1 | //===----------------------------------------------------------------------===// |
2 | // DuckDB |
3 | // |
4 | // duckdb/function/udf_function.hpp |
5 | // |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #pragma once |
10 | |
11 | #include "duckdb/function/scalar_function.hpp" |
12 | #include "duckdb/function/aggregate_function.hpp" |
13 | |
14 | namespace duckdb { |
15 | |
16 | struct UDFWrapper { |
17 | public: |
18 | template <typename TR, typename... Args> |
19 | inline static scalar_function_t CreateScalarFunction(const string &name, TR (*udf_func)(Args...)) { |
20 | const std::size_t num_template_argc = sizeof...(Args); |
21 | switch (num_template_argc) { |
22 | case 1: |
23 | return CreateUnaryFunction<TR, Args...>(name, udf_func); |
24 | case 2: |
25 | return CreateBinaryFunction<TR, Args...>(name, udf_func); |
26 | case 3: |
27 | return CreateTernaryFunction<TR, Args...>(name, udf_func); |
28 | default: // LCOV_EXCL_START |
29 | throw std::runtime_error("UDF function only supported until ternary!" ); |
30 | } // LCOV_EXCL_STOP |
31 | } |
32 | |
33 | template <typename TR, typename... Args> |
34 | inline static scalar_function_t CreateScalarFunction(const string &name, vector<LogicalType> args, |
35 | LogicalType ret_type, TR (*udf_func)(Args...)) { |
36 | if (!TypesMatch<TR>(ret_type)) { // LCOV_EXCL_START |
37 | throw std::runtime_error("Return type doesn't match with the first template type." ); |
38 | } // LCOV_EXCL_STOP |
39 | |
40 | const std::size_t num_template_types = sizeof...(Args); |
41 | if (num_template_types != args.size()) { // LCOV_EXCL_START |
42 | throw std::runtime_error( |
43 | "The number of templated types should be the same quantity of the LogicalType arguments." ); |
44 | } // LCOV_EXCL_STOP |
45 | |
46 | switch (num_template_types) { |
47 | case 1: |
48 | return CreateUnaryFunction<TR, Args...>(name, args, ret_type, udf_func); |
49 | case 2: |
50 | return CreateBinaryFunction<TR, Args...>(name, args, ret_type, udf_func); |
51 | case 3: |
52 | return CreateTernaryFunction<TR, Args...>(name, args, ret_type, udf_func); |
53 | default: // LCOV_EXCL_START |
54 | throw std::runtime_error("UDF function only supported until ternary!" ); |
55 | } // LCOV_EXCL_STOP |
56 | } |
57 | |
58 | template <typename TR, typename... Args> |
59 | inline static void RegisterFunction(const string &name, scalar_function_t udf_function, ClientContext &context, |
60 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID)) { |
61 | vector<LogicalType> arguments; |
62 | GetArgumentTypesRecursive<Args...>(arguments); |
63 | |
64 | LogicalType ret_type = GetArgumentType<TR>(); |
65 | |
66 | RegisterFunction(name, args: arguments, ret_type, udf_function, context, varargs); |
67 | } |
68 | |
69 | static void RegisterFunction(string name, vector<LogicalType> args, LogicalType ret_type, |
70 | scalar_function_t udf_function, ClientContext &context, |
71 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); |
72 | |
73 | //--------------------------------- Aggregate UDFs ------------------------------------// |
74 | template <typename UDF_OP, typename STATE, typename TR, typename TA> |
75 | inline static AggregateFunction CreateAggregateFunction(const string &name) { |
76 | return CreateUnaryAggregateFunction<UDF_OP, STATE, TR, TA>(name); |
77 | } |
78 | |
79 | template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB> |
80 | inline static AggregateFunction CreateAggregateFunction(const string &name) { |
81 | return CreateBinaryAggregateFunction<UDF_OP, STATE, TR, TA, TB>(name); |
82 | } |
83 | |
84 | template <typename UDF_OP, typename STATE, typename TR, typename TA> |
85 | inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type, |
86 | LogicalType input_type) { |
87 | if (!TypesMatch<TR>(ret_type)) { // LCOV_EXCL_START |
88 | throw std::runtime_error("The return argument don't match!" ); |
89 | } // LCOV_EXCL_STOP |
90 | |
91 | if (!TypesMatch<TA>(input_type)) { // LCOV_EXCL_START |
92 | throw std::runtime_error("The input argument don't match!" ); |
93 | } // LCOV_EXCL_STOP |
94 | |
95 | return CreateUnaryAggregateFunction<UDF_OP, STATE, TR, TA>(name, ret_type, input_type); |
96 | } |
97 | |
98 | template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB> |
99 | inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type, |
100 | LogicalType input_typeA, LogicalType input_typeB) { |
101 | if (!TypesMatch<TR>(ret_type)) { // LCOV_EXCL_START |
102 | throw std::runtime_error("The return argument don't match!" ); |
103 | } |
104 | |
105 | if (!TypesMatch<TA>(input_typeA)) { |
106 | throw std::runtime_error("The first input argument don't match!" ); |
107 | } |
108 | |
109 | if (!TypesMatch<TB>(input_typeB)) { |
110 | throw std::runtime_error("The second input argument don't match!" ); |
111 | } // LCOV_EXCL_STOP |
112 | |
113 | return CreateBinaryAggregateFunction<UDF_OP, STATE, TR, TA, TB>(name, ret_type, input_typeA, input_typeB); |
114 | } |
115 | |
116 | //! A generic CreateAggregateFunction ---------------------------------------------------------------------------// |
117 | inline static AggregateFunction |
118 | CreateAggregateFunction(string name, vector<LogicalType> arguments, LogicalType return_type, |
119 | aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, |
120 | aggregate_combine_t combine, aggregate_finalize_t finalize, |
121 | aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, |
122 | aggregate_destructor_t destructor = nullptr) { |
123 | |
124 | AggregateFunction aggr_function(std::move(name), std::move(arguments), std::move(return_type), state_size, |
125 | initialize, update, combine, finalize, simple_update, bind, destructor); |
126 | aggr_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
127 | return aggr_function; |
128 | } |
129 | |
130 | static void RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context, |
131 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); |
132 | |
133 | private: |
134 | //-------------------------------- Templated functions --------------------------------// |
135 | struct UnaryUDFExecutor { |
136 | template <class INPUT_TYPE, class RESULT_TYPE> |
137 | static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { |
138 | typedef RESULT_TYPE (*unary_function_t)(INPUT_TYPE); |
139 | auto udf = (unary_function_t)dataptr; |
140 | return udf(input); |
141 | } |
142 | }; |
143 | |
144 | template <typename TR, typename TA> |
145 | inline static scalar_function_t CreateUnaryFunction(const string &name, TR (*udf_func)(TA)) { |
146 | scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { |
147 | UnaryExecutor::GenericExecute<TA, TR, UnaryUDFExecutor>(input.data[0], result, input.size(), |
148 | (void *)udf_func); |
149 | }; |
150 | return udf_function; |
151 | } |
152 | |
153 | template <typename TR, typename TA, typename TB> |
154 | inline static scalar_function_t CreateBinaryFunction(const string &name, TR (*udf_func)(TA, TB)) { |
155 | scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { |
156 | BinaryExecutor::Execute<TA, TB, TR>(input.data[0], input.data[1], result, input.size(), udf_func); |
157 | }; |
158 | return udf_function; |
159 | } |
160 | |
161 | template <typename TR, typename TA, typename TB, typename TC> |
162 | inline static scalar_function_t CreateTernaryFunction(const string &name, TR (*udf_func)(TA, TB, TC)) { |
163 | scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { |
164 | TernaryExecutor::Execute<TA, TB, TC, TR>(input.data[0], input.data[1], input.data[2], result, input.size(), |
165 | udf_func); |
166 | }; |
167 | return udf_function; |
168 | } |
169 | |
170 | template <typename TR, typename... Args> |
171 | inline static scalar_function_t CreateUnaryFunction(const string &name, |
172 | TR (*udf_func)(Args...)) { // LCOV_EXCL_START |
173 | throw std::runtime_error("Incorrect number of arguments for unary function" ); |
174 | } // LCOV_EXCL_STOP |
175 | |
176 | template <typename TR, typename... Args> |
177 | inline static scalar_function_t CreateBinaryFunction(const string &name, |
178 | TR (*udf_func)(Args...)) { // LCOV_EXCL_START |
179 | throw std::runtime_error("Incorrect number of arguments for binary function" ); |
180 | } // LCOV_EXCL_STOP |
181 | |
182 | template <typename TR, typename... Args> |
183 | inline static scalar_function_t CreateTernaryFunction(const string &name, |
184 | TR (*udf_func)(Args...)) { // LCOV_EXCL_START |
185 | throw std::runtime_error("Incorrect number of arguments for ternary function" ); |
186 | } // LCOV_EXCL_STOP |
187 | |
188 | template <typename T> |
189 | inline static LogicalType GetArgumentType() { |
190 | if (std::is_same<T, bool>()) { |
191 | return LogicalType(LogicalTypeId::BOOLEAN); |
192 | } else if (std::is_same<T, int8_t>()) { |
193 | return LogicalType(LogicalTypeId::TINYINT); |
194 | } else if (std::is_same<T, int16_t>()) { |
195 | return LogicalType(LogicalTypeId::SMALLINT); |
196 | } else if (std::is_same<T, int32_t>()) { |
197 | return LogicalType(LogicalTypeId::INTEGER); |
198 | } else if (std::is_same<T, int64_t>()) { |
199 | return LogicalType(LogicalTypeId::BIGINT); |
200 | } else if (std::is_same<T, float>()) { |
201 | return LogicalType(LogicalTypeId::FLOAT); |
202 | } else if (std::is_same<T, double>()) { |
203 | return LogicalType(LogicalTypeId::DOUBLE); |
204 | } else if (std::is_same<T, string_t>()) { |
205 | return LogicalType(LogicalTypeId::VARCHAR); |
206 | } else { // LCOV_EXCL_START |
207 | throw std::runtime_error("Unrecognized type!" ); |
208 | } // LCOV_EXCL_STOP |
209 | } |
210 | |
211 | template <typename TA, typename TB, typename... Args> |
212 | inline static void GetArgumentTypesRecursive(vector<LogicalType> &arguments) { |
213 | arguments.push_back(GetArgumentType<TA>()); |
214 | GetArgumentTypesRecursive<TB, Args...>(arguments); |
215 | } |
216 | |
217 | template <typename TA> |
218 | inline static void GetArgumentTypesRecursive(vector<LogicalType> &arguments) { |
219 | arguments.push_back(GetArgumentType<TA>()); |
220 | } |
221 | |
222 | private: |
223 | //-------------------------------- Argumented functions --------------------------------// |
224 | |
225 | template <typename TR, typename... Args> |
226 | inline static scalar_function_t CreateUnaryFunction(const string &name, vector<LogicalType> args, |
227 | LogicalType ret_type, |
228 | TR (*udf_func)(Args...)) { // LCOV_EXCL_START |
229 | throw std::runtime_error("Incorrect number of arguments for unary function" ); |
230 | } // LCOV_EXCL_STOP |
231 | |
232 | template <typename TR, typename TA> |
233 | inline static scalar_function_t CreateUnaryFunction(const string &name, vector<LogicalType> args, |
234 | LogicalType ret_type, TR (*udf_func)(TA)) { |
235 | if (args.size() != 1) { // LCOV_EXCL_START |
236 | throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 1!" ); |
237 | } |
238 | if (!TypesMatch<TA>(args[0])) { |
239 | throw std::runtime_error("The first arguments don't match!" ); |
240 | } // LCOV_EXCL_STOP |
241 | |
242 | scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { |
243 | UnaryExecutor::GenericExecute<TA, TR, UnaryUDFExecutor>(input.data[0], result, input.size(), |
244 | (void *)udf_func); |
245 | }; |
246 | return udf_function; |
247 | } |
248 | |
249 | template <typename TR, typename... Args> |
250 | inline static scalar_function_t CreateBinaryFunction(const string &name, vector<LogicalType> args, |
251 | LogicalType ret_type, |
252 | TR (*udf_func)(Args...)) { // LCOV_EXCL_START |
253 | throw std::runtime_error("Incorrect number of arguments for binary function" ); |
254 | } // LCOV_EXCL_STOP |
255 | |
256 | template <typename TR, typename TA, typename TB> |
257 | inline static scalar_function_t CreateBinaryFunction(const string &name, vector<LogicalType> args, |
258 | LogicalType ret_type, TR (*udf_func)(TA, TB)) { |
259 | if (args.size() != 2) { // LCOV_EXCL_START |
260 | throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 2!" ); |
261 | } |
262 | if (!TypesMatch<TA>(args[0])) { |
263 | throw std::runtime_error("The first arguments don't match!" ); |
264 | } |
265 | if (!TypesMatch<TB>(args[1])) { |
266 | throw std::runtime_error("The second arguments don't match!" ); |
267 | } // LCOV_EXCL_STOP |
268 | |
269 | scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) { |
270 | BinaryExecutor::Execute<TA, TB, TR>(input.data[0], input.data[1], result, input.size(), udf_func); |
271 | }; |
272 | return udf_function; |
273 | } |
274 | |
275 | template <typename TR, typename... Args> |
276 | inline static scalar_function_t CreateTernaryFunction(const string &name, vector<LogicalType> args, |
277 | LogicalType ret_type, |
278 | TR (*udf_func)(Args...)) { // LCOV_EXCL_START |
279 | throw std::runtime_error("Incorrect number of arguments for ternary function" ); |
280 | } // LCOV_EXCL_STOP |
281 | |
282 | template <typename TR, typename TA, typename TB, typename TC> |
283 | inline static scalar_function_t CreateTernaryFunction(const string &name, vector<LogicalType> args, |
284 | LogicalType ret_type, TR (*udf_func)(TA, TB, TC)) { |
285 | if (args.size() != 3) { // LCOV_EXCL_START |
286 | throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 3!" ); |
287 | } |
288 | if (!TypesMatch<TA>(args[0])) { |
289 | throw std::runtime_error("The first arguments don't match!" ); |
290 | } |
291 | if (!TypesMatch<TB>(args[1])) { |
292 | throw std::runtime_error("The second arguments don't match!" ); |
293 | } |
294 | if (!TypesMatch<TC>(args[2])) { |
295 | throw std::runtime_error("The second arguments don't match!" ); |
296 | } // LCOV_EXCL_STOP |
297 | |
298 | scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { |
299 | TernaryExecutor::Execute<TA, TB, TC, TR>(input.data[0], input.data[1], input.data[2], result, input.size(), |
300 | udf_func); |
301 | }; |
302 | return udf_function; |
303 | } |
304 | |
305 | template <typename T> |
306 | inline static bool TypesMatch(const LogicalType &sql_type) { |
307 | switch (sql_type.id()) { |
308 | case LogicalTypeId::BOOLEAN: |
309 | return std::is_same<T, bool>(); |
310 | case LogicalTypeId::TINYINT: |
311 | return std::is_same<T, int8_t>(); |
312 | case LogicalTypeId::SMALLINT: |
313 | return std::is_same<T, int16_t>(); |
314 | case LogicalTypeId::INTEGER: |
315 | return std::is_same<T, int32_t>(); |
316 | case LogicalTypeId::BIGINT: |
317 | return std::is_same<T, int64_t>(); |
318 | case LogicalTypeId::DATE: |
319 | return std::is_same<T, date_t>(); |
320 | case LogicalTypeId::TIME: |
321 | case LogicalTypeId::TIME_TZ: |
322 | return std::is_same<T, dtime_t>(); |
323 | case LogicalTypeId::TIMESTAMP: |
324 | case LogicalTypeId::TIMESTAMP_MS: |
325 | case LogicalTypeId::TIMESTAMP_NS: |
326 | case LogicalTypeId::TIMESTAMP_SEC: |
327 | case LogicalTypeId::TIMESTAMP_TZ: |
328 | return std::is_same<T, timestamp_t>(); |
329 | case LogicalTypeId::FLOAT: |
330 | return std::is_same<T, float>(); |
331 | case LogicalTypeId::DOUBLE: |
332 | return std::is_same<T, double>(); |
333 | case LogicalTypeId::VARCHAR: |
334 | case LogicalTypeId::CHAR: |
335 | case LogicalTypeId::BLOB: |
336 | return std::is_same<T, string_t>(); |
337 | default: // LCOV_EXCL_START |
338 | throw std::runtime_error("Type is not supported!" ); |
339 | } // LCOV_EXCL_STOP |
340 | } |
341 | |
342 | private: |
343 | //-------------------------------- Aggregate functions --------------------------------// |
344 | template <typename UDF_OP, typename STATE, typename TR, typename TA> |
345 | inline static AggregateFunction CreateUnaryAggregateFunction(const string &name) { |
346 | LogicalType return_type = GetArgumentType<TR>(); |
347 | LogicalType input_type = GetArgumentType<TA>(); |
348 | return CreateUnaryAggregateFunction<UDF_OP, STATE, TR, TA>(name, return_type, input_type); |
349 | } |
350 | |
351 | template <typename UDF_OP, typename STATE, typename TR, typename TA> |
352 | inline static AggregateFunction CreateUnaryAggregateFunction(const string &name, LogicalType ret_type, |
353 | LogicalType input_type) { |
354 | AggregateFunction aggr_function = |
355 | AggregateFunction::UnaryAggregate<STATE, TR, TA, UDF_OP>(input_type, ret_type); |
356 | aggr_function.name = name; |
357 | return aggr_function; |
358 | } |
359 | |
360 | template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB> |
361 | inline static AggregateFunction CreateBinaryAggregateFunction(const string &name) { |
362 | LogicalType return_type = GetArgumentType<TR>(); |
363 | LogicalType input_typeA = GetArgumentType<TA>(); |
364 | LogicalType input_typeB = GetArgumentType<TB>(); |
365 | return CreateBinaryAggregateFunction<UDF_OP, STATE, TR, TA, TB>(name, return_type, input_typeA, input_typeB); |
366 | } |
367 | |
368 | template <typename UDF_OP, typename STATE, typename TR, typename TA, typename TB> |
369 | inline static AggregateFunction CreateBinaryAggregateFunction(const string &name, LogicalType ret_type, |
370 | LogicalType input_typeA, LogicalType input_typeB) { |
371 | AggregateFunction aggr_function = |
372 | AggregateFunction::BinaryAggregate<STATE, TR, TA, TB, UDF_OP>(input_typeA, input_typeB, ret_type); |
373 | aggr_function.name = name; |
374 | return aggr_function; |
375 | } |
376 | }; // end UDFWrapper |
377 | |
378 | } // namespace duckdb |
379 | |