//===----------------------------------------------------------------------===//
//                         DuckDB
//
// duckdb/function/scalar_function.hpp
//
//
//===----------------------------------------------------------------------===//
8
9#pragma once
10
11#include "duckdb/common/vector_operations/binary_executor.hpp"
12#include "duckdb/common/vector_operations/ternary_executor.hpp"
13#include "duckdb/common/vector_operations/unary_executor.hpp"
14#include "duckdb/common/vector_operations/vector_operations.hpp"
15#include "duckdb/execution/expression_executor_state.hpp"
16#include "duckdb/function/function.hpp"
17#include "duckdb/planner/plan_serialization.hpp"
18#include "duckdb/storage/statistics/base_statistics.hpp"
19#include "duckdb/common/optional_ptr.hpp"
20
21namespace duckdb {
22
23struct FunctionLocalState {
24 DUCKDB_API virtual ~FunctionLocalState();
25
26 template <class TARGET>
27 TARGET &Cast() {
28 D_ASSERT(dynamic_cast<TARGET *>(this));
29 return reinterpret_cast<TARGET &>(*this);
30 }
31 template <class TARGET>
32 const TARGET &Cast() const {
33 D_ASSERT(dynamic_cast<const TARGET *>(this));
34 return reinterpret_cast<const TARGET &>(*this);
35 }
36};
37
38class Binder;
39class BoundFunctionExpression;
40class DependencyList;
41class ScalarFunctionCatalogEntry;
42
43struct FunctionStatisticsInput {
44 FunctionStatisticsInput(BoundFunctionExpression &expr_p, optional_ptr<FunctionData> bind_data_p,
45 vector<BaseStatistics> &child_stats_p, unique_ptr<Expression> *expr_ptr_p)
46 : expr(expr_p), bind_data(bind_data_p), child_stats(child_stats_p), expr_ptr(expr_ptr_p) {
47 }
48
49 BoundFunctionExpression &expr;
50 optional_ptr<FunctionData> bind_data;
51 vector<BaseStatistics> &child_stats;
52 unique_ptr<Expression> *expr_ptr;
53};
54
55//! The type used for scalar functions
56typedef std::function<void(DataChunk &, ExpressionState &, Vector &)> scalar_function_t;
57//! Binds the scalar function and creates the function data
58typedef unique_ptr<FunctionData> (*bind_scalar_function_t)(ClientContext &context, ScalarFunction &bound_function,
59 vector<unique_ptr<Expression>> &arguments);
60typedef unique_ptr<FunctionLocalState> (*init_local_state_t)(ExpressionState &state,
61 const BoundFunctionExpression &expr,
62 FunctionData *bind_data);
63typedef unique_ptr<BaseStatistics> (*function_statistics_t)(ClientContext &context, FunctionStatisticsInput &input);
64//! Adds the dependencies of this BoundFunctionExpression to the set of dependencies
65typedef void (*dependency_function_t)(BoundFunctionExpression &expr, DependencyList &dependencies);
66
67typedef void (*function_serialize_t)(FieldWriter &writer, const FunctionData *bind_data,
68 const ScalarFunction &function);
69typedef unique_ptr<FunctionData> (*function_deserialize_t)(PlanDeserializationState &state, FieldReader &reader,
70 ScalarFunction &function);
71
72class ScalarFunction : public BaseScalarFunction {
73public:
74 DUCKDB_API ScalarFunction(string name, vector<LogicalType> arguments, LogicalType return_type,
75 scalar_function_t function, bind_scalar_function_t bind = nullptr,
76 dependency_function_t dependency = nullptr, function_statistics_t statistics = nullptr,
77 init_local_state_t init_local_state = nullptr,
78 LogicalType varargs = LogicalType(LogicalTypeId::INVALID),
79 FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS,
80 FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING);
81
82 DUCKDB_API ScalarFunction(vector<LogicalType> arguments, LogicalType return_type, scalar_function_t function,
83 bind_scalar_function_t bind = nullptr, dependency_function_t dependency = nullptr,
84 function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr,
85 LogicalType varargs = LogicalType(LogicalTypeId::INVALID),
86 FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS,
87 FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING);
88
89 //! The main scalar function to execute
90 scalar_function_t function;
91 //! The bind function (if any)
92 bind_scalar_function_t bind;
93 //! Init thread local state for the function (if any)
94 init_local_state_t init_local_state;
95 //! The dependency function (if any)
96 dependency_function_t dependency;
97 //! The statistics propagation function (if any)
98 function_statistics_t statistics;
99
100 function_serialize_t serialize;
101 function_deserialize_t deserialize;
102
103 DUCKDB_API bool operator==(const ScalarFunction &rhs) const;
104 DUCKDB_API bool operator!=(const ScalarFunction &rhs) const;
105
106 DUCKDB_API bool Equal(const ScalarFunction &rhs) const;
107
108private:
109 bool CompareScalarFunctionT(const scalar_function_t &other) const;
110
111public:
112 DUCKDB_API static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result);
113
114 template <class TA, class TR, class OP>
115 static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) {
116 D_ASSERT(input.ColumnCount() >= 1);
117 UnaryExecutor::Execute<TA, TR, OP>(input.data[0], result, input.size());
118 }
119
120 template <class TA, class TB, class TR, class OP>
121 static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) {
122 D_ASSERT(input.ColumnCount() == 2);
123 BinaryExecutor::ExecuteStandard<TA, TB, TR, OP>(input.data[0], input.data[1], result, input.size());
124 }
125
126 template <class TA, class TB, class TC, class TR, class OP>
127 static void TernaryFunction(DataChunk &input, ExpressionState &state, Vector &result) {
128 D_ASSERT(input.ColumnCount() == 3);
129 TernaryExecutor::ExecuteStandard<TA, TB, TC, TR, OP>(input.data[0], input.data[1], input.data[2], result,
130 input.size());
131 }
132
133public:
134 template <class OP>
135 static scalar_function_t GetScalarUnaryFunction(LogicalType type) {
136 scalar_function_t function;
137 switch (type.id()) {
138 case LogicalTypeId::TINYINT:
139 function = &ScalarFunction::UnaryFunction<int8_t, int8_t, OP>;
140 break;
141 case LogicalTypeId::SMALLINT:
142 function = &ScalarFunction::UnaryFunction<int16_t, int16_t, OP>;
143 break;
144 case LogicalTypeId::INTEGER:
145 function = &ScalarFunction::UnaryFunction<int32_t, int32_t, OP>;
146 break;
147 case LogicalTypeId::BIGINT:
148 function = &ScalarFunction::UnaryFunction<int64_t, int64_t, OP>;
149 break;
150 case LogicalTypeId::UTINYINT:
151 function = &ScalarFunction::UnaryFunction<uint8_t, uint8_t, OP>;
152 break;
153 case LogicalTypeId::USMALLINT:
154 function = &ScalarFunction::UnaryFunction<uint16_t, uint16_t, OP>;
155 break;
156 case LogicalTypeId::UINTEGER:
157 function = &ScalarFunction::UnaryFunction<uint32_t, uint32_t, OP>;
158 break;
159 case LogicalTypeId::UBIGINT:
160 function = &ScalarFunction::UnaryFunction<uint64_t, uint64_t, OP>;
161 break;
162 case LogicalTypeId::HUGEINT:
163 function = &ScalarFunction::UnaryFunction<hugeint_t, hugeint_t, OP>;
164 break;
165 case LogicalTypeId::FLOAT:
166 function = &ScalarFunction::UnaryFunction<float, float, OP>;
167 break;
168 case LogicalTypeId::DOUBLE:
169 function = &ScalarFunction::UnaryFunction<double, double, OP>;
170 break;
171 default:
172 throw InternalException("Unimplemented type for GetScalarUnaryFunction");
173 }
174 return function;
175 }
176
177 template <class TR, class OP>
178 static scalar_function_t GetScalarUnaryFunctionFixedReturn(LogicalType type) {
179 scalar_function_t function;
180 switch (type.id()) {
181 case LogicalTypeId::TINYINT:
182 function = &ScalarFunction::UnaryFunction<int8_t, TR, OP>;
183 break;
184 case LogicalTypeId::SMALLINT:
185 function = &ScalarFunction::UnaryFunction<int16_t, TR, OP>;
186 break;
187 case LogicalTypeId::INTEGER:
188 function = &ScalarFunction::UnaryFunction<int32_t, TR, OP>;
189 break;
190 case LogicalTypeId::BIGINT:
191 function = &ScalarFunction::UnaryFunction<int64_t, TR, OP>;
192 break;
193 case LogicalTypeId::UTINYINT:
194 function = &ScalarFunction::UnaryFunction<uint8_t, TR, OP>;
195 break;
196 case LogicalTypeId::USMALLINT:
197 function = &ScalarFunction::UnaryFunction<uint16_t, TR, OP>;
198 break;
199 case LogicalTypeId::UINTEGER:
200 function = &ScalarFunction::UnaryFunction<uint32_t, TR, OP>;
201 break;
202 case LogicalTypeId::UBIGINT:
203 function = &ScalarFunction::UnaryFunction<uint64_t, TR, OP>;
204 break;
205 case LogicalTypeId::HUGEINT:
206 function = &ScalarFunction::UnaryFunction<hugeint_t, TR, OP>;
207 break;
208 case LogicalTypeId::FLOAT:
209 function = &ScalarFunction::UnaryFunction<float, TR, OP>;
210 break;
211 case LogicalTypeId::DOUBLE:
212 function = &ScalarFunction::UnaryFunction<double, TR, OP>;
213 break;
214 default:
215 throw InternalException("Unimplemented type for GetScalarUnaryFunctionFixedReturn");
216 }
217 return function;
218 }
219};
220
221} // namespace duckdb
222