1#include "duckdb/common/field_writer.hpp"
2#include "duckdb/common/operator/add.hpp"
3#include "duckdb/common/operator/multiply.hpp"
4#include "duckdb/common/operator/numeric_binary_operators.hpp"
5#include "duckdb/common/operator/subtract.hpp"
6#include "duckdb/common/types/date.hpp"
7#include "duckdb/common/types/decimal.hpp"
8#include "duckdb/common/types/hugeint.hpp"
9#include "duckdb/common/types/interval.hpp"
10#include "duckdb/common/types/time.hpp"
11#include "duckdb/common/types/timestamp.hpp"
12#include "duckdb/common/vector_operations/vector_operations.hpp"
13#include "duckdb/common/enum_util.hpp"
14#include "duckdb/function/scalar/operators.hpp"
15#include "duckdb/planner/expression/bound_function_expression.hpp"
16#include "duckdb/function/scalar/nested_functions.hpp"
17
18#include <limits>
19
20namespace duckdb {
21
22template <class OP>
23static scalar_function_t GetScalarIntegerFunction(PhysicalType type) {
24 scalar_function_t function;
25 switch (type) {
26 case PhysicalType::INT8:
27 function = &ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>;
28 break;
29 case PhysicalType::INT16:
30 function = &ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>;
31 break;
32 case PhysicalType::INT32:
33 function = &ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>;
34 break;
35 case PhysicalType::INT64:
36 function = &ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>;
37 break;
38 case PhysicalType::UINT8:
39 function = &ScalarFunction::BinaryFunction<uint8_t, uint8_t, uint8_t, OP>;
40 break;
41 case PhysicalType::UINT16:
42 function = &ScalarFunction::BinaryFunction<uint16_t, uint16_t, uint16_t, OP>;
43 break;
44 case PhysicalType::UINT32:
45 function = &ScalarFunction::BinaryFunction<uint32_t, uint32_t, uint32_t, OP>;
46 break;
47 case PhysicalType::UINT64:
48 function = &ScalarFunction::BinaryFunction<uint64_t, uint64_t, uint64_t, OP>;
49 break;
50 default:
51 throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction");
52 }
53 return function;
54}
55
56template <class OP>
57static scalar_function_t GetScalarBinaryFunction(PhysicalType type) {
58 scalar_function_t function;
59 switch (type) {
60 case PhysicalType::INT128:
61 function = &ScalarFunction::BinaryFunction<hugeint_t, hugeint_t, hugeint_t, OP>;
62 break;
63 case PhysicalType::FLOAT:
64 function = &ScalarFunction::BinaryFunction<float, float, float, OP>;
65 break;
66 case PhysicalType::DOUBLE:
67 function = &ScalarFunction::BinaryFunction<double, double, double, OP>;
68 break;
69 default:
70 function = GetScalarIntegerFunction<OP>(type);
71 break;
72 }
73 return function;
74}
75
76//===--------------------------------------------------------------------===//
77// + [add]
78//===--------------------------------------------------------------------===//
79struct AddPropagateStatistics {
80 template <class T, class OP>
81 static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min,
82 Value &new_max) {
83 T min, max;
84 // new min is min+min
85 if (!OP::Operation(NumericStats::GetMin<T>(lstats), NumericStats::GetMin<T>(rstats), min)) {
86 return true;
87 }
88 // new max is max+max
89 if (!OP::Operation(NumericStats::GetMax<T>(lstats), NumericStats::GetMax<T>(rstats), max)) {
90 return true;
91 }
92 new_min = Value::Numeric(type, min);
93 new_max = Value::Numeric(type, max);
94 return false;
95 }
96};
97
98struct SubtractPropagateStatistics {
99 template <class T, class OP>
100 static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min,
101 Value &new_max) {
102 T min, max;
103 if (!OP::Operation(NumericStats::GetMin<T>(lstats), NumericStats::GetMax<T>(rstats), min)) {
104 return true;
105 }
106 if (!OP::Operation(NumericStats::GetMax<T>(lstats), NumericStats::GetMin<T>(rstats), max)) {
107 return true;
108 }
109 new_min = Value::Numeric(type, min);
110 new_max = Value::Numeric(type, max);
111 return false;
112 }
113};
114
115struct DecimalArithmeticBindData : public FunctionData {
116 DecimalArithmeticBindData() : check_overflow(true) {
117 }
118
119 unique_ptr<FunctionData> Copy() const override {
120 auto res = make_uniq<DecimalArithmeticBindData>();
121 res->check_overflow = check_overflow;
122 return std::move(res);
123 }
124
125 bool Equals(const FunctionData &other_p) const override {
126 auto other = other_p.Cast<DecimalArithmeticBindData>();
127 return other.check_overflow == check_overflow;
128 }
129
130 bool check_overflow;
131};
132
133template <class OP, class PROPAGATE, class BASEOP>
134static unique_ptr<BaseStatistics> PropagateNumericStats(ClientContext &context, FunctionStatisticsInput &input) {
135 auto &child_stats = input.child_stats;
136 auto &expr = input.expr;
137 D_ASSERT(child_stats.size() == 2);
138 // can only propagate stats if the children have stats
139 auto &lstats = child_stats[0];
140 auto &rstats = child_stats[1];
141 Value new_min, new_max;
142 bool potential_overflow = true;
143 if (NumericStats::HasMinMax(stats: lstats) && NumericStats::HasMinMax(stats: rstats)) {
144 switch (expr.return_type.InternalType()) {
145 case PhysicalType::INT8:
146 potential_overflow =
147 PROPAGATE::template Operation<int8_t, OP>(expr.return_type, lstats, rstats, new_min, new_max);
148 break;
149 case PhysicalType::INT16:
150 potential_overflow =
151 PROPAGATE::template Operation<int16_t, OP>(expr.return_type, lstats, rstats, new_min, new_max);
152 break;
153 case PhysicalType::INT32:
154 potential_overflow =
155 PROPAGATE::template Operation<int32_t, OP>(expr.return_type, lstats, rstats, new_min, new_max);
156 break;
157 case PhysicalType::INT64:
158 potential_overflow =
159 PROPAGATE::template Operation<int64_t, OP>(expr.return_type, lstats, rstats, new_min, new_max);
160 break;
161 default:
162 return nullptr;
163 }
164 }
165 if (potential_overflow) {
166 new_min = Value(expr.return_type);
167 new_max = Value(expr.return_type);
168 } else {
169 // no potential overflow: replace with non-overflowing operator
170 if (input.bind_data) {
171 auto &bind_data = input.bind_data->Cast<DecimalArithmeticBindData>();
172 bind_data.check_overflow = false;
173 }
174 expr.function.function = GetScalarIntegerFunction<BASEOP>(expr.return_type.InternalType());
175 }
176 auto result = NumericStats::CreateEmpty(type: expr.return_type);
177 NumericStats::SetMin(stats&: result, val: new_min);
178 NumericStats::SetMax(stats&: result, val: new_max);
179 result.CombineValidity(left&: lstats, right&: rstats);
180 return result.ToUnique();
181}
182
183template <class OP, class OPOVERFLOWCHECK, bool IS_SUBTRACT = false>
184unique_ptr<FunctionData> BindDecimalAddSubtract(ClientContext &context, ScalarFunction &bound_function,
185 vector<unique_ptr<Expression>> &arguments) {
186
187 auto bind_data = make_uniq<DecimalArithmeticBindData>();
188
189 // get the max width and scale of the input arguments
190 uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0;
191 for (idx_t i = 0; i < arguments.size(); i++) {
192 if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) {
193 continue;
194 }
195 uint8_t width, scale;
196 auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale);
197 if (!can_convert) {
198 throw InternalException("Could not convert type %s to a decimal.", arguments[i]->return_type.ToString());
199 }
200 max_width = MaxValue<uint8_t>(a: width, b: max_width);
201 max_scale = MaxValue<uint8_t>(a: scale, b: max_scale);
202 max_width_over_scale = MaxValue<uint8_t>(a: width - scale, b: max_width_over_scale);
203 }
204 D_ASSERT(max_width > 0);
205 // for addition/subtraction, we add 1 to the width to ensure we don't overflow
206 auto required_width = MaxValue<uint8_t>(a: max_scale + max_width_over_scale, b: max_width) + 1;
207 if (required_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64) {
208 // we don't automatically promote past the hugeint boundary to avoid the large hugeint performance penalty
209 bind_data->check_overflow = true;
210 required_width = Decimal::MAX_WIDTH_INT64;
211 }
212 if (required_width > Decimal::MAX_WIDTH_DECIMAL) {
213 // target width does not fit in decimal at all: truncate the scale and perform overflow detection
214 bind_data->check_overflow = true;
215 required_width = Decimal::MAX_WIDTH_DECIMAL;
216 }
217 // arithmetic between two decimal arguments: check the types of the input arguments
218 LogicalType result_type = LogicalType::DECIMAL(width: required_width, scale: max_scale);
219 // we cast all input types to the specified type
220 for (idx_t i = 0; i < arguments.size(); i++) {
221 // first check if the cast is necessary
222 // if the argument has a matching scale and internal type as the output type, no casting is necessary
223 auto &argument_type = arguments[i]->return_type;
224 uint8_t width, scale;
225 argument_type.GetDecimalProperties(width, scale);
226 if (scale == DecimalType::GetScale(type: result_type) && argument_type.InternalType() == result_type.InternalType()) {
227 bound_function.arguments[i] = argument_type;
228 } else {
229 bound_function.arguments[i] = result_type;
230 }
231 }
232 bound_function.return_type = result_type;
233 // now select the physical function to execute
234 if (bind_data->check_overflow) {
235 bound_function.function = GetScalarBinaryFunction<OPOVERFLOWCHECK>(result_type.InternalType());
236 } else {
237 bound_function.function = GetScalarBinaryFunction<OP>(result_type.InternalType());
238 }
239 if (result_type.InternalType() != PhysicalType::INT128) {
240 if (IS_SUBTRACT) {
241 bound_function.statistics =
242 PropagateNumericStats<TryDecimalSubtract, SubtractPropagateStatistics, SubtractOperator>;
243 } else {
244 bound_function.statistics = PropagateNumericStats<TryDecimalAdd, AddPropagateStatistics, AddOperator>;
245 }
246 }
247 return std::move(bind_data);
248}
249
250static void SerializeDecimalArithmetic(FieldWriter &writer, const FunctionData *bind_data_p,
251 const ScalarFunction &function) {
252 auto &bind_data = bind_data_p->Cast<DecimalArithmeticBindData>();
253 writer.WriteField(element: bind_data.check_overflow);
254 writer.WriteSerializable(element: function.return_type);
255 writer.WriteRegularSerializableList(elements: function.arguments);
256}
257
258// TODO this is partially duplicated from the bind
259template <class OP, class OPOVERFLOWCHECK, bool IS_SUBTRACT = false>
260unique_ptr<FunctionData> DeserializeDecimalArithmetic(PlanDeserializationState &state, FieldReader &reader,
261 ScalarFunction &bound_function) {
262 // re-change the function pointers
263 auto check_overflow = reader.ReadRequired<bool>();
264 auto return_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>();
265 auto arguments = reader.template ReadRequiredSerializableList<LogicalType, LogicalType>();
266
267 if (check_overflow) {
268 bound_function.function = GetScalarBinaryFunction<OPOVERFLOWCHECK>(return_type.InternalType());
269 } else {
270 bound_function.function = GetScalarBinaryFunction<OP>(return_type.InternalType());
271 }
272 bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again
273 bound_function.return_type = return_type;
274 bound_function.arguments = arguments;
275
276 auto bind_data = make_uniq<DecimalArithmeticBindData>();
277 bind_data->check_overflow = check_overflow;
278 return std::move(bind_data);
279}
280
281unique_ptr<FunctionData> NopDecimalBind(ClientContext &context, ScalarFunction &bound_function,
282 vector<unique_ptr<Expression>> &arguments) {
283 bound_function.return_type = arguments[0]->return_type;
284 bound_function.arguments[0] = arguments[0]->return_type;
285 return nullptr;
286}
287
288ScalarFunction AddFun::GetFunction(const LogicalType &type) {
289 D_ASSERT(type.IsNumeric());
290 if (type.id() == LogicalTypeId::DECIMAL) {
291 return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction, NopDecimalBind);
292 } else {
293 return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction);
294 }
295}
296
297ScalarFunction AddFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) {
298 if (left_type.IsNumeric() && left_type.id() == right_type.id()) {
299 if (left_type.id() == LogicalTypeId::DECIMAL) {
300 auto function = ScalarFunction("+", {left_type, right_type}, left_type, nullptr,
301 BindDecimalAddSubtract<AddOperator, DecimalAddOverflowCheck>);
302 function.serialize = SerializeDecimalArithmetic;
303 function.deserialize = DeserializeDecimalArithmetic<AddOperator, DecimalAddOverflowCheck>;
304 return function;
305 } else if (left_type.IsIntegral() && left_type.id() != LogicalTypeId::HUGEINT) {
306 return ScalarFunction("+", {left_type, right_type}, left_type,
307 GetScalarIntegerFunction<AddOperatorOverflowCheck>(type: left_type.InternalType()), nullptr,
308 nullptr, PropagateNumericStats<TryAddOperator, AddPropagateStatistics, AddOperator>);
309 } else {
310 return ScalarFunction("+", {left_type, right_type}, left_type,
311 GetScalarBinaryFunction<AddOperator>(type: left_type.InternalType()));
312 }
313 }
314
315 switch (left_type.id()) {
316 case LogicalTypeId::DATE:
317 if (right_type.id() == LogicalTypeId::INTEGER) {
318 return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE,
319 ScalarFunction::BinaryFunction<date_t, int32_t, date_t, AddOperator>);
320 } else if (right_type.id() == LogicalTypeId::INTERVAL) {
321 return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE,
322 ScalarFunction::BinaryFunction<date_t, interval_t, date_t, AddOperator>);
323 } else if (right_type.id() == LogicalTypeId::TIME) {
324 return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP,
325 ScalarFunction::BinaryFunction<date_t, dtime_t, timestamp_t, AddOperator>);
326 }
327 break;
328 case LogicalTypeId::INTEGER:
329 if (right_type.id() == LogicalTypeId::DATE) {
330 return ScalarFunction("+", {left_type, right_type}, right_type,
331 ScalarFunction::BinaryFunction<int32_t, date_t, date_t, AddOperator>);
332 }
333 break;
334 case LogicalTypeId::INTERVAL:
335 if (right_type.id() == LogicalTypeId::INTERVAL) {
336 return ScalarFunction("+", {left_type, right_type}, LogicalType::INTERVAL,
337 ScalarFunction::BinaryFunction<interval_t, interval_t, interval_t, AddOperator>);
338 } else if (right_type.id() == LogicalTypeId::DATE) {
339 return ScalarFunction("+", {left_type, right_type}, LogicalType::DATE,
340 ScalarFunction::BinaryFunction<interval_t, date_t, date_t, AddOperator>);
341 } else if (right_type.id() == LogicalTypeId::TIME) {
342 return ScalarFunction("+", {left_type, right_type}, LogicalType::TIME,
343 ScalarFunction::BinaryFunction<interval_t, dtime_t, dtime_t, AddTimeOperator>);
344 } else if (right_type.id() == LogicalTypeId::TIMESTAMP) {
345 return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP,
346 ScalarFunction::BinaryFunction<interval_t, timestamp_t, timestamp_t, AddOperator>);
347 }
348 break;
349 case LogicalTypeId::TIME:
350 if (right_type.id() == LogicalTypeId::INTERVAL) {
351 return ScalarFunction("+", {left_type, right_type}, LogicalType::TIME,
352 ScalarFunction::BinaryFunction<dtime_t, interval_t, dtime_t, AddTimeOperator>);
353 } else if (right_type.id() == LogicalTypeId::DATE) {
354 return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP,
355 ScalarFunction::BinaryFunction<dtime_t, date_t, timestamp_t, AddOperator>);
356 }
357 break;
358 case LogicalTypeId::TIMESTAMP:
359 if (right_type.id() == LogicalTypeId::INTERVAL) {
360 return ScalarFunction("+", {left_type, right_type}, LogicalType::TIMESTAMP,
361 ScalarFunction::BinaryFunction<timestamp_t, interval_t, timestamp_t, AddOperator>);
362 }
363 break;
364 default:
365 break;
366 }
367 // LCOV_EXCL_START
368 throw NotImplementedException("AddFun for types %s, %s", EnumUtil::ToString(value: left_type.id()),
369 EnumUtil::ToString(value: right_type.id()));
370 // LCOV_EXCL_STOP
371}
372
373void AddFun::RegisterFunction(BuiltinFunctions &set) {
374 ScalarFunctionSet functions("+");
375 for (auto &type : LogicalType::Numeric()) {
376 // unary add function is a nop, but only exists for numeric types
377 functions.AddFunction(function: GetFunction(type));
378 // binary add function adds two numbers together
379 functions.AddFunction(function: GetFunction(left_type: type, right_type: type));
380 }
381 // we can add integers to dates
382 functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTEGER));
383 functions.AddFunction(function: GetFunction(left_type: LogicalType::INTEGER, right_type: LogicalType::DATE));
384 // we can add intervals together
385 functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::INTERVAL));
386 // we can add intervals to dates/times/timestamps
387 functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTERVAL));
388 functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::DATE));
389
390 functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::INTERVAL));
391 functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::TIME));
392
393 functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::INTERVAL));
394 functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::TIMESTAMP));
395
396 // we can add times to dates
397 functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::DATE));
398 functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::TIME));
399
400 // we can add lists together
401 functions.AddFunction(function: ListConcatFun::GetFunction());
402
403 set.AddFunction(set: functions);
404
405 functions.name = "add";
406 set.AddFunction(set: functions);
407}
408
409//===--------------------------------------------------------------------===//
410// - [subtract]
411//===--------------------------------------------------------------------===//
412struct NegateOperator {
413 template <class T>
414 static bool CanNegate(T input) {
415 using Limits = std::numeric_limits<T>;
416 return !(Limits::is_integer && Limits::is_signed && Limits::lowest() == input);
417 }
418
419 template <class TA, class TR>
420 static inline TR Operation(TA input) {
421 auto cast = (TR)input;
422 if (!CanNegate<TR>(cast)) {
423 throw OutOfRangeException("Overflow in negation of integer!");
424 }
425 return -cast;
426 }
427};
428
429template <>
430bool NegateOperator::CanNegate(float input) {
431 return true;
432}
433
434template <>
435bool NegateOperator::CanNegate(double input) {
436 return true;
437}
438
439template <>
440interval_t NegateOperator::Operation(interval_t input) {
441 interval_t result;
442 result.months = NegateOperator::Operation<int32_t, int32_t>(input: input.months);
443 result.days = NegateOperator::Operation<int32_t, int32_t>(input: input.days);
444 result.micros = NegateOperator::Operation<int64_t, int64_t>(input: input.micros);
445 return result;
446}
447
448struct DecimalNegateBindData : public FunctionData {
449 DecimalNegateBindData() : bound_type(LogicalTypeId::INVALID) {
450 }
451
452 unique_ptr<FunctionData> Copy() const override {
453 auto res = make_uniq<DecimalNegateBindData>();
454 res->bound_type = bound_type;
455 return std::move(res);
456 }
457
458 bool Equals(const FunctionData &other_p) const override {
459 auto other = other_p.Cast<DecimalNegateBindData>();
460 return other.bound_type == bound_type;
461 }
462
463 LogicalTypeId bound_type;
464};
465
466unique_ptr<FunctionData> DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function,
467 vector<unique_ptr<Expression>> &arguments) {
468
469 auto bind_data = make_uniq<DecimalNegateBindData>();
470
471 auto &decimal_type = arguments[0]->return_type;
472 auto width = DecimalType::GetWidth(type: decimal_type);
473 if (width <= Decimal::MAX_WIDTH_INT16) {
474 bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::SMALLINT);
475 } else if (width <= Decimal::MAX_WIDTH_INT32) {
476 bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::INTEGER);
477 } else if (width <= Decimal::MAX_WIDTH_INT64) {
478 bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::BIGINT);
479 } else {
480 D_ASSERT(width <= Decimal::MAX_WIDTH_INT128);
481 bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::HUGEINT);
482 }
483 decimal_type.Verify();
484 bound_function.arguments[0] = decimal_type;
485 bound_function.return_type = decimal_type;
486 return nullptr;
487}
488
489struct NegatePropagateStatistics {
490 template <class T>
491 static bool Operation(LogicalType type, BaseStatistics &istats, Value &new_min, Value &new_max) {
492 auto max_value = NumericStats::GetMax<T>(istats);
493 auto min_value = NumericStats::GetMin<T>(istats);
494 if (!NegateOperator::CanNegate<T>(min_value) || !NegateOperator::CanNegate<T>(max_value)) {
495 return true;
496 }
497 // new min is -max
498 new_min = Value::Numeric(type, NegateOperator::Operation<T, T>(max_value));
499 // new max is -min
500 new_max = Value::Numeric(type, NegateOperator::Operation<T, T>(min_value));
501 return false;
502 }
503};
504
505static unique_ptr<BaseStatistics> NegateBindStatistics(ClientContext &context, FunctionStatisticsInput &input) {
506 auto &child_stats = input.child_stats;
507 auto &expr = input.expr;
508 D_ASSERT(child_stats.size() == 1);
509 // can only propagate stats if the children have stats
510 auto &istats = child_stats[0];
511 Value new_min, new_max;
512 bool potential_overflow = true;
513 if (NumericStats::HasMinMax(stats: istats)) {
514 switch (expr.return_type.InternalType()) {
515 case PhysicalType::INT8:
516 potential_overflow =
517 NegatePropagateStatistics::Operation<int8_t>(type: expr.return_type, istats, new_min, new_max);
518 break;
519 case PhysicalType::INT16:
520 potential_overflow =
521 NegatePropagateStatistics::Operation<int16_t>(type: expr.return_type, istats, new_min, new_max);
522 break;
523 case PhysicalType::INT32:
524 potential_overflow =
525 NegatePropagateStatistics::Operation<int32_t>(type: expr.return_type, istats, new_min, new_max);
526 break;
527 case PhysicalType::INT64:
528 potential_overflow =
529 NegatePropagateStatistics::Operation<int64_t>(type: expr.return_type, istats, new_min, new_max);
530 break;
531 default:
532 return nullptr;
533 }
534 }
535 if (potential_overflow) {
536 new_min = Value(expr.return_type);
537 new_max = Value(expr.return_type);
538 }
539 auto stats = NumericStats::CreateEmpty(type: expr.return_type);
540 NumericStats::SetMin(stats, val: new_min);
541 NumericStats::SetMax(stats, val: new_max);
542 stats.CopyValidity(stats&: istats);
543 return stats.ToUnique();
544}
545
546ScalarFunction SubtractFun::GetFunction(const LogicalType &type) {
547 if (type.id() == LogicalTypeId::INTERVAL) {
548 return ScalarFunction("-", {type}, type, ScalarFunction::UnaryFunction<interval_t, interval_t, NegateOperator>);
549 } else if (type.id() == LogicalTypeId::DECIMAL) {
550 return ScalarFunction("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics);
551 } else {
552 D_ASSERT(type.IsNumeric());
553 return ScalarFunction("-", {type}, type, ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type), nullptr,
554 nullptr, NegateBindStatistics);
555 }
556}
557
558ScalarFunction SubtractFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) {
559 if (left_type.IsNumeric() && left_type.id() == right_type.id()) {
560 if (left_type.id() == LogicalTypeId::DECIMAL) {
561 auto function =
562 ScalarFunction("-", {left_type, right_type}, left_type, nullptr,
563 BindDecimalAddSubtract<SubtractOperator, DecimalSubtractOverflowCheck, true>);
564 function.serialize = SerializeDecimalArithmetic;
565 function.deserialize = DeserializeDecimalArithmetic<SubtractOperator, DecimalSubtractOverflowCheck>;
566 return function;
567 } else if (left_type.IsIntegral() && left_type.id() != LogicalTypeId::HUGEINT) {
568 return ScalarFunction(
569 "-", {left_type, right_type}, left_type,
570 GetScalarIntegerFunction<SubtractOperatorOverflowCheck>(type: left_type.InternalType()), nullptr, nullptr,
571 PropagateNumericStats<TrySubtractOperator, SubtractPropagateStatistics, SubtractOperator>);
572
573 } else {
574 return ScalarFunction("-", {left_type, right_type}, left_type,
575 GetScalarBinaryFunction<SubtractOperator>(type: left_type.InternalType()));
576 }
577 }
578
579 switch (left_type.id()) {
580 case LogicalTypeId::DATE:
581 if (right_type.id() == LogicalTypeId::DATE) {
582 return ScalarFunction("-", {left_type, right_type}, LogicalType::BIGINT,
583 ScalarFunction::BinaryFunction<date_t, date_t, int64_t, SubtractOperator>);
584
585 } else if (right_type.id() == LogicalTypeId::INTEGER) {
586 return ScalarFunction("-", {left_type, right_type}, LogicalType::DATE,
587 ScalarFunction::BinaryFunction<date_t, int32_t, date_t, SubtractOperator>);
588 } else if (right_type.id() == LogicalTypeId::INTERVAL) {
589 return ScalarFunction("-", {left_type, right_type}, LogicalType::DATE,
590 ScalarFunction::BinaryFunction<date_t, interval_t, date_t, SubtractOperator>);
591 }
592 break;
593 case LogicalTypeId::TIMESTAMP:
594 if (right_type.id() == LogicalTypeId::TIMESTAMP) {
595 return ScalarFunction(
596 "-", {left_type, right_type}, LogicalType::INTERVAL,
597 ScalarFunction::BinaryFunction<timestamp_t, timestamp_t, interval_t, SubtractOperator>);
598 } else if (right_type.id() == LogicalTypeId::INTERVAL) {
599 return ScalarFunction(
600 "-", {left_type, right_type}, LogicalType::TIMESTAMP,
601 ScalarFunction::BinaryFunction<timestamp_t, interval_t, timestamp_t, SubtractOperator>);
602 }
603 break;
604 case LogicalTypeId::INTERVAL:
605 if (right_type.id() == LogicalTypeId::INTERVAL) {
606 return ScalarFunction("-", {left_type, right_type}, LogicalType::INTERVAL,
607 ScalarFunction::BinaryFunction<interval_t, interval_t, interval_t, SubtractOperator>);
608 }
609 break;
610 case LogicalTypeId::TIME:
611 if (right_type.id() == LogicalTypeId::INTERVAL) {
612 return ScalarFunction("-", {left_type, right_type}, LogicalType::TIME,
613 ScalarFunction::BinaryFunction<dtime_t, interval_t, dtime_t, SubtractTimeOperator>);
614 }
615 break;
616 default:
617 break;
618 }
619 // LCOV_EXCL_START
620 throw NotImplementedException("SubtractFun for types %s, %s", EnumUtil::ToString(value: left_type.id()),
621 EnumUtil::ToString(value: right_type.id()));
622 // LCOV_EXCL_STOP
623}
624
625void SubtractFun::RegisterFunction(BuiltinFunctions &set) {
626 ScalarFunctionSet functions("-");
627 for (auto &type : LogicalType::Numeric()) {
628 // unary subtract function, negates the input (i.e. multiplies by -1)
629 functions.AddFunction(function: GetFunction(type));
630 // binary subtract function "a - b", subtracts b from a
631 functions.AddFunction(function: GetFunction(left_type: type, right_type: type));
632 }
633 // we can subtract dates from each other
634 functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::DATE));
635 // we can subtract integers from dates
636 functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTEGER));
637 // we can subtract timestamps from each other
638 functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::TIMESTAMP));
639 // we can subtract intervals from each other
640 functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::INTERVAL));
641 // we can subtract intervals from dates/times/timestamps, but not the other way around
642 functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTERVAL));
643 functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::INTERVAL));
644 functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::INTERVAL));
645 // we can negate intervals
646 functions.AddFunction(function: GetFunction(type: LogicalType::INTERVAL));
647 set.AddFunction(set: functions);
648
649 functions.name = "subtract";
650 set.AddFunction(set: functions);
651}
652
653//===--------------------------------------------------------------------===//
654// * [multiply]
655//===--------------------------------------------------------------------===//
656struct MultiplyPropagateStatistics {
657 template <class T, class OP>
658 static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min,
659 Value &new_max) {
660 // statistics propagation on the multiplication is slightly less straightforward because of negative numbers
661 // the new min/max depend on the signs of the input types
662 // if both are positive the result is [lmin * rmin][lmax * rmax]
663 // if lmin/lmax are negative the result is [lmin * rmax][lmax * rmin]
664 // etc
665 // rather than doing all this switcheroo we just multiply all combinations of lmin/lmax with rmin/rmax
666 // and check what the minimum/maximum value is
667 T lvals[] {NumericStats::GetMin<T>(lstats), NumericStats::GetMax<T>(lstats)};
668 T rvals[] {NumericStats::GetMin<T>(rstats), NumericStats::GetMax<T>(rstats)};
669 T min = NumericLimits<T>::Maximum();
670 T max = NumericLimits<T>::Minimum();
671 // multiplications
672 for (idx_t l = 0; l < 2; l++) {
673 for (idx_t r = 0; r < 2; r++) {
674 T result;
675 if (!OP::Operation(lvals[l], rvals[r], result)) {
676 // potential overflow
677 return true;
678 }
679 if (result < min) {
680 min = result;
681 }
682 if (result > max) {
683 max = result;
684 }
685 }
686 }
687 new_min = Value::Numeric(type, min);
688 new_max = Value::Numeric(type, max);
689 return false;
690 }
691};
692
693unique_ptr<FunctionData> BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function,
694 vector<unique_ptr<Expression>> &arguments) {
695
696 auto bind_data = make_uniq<DecimalArithmeticBindData>();
697
698 uint8_t result_width = 0, result_scale = 0;
699 uint8_t max_width = 0;
700 for (idx_t i = 0; i < arguments.size(); i++) {
701 if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) {
702 continue;
703 }
704 uint8_t width, scale;
705 auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale);
706 if (!can_convert) {
707 throw InternalException("Could not convert type %s to a decimal?", arguments[i]->return_type.ToString());
708 }
709 if (width > max_width) {
710 max_width = width;
711 }
712 result_width += width;
713 result_scale += scale;
714 }
715 D_ASSERT(max_width > 0);
716 if (result_scale > Decimal::MAX_WIDTH_DECIMAL) {
717 throw OutOfRangeException(
718 "Needed scale %d to accurately represent the multiplication result, but this is out of range of the "
719 "DECIMAL type. Max scale is %d; could not perform an accurate multiplication. Either add a cast to DOUBLE, "
720 "or add an explicit cast to a decimal with a lower scale.",
721 result_scale, Decimal::MAX_WIDTH_DECIMAL);
722 }
723 if (result_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64 &&
724 result_scale < Decimal::MAX_WIDTH_INT64) {
725 bind_data->check_overflow = true;
726 result_width = Decimal::MAX_WIDTH_INT64;
727 }
728 if (result_width > Decimal::MAX_WIDTH_DECIMAL) {
729 bind_data->check_overflow = true;
730 result_width = Decimal::MAX_WIDTH_DECIMAL;
731 }
732 LogicalType result_type = LogicalType::DECIMAL(width: result_width, scale: result_scale);
733 // since our scale is the summation of our input scales, we do not need to cast to the result scale
734 // however, we might need to cast to the correct internal type
735 for (idx_t i = 0; i < arguments.size(); i++) {
736 auto &argument_type = arguments[i]->return_type;
737 if (argument_type.InternalType() == result_type.InternalType()) {
738 bound_function.arguments[i] = argument_type;
739 } else {
740 uint8_t width, scale;
741 if (!argument_type.GetDecimalProperties(width, scale)) {
742 scale = 0;
743 }
744
745 bound_function.arguments[i] = LogicalType::DECIMAL(width: result_width, scale);
746 }
747 }
748 result_type.Verify();
749 bound_function.return_type = result_type;
750 // now select the physical function to execute
751 if (bind_data->check_overflow) {
752 bound_function.function = GetScalarBinaryFunction<DecimalMultiplyOverflowCheck>(type: result_type.InternalType());
753 } else {
754 bound_function.function = GetScalarBinaryFunction<MultiplyOperator>(type: result_type.InternalType());
755 }
756 if (result_type.InternalType() != PhysicalType::INT128) {
757 bound_function.statistics =
758 PropagateNumericStats<TryDecimalMultiply, MultiplyPropagateStatistics, MultiplyOperator>;
759 }
760 return std::move(bind_data);
761}
762
763void MultiplyFun::RegisterFunction(BuiltinFunctions &set) {
764 ScalarFunctionSet functions("*");
765 for (auto &type : LogicalType::Numeric()) {
766 if (type.id() == LogicalTypeId::DECIMAL) {
767 ScalarFunction function({type, type}, type, nullptr, BindDecimalMultiply);
768 function.serialize = SerializeDecimalArithmetic;
769 function.deserialize = DeserializeDecimalArithmetic<MultiplyOperator, DecimalMultiplyOverflowCheck>;
770 functions.AddFunction(function);
771 } else if (TypeIsIntegral(type: type.InternalType()) && type.id() != LogicalTypeId::HUGEINT) {
772 functions.AddFunction(function: ScalarFunction(
773 {type, type}, type, GetScalarIntegerFunction<MultiplyOperatorOverflowCheck>(type: type.InternalType()),
774 nullptr, nullptr,
775 PropagateNumericStats<TryMultiplyOperator, MultiplyPropagateStatistics, MultiplyOperator>));
776 } else {
777 functions.AddFunction(
778 function: ScalarFunction({type, type}, type, GetScalarBinaryFunction<MultiplyOperator>(type: type.InternalType())));
779 }
780 }
781 functions.AddFunction(
782 function: ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL,
783 ScalarFunction::BinaryFunction<interval_t, int64_t, interval_t, MultiplyOperator>));
784 functions.AddFunction(
785 function: ScalarFunction({LogicalType::BIGINT, LogicalType::INTERVAL}, LogicalType::INTERVAL,
786 ScalarFunction::BinaryFunction<int64_t, interval_t, interval_t, MultiplyOperator>));
787 set.AddFunction(set: functions);
788
789 functions.name = "multiply";
790 set.AddFunction(set: functions);
791}
792
793//===--------------------------------------------------------------------===//
794// / [divide]
795//===--------------------------------------------------------------------===//
796template <>
797float DivideOperator::Operation(float left, float right) {
798 auto result = left / right;
799 return result;
800}
801
802template <>
803double DivideOperator::Operation(double left, double right) {
804 auto result = left / right;
805 return result;
806}
807
808template <>
809hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) {
810 if (right.lower == 0 && right.upper == 0) {
811 throw InternalException("Hugeint division by zero!");
812 }
813 return left / right;
814}
815
816template <>
817interval_t DivideOperator::Operation(interval_t left, int64_t right) {
818 left.days /= right;
819 left.months /= right;
820 left.micros /= right;
821 return left;
822}
823
824struct BinaryNumericDivideWrapper {
825 template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE>
826 static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) {
827 if (left == NumericLimits<LEFT_TYPE>::Minimum() && right == -1) {
828 throw OutOfRangeException("Overflow in division of %d / %d", left, right);
829 } else if (right == 0) {
830 mask.SetInvalid(idx);
831 return left;
832 } else {
833 return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right);
834 }
835 }
836
837 static bool AddsNulls() {
838 return true;
839 }
840};
841
842struct BinaryZeroIsNullWrapper {
843 template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE>
844 static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) {
845 if (right == 0) {
846 mask.SetInvalid(idx);
847 return left;
848 } else {
849 return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right);
850 }
851 }
852
853 static bool AddsNulls() {
854 return true;
855 }
856};
857
858struct BinaryZeroIsNullHugeintWrapper {
859 template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE>
860 static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) {
861 if (right.upper == 0 && right.lower == 0) {
862 mask.SetInvalid(idx);
863 return left;
864 } else {
865 return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right);
866 }
867 }
868
869 static bool AddsNulls() {
870 return true;
871 }
872};
873
874template <class TA, class TB, class TC, class OP, class ZWRAPPER = BinaryZeroIsNullWrapper>
875static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) {
876 BinaryExecutor::Execute<TA, TB, TC, OP, ZWRAPPER>(input.data[0], input.data[1], result, input.size());
877}
878
879template <class OP>
880static scalar_function_t GetBinaryFunctionIgnoreZero(const LogicalType &type) {
881 switch (type.id()) {
882 case LogicalTypeId::TINYINT:
883 return BinaryScalarFunctionIgnoreZero<int8_t, int8_t, int8_t, OP, BinaryNumericDivideWrapper>;
884 case LogicalTypeId::SMALLINT:
885 return BinaryScalarFunctionIgnoreZero<int16_t, int16_t, int16_t, OP, BinaryNumericDivideWrapper>;
886 case LogicalTypeId::INTEGER:
887 return BinaryScalarFunctionIgnoreZero<int32_t, int32_t, int32_t, OP, BinaryNumericDivideWrapper>;
888 case LogicalTypeId::BIGINT:
889 return BinaryScalarFunctionIgnoreZero<int64_t, int64_t, int64_t, OP, BinaryNumericDivideWrapper>;
890 case LogicalTypeId::UTINYINT:
891 return BinaryScalarFunctionIgnoreZero<uint8_t, uint8_t, uint8_t, OP>;
892 case LogicalTypeId::USMALLINT:
893 return BinaryScalarFunctionIgnoreZero<uint16_t, uint16_t, uint16_t, OP>;
894 case LogicalTypeId::UINTEGER:
895 return BinaryScalarFunctionIgnoreZero<uint32_t, uint32_t, uint32_t, OP>;
896 case LogicalTypeId::UBIGINT:
897 return BinaryScalarFunctionIgnoreZero<uint64_t, uint64_t, uint64_t, OP>;
898 case LogicalTypeId::HUGEINT:
899 return BinaryScalarFunctionIgnoreZero<hugeint_t, hugeint_t, hugeint_t, OP, BinaryZeroIsNullHugeintWrapper>;
900 case LogicalTypeId::FLOAT:
901 return BinaryScalarFunctionIgnoreZero<float, float, float, OP>;
902 case LogicalTypeId::DOUBLE:
903 return BinaryScalarFunctionIgnoreZero<double, double, double, OP>;
904 default:
905 throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction");
906 }
907}
908
909void DivideFun::RegisterFunction(BuiltinFunctions &set) {
910 ScalarFunctionSet fp_divide("/");
911 fp_divide.AddFunction(function: ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT,
912 GetBinaryFunctionIgnoreZero<DivideOperator>(type: LogicalType::FLOAT)));
913 fp_divide.AddFunction(function: ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE,
914 GetBinaryFunctionIgnoreZero<DivideOperator>(type: LogicalType::DOUBLE)));
915 fp_divide.AddFunction(
916 function: ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL,
917 BinaryScalarFunctionIgnoreZero<interval_t, int64_t, interval_t, DivideOperator>));
918 set.AddFunction(set: fp_divide);
919
920 ScalarFunctionSet full_divide("//");
921 for (auto &type : LogicalType::Numeric()) {
922 if (type.id() == LogicalTypeId::DECIMAL) {
923 continue;
924 } else {
925 full_divide.AddFunction(
926 function: ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<DivideOperator>(type)));
927 }
928 }
929 set.AddFunction(set: full_divide);
930
931 full_divide.name = "divide";
932 set.AddFunction(set: full_divide);
933}
934
935//===--------------------------------------------------------------------===//
936// % [modulo]
937//===--------------------------------------------------------------------===//
938template <>
939float ModuloOperator::Operation(float left, float right) {
940 D_ASSERT(right != 0);
941 auto result = std::fmod(x: left, y: right);
942 return result;
943}
944
945template <>
946double ModuloOperator::Operation(double left, double right) {
947 D_ASSERT(right != 0);
948 auto result = std::fmod(x: left, y: right);
949 return result;
950}
951
952template <>
953hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) {
954 if (right.lower == 0 && right.upper == 0) {
955 throw InternalException("Hugeint division by zero!");
956 }
957 return left % right;
958}
959
960void ModFun::RegisterFunction(BuiltinFunctions &set) {
961 ScalarFunctionSet functions("%");
962 for (auto &type : LogicalType::Numeric()) {
963 if (type.id() == LogicalTypeId::DECIMAL) {
964 continue;
965 } else {
966 functions.AddFunction(
967 function: ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<ModuloOperator>(type)));
968 }
969 }
970 set.AddFunction(set: functions);
971 functions.name = "mod";
972 set.AddFunction(set: functions);
973}
974
975} // namespace duckdb
976