| 1 | #include "duckdb/function/scalar/operators.hpp" |
| 2 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 3 | #include "duckdb/common/operator/numeric_binary_operators.hpp" |
| 4 | |
| 5 | using namespace std; |
| 6 | |
| 7 | namespace duckdb { |
| 8 | |
| 9 | template <class OP> static scalar_function_t GetScalarBinaryFunction(SQLType type) { |
| 10 | switch (type.id) { |
| 11 | case SQLTypeId::TINYINT: |
| 12 | return ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>; |
| 13 | case SQLTypeId::SMALLINT: |
| 14 | return ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>; |
| 15 | case SQLTypeId::INTEGER: |
| 16 | return ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>; |
| 17 | case SQLTypeId::BIGINT: |
| 18 | return ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>; |
| 19 | case SQLTypeId::FLOAT: |
| 20 | return ScalarFunction::BinaryFunction<float, float, float, OP, true>; |
| 21 | case SQLTypeId::DOUBLE: |
| 22 | return ScalarFunction::BinaryFunction<double, double, double, OP, true>; |
| 23 | case SQLTypeId::DECIMAL: |
| 24 | return ScalarFunction::BinaryFunction<double, double, double, OP, true>; |
| 25 | default: |
| 26 | throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction" ); |
| 27 | } |
| 28 | } |
| 29 | |
| 30 | //===--------------------------------------------------------------------===// |
| 31 | // + [add] |
| 32 | //===--------------------------------------------------------------------===// |
| 33 | template <> float AddOperator::Operation(float left, float right) { |
| 34 | auto result = left + right; |
| 35 | if (!Value::FloatIsValid(result)) { |
| 36 | throw OutOfRangeException("Overflow in addition of float!" ); |
| 37 | } |
| 38 | return result; |
| 39 | } |
| 40 | |
| 41 | template <> double AddOperator::Operation(double left, double right) { |
| 42 | auto result = left + right; |
| 43 | if (!Value::DoubleIsValid(result)) { |
| 44 | throw OutOfRangeException("Overflow in addition of double!" ); |
| 45 | } |
| 46 | return result; |
| 47 | } |
| 48 | |
| 49 | void AddFun::RegisterFunction(BuiltinFunctions &set) { |
| 50 | ScalarFunctionSet functions("+" ); |
| 51 | // binary add function adds two numbers together |
| 52 | for (auto &type : SQLType::NUMERIC) { |
| 53 | functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<AddOperator>(type))); |
| 54 | } |
| 55 | // we can add integers to dates |
| 56 | functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::INTEGER}, SQLType::DATE, |
| 57 | GetScalarBinaryFunction<AddOperator>(SQLType::INTEGER))); |
| 58 | functions.AddFunction(ScalarFunction({SQLType::INTEGER, SQLType::DATE}, SQLType::DATE, |
| 59 | GetScalarBinaryFunction<AddOperator>(SQLType::INTEGER))); |
| 60 | // unary add function is a nop, but only exists for numeric types |
| 61 | for (auto &type : SQLType::NUMERIC) { |
| 62 | functions.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction)); |
| 63 | } |
| 64 | set.AddFunction(functions); |
| 65 | } |
| 66 | |
| 67 | //===--------------------------------------------------------------------===// |
| 68 | // - [subtract] |
| 69 | //===--------------------------------------------------------------------===// |
| 70 | template <> float SubtractOperator::Operation(float left, float right) { |
| 71 | auto result = left - right; |
| 72 | if (!Value::FloatIsValid(result)) { |
| 73 | throw OutOfRangeException("Overflow in subtraction of float!" ); |
| 74 | } |
| 75 | return result; |
| 76 | } |
| 77 | |
| 78 | template <> double SubtractOperator::Operation(double left, double right) { |
| 79 | auto result = left - right; |
| 80 | if (!Value::DoubleIsValid(result)) { |
| 81 | throw OutOfRangeException("Overflow in subtraction of double!" ); |
| 82 | } |
| 83 | return result; |
| 84 | } |
| 85 | |
| 86 | void SubtractFun::RegisterFunction(BuiltinFunctions &set) { |
| 87 | ScalarFunctionSet functions("-" ); |
| 88 | // binary subtract function "a - b", subtracts b from a |
| 89 | for (auto &type : SQLType::NUMERIC) { |
| 90 | functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<SubtractOperator>(type))); |
| 91 | } |
| 92 | functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::DATE}, SQLType::INTEGER, |
| 93 | GetScalarBinaryFunction<SubtractOperator>(SQLType::INTEGER))); |
| 94 | functions.AddFunction(ScalarFunction({SQLType::DATE, SQLType::INTEGER}, SQLType::DATE, |
| 95 | GetScalarBinaryFunction<SubtractOperator>(SQLType::INTEGER))); |
| 96 | // unary subtract function, negates the input (i.e. multiplies by -1) |
| 97 | for (auto &type : SQLType::NUMERIC) { |
| 98 | functions.AddFunction( |
| 99 | ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type))); |
| 100 | } |
| 101 | set.AddFunction(functions); |
| 102 | } |
| 103 | |
| 104 | //===--------------------------------------------------------------------===// |
| 105 | // * [multiply] |
| 106 | //===--------------------------------------------------------------------===// |
| 107 | template <> float MultiplyOperator::Operation(float left, float right) { |
| 108 | auto result = left * right; |
| 109 | if (!Value::FloatIsValid(result)) { |
| 110 | throw OutOfRangeException("Overflow in multiplication of float!" ); |
| 111 | } |
| 112 | return result; |
| 113 | } |
| 114 | |
| 115 | template <> double MultiplyOperator::Operation(double left, double right) { |
| 116 | auto result = left * right; |
| 117 | if (!Value::DoubleIsValid(result)) { |
| 118 | throw OutOfRangeException("Overflow in multiplication of double!" ); |
| 119 | } |
| 120 | return result; |
| 121 | } |
| 122 | |
| 123 | void MultiplyFun::RegisterFunction(BuiltinFunctions &set) { |
| 124 | ScalarFunctionSet functions("*" ); |
| 125 | for (auto &type : SQLType::NUMERIC) { |
| 126 | functions.AddFunction(ScalarFunction({type, type}, type, GetScalarBinaryFunction<MultiplyOperator>(type))); |
| 127 | } |
| 128 | set.AddFunction(functions); |
| 129 | } |
| 130 | |
| 131 | //===--------------------------------------------------------------------===// |
| 132 | // / [divide] |
| 133 | //===--------------------------------------------------------------------===// |
| 134 | template <> float DivideOperator::Operation(float left, float right) { |
| 135 | auto result = left / right; |
| 136 | if (!Value::FloatIsValid(result)) { |
| 137 | throw OutOfRangeException("Overflow in division of float!" ); |
| 138 | } |
| 139 | return result; |
| 140 | } |
| 141 | |
| 142 | template <> double DivideOperator::Operation(double left, double right) { |
| 143 | auto result = left / right; |
| 144 | if (!Value::DoubleIsValid(result)) { |
| 145 | throw OutOfRangeException("Overflow in division of double!" ); |
| 146 | } |
| 147 | return result; |
| 148 | } |
| 149 | |
| 150 | struct BinaryZeroIsNullWrapper { |
| 151 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
| 152 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, nullmask_t &nullmask, idx_t idx) { |
| 153 | if (right == 0) { |
| 154 | nullmask[idx] = true; |
| 155 | return 0; |
| 156 | } else { |
| 157 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
| 158 | } |
| 159 | } |
| 160 | }; |
| 161 | |
| 162 | template <class T, class OP> |
| 163 | static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { |
| 164 | BinaryExecutor::Execute<T, T, T, OP, true, BinaryZeroIsNullWrapper>(input.data[0], input.data[1], result, |
| 165 | input.size()); |
| 166 | } |
| 167 | |
| 168 | template <class OP> static scalar_function_t GetBinaryFunctionIgnoreZero(SQLType type) { |
| 169 | switch (type.id) { |
| 170 | case SQLTypeId::TINYINT: |
| 171 | return BinaryScalarFunctionIgnoreZero<int8_t, OP>; |
| 172 | case SQLTypeId::SMALLINT: |
| 173 | return BinaryScalarFunctionIgnoreZero<int16_t, OP>; |
| 174 | case SQLTypeId::INTEGER: |
| 175 | return BinaryScalarFunctionIgnoreZero<int32_t, OP>; |
| 176 | case SQLTypeId::BIGINT: |
| 177 | return BinaryScalarFunctionIgnoreZero<int64_t, OP>; |
| 178 | case SQLTypeId::FLOAT: |
| 179 | return BinaryScalarFunctionIgnoreZero<float, OP>; |
| 180 | case SQLTypeId::DOUBLE: |
| 181 | return BinaryScalarFunctionIgnoreZero<double, OP>; |
| 182 | case SQLTypeId::DECIMAL: |
| 183 | return BinaryScalarFunctionIgnoreZero<double, OP>; |
| 184 | default: |
| 185 | throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction" ); |
| 186 | } |
| 187 | } |
| 188 | |
| 189 | void DivideFun::RegisterFunction(BuiltinFunctions &set) { |
| 190 | ScalarFunctionSet functions("/" ); |
| 191 | for (auto &type : SQLType::NUMERIC) { |
| 192 | functions.AddFunction(ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<DivideOperator>(type))); |
| 193 | } |
| 194 | set.AddFunction(functions); |
| 195 | } |
| 196 | |
| 197 | //===--------------------------------------------------------------------===// |
| 198 | // % [modulo] |
| 199 | //===--------------------------------------------------------------------===// |
| 200 | template <> float ModuloOperator::Operation(float left, float right) { |
| 201 | assert(right != 0); |
| 202 | return fmod(left, right); |
| 203 | } |
| 204 | |
| 205 | template <> double ModuloOperator::Operation(double left, double right) { |
| 206 | assert(right != 0); |
| 207 | return fmod(left, right); |
| 208 | } |
| 209 | |
| 210 | void ModFun::RegisterFunction(BuiltinFunctions &set) { |
| 211 | ScalarFunctionSet functions("%" ); |
| 212 | for (auto &type : SQLType::NUMERIC) { |
| 213 | functions.AddFunction(ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<ModuloOperator>(type))); |
| 214 | } |
| 215 | set.AddFunction(functions); |
| 216 | functions.name = "mod" ; |
| 217 | set.AddFunction(functions); |
| 218 | } |
| 219 | |
| 220 | } // namespace duckdb |
| 221 | |