1//===----------------------------------------------------------------------===//
2// DuckDB
3//
4// duckdb/function/udf_function.hpp
5//
6//
7//===----------------------------------------------------------------------===//
8
9#pragma once
10
11#include "duckdb/function/scalar_function.hpp"
12#include "duckdb/function/aggregate_function.hpp"
13
14namespace duckdb {
15
16struct UDFWrapper {
17public:
18 template <typename TR, typename... Args>
19 inline static scalar_function_t CreateScalarFunction(const string &name, TR (*udf_func)(Args...)) {
20 const std::size_t num_template_argc = sizeof...(Args);
21 switch (num_template_argc) {
22 case 1:
23 return CreateUnaryFunction<TR, Args...>(name, udf_func);
24 case 2:
25 return CreateBinaryFunction<TR, Args...>(name, udf_func);
26 case 3:
27 return CreateTernaryFunction<TR, Args...>(name, udf_func);
28 default: // LCOV_EXCL_START
29 throw std::runtime_error("UDF function only supported until ternary!");
30 } // LCOV_EXCL_STOP
31 }
32
33 template <typename TR, typename... Args>
34 inline static scalar_function_t CreateScalarFunction(const string &name, vector<LogicalType> args,
35 LogicalType ret_type, TR (*udf_func)(Args...)) {
36 if (!TypesMatch<TR>(ret_type)) { // LCOV_EXCL_START
37 throw std::runtime_error("Return type doesn't match with the first template type.");
38 } // LCOV_EXCL_STOP
39
40 const std::size_t num_template_types = sizeof...(Args);
41 if (num_template_types != args.size()) { // LCOV_EXCL_START
42 throw std::runtime_error(
43 "The number of templated types should be the same quantity of the LogicalType arguments.");
44 } // LCOV_EXCL_STOP
45
46 switch (num_template_types) {
47 case 1:
48 return CreateUnaryFunction<TR, Args...>(name, args, ret_type, udf_func);
49 case 2:
50 return CreateBinaryFunction<TR, Args...>(name, args, ret_type, udf_func);
51 case 3:
52 return CreateTernaryFunction<TR, Args...>(name, args, ret_type, udf_func);
53 default: // LCOV_EXCL_START
54 throw std::runtime_error("UDF function only supported until ternary!");
55 } // LCOV_EXCL_STOP
56 }
57
58 template <typename TR, typename... Args>
59 inline static void RegisterFunction(const string &name, scalar_function_t udf_function, ClientContext &context,
60 LogicalType varargs = LogicalType(LogicalTypeId::INVALID)) {
61 vector<LogicalType> arguments;
62 GetArgumentTypesRecursive<Args...>(arguments);
63
64 LogicalType ret_type = GetArgumentType<TR>();
65
66 RegisterFunction(name, args: arguments, ret_type, udf_function, context, varargs);
67 }
68
69 static void RegisterFunction(string name, vector<LogicalType> args, LogicalType ret_type,
70 scalar_function_t udf_function, ClientContext &context,
71 LogicalType varargs = LogicalType(LogicalTypeId::INVALID));
72
73 //--------------------------------- Aggregate UDFs ------------------------------------//
74 template <typename UDF_OP, typename STATE, typename TR, typename TA>
75 inline static AggregateFunction CreateAggregateFunction(const string &name) {
76 return CreateUnaryAggregateFunction<UDF_OP, STATE, TR, TA>(name);
77 }
78
79 template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB>
80 inline static AggregateFunction CreateAggregateFunction(const string &name) {
81 return CreateBinaryAggregateFunction<UDF_OP, STATE, TR, TA, TB>(name);
82 }
83
84 template <typename UDF_OP, typename STATE, typename TR, typename TA>
85 inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type,
86 LogicalType input_type) {
87 if (!TypesMatch<TR>(ret_type)) { // LCOV_EXCL_START
88 throw std::runtime_error("The return argument don't match!");
89 } // LCOV_EXCL_STOP
90
91 if (!TypesMatch<TA>(input_type)) { // LCOV_EXCL_START
92 throw std::runtime_error("The input argument don't match!");
93 } // LCOV_EXCL_STOP
94
95 return CreateUnaryAggregateFunction<UDF_OP, STATE, TR, TA>(name, ret_type, input_type);
96 }
97
98 template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB>
99 inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type,
100 LogicalType input_typeA, LogicalType input_typeB) {
101 if (!TypesMatch<TR>(ret_type)) { // LCOV_EXCL_START
102 throw std::runtime_error("The return argument don't match!");
103 }
104
105 if (!TypesMatch<TA>(input_typeA)) {
106 throw std::runtime_error("The first input argument don't match!");
107 }
108
109 if (!TypesMatch<TB>(input_typeB)) {
110 throw std::runtime_error("The second input argument don't match!");
111 } // LCOV_EXCL_STOP
112
113 return CreateBinaryAggregateFunction<UDF_OP, STATE, TR, TA, TB>(name, ret_type, input_typeA, input_typeB);
114 }
115
116 //! A generic CreateAggregateFunction ---------------------------------------------------------------------------//
117 inline static AggregateFunction
118 CreateAggregateFunction(string name, vector<LogicalType> arguments, LogicalType return_type,
119 aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update,
120 aggregate_combine_t combine, aggregate_finalize_t finalize,
121 aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr,
122 aggregate_destructor_t destructor = nullptr) {
123
124 AggregateFunction aggr_function(std::move(name), std::move(arguments), std::move(return_type), state_size,
125 initialize, update, combine, finalize, simple_update, bind, destructor);
126 aggr_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
127 return aggr_function;
128 }
129
130 static void RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context,
131 LogicalType varargs = LogicalType(LogicalTypeId::INVALID));
132
133private:
134 //-------------------------------- Templated functions --------------------------------//
135 struct UnaryUDFExecutor {
136 template <class INPUT_TYPE, class RESULT_TYPE>
137 static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) {
138 typedef RESULT_TYPE (*unary_function_t)(INPUT_TYPE);
139 auto udf = (unary_function_t)dataptr;
140 return udf(input);
141 }
142 };
143
144 template <typename TR, typename TA>
145 inline static scalar_function_t CreateUnaryFunction(const string &name, TR (*udf_func)(TA)) {
146 scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void {
147 UnaryExecutor::GenericExecute<TA, TR, UnaryUDFExecutor>(input.data[0], result, input.size(),
148 (void *)udf_func);
149 };
150 return udf_function;
151 }
152
153 template <typename TR, typename TA, typename TB>
154 inline static scalar_function_t CreateBinaryFunction(const string &name, TR (*udf_func)(TA, TB)) {
155 scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void {
156 BinaryExecutor::Execute<TA, TB, TR>(input.data[0], input.data[1], result, input.size(), udf_func);
157 };
158 return udf_function;
159 }
160
161 template <typename TR, typename TA, typename TB, typename TC>
162 inline static scalar_function_t CreateTernaryFunction(const string &name, TR (*udf_func)(TA, TB, TC)) {
163 scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void {
164 TernaryExecutor::Execute<TA, TB, TC, TR>(input.data[0], input.data[1], input.data[2], result, input.size(),
165 udf_func);
166 };
167 return udf_function;
168 }
169
170 template <typename TR, typename... Args>
171 inline static scalar_function_t CreateUnaryFunction(const string &name,
172 TR (*udf_func)(Args...)) { // LCOV_EXCL_START
173 throw std::runtime_error("Incorrect number of arguments for unary function");
174 } // LCOV_EXCL_STOP
175
176 template <typename TR, typename... Args>
177 inline static scalar_function_t CreateBinaryFunction(const string &name,
178 TR (*udf_func)(Args...)) { // LCOV_EXCL_START
179 throw std::runtime_error("Incorrect number of arguments for binary function");
180 } // LCOV_EXCL_STOP
181
182 template <typename TR, typename... Args>
183 inline static scalar_function_t CreateTernaryFunction(const string &name,
184 TR (*udf_func)(Args...)) { // LCOV_EXCL_START
185 throw std::runtime_error("Incorrect number of arguments for ternary function");
186 } // LCOV_EXCL_STOP
187
188 template <typename T>
189 inline static LogicalType GetArgumentType() {
190 if (std::is_same<T, bool>()) {
191 return LogicalType(LogicalTypeId::BOOLEAN);
192 } else if (std::is_same<T, int8_t>()) {
193 return LogicalType(LogicalTypeId::TINYINT);
194 } else if (std::is_same<T, int16_t>()) {
195 return LogicalType(LogicalTypeId::SMALLINT);
196 } else if (std::is_same<T, int32_t>()) {
197 return LogicalType(LogicalTypeId::INTEGER);
198 } else if (std::is_same<T, int64_t>()) {
199 return LogicalType(LogicalTypeId::BIGINT);
200 } else if (std::is_same<T, float>()) {
201 return LogicalType(LogicalTypeId::FLOAT);
202 } else if (std::is_same<T, double>()) {
203 return LogicalType(LogicalTypeId::DOUBLE);
204 } else if (std::is_same<T, string_t>()) {
205 return LogicalType(LogicalTypeId::VARCHAR);
206 } else { // LCOV_EXCL_START
207 throw std::runtime_error("Unrecognized type!");
208 } // LCOV_EXCL_STOP
209 }
210
211 template <typename TA, typename TB, typename... Args>
212 inline static void GetArgumentTypesRecursive(vector<LogicalType> &arguments) {
213 arguments.push_back(GetArgumentType<TA>());
214 GetArgumentTypesRecursive<TB, Args...>(arguments);
215 }
216
217 template <typename TA>
218 inline static void GetArgumentTypesRecursive(vector<LogicalType> &arguments) {
219 arguments.push_back(GetArgumentType<TA>());
220 }
221
222private:
223 //-------------------------------- Argumented functions --------------------------------//
224
225 template <typename TR, typename... Args>
226 inline static scalar_function_t CreateUnaryFunction(const string &name, vector<LogicalType> args,
227 LogicalType ret_type,
228 TR (*udf_func)(Args...)) { // LCOV_EXCL_START
229 throw std::runtime_error("Incorrect number of arguments for unary function");
230 } // LCOV_EXCL_STOP
231
232 template <typename TR, typename TA>
233 inline static scalar_function_t CreateUnaryFunction(const string &name, vector<LogicalType> args,
234 LogicalType ret_type, TR (*udf_func)(TA)) {
235 if (args.size() != 1) { // LCOV_EXCL_START
236 throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 1!");
237 }
238 if (!TypesMatch<TA>(args[0])) {
239 throw std::runtime_error("The first arguments don't match!");
240 } // LCOV_EXCL_STOP
241
242 scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void {
243 UnaryExecutor::GenericExecute<TA, TR, UnaryUDFExecutor>(input.data[0], result, input.size(),
244 (void *)udf_func);
245 };
246 return udf_function;
247 }
248
249 template <typename TR, typename... Args>
250 inline static scalar_function_t CreateBinaryFunction(const string &name, vector<LogicalType> args,
251 LogicalType ret_type,
252 TR (*udf_func)(Args...)) { // LCOV_EXCL_START
253 throw std::runtime_error("Incorrect number of arguments for binary function");
254 } // LCOV_EXCL_STOP
255
256 template <typename TR, typename TA, typename TB>
257 inline static scalar_function_t CreateBinaryFunction(const string &name, vector<LogicalType> args,
258 LogicalType ret_type, TR (*udf_func)(TA, TB)) {
259 if (args.size() != 2) { // LCOV_EXCL_START
260 throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 2!");
261 }
262 if (!TypesMatch<TA>(args[0])) {
263 throw std::runtime_error("The first arguments don't match!");
264 }
265 if (!TypesMatch<TB>(args[1])) {
266 throw std::runtime_error("The second arguments don't match!");
267 } // LCOV_EXCL_STOP
268
269 scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) {
270 BinaryExecutor::Execute<TA, TB, TR>(input.data[0], input.data[1], result, input.size(), udf_func);
271 };
272 return udf_function;
273 }
274
275 template <typename TR, typename... Args>
276 inline static scalar_function_t CreateTernaryFunction(const string &name, vector<LogicalType> args,
277 LogicalType ret_type,
278 TR (*udf_func)(Args...)) { // LCOV_EXCL_START
279 throw std::runtime_error("Incorrect number of arguments for ternary function");
280 } // LCOV_EXCL_STOP
281
282 template <typename TR, typename TA, typename TB, typename TC>
283 inline static scalar_function_t CreateTernaryFunction(const string &name, vector<LogicalType> args,
284 LogicalType ret_type, TR (*udf_func)(TA, TB, TC)) {
285 if (args.size() != 3) { // LCOV_EXCL_START
286 throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 3!");
287 }
288 if (!TypesMatch<TA>(args[0])) {
289 throw std::runtime_error("The first arguments don't match!");
290 }
291 if (!TypesMatch<TB>(args[1])) {
292 throw std::runtime_error("The second arguments don't match!");
293 }
294 if (!TypesMatch<TC>(args[2])) {
295 throw std::runtime_error("The second arguments don't match!");
296 } // LCOV_EXCL_STOP
297
298 scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void {
299 TernaryExecutor::Execute<TA, TB, TC, TR>(input.data[0], input.data[1], input.data[2], result, input.size(),
300 udf_func);
301 };
302 return udf_function;
303 }
304
305 template <typename T>
306 inline static bool TypesMatch(const LogicalType &sql_type) {
307 switch (sql_type.id()) {
308 case LogicalTypeId::BOOLEAN:
309 return std::is_same<T, bool>();
310 case LogicalTypeId::TINYINT:
311 return std::is_same<T, int8_t>();
312 case LogicalTypeId::SMALLINT:
313 return std::is_same<T, int16_t>();
314 case LogicalTypeId::INTEGER:
315 return std::is_same<T, int32_t>();
316 case LogicalTypeId::BIGINT:
317 return std::is_same<T, int64_t>();
318 case LogicalTypeId::DATE:
319 return std::is_same<T, date_t>();
320 case LogicalTypeId::TIME:
321 case LogicalTypeId::TIME_TZ:
322 return std::is_same<T, dtime_t>();
323 case LogicalTypeId::TIMESTAMP:
324 case LogicalTypeId::TIMESTAMP_MS:
325 case LogicalTypeId::TIMESTAMP_NS:
326 case LogicalTypeId::TIMESTAMP_SEC:
327 case LogicalTypeId::TIMESTAMP_TZ:
328 return std::is_same<T, timestamp_t>();
329 case LogicalTypeId::FLOAT:
330 return std::is_same<T, float>();
331 case LogicalTypeId::DOUBLE:
332 return std::is_same<T, double>();
333 case LogicalTypeId::VARCHAR:
334 case LogicalTypeId::CHAR:
335 case LogicalTypeId::BLOB:
336 return std::is_same<T, string_t>();
337 default: // LCOV_EXCL_START
338 throw std::runtime_error("Type is not supported!");
339 } // LCOV_EXCL_STOP
340 }
341
342private:
343 //-------------------------------- Aggregate functions --------------------------------//
344 template <typename UDF_OP, typename STATE, typename TR, typename TA>
345 inline static AggregateFunction CreateUnaryAggregateFunction(const string &name) {
346 LogicalType return_type = GetArgumentType<TR>();
347 LogicalType input_type = GetArgumentType<TA>();
348 return CreateUnaryAggregateFunction<UDF_OP, STATE, TR, TA>(name, return_type, input_type);
349 }
350
351 template <typename UDF_OP, typename STATE, typename TR, typename TA>
352 inline static AggregateFunction CreateUnaryAggregateFunction(const string &name, LogicalType ret_type,
353 LogicalType input_type) {
354 AggregateFunction aggr_function =
355 AggregateFunction::UnaryAggregate<STATE, TR, TA, UDF_OP>(input_type, ret_type);
356 aggr_function.name = name;
357 return aggr_function;
358 }
359
360 template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB>
361 inline static AggregateFunction CreateBinaryAggregateFunction(const string &name) {
362 LogicalType return_type = GetArgumentType<TR>();
363 LogicalType input_typeA = GetArgumentType<TA>();
364 LogicalType input_typeB = GetArgumentType<TB>();
365 return CreateBinaryAggregateFunction<UDF_OP, STATE, TR, TA, TB>(name, return_type, input_typeA, input_typeB);
366 }
367
368 template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB>
369 inline static AggregateFunction CreateBinaryAggregateFunction(const string &name, LogicalType ret_type,
370 LogicalType input_typeA, LogicalType input_typeB) {
371 AggregateFunction aggr_function =
372 AggregateFunction::BinaryAggregate<STATE, TR, TA, TB, UDF_OP>(input_typeA, input_typeB, ret_type);
373 aggr_function.name = name;
374 return aggr_function;
375 }
376}; // end UDFWrapper
377
378} // namespace duckdb
379