| 1 | #include "duckdb/common/field_writer.hpp" |
| 2 | #include "duckdb/common/operator/add.hpp" |
| 3 | #include "duckdb/common/operator/multiply.hpp" |
| 4 | #include "duckdb/common/operator/numeric_binary_operators.hpp" |
| 5 | #include "duckdb/common/operator/subtract.hpp" |
| 6 | #include "duckdb/common/types/date.hpp" |
| 7 | #include "duckdb/common/types/decimal.hpp" |
| 8 | #include "duckdb/common/types/hugeint.hpp" |
| 9 | #include "duckdb/common/types/interval.hpp" |
| 10 | #include "duckdb/common/types/time.hpp" |
| 11 | #include "duckdb/common/types/timestamp.hpp" |
| 12 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 13 | #include "duckdb/common/enum_util.hpp" |
| 14 | #include "duckdb/function/scalar/operators.hpp" |
| 15 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 16 | #include "duckdb/function/scalar/nested_functions.hpp" |
| 17 | |
| 18 | #include <limits> |
| 19 | |
| 20 | namespace duckdb { |
| 21 | |
| 22 | template <class OP> |
| 23 | static scalar_function_t GetScalarIntegerFunction(PhysicalType type) { |
| 24 | scalar_function_t function; |
| 25 | switch (type) { |
| 26 | case PhysicalType::INT8: |
| 27 | function = &ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>; |
| 28 | break; |
| 29 | case PhysicalType::INT16: |
| 30 | function = &ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>; |
| 31 | break; |
| 32 | case PhysicalType::INT32: |
| 33 | function = &ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>; |
| 34 | break; |
| 35 | case PhysicalType::INT64: |
| 36 | function = &ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>; |
| 37 | break; |
| 38 | case PhysicalType::UINT8: |
| 39 | function = &ScalarFunction::BinaryFunction<uint8_t, uint8_t, uint8_t, OP>; |
| 40 | break; |
| 41 | case PhysicalType::UINT16: |
| 42 | function = &ScalarFunction::BinaryFunction<uint16_t, uint16_t, uint16_t, OP>; |
| 43 | break; |
| 44 | case PhysicalType::UINT32: |
| 45 | function = &ScalarFunction::BinaryFunction<uint32_t, uint32_t, uint32_t, OP>; |
| 46 | break; |
| 47 | case PhysicalType::UINT64: |
| 48 | function = &ScalarFunction::BinaryFunction<uint64_t, uint64_t, uint64_t, OP>; |
| 49 | break; |
| 50 | default: |
| 51 | throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction" ); |
| 52 | } |
| 53 | return function; |
| 54 | } |
| 55 | |
| 56 | template <class OP> |
| 57 | static scalar_function_t GetScalarBinaryFunction(PhysicalType type) { |
| 58 | scalar_function_t function; |
| 59 | switch (type) { |
| 60 | case PhysicalType::INT128: |
| 61 | function = &ScalarFunction::BinaryFunction<hugeint_t, hugeint_t, hugeint_t, OP>; |
| 62 | break; |
| 63 | case PhysicalType::FLOAT: |
| 64 | function = &ScalarFunction::BinaryFunction<float, float, float, OP>; |
| 65 | break; |
| 66 | case PhysicalType::DOUBLE: |
| 67 | function = &ScalarFunction::BinaryFunction<double, double, double, OP>; |
| 68 | break; |
| 69 | default: |
| 70 | function = GetScalarIntegerFunction<OP>(type); |
| 71 | break; |
| 72 | } |
| 73 | return function; |
| 74 | } |
| 75 | |
| 76 | //===--------------------------------------------------------------------===// |
| 77 | // + [add] |
| 78 | //===--------------------------------------------------------------------===// |
| 79 | struct AddPropagateStatistics { |
| 80 | template <class T, class OP> |
| 81 | static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, |
| 82 | Value &new_max) { |
| 83 | T min, max; |
| 84 | // new min is min+min |
| 85 | if (!OP::Operation(NumericStats::GetMin<T>(lstats), NumericStats::GetMin<T>(rstats), min)) { |
| 86 | return true; |
| 87 | } |
| 88 | // new max is max+max |
| 89 | if (!OP::Operation(NumericStats::GetMax<T>(lstats), NumericStats::GetMax<T>(rstats), max)) { |
| 90 | return true; |
| 91 | } |
| 92 | new_min = Value::Numeric(type, min); |
| 93 | new_max = Value::Numeric(type, max); |
| 94 | return false; |
| 95 | } |
| 96 | }; |
| 97 | |
| 98 | struct SubtractPropagateStatistics { |
| 99 | template <class T, class OP> |
| 100 | static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, |
| 101 | Value &new_max) { |
| 102 | T min, max; |
| 103 | if (!OP::Operation(NumericStats::GetMin<T>(lstats), NumericStats::GetMax<T>(rstats), min)) { |
| 104 | return true; |
| 105 | } |
| 106 | if (!OP::Operation(NumericStats::GetMax<T>(lstats), NumericStats::GetMin<T>(rstats), max)) { |
| 107 | return true; |
| 108 | } |
| 109 | new_min = Value::Numeric(type, min); |
| 110 | new_max = Value::Numeric(type, max); |
| 111 | return false; |
| 112 | } |
| 113 | }; |
| 114 | |
| 115 | struct DecimalArithmeticBindData : public FunctionData { |
| 116 | DecimalArithmeticBindData() : check_overflow(true) { |
| 117 | } |
| 118 | |
| 119 | unique_ptr<FunctionData> Copy() const override { |
| 120 | auto res = make_uniq<DecimalArithmeticBindData>(); |
| 121 | res->check_overflow = check_overflow; |
| 122 | return std::move(res); |
| 123 | } |
| 124 | |
| 125 | bool Equals(const FunctionData &other_p) const override { |
| 126 | auto other = other_p.Cast<DecimalArithmeticBindData>(); |
| 127 | return other.check_overflow == check_overflow; |
| 128 | } |
| 129 | |
| 130 | bool check_overflow; |
| 131 | }; |
| 132 | |
| 133 | template <class OP, class PROPAGATE, class BASEOP> |
| 134 | static unique_ptr<BaseStatistics> PropagateNumericStats(ClientContext &context, FunctionStatisticsInput &input) { |
| 135 | auto &child_stats = input.child_stats; |
| 136 | auto &expr = input.expr; |
| 137 | D_ASSERT(child_stats.size() == 2); |
| 138 | // can only propagate stats if the children have stats |
| 139 | auto &lstats = child_stats[0]; |
| 140 | auto &rstats = child_stats[1]; |
| 141 | Value new_min, new_max; |
| 142 | bool potential_overflow = true; |
| 143 | if (NumericStats::HasMinMax(stats: lstats) && NumericStats::HasMinMax(stats: rstats)) { |
| 144 | switch (expr.return_type.InternalType()) { |
| 145 | case PhysicalType::INT8: |
| 146 | potential_overflow = |
| 147 | PROPAGATE::template Operation<int8_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
| 148 | break; |
| 149 | case PhysicalType::INT16: |
| 150 | potential_overflow = |
| 151 | PROPAGATE::template Operation<int16_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
| 152 | break; |
| 153 | case PhysicalType::INT32: |
| 154 | potential_overflow = |
| 155 | PROPAGATE::template Operation<int32_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
| 156 | break; |
| 157 | case PhysicalType::INT64: |
| 158 | potential_overflow = |
| 159 | PROPAGATE::template Operation<int64_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
| 160 | break; |
| 161 | default: |
| 162 | return nullptr; |
| 163 | } |
| 164 | } |
| 165 | if (potential_overflow) { |
| 166 | new_min = Value(expr.return_type); |
| 167 | new_max = Value(expr.return_type); |
| 168 | } else { |
| 169 | // no potential overflow: replace with non-overflowing operator |
| 170 | if (input.bind_data) { |
| 171 | auto &bind_data = input.bind_data->Cast<DecimalArithmeticBindData>(); |
| 172 | bind_data.check_overflow = false; |
| 173 | } |
| 174 | expr.function.function = GetScalarIntegerFunction<BASEOP>(expr.return_type.InternalType()); |
| 175 | } |
| 176 | auto result = NumericStats::CreateEmpty(type: expr.return_type); |
| 177 | NumericStats::SetMin(stats&: result, val: new_min); |
| 178 | NumericStats::SetMax(stats&: result, val: new_max); |
| 179 | result.CombineValidity(left&: lstats, right&: rstats); |
| 180 | return result.ToUnique(); |
| 181 | } |
| 182 | |
| 183 | template <class OP, class OPOVERFLOWCHECK, bool IS_SUBTRACT = false> |
| 184 | unique_ptr<FunctionData> BindDecimalAddSubtract(ClientContext &context, ScalarFunction &bound_function, |
| 185 | vector<unique_ptr<Expression>> &arguments) { |
| 186 | |
| 187 | auto bind_data = make_uniq<DecimalArithmeticBindData>(); |
| 188 | |
| 189 | // get the max width and scale of the input arguments |
| 190 | uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; |
| 191 | for (idx_t i = 0; i < arguments.size(); i++) { |
| 192 | if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { |
| 193 | continue; |
| 194 | } |
| 195 | uint8_t width, scale; |
| 196 | auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); |
| 197 | if (!can_convert) { |
| 198 | throw InternalException("Could not convert type %s to a decimal." , arguments[i]->return_type.ToString()); |
| 199 | } |
| 200 | max_width = MaxValue<uint8_t>(a: width, b: max_width); |
| 201 | max_scale = MaxValue<uint8_t>(a: scale, b: max_scale); |
| 202 | max_width_over_scale = MaxValue<uint8_t>(a: width - scale, b: max_width_over_scale); |
| 203 | } |
| 204 | D_ASSERT(max_width > 0); |
| 205 | // for addition/subtraction, we add 1 to the width to ensure we don't overflow |
| 206 | auto required_width = MaxValue<uint8_t>(a: max_scale + max_width_over_scale, b: max_width) + 1; |
| 207 | if (required_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64) { |
| 208 | // we don't automatically promote past the hugeint boundary to avoid the large hugeint performance penalty |
| 209 | bind_data->check_overflow = true; |
| 210 | required_width = Decimal::MAX_WIDTH_INT64; |
| 211 | } |
| 212 | if (required_width > Decimal::MAX_WIDTH_DECIMAL) { |
| 213 | // target width does not fit in decimal at all: truncate the scale and perform overflow detection |
| 214 | bind_data->check_overflow = true; |
| 215 | required_width = Decimal::MAX_WIDTH_DECIMAL; |
| 216 | } |
| 217 | // arithmetic between two decimal arguments: check the types of the input arguments |
| 218 | LogicalType result_type = LogicalType::DECIMAL(width: required_width, scale: max_scale); |
| 219 | // we cast all input types to the specified type |
| 220 | for (idx_t i = 0; i < arguments.size(); i++) { |
| 221 | // first check if the cast is necessary |
| 222 | // if the argument has a matching scale and internal type as the output type, no casting is necessary |
| 223 | auto &argument_type = arguments[i]->return_type; |
| 224 | uint8_t width, scale; |
| 225 | argument_type.GetDecimalProperties(width, scale); |
| 226 | if (scale == DecimalType::GetScale(type: result_type) && argument_type.InternalType() == result_type.InternalType()) { |
| 227 | bound_function.arguments[i] = argument_type; |
| 228 | } else { |
| 229 | bound_function.arguments[i] = result_type; |
| 230 | } |
| 231 | } |
| 232 | bound_function.return_type = result_type; |
| 233 | // now select the physical function to execute |
| 234 | if (bind_data->check_overflow) { |
| 235 | bound_function.function = GetScalarBinaryFunction<OPOVERFLOWCHECK>(result_type.InternalType()); |
| 236 | } else { |
| 237 | bound_function.function = GetScalarBinaryFunction<OP>(result_type.InternalType()); |
| 238 | } |
| 239 | if (result_type.InternalType() != PhysicalType::INT128) { |
| 240 | if (IS_SUBTRACT) { |
| 241 | bound_function.statistics = |
| 242 | PropagateNumericStats<TryDecimalSubtract, SubtractPropagateStatistics, SubtractOperator>; |
| 243 | } else { |
| 244 | bound_function.statistics = PropagateNumericStats<TryDecimalAdd, AddPropagateStatistics, AddOperator>; |
| 245 | } |
| 246 | } |
| 247 | return std::move(bind_data); |
| 248 | } |
| 249 | |
| 250 | static void SerializeDecimalArithmetic(FieldWriter &writer, const FunctionData *bind_data_p, |
| 251 | const ScalarFunction &function) { |
| 252 | auto &bind_data = bind_data_p->Cast<DecimalArithmeticBindData>(); |
| 253 | writer.WriteField(element: bind_data.check_overflow); |
| 254 | writer.WriteSerializable(element: function.return_type); |
| 255 | writer.WriteRegularSerializableList(elements: function.arguments); |
| 256 | } |
| 257 | |
| 258 | // TODO this is partially duplicated from the bind |
| 259 | template <class OP, class OPOVERFLOWCHECK, bool IS_SUBTRACT = false> |
| 260 | unique_ptr<FunctionData> DeserializeDecimalArithmetic(PlanDeserializationState &state, FieldReader &reader, |
| 261 | ScalarFunction &bound_function) { |
| 262 | // re-change the function pointers |
| 263 | auto check_overflow = reader.ReadRequired<bool>(); |
| 264 | auto return_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>(); |
| 265 | auto arguments = reader.template ReadRequiredSerializableList<LogicalType, LogicalType>(); |
| 266 | |
| 267 | if (check_overflow) { |
| 268 | bound_function.function = GetScalarBinaryFunction<OPOVERFLOWCHECK>(return_type.InternalType()); |
| 269 | } else { |
| 270 | bound_function.function = GetScalarBinaryFunction<OP>(return_type.InternalType()); |
| 271 | } |
| 272 | bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again |
| 273 | bound_function.return_type = return_type; |
| 274 | bound_function.arguments = arguments; |
| 275 | |
| 276 | auto bind_data = make_uniq<DecimalArithmeticBindData>(); |
| 277 | bind_data->check_overflow = check_overflow; |
| 278 | return std::move(bind_data); |
| 279 | } |
| 280 | |
| 281 | unique_ptr<FunctionData> NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, |
| 282 | vector<unique_ptr<Expression>> &arguments) { |
| 283 | bound_function.return_type = arguments[0]->return_type; |
| 284 | bound_function.arguments[0] = arguments[0]->return_type; |
| 285 | return nullptr; |
| 286 | } |
| 287 | |
| 288 | ScalarFunction AddFun::GetFunction(const LogicalType &type) { |
| 289 | D_ASSERT(type.IsNumeric()); |
| 290 | if (type.id() == LogicalTypeId::DECIMAL) { |
| 291 | return ScalarFunction("+" , {type}, type, ScalarFunction::NopFunction, NopDecimalBind); |
| 292 | } else { |
| 293 | return ScalarFunction("+" , {type}, type, ScalarFunction::NopFunction); |
| 294 | } |
| 295 | } |
| 296 | |
| 297 | ScalarFunction AddFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { |
| 298 | if (left_type.IsNumeric() && left_type.id() == right_type.id()) { |
| 299 | if (left_type.id() == LogicalTypeId::DECIMAL) { |
| 300 | auto function = ScalarFunction("+" , {left_type, right_type}, left_type, nullptr, |
| 301 | BindDecimalAddSubtract<AddOperator, DecimalAddOverflowCheck>); |
| 302 | function.serialize = SerializeDecimalArithmetic; |
| 303 | function.deserialize = DeserializeDecimalArithmetic<AddOperator, DecimalAddOverflowCheck>; |
| 304 | return function; |
| 305 | } else if (left_type.IsIntegral() && left_type.id() != LogicalTypeId::HUGEINT) { |
| 306 | return ScalarFunction("+" , {left_type, right_type}, left_type, |
| 307 | GetScalarIntegerFunction<AddOperatorOverflowCheck>(type: left_type.InternalType()), nullptr, |
| 308 | nullptr, PropagateNumericStats<TryAddOperator, AddPropagateStatistics, AddOperator>); |
| 309 | } else { |
| 310 | return ScalarFunction("+" , {left_type, right_type}, left_type, |
| 311 | GetScalarBinaryFunction<AddOperator>(type: left_type.InternalType())); |
| 312 | } |
| 313 | } |
| 314 | |
| 315 | switch (left_type.id()) { |
| 316 | case LogicalTypeId::DATE: |
| 317 | if (right_type.id() == LogicalTypeId::INTEGER) { |
| 318 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::DATE, |
| 319 | ScalarFunction::BinaryFunction<date_t, int32_t, date_t, AddOperator>); |
| 320 | } else if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 321 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::DATE, |
| 322 | ScalarFunction::BinaryFunction<date_t, interval_t, date_t, AddOperator>); |
| 323 | } else if (right_type.id() == LogicalTypeId::TIME) { |
| 324 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
| 325 | ScalarFunction::BinaryFunction<date_t, dtime_t, timestamp_t, AddOperator>); |
| 326 | } |
| 327 | break; |
| 328 | case LogicalTypeId::INTEGER: |
| 329 | if (right_type.id() == LogicalTypeId::DATE) { |
| 330 | return ScalarFunction("+" , {left_type, right_type}, right_type, |
| 331 | ScalarFunction::BinaryFunction<int32_t, date_t, date_t, AddOperator>); |
| 332 | } |
| 333 | break; |
| 334 | case LogicalTypeId::INTERVAL: |
| 335 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 336 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::INTERVAL, |
| 337 | ScalarFunction::BinaryFunction<interval_t, interval_t, interval_t, AddOperator>); |
| 338 | } else if (right_type.id() == LogicalTypeId::DATE) { |
| 339 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::DATE, |
| 340 | ScalarFunction::BinaryFunction<interval_t, date_t, date_t, AddOperator>); |
| 341 | } else if (right_type.id() == LogicalTypeId::TIME) { |
| 342 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIME, |
| 343 | ScalarFunction::BinaryFunction<interval_t, dtime_t, dtime_t, AddTimeOperator>); |
| 344 | } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { |
| 345 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
| 346 | ScalarFunction::BinaryFunction<interval_t, timestamp_t, timestamp_t, AddOperator>); |
| 347 | } |
| 348 | break; |
| 349 | case LogicalTypeId::TIME: |
| 350 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 351 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIME, |
| 352 | ScalarFunction::BinaryFunction<dtime_t, interval_t, dtime_t, AddTimeOperator>); |
| 353 | } else if (right_type.id() == LogicalTypeId::DATE) { |
| 354 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
| 355 | ScalarFunction::BinaryFunction<dtime_t, date_t, timestamp_t, AddOperator>); |
| 356 | } |
| 357 | break; |
| 358 | case LogicalTypeId::TIMESTAMP: |
| 359 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 360 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
| 361 | ScalarFunction::BinaryFunction<timestamp_t, interval_t, timestamp_t, AddOperator>); |
| 362 | } |
| 363 | break; |
| 364 | default: |
| 365 | break; |
| 366 | } |
| 367 | // LCOV_EXCL_START |
| 368 | throw NotImplementedException("AddFun for types %s, %s" , EnumUtil::ToString(value: left_type.id()), |
| 369 | EnumUtil::ToString(value: right_type.id())); |
| 370 | // LCOV_EXCL_STOP |
| 371 | } |
| 372 | |
| 373 | void AddFun::RegisterFunction(BuiltinFunctions &set) { |
| 374 | ScalarFunctionSet functions("+" ); |
| 375 | for (auto &type : LogicalType::Numeric()) { |
| 376 | // unary add function is a nop, but only exists for numeric types |
| 377 | functions.AddFunction(function: GetFunction(type)); |
| 378 | // binary add function adds two numbers together |
| 379 | functions.AddFunction(function: GetFunction(left_type: type, right_type: type)); |
| 380 | } |
| 381 | // we can add integers to dates |
| 382 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTEGER)); |
| 383 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTEGER, right_type: LogicalType::DATE)); |
| 384 | // we can add intervals together |
| 385 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::INTERVAL)); |
| 386 | // we can add intervals to dates/times/timestamps |
| 387 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTERVAL)); |
| 388 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::DATE)); |
| 389 | |
| 390 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::INTERVAL)); |
| 391 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::TIME)); |
| 392 | |
| 393 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::INTERVAL)); |
| 394 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::TIMESTAMP)); |
| 395 | |
| 396 | // we can add times to dates |
| 397 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::DATE)); |
| 398 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::TIME)); |
| 399 | |
| 400 | // we can add lists together |
| 401 | functions.AddFunction(function: ListConcatFun::GetFunction()); |
| 402 | |
| 403 | set.AddFunction(set: functions); |
| 404 | |
| 405 | functions.name = "add" ; |
| 406 | set.AddFunction(set: functions); |
| 407 | } |
| 408 | |
| 409 | //===--------------------------------------------------------------------===// |
| 410 | // - [subtract] |
| 411 | //===--------------------------------------------------------------------===// |
| 412 | struct NegateOperator { |
| 413 | template <class T> |
| 414 | static bool CanNegate(T input) { |
| 415 | using Limits = std::numeric_limits<T>; |
| 416 | return !(Limits::is_integer && Limits::is_signed && Limits::lowest() == input); |
| 417 | } |
| 418 | |
| 419 | template <class TA, class TR> |
| 420 | static inline TR Operation(TA input) { |
| 421 | auto cast = (TR)input; |
| 422 | if (!CanNegate<TR>(cast)) { |
| 423 | throw OutOfRangeException("Overflow in negation of integer!" ); |
| 424 | } |
| 425 | return -cast; |
| 426 | } |
| 427 | }; |
| 428 | |
| 429 | template <> |
| 430 | bool NegateOperator::CanNegate(float input) { |
| 431 | return true; |
| 432 | } |
| 433 | |
| 434 | template <> |
| 435 | bool NegateOperator::CanNegate(double input) { |
| 436 | return true; |
| 437 | } |
| 438 | |
| 439 | template <> |
| 440 | interval_t NegateOperator::Operation(interval_t input) { |
| 441 | interval_t result; |
| 442 | result.months = NegateOperator::Operation<int32_t, int32_t>(input: input.months); |
| 443 | result.days = NegateOperator::Operation<int32_t, int32_t>(input: input.days); |
| 444 | result.micros = NegateOperator::Operation<int64_t, int64_t>(input: input.micros); |
| 445 | return result; |
| 446 | } |
| 447 | |
| 448 | struct DecimalNegateBindData : public FunctionData { |
| 449 | DecimalNegateBindData() : bound_type(LogicalTypeId::INVALID) { |
| 450 | } |
| 451 | |
| 452 | unique_ptr<FunctionData> Copy() const override { |
| 453 | auto res = make_uniq<DecimalNegateBindData>(); |
| 454 | res->bound_type = bound_type; |
| 455 | return std::move(res); |
| 456 | } |
| 457 | |
| 458 | bool Equals(const FunctionData &other_p) const override { |
| 459 | auto other = other_p.Cast<DecimalNegateBindData>(); |
| 460 | return other.bound_type == bound_type; |
| 461 | } |
| 462 | |
| 463 | LogicalTypeId bound_type; |
| 464 | }; |
| 465 | |
| 466 | unique_ptr<FunctionData> DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, |
| 467 | vector<unique_ptr<Expression>> &arguments) { |
| 468 | |
| 469 | auto bind_data = make_uniq<DecimalNegateBindData>(); |
| 470 | |
| 471 | auto &decimal_type = arguments[0]->return_type; |
| 472 | auto width = DecimalType::GetWidth(type: decimal_type); |
| 473 | if (width <= Decimal::MAX_WIDTH_INT16) { |
| 474 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::SMALLINT); |
| 475 | } else if (width <= Decimal::MAX_WIDTH_INT32) { |
| 476 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::INTEGER); |
| 477 | } else if (width <= Decimal::MAX_WIDTH_INT64) { |
| 478 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::BIGINT); |
| 479 | } else { |
| 480 | D_ASSERT(width <= Decimal::MAX_WIDTH_INT128); |
| 481 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::HUGEINT); |
| 482 | } |
| 483 | decimal_type.Verify(); |
| 484 | bound_function.arguments[0] = decimal_type; |
| 485 | bound_function.return_type = decimal_type; |
| 486 | return nullptr; |
| 487 | } |
| 488 | |
| 489 | struct NegatePropagateStatistics { |
| 490 | template <class T> |
| 491 | static bool Operation(LogicalType type, BaseStatistics &istats, Value &new_min, Value &new_max) { |
| 492 | auto max_value = NumericStats::GetMax<T>(istats); |
| 493 | auto min_value = NumericStats::GetMin<T>(istats); |
| 494 | if (!NegateOperator::CanNegate<T>(min_value) || !NegateOperator::CanNegate<T>(max_value)) { |
| 495 | return true; |
| 496 | } |
| 497 | // new min is -max |
| 498 | new_min = Value::Numeric(type, NegateOperator::Operation<T, T>(max_value)); |
| 499 | // new max is -min |
| 500 | new_max = Value::Numeric(type, NegateOperator::Operation<T, T>(min_value)); |
| 501 | return false; |
| 502 | } |
| 503 | }; |
| 504 | |
| 505 | static unique_ptr<BaseStatistics> NegateBindStatistics(ClientContext &context, FunctionStatisticsInput &input) { |
| 506 | auto &child_stats = input.child_stats; |
| 507 | auto &expr = input.expr; |
| 508 | D_ASSERT(child_stats.size() == 1); |
| 509 | // can only propagate stats if the children have stats |
| 510 | auto &istats = child_stats[0]; |
| 511 | Value new_min, new_max; |
| 512 | bool potential_overflow = true; |
| 513 | if (NumericStats::HasMinMax(stats: istats)) { |
| 514 | switch (expr.return_type.InternalType()) { |
| 515 | case PhysicalType::INT8: |
| 516 | potential_overflow = |
| 517 | NegatePropagateStatistics::Operation<int8_t>(type: expr.return_type, istats, new_min, new_max); |
| 518 | break; |
| 519 | case PhysicalType::INT16: |
| 520 | potential_overflow = |
| 521 | NegatePropagateStatistics::Operation<int16_t>(type: expr.return_type, istats, new_min, new_max); |
| 522 | break; |
| 523 | case PhysicalType::INT32: |
| 524 | potential_overflow = |
| 525 | NegatePropagateStatistics::Operation<int32_t>(type: expr.return_type, istats, new_min, new_max); |
| 526 | break; |
| 527 | case PhysicalType::INT64: |
| 528 | potential_overflow = |
| 529 | NegatePropagateStatistics::Operation<int64_t>(type: expr.return_type, istats, new_min, new_max); |
| 530 | break; |
| 531 | default: |
| 532 | return nullptr; |
| 533 | } |
| 534 | } |
| 535 | if (potential_overflow) { |
| 536 | new_min = Value(expr.return_type); |
| 537 | new_max = Value(expr.return_type); |
| 538 | } |
| 539 | auto stats = NumericStats::CreateEmpty(type: expr.return_type); |
| 540 | NumericStats::SetMin(stats, val: new_min); |
| 541 | NumericStats::SetMax(stats, val: new_max); |
| 542 | stats.CopyValidity(stats&: istats); |
| 543 | return stats.ToUnique(); |
| 544 | } |
| 545 | |
| 546 | ScalarFunction SubtractFun::GetFunction(const LogicalType &type) { |
| 547 | if (type.id() == LogicalTypeId::INTERVAL) { |
| 548 | return ScalarFunction("-" , {type}, type, ScalarFunction::UnaryFunction<interval_t, interval_t, NegateOperator>); |
| 549 | } else if (type.id() == LogicalTypeId::DECIMAL) { |
| 550 | return ScalarFunction("-" , {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); |
| 551 | } else { |
| 552 | D_ASSERT(type.IsNumeric()); |
| 553 | return ScalarFunction("-" , {type}, type, ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type), nullptr, |
| 554 | nullptr, NegateBindStatistics); |
| 555 | } |
| 556 | } |
| 557 | |
| 558 | ScalarFunction SubtractFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { |
| 559 | if (left_type.IsNumeric() && left_type.id() == right_type.id()) { |
| 560 | if (left_type.id() == LogicalTypeId::DECIMAL) { |
| 561 | auto function = |
| 562 | ScalarFunction("-" , {left_type, right_type}, left_type, nullptr, |
| 563 | BindDecimalAddSubtract<SubtractOperator, DecimalSubtractOverflowCheck, true>); |
| 564 | function.serialize = SerializeDecimalArithmetic; |
| 565 | function.deserialize = DeserializeDecimalArithmetic<SubtractOperator, DecimalSubtractOverflowCheck>; |
| 566 | return function; |
| 567 | } else if (left_type.IsIntegral() && left_type.id() != LogicalTypeId::HUGEINT) { |
| 568 | return ScalarFunction( |
| 569 | "-" , {left_type, right_type}, left_type, |
| 570 | GetScalarIntegerFunction<SubtractOperatorOverflowCheck>(type: left_type.InternalType()), nullptr, nullptr, |
| 571 | PropagateNumericStats<TrySubtractOperator, SubtractPropagateStatistics, SubtractOperator>); |
| 572 | |
| 573 | } else { |
| 574 | return ScalarFunction("-" , {left_type, right_type}, left_type, |
| 575 | GetScalarBinaryFunction<SubtractOperator>(type: left_type.InternalType())); |
| 576 | } |
| 577 | } |
| 578 | |
| 579 | switch (left_type.id()) { |
| 580 | case LogicalTypeId::DATE: |
| 581 | if (right_type.id() == LogicalTypeId::DATE) { |
| 582 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::BIGINT, |
| 583 | ScalarFunction::BinaryFunction<date_t, date_t, int64_t, SubtractOperator>); |
| 584 | |
| 585 | } else if (right_type.id() == LogicalTypeId::INTEGER) { |
| 586 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::DATE, |
| 587 | ScalarFunction::BinaryFunction<date_t, int32_t, date_t, SubtractOperator>); |
| 588 | } else if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 589 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::DATE, |
| 590 | ScalarFunction::BinaryFunction<date_t, interval_t, date_t, SubtractOperator>); |
| 591 | } |
| 592 | break; |
| 593 | case LogicalTypeId::TIMESTAMP: |
| 594 | if (right_type.id() == LogicalTypeId::TIMESTAMP) { |
| 595 | return ScalarFunction( |
| 596 | "-" , {left_type, right_type}, LogicalType::INTERVAL, |
| 597 | ScalarFunction::BinaryFunction<timestamp_t, timestamp_t, interval_t, SubtractOperator>); |
| 598 | } else if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 599 | return ScalarFunction( |
| 600 | "-" , {left_type, right_type}, LogicalType::TIMESTAMP, |
| 601 | ScalarFunction::BinaryFunction<timestamp_t, interval_t, timestamp_t, SubtractOperator>); |
| 602 | } |
| 603 | break; |
| 604 | case LogicalTypeId::INTERVAL: |
| 605 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 606 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::INTERVAL, |
| 607 | ScalarFunction::BinaryFunction<interval_t, interval_t, interval_t, SubtractOperator>); |
| 608 | } |
| 609 | break; |
| 610 | case LogicalTypeId::TIME: |
| 611 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
| 612 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::TIME, |
| 613 | ScalarFunction::BinaryFunction<dtime_t, interval_t, dtime_t, SubtractTimeOperator>); |
| 614 | } |
| 615 | break; |
| 616 | default: |
| 617 | break; |
| 618 | } |
| 619 | // LCOV_EXCL_START |
| 620 | throw NotImplementedException("SubtractFun for types %s, %s" , EnumUtil::ToString(value: left_type.id()), |
| 621 | EnumUtil::ToString(value: right_type.id())); |
| 622 | // LCOV_EXCL_STOP |
| 623 | } |
| 624 | |
| 625 | void SubtractFun::RegisterFunction(BuiltinFunctions &set) { |
| 626 | ScalarFunctionSet functions("-" ); |
| 627 | for (auto &type : LogicalType::Numeric()) { |
| 628 | // unary subtract function, negates the input (i.e. multiplies by -1) |
| 629 | functions.AddFunction(function: GetFunction(type)); |
| 630 | // binary subtract function "a - b", subtracts b from a |
| 631 | functions.AddFunction(function: GetFunction(left_type: type, right_type: type)); |
| 632 | } |
| 633 | // we can subtract dates from each other |
| 634 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::DATE)); |
| 635 | // we can subtract integers from dates |
| 636 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTEGER)); |
| 637 | // we can subtract timestamps from each other |
| 638 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::TIMESTAMP)); |
| 639 | // we can subtract intervals from each other |
| 640 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::INTERVAL)); |
| 641 | // we can subtract intervals from dates/times/timestamps, but not the other way around |
| 642 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTERVAL)); |
| 643 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::INTERVAL)); |
| 644 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::INTERVAL)); |
| 645 | // we can negate intervals |
| 646 | functions.AddFunction(function: GetFunction(type: LogicalType::INTERVAL)); |
| 647 | set.AddFunction(set: functions); |
| 648 | |
| 649 | functions.name = "subtract" ; |
| 650 | set.AddFunction(set: functions); |
| 651 | } |
| 652 | |
| 653 | //===--------------------------------------------------------------------===// |
| 654 | // * [multiply] |
| 655 | //===--------------------------------------------------------------------===// |
| 656 | struct MultiplyPropagateStatistics { |
| 657 | template <class T, class OP> |
| 658 | static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, |
| 659 | Value &new_max) { |
| 660 | // statistics propagation on the multiplication is slightly less straightforward because of negative numbers |
| 661 | // the new min/max depend on the signs of the input types |
| 662 | // if both are positive the result is [lmin * rmin][lmax * rmax] |
| 663 | // if lmin/lmax are negative the result is [lmin * rmax][lmax * rmin] |
| 664 | // etc |
| 665 | // rather than doing all this switcheroo we just multiply all combinations of lmin/lmax with rmin/rmax |
| 666 | // and check what the minimum/maximum value is |
| 667 | T lvals[] {NumericStats::GetMin<T>(lstats), NumericStats::GetMax<T>(lstats)}; |
| 668 | T rvals[] {NumericStats::GetMin<T>(rstats), NumericStats::GetMax<T>(rstats)}; |
| 669 | T min = NumericLimits<T>::Maximum(); |
| 670 | T max = NumericLimits<T>::Minimum(); |
| 671 | // multiplications |
| 672 | for (idx_t l = 0; l < 2; l++) { |
| 673 | for (idx_t r = 0; r < 2; r++) { |
| 674 | T result; |
| 675 | if (!OP::Operation(lvals[l], rvals[r], result)) { |
| 676 | // potential overflow |
| 677 | return true; |
| 678 | } |
| 679 | if (result < min) { |
| 680 | min = result; |
| 681 | } |
| 682 | if (result > max) { |
| 683 | max = result; |
| 684 | } |
| 685 | } |
| 686 | } |
| 687 | new_min = Value::Numeric(type, min); |
| 688 | new_max = Value::Numeric(type, max); |
| 689 | return false; |
| 690 | } |
| 691 | }; |
| 692 | |
| 693 | unique_ptr<FunctionData> BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, |
| 694 | vector<unique_ptr<Expression>> &arguments) { |
| 695 | |
| 696 | auto bind_data = make_uniq<DecimalArithmeticBindData>(); |
| 697 | |
| 698 | uint8_t result_width = 0, result_scale = 0; |
| 699 | uint8_t max_width = 0; |
| 700 | for (idx_t i = 0; i < arguments.size(); i++) { |
| 701 | if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { |
| 702 | continue; |
| 703 | } |
| 704 | uint8_t width, scale; |
| 705 | auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); |
| 706 | if (!can_convert) { |
| 707 | throw InternalException("Could not convert type %s to a decimal?" , arguments[i]->return_type.ToString()); |
| 708 | } |
| 709 | if (width > max_width) { |
| 710 | max_width = width; |
| 711 | } |
| 712 | result_width += width; |
| 713 | result_scale += scale; |
| 714 | } |
| 715 | D_ASSERT(max_width > 0); |
| 716 | if (result_scale > Decimal::MAX_WIDTH_DECIMAL) { |
| 717 | throw OutOfRangeException( |
| 718 | "Needed scale %d to accurately represent the multiplication result, but this is out of range of the " |
| 719 | "DECIMAL type. Max scale is %d; could not perform an accurate multiplication. Either add a cast to DOUBLE, " |
| 720 | "or add an explicit cast to a decimal with a lower scale." , |
| 721 | result_scale, Decimal::MAX_WIDTH_DECIMAL); |
| 722 | } |
| 723 | if (result_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64 && |
| 724 | result_scale < Decimal::MAX_WIDTH_INT64) { |
| 725 | bind_data->check_overflow = true; |
| 726 | result_width = Decimal::MAX_WIDTH_INT64; |
| 727 | } |
| 728 | if (result_width > Decimal::MAX_WIDTH_DECIMAL) { |
| 729 | bind_data->check_overflow = true; |
| 730 | result_width = Decimal::MAX_WIDTH_DECIMAL; |
| 731 | } |
| 732 | LogicalType result_type = LogicalType::DECIMAL(width: result_width, scale: result_scale); |
| 733 | // since our scale is the summation of our input scales, we do not need to cast to the result scale |
| 734 | // however, we might need to cast to the correct internal type |
| 735 | for (idx_t i = 0; i < arguments.size(); i++) { |
| 736 | auto &argument_type = arguments[i]->return_type; |
| 737 | if (argument_type.InternalType() == result_type.InternalType()) { |
| 738 | bound_function.arguments[i] = argument_type; |
| 739 | } else { |
| 740 | uint8_t width, scale; |
| 741 | if (!argument_type.GetDecimalProperties(width, scale)) { |
| 742 | scale = 0; |
| 743 | } |
| 744 | |
| 745 | bound_function.arguments[i] = LogicalType::DECIMAL(width: result_width, scale); |
| 746 | } |
| 747 | } |
| 748 | result_type.Verify(); |
| 749 | bound_function.return_type = result_type; |
| 750 | // now select the physical function to execute |
| 751 | if (bind_data->check_overflow) { |
| 752 | bound_function.function = GetScalarBinaryFunction<DecimalMultiplyOverflowCheck>(type: result_type.InternalType()); |
| 753 | } else { |
| 754 | bound_function.function = GetScalarBinaryFunction<MultiplyOperator>(type: result_type.InternalType()); |
| 755 | } |
| 756 | if (result_type.InternalType() != PhysicalType::INT128) { |
| 757 | bound_function.statistics = |
| 758 | PropagateNumericStats<TryDecimalMultiply, MultiplyPropagateStatistics, MultiplyOperator>; |
| 759 | } |
| 760 | return std::move(bind_data); |
| 761 | } |
| 762 | |
| 763 | void MultiplyFun::RegisterFunction(BuiltinFunctions &set) { |
| 764 | ScalarFunctionSet functions("*" ); |
| 765 | for (auto &type : LogicalType::Numeric()) { |
| 766 | if (type.id() == LogicalTypeId::DECIMAL) { |
| 767 | ScalarFunction function({type, type}, type, nullptr, BindDecimalMultiply); |
| 768 | function.serialize = SerializeDecimalArithmetic; |
| 769 | function.deserialize = DeserializeDecimalArithmetic<MultiplyOperator, DecimalMultiplyOverflowCheck>; |
| 770 | functions.AddFunction(function); |
| 771 | } else if (TypeIsIntegral(type: type.InternalType()) && type.id() != LogicalTypeId::HUGEINT) { |
| 772 | functions.AddFunction(function: ScalarFunction( |
| 773 | {type, type}, type, GetScalarIntegerFunction<MultiplyOperatorOverflowCheck>(type: type.InternalType()), |
| 774 | nullptr, nullptr, |
| 775 | PropagateNumericStats<TryMultiplyOperator, MultiplyPropagateStatistics, MultiplyOperator>)); |
| 776 | } else { |
| 777 | functions.AddFunction( |
| 778 | function: ScalarFunction({type, type}, type, GetScalarBinaryFunction<MultiplyOperator>(type: type.InternalType()))); |
| 779 | } |
| 780 | } |
| 781 | functions.AddFunction( |
| 782 | function: ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, |
| 783 | ScalarFunction::BinaryFunction<interval_t, int64_t, interval_t, MultiplyOperator>)); |
| 784 | functions.AddFunction( |
| 785 | function: ScalarFunction({LogicalType::BIGINT, LogicalType::INTERVAL}, LogicalType::INTERVAL, |
| 786 | ScalarFunction::BinaryFunction<int64_t, interval_t, interval_t, MultiplyOperator>)); |
| 787 | set.AddFunction(set: functions); |
| 788 | |
| 789 | functions.name = "multiply" ; |
| 790 | set.AddFunction(set: functions); |
| 791 | } |
| 792 | |
| 793 | //===--------------------------------------------------------------------===// |
| 794 | // / [divide] |
| 795 | //===--------------------------------------------------------------------===// |
| 796 | template <> |
| 797 | float DivideOperator::Operation(float left, float right) { |
| 798 | auto result = left / right; |
| 799 | return result; |
| 800 | } |
| 801 | |
| 802 | template <> |
| 803 | double DivideOperator::Operation(double left, double right) { |
| 804 | auto result = left / right; |
| 805 | return result; |
| 806 | } |
| 807 | |
| 808 | template <> |
| 809 | hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) { |
| 810 | if (right.lower == 0 && right.upper == 0) { |
| 811 | throw InternalException("Hugeint division by zero!" ); |
| 812 | } |
| 813 | return left / right; |
| 814 | } |
| 815 | |
| 816 | template <> |
| 817 | interval_t DivideOperator::Operation(interval_t left, int64_t right) { |
| 818 | left.days /= right; |
| 819 | left.months /= right; |
| 820 | left.micros /= right; |
| 821 | return left; |
| 822 | } |
| 823 | |
| 824 | struct BinaryNumericDivideWrapper { |
| 825 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
| 826 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { |
| 827 | if (left == NumericLimits<LEFT_TYPE>::Minimum() && right == -1) { |
| 828 | throw OutOfRangeException("Overflow in division of %d / %d" , left, right); |
| 829 | } else if (right == 0) { |
| 830 | mask.SetInvalid(idx); |
| 831 | return left; |
| 832 | } else { |
| 833 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
| 834 | } |
| 835 | } |
| 836 | |
| 837 | static bool AddsNulls() { |
| 838 | return true; |
| 839 | } |
| 840 | }; |
| 841 | |
| 842 | struct BinaryZeroIsNullWrapper { |
| 843 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
| 844 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { |
| 845 | if (right == 0) { |
| 846 | mask.SetInvalid(idx); |
| 847 | return left; |
| 848 | } else { |
| 849 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
| 850 | } |
| 851 | } |
| 852 | |
| 853 | static bool AddsNulls() { |
| 854 | return true; |
| 855 | } |
| 856 | }; |
| 857 | |
| 858 | struct BinaryZeroIsNullHugeintWrapper { |
| 859 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
| 860 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { |
| 861 | if (right.upper == 0 && right.lower == 0) { |
| 862 | mask.SetInvalid(idx); |
| 863 | return left; |
| 864 | } else { |
| 865 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
| 866 | } |
| 867 | } |
| 868 | |
| 869 | static bool AddsNulls() { |
| 870 | return true; |
| 871 | } |
| 872 | }; |
| 873 | |
| 874 | template <class TA, class TB, class TC, class OP, class ZWRAPPER = BinaryZeroIsNullWrapper> |
| 875 | static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { |
| 876 | BinaryExecutor::Execute<TA, TB, TC, OP, ZWRAPPER>(input.data[0], input.data[1], result, input.size()); |
| 877 | } |
| 878 | |
| 879 | template <class OP> |
| 880 | static scalar_function_t GetBinaryFunctionIgnoreZero(const LogicalType &type) { |
| 881 | switch (type.id()) { |
| 882 | case LogicalTypeId::TINYINT: |
| 883 | return BinaryScalarFunctionIgnoreZero<int8_t, int8_t, int8_t, OP, BinaryNumericDivideWrapper>; |
| 884 | case LogicalTypeId::SMALLINT: |
| 885 | return BinaryScalarFunctionIgnoreZero<int16_t, int16_t, int16_t, OP, BinaryNumericDivideWrapper>; |
| 886 | case LogicalTypeId::INTEGER: |
| 887 | return BinaryScalarFunctionIgnoreZero<int32_t, int32_t, int32_t, OP, BinaryNumericDivideWrapper>; |
| 888 | case LogicalTypeId::BIGINT: |
| 889 | return BinaryScalarFunctionIgnoreZero<int64_t, int64_t, int64_t, OP, BinaryNumericDivideWrapper>; |
| 890 | case LogicalTypeId::UTINYINT: |
| 891 | return BinaryScalarFunctionIgnoreZero<uint8_t, uint8_t, uint8_t, OP>; |
| 892 | case LogicalTypeId::USMALLINT: |
| 893 | return BinaryScalarFunctionIgnoreZero<uint16_t, uint16_t, uint16_t, OP>; |
| 894 | case LogicalTypeId::UINTEGER: |
| 895 | return BinaryScalarFunctionIgnoreZero<uint32_t, uint32_t, uint32_t, OP>; |
| 896 | case LogicalTypeId::UBIGINT: |
| 897 | return BinaryScalarFunctionIgnoreZero<uint64_t, uint64_t, uint64_t, OP>; |
| 898 | case LogicalTypeId::HUGEINT: |
| 899 | return BinaryScalarFunctionIgnoreZero<hugeint_t, hugeint_t, hugeint_t, OP, BinaryZeroIsNullHugeintWrapper>; |
| 900 | case LogicalTypeId::FLOAT: |
| 901 | return BinaryScalarFunctionIgnoreZero<float, float, float, OP>; |
| 902 | case LogicalTypeId::DOUBLE: |
| 903 | return BinaryScalarFunctionIgnoreZero<double, double, double, OP>; |
| 904 | default: |
| 905 | throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction" ); |
| 906 | } |
| 907 | } |
| 908 | |
| 909 | void DivideFun::RegisterFunction(BuiltinFunctions &set) { |
| 910 | ScalarFunctionSet fp_divide("/" ); |
| 911 | fp_divide.AddFunction(function: ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, |
| 912 | GetBinaryFunctionIgnoreZero<DivideOperator>(type: LogicalType::FLOAT))); |
| 913 | fp_divide.AddFunction(function: ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, |
| 914 | GetBinaryFunctionIgnoreZero<DivideOperator>(type: LogicalType::DOUBLE))); |
| 915 | fp_divide.AddFunction( |
| 916 | function: ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, |
| 917 | BinaryScalarFunctionIgnoreZero<interval_t, int64_t, interval_t, DivideOperator>)); |
| 918 | set.AddFunction(set: fp_divide); |
| 919 | |
| 920 | ScalarFunctionSet full_divide("//" ); |
| 921 | for (auto &type : LogicalType::Numeric()) { |
| 922 | if (type.id() == LogicalTypeId::DECIMAL) { |
| 923 | continue; |
| 924 | } else { |
| 925 | full_divide.AddFunction( |
| 926 | function: ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<DivideOperator>(type))); |
| 927 | } |
| 928 | } |
| 929 | set.AddFunction(set: full_divide); |
| 930 | |
| 931 | full_divide.name = "divide" ; |
| 932 | set.AddFunction(set: full_divide); |
| 933 | } |
| 934 | |
| 935 | //===--------------------------------------------------------------------===// |
| 936 | // % [modulo] |
| 937 | //===--------------------------------------------------------------------===// |
| 938 | template <> |
| 939 | float ModuloOperator::Operation(float left, float right) { |
| 940 | D_ASSERT(right != 0); |
| 941 | auto result = std::fmod(x: left, y: right); |
| 942 | return result; |
| 943 | } |
| 944 | |
| 945 | template <> |
| 946 | double ModuloOperator::Operation(double left, double right) { |
| 947 | D_ASSERT(right != 0); |
| 948 | auto result = std::fmod(x: left, y: right); |
| 949 | return result; |
| 950 | } |
| 951 | |
| 952 | template <> |
| 953 | hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) { |
| 954 | if (right.lower == 0 && right.upper == 0) { |
| 955 | throw InternalException("Hugeint division by zero!" ); |
| 956 | } |
| 957 | return left % right; |
| 958 | } |
| 959 | |
| 960 | void ModFun::RegisterFunction(BuiltinFunctions &set) { |
| 961 | ScalarFunctionSet functions("%" ); |
| 962 | for (auto &type : LogicalType::Numeric()) { |
| 963 | if (type.id() == LogicalTypeId::DECIMAL) { |
| 964 | continue; |
| 965 | } else { |
| 966 | functions.AddFunction( |
| 967 | function: ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<ModuloOperator>(type))); |
| 968 | } |
| 969 | } |
| 970 | set.AddFunction(set: functions); |
| 971 | functions.name = "mod" ; |
| 972 | set.AddFunction(set: functions); |
| 973 | } |
| 974 | |
| 975 | } // namespace duckdb |
| 976 | |