1#include "duckdb/function/scalar/operators.hpp"
2#include "duckdb/common/vector_operations/vector_operations.hpp"
3#include "duckdb/common/operator/numeric_binary_operators.hpp"
4
5using namespace std;
6
7namespace duckdb {
8
9template <class OP> static scalar_function_t GetScalarBinaryFunction(SQLType type) {
10 switch (type.id) {
11 case SQLTypeId::TINYINT:
12 return ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>;
13 case SQLTypeId::SMALLINT:
14 return ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>;
15 case SQLTypeId::INTEGER:
16 return ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>;
17 case SQLTypeId::BIGINT:
18 return ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>;
19 case SQLTypeId::FLOAT:
20 return ScalarFunction::BinaryFunction<float, float, float, OP, true>;
21 case SQLTypeId::DOUBLE:
22 return ScalarFunction::BinaryFunction<double, double, double, OP, true>;
23 case SQLTypeId::DECIMAL:
24 return ScalarFunction::BinaryFunction<double, double, double, OP, true>;
25 default:
26 throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction");
27 }
28}
29
30//===--------------------------------------------------------------------===//
31// + [add]
32//===--------------------------------------------------------------------===//
33template <> float AddOperator::Operation(float left, float right) {
34 auto result = left + right;
35 if (!Value::FloatIsValid(result)) {
36 throw OutOfRangeException("Overflow in addition of float!");
37 }
38 return result;
39}
40
41template <> double AddOperator::Operation(double left, double right) {
42 auto result = left + right;
43 if (!Value::DoubleIsValid(result)) {
44 throw OutOfRangeException("Overflow in addition of double!");
45 }
46 return result;
47}
48
49void AddFun::RegisterFunction(BuiltinFunctions &set) {
50 ScalarFunctionSet functions("+");
51 // binary add function adds two numbers together
52 for (auto &type : SQLType::NUMERIC) {
53 functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<AddOperator>(type)));
54 }
55 // we can add integers to dates
56 functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::INTEGER}, SQLType::DATE,
57 GetScalarBinaryFunction<AddOperator>(SQLType::INTEGER)));
58 functions.AddFunction(ScalarFunction({SQLType::INTEGER, SQLType::DATE}, SQLType::DATE,
59 GetScalarBinaryFunction<AddOperator>(SQLType::INTEGER)));
60 // unary add function is a nop, but only exists for numeric types
61 for (auto &type : SQLType::NUMERIC) {
62 functions.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction));
63 }
64 set.AddFunction(functions);
65}
66
67//===--------------------------------------------------------------------===//
68// - [subtract]
69//===--------------------------------------------------------------------===//
70template <> float SubtractOperator::Operation(float left, float right) {
71 auto result = left - right;
72 if (!Value::FloatIsValid(result)) {
73 throw OutOfRangeException("Overflow in subtraction of float!");
74 }
75 return result;
76}
77
78template <> double SubtractOperator::Operation(double left, double right) {
79 auto result = left - right;
80 if (!Value::DoubleIsValid(result)) {
81 throw OutOfRangeException("Overflow in subtraction of double!");
82 }
83 return result;
84}
85
86void SubtractFun::RegisterFunction(BuiltinFunctions &set) {
87 ScalarFunctionSet functions("-");
88 // binary subtract function "a - b", subtracts b from a
89 for (auto &type : SQLType::NUMERIC) {
90 functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<SubtractOperator>(type)));
91 }
92 functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::DATE}, SQLType::INTEGER,
93 GetScalarBinaryFunction<SubtractOperator>(SQLType::INTEGER)));
94 functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::INTEGER}, SQLType::DATE,
95 GetScalarBinaryFunction<SubtractOperator>(SQLType::INTEGER)));
96 // unary subtract function, negates the input (i.e. multiplies by -1)
97 for (auto &type : SQLType::NUMERIC) {
98 functions.AddFunction(
99 ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type)));
100 }
101 set.AddFunction(functions);
102}
103
104//===--------------------------------------------------------------------===//
105// * [multiply]
106//===--------------------------------------------------------------------===//
107template <> float MultiplyOperator::Operation(float left, float right) {
108 auto result = left * right;
109 if (!Value::FloatIsValid(result)) {
110 throw OutOfRangeException("Overflow in multiplication of float!");
111 }
112 return result;
113}
114
115template <> double MultiplyOperator::Operation(double left, double right) {
116 auto result = left * right;
117 if (!Value::DoubleIsValid(result)) {
118 throw OutOfRangeException("Overflow in multiplication of double!");
119 }
120 return result;
121}
122
123void MultiplyFun::RegisterFunction(BuiltinFunctions &set) {
124 ScalarFunctionSet functions("*");
125 for (auto &type : SQLType::NUMERIC) {
126 functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<MultiplyOperator>(type)));
127 }
128 set.AddFunction(functions);
129}
130
131//===--------------------------------------------------------------------===//
132// / [divide]
133//===--------------------------------------------------------------------===//
134template <> float DivideOperator::Operation(float left, float right) {
135 auto result = left / right;
136 if (!Value::FloatIsValid(result)) {
137 throw OutOfRangeException("Overflow in division of float!");
138 }
139 return result;
140}
141
142template <> double DivideOperator::Operation(double left, double right) {
143 auto result = left / right;
144 if (!Value::DoubleIsValid(result)) {
145 throw OutOfRangeException("Overflow in division of double!");
146 }
147 return result;
148}
149
150struct BinaryZeroIsNullWrapper {
151 template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE>
152 static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, nullmask_t &nullmask, idx_t idx) {
153 if (right == 0) {
154 nullmask[idx] = true;
155 return 0;
156 } else {
157 return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right);
158 }
159 }
160};
161
162template <class T, class OP>
163static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) {
164 BinaryExecutor::Execute<T, T, T, OP, true, BinaryZeroIsNullWrapper>(input.data[0], input.data[1], result,
165 input.size());
166}
167
168template <class OP> static scalar_function_t GetBinaryFunctionIgnoreZero(SQLType type) {
169 switch (type.id) {
170 case SQLTypeId::TINYINT:
171 return BinaryScalarFunctionIgnoreZero<int8_t, OP>;
172 case SQLTypeId::SMALLINT:
173 return BinaryScalarFunctionIgnoreZero<int16_t, OP>;
174 case SQLTypeId::INTEGER:
175 return BinaryScalarFunctionIgnoreZero<int32_t, OP>;
176 case SQLTypeId::BIGINT:
177 return BinaryScalarFunctionIgnoreZero<int64_t, OP>;
178 case SQLTypeId::FLOAT:
179 return BinaryScalarFunctionIgnoreZero<float, OP>;
180 case SQLTypeId::DOUBLE:
181 return BinaryScalarFunctionIgnoreZero<double, OP>;
182 case SQLTypeId::DECIMAL:
183 return BinaryScalarFunctionIgnoreZero<double, OP>;
184 default:
185 throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction");
186 }
187}
188
189void DivideFun::RegisterFunction(BuiltinFunctions &set) {
190 ScalarFunctionSet functions("/");
191 for (auto &type : SQLType::NUMERIC) {
192 functions.AddFunction(ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<DivideOperator>(type)));
193 }
194 set.AddFunction(functions);
195}
196
197//===--------------------------------------------------------------------===//
198// % [modulo]
199//===--------------------------------------------------------------------===//
200template <> float ModuloOperator::Operation(float left, float right) {
201 assert(right != 0);
202 return fmod(left, right);
203}
204
205template <> double ModuloOperator::Operation(double left, double right) {
206 assert(right != 0);
207 return fmod(left, right);
208}
209
210void ModFun::RegisterFunction(BuiltinFunctions &set) {
211 ScalarFunctionSet functions("%");
212 for (auto &type : SQLType::NUMERIC) {
213 functions.AddFunction(ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<ModuloOperator>(type)));
214 }
215 set.AddFunction(functions);
216 functions.name = "mod";
217 set.AddFunction(functions);
218}
219
220} // namespace duckdb
221