1 | #include "duckdb/function/scalar/operators.hpp" |
2 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
3 | #include "duckdb/common/operator/numeric_binary_operators.hpp" |
4 | |
5 | using namespace std; |
6 | |
7 | namespace duckdb { |
8 | |
9 | template <class OP> static scalar_function_t GetScalarBinaryFunction(SQLType type) { |
10 | switch (type.id) { |
11 | case SQLTypeId::TINYINT: |
12 | return ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>; |
13 | case SQLTypeId::SMALLINT: |
14 | return ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>; |
15 | case SQLTypeId::INTEGER: |
16 | return ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>; |
17 | case SQLTypeId::BIGINT: |
18 | return ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>; |
19 | case SQLTypeId::FLOAT: |
20 | return ScalarFunction::BinaryFunction<float, float, float, OP, true>; |
21 | case SQLTypeId::DOUBLE: |
22 | return ScalarFunction::BinaryFunction<double, double, double, OP, true>; |
23 | case SQLTypeId::DECIMAL: |
24 | return ScalarFunction::BinaryFunction<double, double, double, OP, true>; |
25 | default: |
26 | throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction" ); |
27 | } |
28 | } |
29 | |
30 | //===--------------------------------------------------------------------===// |
31 | // + [add] |
32 | //===--------------------------------------------------------------------===// |
33 | template <> float AddOperator::Operation(float left, float right) { |
34 | auto result = left + right; |
35 | if (!Value::FloatIsValid(result)) { |
36 | throw OutOfRangeException("Overflow in addition of float!" ); |
37 | } |
38 | return result; |
39 | } |
40 | |
41 | template <> double AddOperator::Operation(double left, double right) { |
42 | auto result = left + right; |
43 | if (!Value::DoubleIsValid(result)) { |
44 | throw OutOfRangeException("Overflow in addition of double!" ); |
45 | } |
46 | return result; |
47 | } |
48 | |
49 | void AddFun::RegisterFunction(BuiltinFunctions &set) { |
50 | ScalarFunctionSet functions("+" ); |
51 | // binary add function adds two numbers together |
52 | for (auto &type : SQLType::NUMERIC) { |
53 | functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<AddOperator>(type))); |
54 | } |
55 | // we can add integers to dates |
56 | functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::INTEGER}, SQLType::DATE, |
57 | GetScalarBinaryFunction<AddOperator>(SQLType::INTEGER))); |
58 | functions.AddFunction(ScalarFunction({SQLType::INTEGER, SQLType::DATE}, SQLType::DATE, |
59 | GetScalarBinaryFunction<AddOperator>(SQLType::INTEGER))); |
60 | // unary add function is a nop, but only exists for numeric types |
61 | for (auto &type : SQLType::NUMERIC) { |
62 | functions.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction)); |
63 | } |
64 | set.AddFunction(functions); |
65 | } |
66 | |
67 | //===--------------------------------------------------------------------===// |
68 | // - [subtract] |
69 | //===--------------------------------------------------------------------===// |
70 | template <> float SubtractOperator::Operation(float left, float right) { |
71 | auto result = left - right; |
72 | if (!Value::FloatIsValid(result)) { |
73 | throw OutOfRangeException("Overflow in subtraction of float!" ); |
74 | } |
75 | return result; |
76 | } |
77 | |
78 | template <> double SubtractOperator::Operation(double left, double right) { |
79 | auto result = left - right; |
80 | if (!Value::DoubleIsValid(result)) { |
81 | throw OutOfRangeException("Overflow in subtraction of double!" ); |
82 | } |
83 | return result; |
84 | } |
85 | |
86 | void SubtractFun::RegisterFunction(BuiltinFunctions &set) { |
87 | ScalarFunctionSet functions("-" ); |
88 | // binary subtract function "a - b", subtracts b from a |
89 | for (auto &type : SQLType::NUMERIC) { |
90 | functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<SubtractOperator>(type))); |
91 | } |
92 | functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::DATE}, SQLType::INTEGER, |
93 | GetScalarBinaryFunction<SubtractOperator>(SQLType::INTEGER))); |
94 | functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::INTEGER}, SQLType::DATE, |
95 | GetScalarBinaryFunction<SubtractOperator>(SQLType::INTEGER))); |
96 | // unary subtract function, negates the input (i.e. multiplies by -1) |
97 | for (auto &type : SQLType::NUMERIC) { |
98 | functions.AddFunction( |
99 | ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type))); |
100 | } |
101 | set.AddFunction(functions); |
102 | } |
103 | |
104 | //===--------------------------------------------------------------------===// |
105 | // * [multiply] |
106 | //===--------------------------------------------------------------------===// |
107 | template <> float MultiplyOperator::Operation(float left, float right) { |
108 | auto result = left * right; |
109 | if (!Value::FloatIsValid(result)) { |
110 | throw OutOfRangeException("Overflow in multiplication of float!" ); |
111 | } |
112 | return result; |
113 | } |
114 | |
115 | template <> double MultiplyOperator::Operation(double left, double right) { |
116 | auto result = left * right; |
117 | if (!Value::DoubleIsValid(result)) { |
118 | throw OutOfRangeException("Overflow in multiplication of double!" ); |
119 | } |
120 | return result; |
121 | } |
122 | |
123 | void MultiplyFun::RegisterFunction(BuiltinFunctions &set) { |
124 | ScalarFunctionSet functions("*" ); |
125 | for (auto &type : SQLType::NUMERIC) { |
126 | functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<MultiplyOperator>(type))); |
127 | } |
128 | set.AddFunction(functions); |
129 | } |
130 | |
131 | //===--------------------------------------------------------------------===// |
132 | // / [divide] |
133 | //===--------------------------------------------------------------------===// |
134 | template <> float DivideOperator::Operation(float left, float right) { |
135 | auto result = left / right; |
136 | if (!Value::FloatIsValid(result)) { |
137 | throw OutOfRangeException("Overflow in division of float!" ); |
138 | } |
139 | return result; |
140 | } |
141 | |
142 | template <> double DivideOperator::Operation(double left, double right) { |
143 | auto result = left / right; |
144 | if (!Value::DoubleIsValid(result)) { |
145 | throw OutOfRangeException("Overflow in division of double!" ); |
146 | } |
147 | return result; |
148 | } |
149 | |
150 | struct BinaryZeroIsNullWrapper { |
151 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
152 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, nullmask_t &nullmask, idx_t idx) { |
153 | if (right == 0) { |
154 | nullmask[idx] = true; |
155 | return 0; |
156 | } else { |
157 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
158 | } |
159 | } |
160 | }; |
161 | |
162 | template <class T, class OP> |
163 | static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { |
164 | BinaryExecutor::Execute<T, T, T, OP, true, BinaryZeroIsNullWrapper>(input.data[0], input.data[1], result, |
165 | input.size()); |
166 | } |
167 | |
168 | template <class OP> static scalar_function_t GetBinaryFunctionIgnoreZero(SQLType type) { |
169 | switch (type.id) { |
170 | case SQLTypeId::TINYINT: |
171 | return BinaryScalarFunctionIgnoreZero<int8_t, OP>; |
172 | case SQLTypeId::SMALLINT: |
173 | return BinaryScalarFunctionIgnoreZero<int16_t, OP>; |
174 | case SQLTypeId::INTEGER: |
175 | return BinaryScalarFunctionIgnoreZero<int32_t, OP>; |
176 | case SQLTypeId::BIGINT: |
177 | return BinaryScalarFunctionIgnoreZero<int64_t, OP>; |
178 | case SQLTypeId::FLOAT: |
179 | return BinaryScalarFunctionIgnoreZero<float, OP>; |
180 | case SQLTypeId::DOUBLE: |
181 | return BinaryScalarFunctionIgnoreZero<double, OP>; |
182 | case SQLTypeId::DECIMAL: |
183 | return BinaryScalarFunctionIgnoreZero<double, OP>; |
184 | default: |
185 | throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction" ); |
186 | } |
187 | } |
188 | |
189 | void DivideFun::RegisterFunction(BuiltinFunctions &set) { |
190 | ScalarFunctionSet functions("/" ); |
191 | for (auto &type : SQLType::NUMERIC) { |
192 | functions.AddFunction(ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<DivideOperator>(type))); |
193 | } |
194 | set.AddFunction(functions); |
195 | } |
196 | |
197 | //===--------------------------------------------------------------------===// |
198 | // % [modulo] |
199 | //===--------------------------------------------------------------------===// |
200 | template <> float ModuloOperator::Operation(float left, float right) { |
201 | assert(right != 0); |
202 | return fmod(left, right); |
203 | } |
204 | |
205 | template <> double ModuloOperator::Operation(double left, double right) { |
206 | assert(right != 0); |
207 | return fmod(left, right); |
208 | } |
209 | |
210 | void ModFun::RegisterFunction(BuiltinFunctions &set) { |
211 | ScalarFunctionSet functions("%" ); |
212 | for (auto &type : SQLType::NUMERIC) { |
213 | functions.AddFunction(ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<ModuloOperator>(type))); |
214 | } |
215 | set.AddFunction(functions); |
216 | functions.name = "mod" ; |
217 | set.AddFunction(functions); |
218 | } |
219 | |
220 | } // namespace duckdb |
221 | |