1 | //===----------------------------------------------------------------------===// |
2 | // DuckDB |
3 | // |
4 | // duckdb/function/scalar_function.hpp |
5 | // |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #pragma once |
10 | |
11 | #include "duckdb/common/vector_operations/binary_executor.hpp" |
12 | #include "duckdb/common/vector_operations/ternary_executor.hpp" |
13 | #include "duckdb/common/vector_operations/unary_executor.hpp" |
14 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
15 | #include "duckdb/execution/expression_executor_state.hpp" |
16 | #include "duckdb/function/function.hpp" |
17 | #include "duckdb/planner/plan_serialization.hpp" |
18 | #include "duckdb/storage/statistics/base_statistics.hpp" |
19 | #include "duckdb/common/optional_ptr.hpp" |
20 | |
21 | namespace duckdb { |
22 | |
23 | struct FunctionLocalState { |
24 | DUCKDB_API virtual ~FunctionLocalState(); |
25 | |
26 | template <class TARGET> |
27 | TARGET &Cast() { |
28 | D_ASSERT(dynamic_cast<TARGET *>(this)); |
29 | return reinterpret_cast<TARGET &>(*this); |
30 | } |
31 | template <class TARGET> |
32 | const TARGET &Cast() const { |
33 | D_ASSERT(dynamic_cast<const TARGET *>(this)); |
34 | return reinterpret_cast<const TARGET &>(*this); |
35 | } |
36 | }; |
37 | |
// Forward declarations so these types can be named in this header without including their definitions.
class BoundFunctionExpression;
class ScalarFunctionCatalogEntry;
class DependencyList;
class Binder;
42 | |
43 | struct FunctionStatisticsInput { |
44 | FunctionStatisticsInput(BoundFunctionExpression &expr_p, optional_ptr<FunctionData> bind_data_p, |
45 | vector<BaseStatistics> &child_stats_p, unique_ptr<Expression> *expr_ptr_p) |
46 | : expr(expr_p), bind_data(bind_data_p), child_stats(child_stats_p), expr_ptr(expr_ptr_p) { |
47 | } |
48 | |
49 | BoundFunctionExpression &expr; |
50 | optional_ptr<FunctionData> bind_data; |
51 | vector<BaseStatistics> &child_stats; |
52 | unique_ptr<Expression> *expr_ptr; |
53 | }; |
54 | |
55 | //! The type used for scalar functions |
56 | typedef std::function<void(DataChunk &, ExpressionState &, Vector &)> scalar_function_t; |
57 | //! Binds the scalar function and creates the function data |
58 | typedef unique_ptr<FunctionData> (*bind_scalar_function_t)(ClientContext &context, ScalarFunction &bound_function, |
59 | vector<unique_ptr<Expression>> &arguments); |
60 | typedef unique_ptr<FunctionLocalState> (*init_local_state_t)(ExpressionState &state, |
61 | const BoundFunctionExpression &expr, |
62 | FunctionData *bind_data); |
63 | typedef unique_ptr<BaseStatistics> (*function_statistics_t)(ClientContext &context, FunctionStatisticsInput &input); |
64 | //! Adds the dependencies of this BoundFunctionExpression to the set of dependencies |
65 | typedef void (*dependency_function_t)(BoundFunctionExpression &expr, DependencyList &dependencies); |
66 | |
67 | typedef void (*function_serialize_t)(FieldWriter &writer, const FunctionData *bind_data, |
68 | const ScalarFunction &function); |
69 | typedef unique_ptr<FunctionData> (*function_deserialize_t)(PlanDeserializationState &state, FieldReader &reader, |
70 | ScalarFunction &function); |
71 | |
72 | class ScalarFunction : public BaseScalarFunction { |
73 | public: |
74 | DUCKDB_API ScalarFunction(string name, vector<LogicalType> arguments, LogicalType return_type, |
75 | scalar_function_t function, bind_scalar_function_t bind = nullptr, |
76 | dependency_function_t dependency = nullptr, function_statistics_t statistics = nullptr, |
77 | init_local_state_t init_local_state = nullptr, |
78 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID), |
79 | FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, |
80 | FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); |
81 | |
82 | DUCKDB_API ScalarFunction(vector<LogicalType> arguments, LogicalType return_type, scalar_function_t function, |
83 | bind_scalar_function_t bind = nullptr, dependency_function_t dependency = nullptr, |
84 | function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr, |
85 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID), |
86 | FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, |
87 | FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); |
88 | |
89 | //! The main scalar function to execute |
90 | scalar_function_t function; |
91 | //! The bind function (if any) |
92 | bind_scalar_function_t bind; |
93 | //! Init thread local state for the function (if any) |
94 | init_local_state_t init_local_state; |
95 | //! The dependency function (if any) |
96 | dependency_function_t dependency; |
97 | //! The statistics propagation function (if any) |
98 | function_statistics_t statistics; |
99 | |
100 | function_serialize_t serialize; |
101 | function_deserialize_t deserialize; |
102 | |
103 | DUCKDB_API bool operator==(const ScalarFunction &rhs) const; |
104 | DUCKDB_API bool operator!=(const ScalarFunction &rhs) const; |
105 | |
106 | DUCKDB_API bool Equal(const ScalarFunction &rhs) const; |
107 | |
108 | private: |
109 | bool CompareScalarFunctionT(const scalar_function_t &other) const; |
110 | |
111 | public: |
112 | DUCKDB_API static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result); |
113 | |
114 | template <class TA, class TR, class OP> |
115 | static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { |
116 | D_ASSERT(input.ColumnCount() >= 1); |
117 | UnaryExecutor::Execute<TA, TR, OP>(input.data[0], result, input.size()); |
118 | } |
119 | |
120 | template <class TA, class TB, class TR, class OP> |
121 | static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { |
122 | D_ASSERT(input.ColumnCount() == 2); |
123 | BinaryExecutor::ExecuteStandard<TA, TB, TR, OP>(input.data[0], input.data[1], result, input.size()); |
124 | } |
125 | |
126 | template <class TA, class TB, class TC, class TR, class OP> |
127 | static void TernaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { |
128 | D_ASSERT(input.ColumnCount() == 3); |
129 | TernaryExecutor::ExecuteStandard<TA, TB, TC, TR, OP>(input.data[0], input.data[1], input.data[2], result, |
130 | input.size()); |
131 | } |
132 | |
133 | public: |
134 | template <class OP> |
135 | static scalar_function_t GetScalarUnaryFunction(LogicalType type) { |
136 | scalar_function_t function; |
137 | switch (type.id()) { |
138 | case LogicalTypeId::TINYINT: |
139 | function = &ScalarFunction::UnaryFunction<int8_t, int8_t, OP>; |
140 | break; |
141 | case LogicalTypeId::SMALLINT: |
142 | function = &ScalarFunction::UnaryFunction<int16_t, int16_t, OP>; |
143 | break; |
144 | case LogicalTypeId::INTEGER: |
145 | function = &ScalarFunction::UnaryFunction<int32_t, int32_t, OP>; |
146 | break; |
147 | case LogicalTypeId::BIGINT: |
148 | function = &ScalarFunction::UnaryFunction<int64_t, int64_t, OP>; |
149 | break; |
150 | case LogicalTypeId::UTINYINT: |
151 | function = &ScalarFunction::UnaryFunction<uint8_t, uint8_t, OP>; |
152 | break; |
153 | case LogicalTypeId::USMALLINT: |
154 | function = &ScalarFunction::UnaryFunction<uint16_t, uint16_t, OP>; |
155 | break; |
156 | case LogicalTypeId::UINTEGER: |
157 | function = &ScalarFunction::UnaryFunction<uint32_t, uint32_t, OP>; |
158 | break; |
159 | case LogicalTypeId::UBIGINT: |
160 | function = &ScalarFunction::UnaryFunction<uint64_t, uint64_t, OP>; |
161 | break; |
162 | case LogicalTypeId::HUGEINT: |
163 | function = &ScalarFunction::UnaryFunction<hugeint_t, hugeint_t, OP>; |
164 | break; |
165 | case LogicalTypeId::FLOAT: |
166 | function = &ScalarFunction::UnaryFunction<float, float, OP>; |
167 | break; |
168 | case LogicalTypeId::DOUBLE: |
169 | function = &ScalarFunction::UnaryFunction<double, double, OP>; |
170 | break; |
171 | default: |
172 | throw InternalException("Unimplemented type for GetScalarUnaryFunction" ); |
173 | } |
174 | return function; |
175 | } |
176 | |
177 | template <class TR, class OP> |
178 | static scalar_function_t GetScalarUnaryFunctionFixedReturn(LogicalType type) { |
179 | scalar_function_t function; |
180 | switch (type.id()) { |
181 | case LogicalTypeId::TINYINT: |
182 | function = &ScalarFunction::UnaryFunction<int8_t, TR, OP>; |
183 | break; |
184 | case LogicalTypeId::SMALLINT: |
185 | function = &ScalarFunction::UnaryFunction<int16_t, TR, OP>; |
186 | break; |
187 | case LogicalTypeId::INTEGER: |
188 | function = &ScalarFunction::UnaryFunction<int32_t, TR, OP>; |
189 | break; |
190 | case LogicalTypeId::BIGINT: |
191 | function = &ScalarFunction::UnaryFunction<int64_t, TR, OP>; |
192 | break; |
193 | case LogicalTypeId::UTINYINT: |
194 | function = &ScalarFunction::UnaryFunction<uint8_t, TR, OP>; |
195 | break; |
196 | case LogicalTypeId::USMALLINT: |
197 | function = &ScalarFunction::UnaryFunction<uint16_t, TR, OP>; |
198 | break; |
199 | case LogicalTypeId::UINTEGER: |
200 | function = &ScalarFunction::UnaryFunction<uint32_t, TR, OP>; |
201 | break; |
202 | case LogicalTypeId::UBIGINT: |
203 | function = &ScalarFunction::UnaryFunction<uint64_t, TR, OP>; |
204 | break; |
205 | case LogicalTypeId::HUGEINT: |
206 | function = &ScalarFunction::UnaryFunction<hugeint_t, TR, OP>; |
207 | break; |
208 | case LogicalTypeId::FLOAT: |
209 | function = &ScalarFunction::UnaryFunction<float, TR, OP>; |
210 | break; |
211 | case LogicalTypeId::DOUBLE: |
212 | function = &ScalarFunction::UnaryFunction<double, TR, OP>; |
213 | break; |
214 | default: |
215 | throw InternalException("Unimplemented type for GetScalarUnaryFunctionFixedReturn" ); |
216 | } |
217 | return function; |
218 | } |
219 | }; |
220 | |
221 | } // namespace duckdb |
222 | |