| 1 | //===----------------------------------------------------------------------===// |
| 2 | // DuckDB |
| 3 | // |
| 4 | // duckdb/function/scalar_function.hpp |
| 5 | // |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #pragma once |
| 10 | |
| 11 | #include "duckdb/common/vector_operations/binary_executor.hpp" |
| 12 | #include "duckdb/common/vector_operations/ternary_executor.hpp" |
| 13 | #include "duckdb/common/vector_operations/unary_executor.hpp" |
| 14 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 15 | #include "duckdb/execution/expression_executor_state.hpp" |
| 16 | #include "duckdb/function/function.hpp" |
| 17 | #include "duckdb/planner/plan_serialization.hpp" |
| 18 | #include "duckdb/storage/statistics/base_statistics.hpp" |
| 19 | #include "duckdb/common/optional_ptr.hpp" |
| 20 | |
| 21 | namespace duckdb { |
| 22 | |
| 23 | struct FunctionLocalState { |
| 24 | DUCKDB_API virtual ~FunctionLocalState(); |
| 25 | |
| 26 | template <class TARGET> |
| 27 | TARGET &Cast() { |
| 28 | D_ASSERT(dynamic_cast<TARGET *>(this)); |
| 29 | return reinterpret_cast<TARGET &>(*this); |
| 30 | } |
| 31 | template <class TARGET> |
| 32 | const TARGET &Cast() const { |
| 33 | D_ASSERT(dynamic_cast<const TARGET *>(this)); |
| 34 | return reinterpret_cast<const TARGET &>(*this); |
| 35 | } |
| 36 | }; |
| 37 | |
| 38 | class Binder; |
| 39 | class BoundFunctionExpression; |
| 40 | class DependencyList; |
| 41 | class ScalarFunctionCatalogEntry; |
| 42 | |
| 43 | struct FunctionStatisticsInput { |
| 44 | FunctionStatisticsInput(BoundFunctionExpression &expr_p, optional_ptr<FunctionData> bind_data_p, |
| 45 | vector<BaseStatistics> &child_stats_p, unique_ptr<Expression> *expr_ptr_p) |
| 46 | : expr(expr_p), bind_data(bind_data_p), child_stats(child_stats_p), expr_ptr(expr_ptr_p) { |
| 47 | } |
| 48 | |
| 49 | BoundFunctionExpression &expr; |
| 50 | optional_ptr<FunctionData> bind_data; |
| 51 | vector<BaseStatistics> &child_stats; |
| 52 | unique_ptr<Expression> *expr_ptr; |
| 53 | }; |
| 54 | |
| 55 | //! The type used for scalar functions |
| 56 | typedef std::function<void(DataChunk &, ExpressionState &, Vector &)> scalar_function_t; |
| 57 | //! Binds the scalar function and creates the function data |
| 58 | typedef unique_ptr<FunctionData> (*bind_scalar_function_t)(ClientContext &context, ScalarFunction &bound_function, |
| 59 | vector<unique_ptr<Expression>> &arguments); |
| 60 | typedef unique_ptr<FunctionLocalState> (*init_local_state_t)(ExpressionState &state, |
| 61 | const BoundFunctionExpression &expr, |
| 62 | FunctionData *bind_data); |
| 63 | typedef unique_ptr<BaseStatistics> (*function_statistics_t)(ClientContext &context, FunctionStatisticsInput &input); |
| 64 | //! Adds the dependencies of this BoundFunctionExpression to the set of dependencies |
| 65 | typedef void (*dependency_function_t)(BoundFunctionExpression &expr, DependencyList &dependencies); |
| 66 | |
| 67 | typedef void (*function_serialize_t)(FieldWriter &writer, const FunctionData *bind_data, |
| 68 | const ScalarFunction &function); |
| 69 | typedef unique_ptr<FunctionData> (*function_deserialize_t)(PlanDeserializationState &state, FieldReader &reader, |
| 70 | ScalarFunction &function); |
| 71 | |
| 72 | class ScalarFunction : public BaseScalarFunction { |
| 73 | public: |
| 74 | DUCKDB_API ScalarFunction(string name, vector<LogicalType> arguments, LogicalType return_type, |
| 75 | scalar_function_t function, bind_scalar_function_t bind = nullptr, |
| 76 | dependency_function_t dependency = nullptr, function_statistics_t statistics = nullptr, |
| 77 | init_local_state_t init_local_state = nullptr, |
| 78 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID), |
| 79 | FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, |
| 80 | FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); |
| 81 | |
| 82 | DUCKDB_API ScalarFunction(vector<LogicalType> arguments, LogicalType return_type, scalar_function_t function, |
| 83 | bind_scalar_function_t bind = nullptr, dependency_function_t dependency = nullptr, |
| 84 | function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr, |
| 85 | LogicalType varargs = LogicalType(LogicalTypeId::INVALID), |
| 86 | FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, |
| 87 | FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); |
| 88 | |
| 89 | //! The main scalar function to execute |
| 90 | scalar_function_t function; |
| 91 | //! The bind function (if any) |
| 92 | bind_scalar_function_t bind; |
| 93 | //! Init thread local state for the function (if any) |
| 94 | init_local_state_t init_local_state; |
| 95 | //! The dependency function (if any) |
| 96 | dependency_function_t dependency; |
| 97 | //! The statistics propagation function (if any) |
| 98 | function_statistics_t statistics; |
| 99 | |
| 100 | function_serialize_t serialize; |
| 101 | function_deserialize_t deserialize; |
| 102 | |
| 103 | DUCKDB_API bool operator==(const ScalarFunction &rhs) const; |
| 104 | DUCKDB_API bool operator!=(const ScalarFunction &rhs) const; |
| 105 | |
| 106 | DUCKDB_API bool Equal(const ScalarFunction &rhs) const; |
| 107 | |
| 108 | private: |
| 109 | bool CompareScalarFunctionT(const scalar_function_t &other) const; |
| 110 | |
| 111 | public: |
| 112 | DUCKDB_API static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result); |
| 113 | |
| 114 | template <class TA, class TR, class OP> |
| 115 | static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { |
| 116 | D_ASSERT(input.ColumnCount() >= 1); |
| 117 | UnaryExecutor::Execute<TA, TR, OP>(input.data[0], result, input.size()); |
| 118 | } |
| 119 | |
| 120 | template <class TA, class TB, class TR, class OP> |
| 121 | static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { |
| 122 | D_ASSERT(input.ColumnCount() == 2); |
| 123 | BinaryExecutor::ExecuteStandard<TA, TB, TR, OP>(input.data[0], input.data[1], result, input.size()); |
| 124 | } |
| 125 | |
| 126 | template <class TA, class TB, class TC, class TR, class OP> |
| 127 | static void TernaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { |
| 128 | D_ASSERT(input.ColumnCount() == 3); |
| 129 | TernaryExecutor::ExecuteStandard<TA, TB, TC, TR, OP>(input.data[0], input.data[1], input.data[2], result, |
| 130 | input.size()); |
| 131 | } |
| 132 | |
| 133 | public: |
| 134 | template <class OP> |
| 135 | static scalar_function_t GetScalarUnaryFunction(LogicalType type) { |
| 136 | scalar_function_t function; |
| 137 | switch (type.id()) { |
| 138 | case LogicalTypeId::TINYINT: |
| 139 | function = &ScalarFunction::UnaryFunction<int8_t, int8_t, OP>; |
| 140 | break; |
| 141 | case LogicalTypeId::SMALLINT: |
| 142 | function = &ScalarFunction::UnaryFunction<int16_t, int16_t, OP>; |
| 143 | break; |
| 144 | case LogicalTypeId::INTEGER: |
| 145 | function = &ScalarFunction::UnaryFunction<int32_t, int32_t, OP>; |
| 146 | break; |
| 147 | case LogicalTypeId::BIGINT: |
| 148 | function = &ScalarFunction::UnaryFunction<int64_t, int64_t, OP>; |
| 149 | break; |
| 150 | case LogicalTypeId::UTINYINT: |
| 151 | function = &ScalarFunction::UnaryFunction<uint8_t, uint8_t, OP>; |
| 152 | break; |
| 153 | case LogicalTypeId::USMALLINT: |
| 154 | function = &ScalarFunction::UnaryFunction<uint16_t, uint16_t, OP>; |
| 155 | break; |
| 156 | case LogicalTypeId::UINTEGER: |
| 157 | function = &ScalarFunction::UnaryFunction<uint32_t, uint32_t, OP>; |
| 158 | break; |
| 159 | case LogicalTypeId::UBIGINT: |
| 160 | function = &ScalarFunction::UnaryFunction<uint64_t, uint64_t, OP>; |
| 161 | break; |
| 162 | case LogicalTypeId::HUGEINT: |
| 163 | function = &ScalarFunction::UnaryFunction<hugeint_t, hugeint_t, OP>; |
| 164 | break; |
| 165 | case LogicalTypeId::FLOAT: |
| 166 | function = &ScalarFunction::UnaryFunction<float, float, OP>; |
| 167 | break; |
| 168 | case LogicalTypeId::DOUBLE: |
| 169 | function = &ScalarFunction::UnaryFunction<double, double, OP>; |
| 170 | break; |
| 171 | default: |
| 172 | throw InternalException("Unimplemented type for GetScalarUnaryFunction" ); |
| 173 | } |
| 174 | return function; |
| 175 | } |
| 176 | |
| 177 | template <class TR, class OP> |
| 178 | static scalar_function_t GetScalarUnaryFunctionFixedReturn(LogicalType type) { |
| 179 | scalar_function_t function; |
| 180 | switch (type.id()) { |
| 181 | case LogicalTypeId::TINYINT: |
| 182 | function = &ScalarFunction::UnaryFunction<int8_t, TR, OP>; |
| 183 | break; |
| 184 | case LogicalTypeId::SMALLINT: |
| 185 | function = &ScalarFunction::UnaryFunction<int16_t, TR, OP>; |
| 186 | break; |
| 187 | case LogicalTypeId::INTEGER: |
| 188 | function = &ScalarFunction::UnaryFunction<int32_t, TR, OP>; |
| 189 | break; |
| 190 | case LogicalTypeId::BIGINT: |
| 191 | function = &ScalarFunction::UnaryFunction<int64_t, TR, OP>; |
| 192 | break; |
| 193 | case LogicalTypeId::UTINYINT: |
| 194 | function = &ScalarFunction::UnaryFunction<uint8_t, TR, OP>; |
| 195 | break; |
| 196 | case LogicalTypeId::USMALLINT: |
| 197 | function = &ScalarFunction::UnaryFunction<uint16_t, TR, OP>; |
| 198 | break; |
| 199 | case LogicalTypeId::UINTEGER: |
| 200 | function = &ScalarFunction::UnaryFunction<uint32_t, TR, OP>; |
| 201 | break; |
| 202 | case LogicalTypeId::UBIGINT: |
| 203 | function = &ScalarFunction::UnaryFunction<uint64_t, TR, OP>; |
| 204 | break; |
| 205 | case LogicalTypeId::HUGEINT: |
| 206 | function = &ScalarFunction::UnaryFunction<hugeint_t, TR, OP>; |
| 207 | break; |
| 208 | case LogicalTypeId::FLOAT: |
| 209 | function = &ScalarFunction::UnaryFunction<float, TR, OP>; |
| 210 | break; |
| 211 | case LogicalTypeId::DOUBLE: |
| 212 | function = &ScalarFunction::UnaryFunction<double, TR, OP>; |
| 213 | break; |
| 214 | default: |
| 215 | throw InternalException("Unimplemented type for GetScalarUnaryFunctionFixedReturn" ); |
| 216 | } |
| 217 | return function; |
| 218 | } |
| 219 | }; |
| 220 | |
| 221 | } // namespace duckdb |
| 222 | |