1 | #include "duckdb/common/field_writer.hpp" |
2 | #include "duckdb/common/operator/add.hpp" |
3 | #include "duckdb/common/operator/multiply.hpp" |
4 | #include "duckdb/common/operator/numeric_binary_operators.hpp" |
5 | #include "duckdb/common/operator/subtract.hpp" |
6 | #include "duckdb/common/types/date.hpp" |
7 | #include "duckdb/common/types/decimal.hpp" |
8 | #include "duckdb/common/types/hugeint.hpp" |
9 | #include "duckdb/common/types/interval.hpp" |
10 | #include "duckdb/common/types/time.hpp" |
11 | #include "duckdb/common/types/timestamp.hpp" |
12 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
13 | #include "duckdb/common/enum_util.hpp" |
14 | #include "duckdb/function/scalar/operators.hpp" |
15 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
16 | #include "duckdb/function/scalar/nested_functions.hpp" |
17 | |
18 | #include <limits> |
19 | |
20 | namespace duckdb { |
21 | |
22 | template <class OP> |
23 | static scalar_function_t GetScalarIntegerFunction(PhysicalType type) { |
24 | scalar_function_t function; |
25 | switch (type) { |
26 | case PhysicalType::INT8: |
27 | function = &ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>; |
28 | break; |
29 | case PhysicalType::INT16: |
30 | function = &ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>; |
31 | break; |
32 | case PhysicalType::INT32: |
33 | function = &ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>; |
34 | break; |
35 | case PhysicalType::INT64: |
36 | function = &ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>; |
37 | break; |
38 | case PhysicalType::UINT8: |
39 | function = &ScalarFunction::BinaryFunction<uint8_t, uint8_t, uint8_t, OP>; |
40 | break; |
41 | case PhysicalType::UINT16: |
42 | function = &ScalarFunction::BinaryFunction<uint16_t, uint16_t, uint16_t, OP>; |
43 | break; |
44 | case PhysicalType::UINT32: |
45 | function = &ScalarFunction::BinaryFunction<uint32_t, uint32_t, uint32_t, OP>; |
46 | break; |
47 | case PhysicalType::UINT64: |
48 | function = &ScalarFunction::BinaryFunction<uint64_t, uint64_t, uint64_t, OP>; |
49 | break; |
50 | default: |
51 | throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction" ); |
52 | } |
53 | return function; |
54 | } |
55 | |
56 | template <class OP> |
57 | static scalar_function_t GetScalarBinaryFunction(PhysicalType type) { |
58 | scalar_function_t function; |
59 | switch (type) { |
60 | case PhysicalType::INT128: |
61 | function = &ScalarFunction::BinaryFunction<hugeint_t, hugeint_t, hugeint_t, OP>; |
62 | break; |
63 | case PhysicalType::FLOAT: |
64 | function = &ScalarFunction::BinaryFunction<float, float, float, OP>; |
65 | break; |
66 | case PhysicalType::DOUBLE: |
67 | function = &ScalarFunction::BinaryFunction<double, double, double, OP>; |
68 | break; |
69 | default: |
70 | function = GetScalarIntegerFunction<OP>(type); |
71 | break; |
72 | } |
73 | return function; |
74 | } |
75 | |
76 | //===--------------------------------------------------------------------===// |
77 | // + [add] |
78 | //===--------------------------------------------------------------------===// |
79 | struct AddPropagateStatistics { |
80 | template <class T, class OP> |
81 | static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, |
82 | Value &new_max) { |
83 | T min, max; |
84 | // new min is min+min |
85 | if (!OP::Operation(NumericStats::GetMin<T>(lstats), NumericStats::GetMin<T>(rstats), min)) { |
86 | return true; |
87 | } |
88 | // new max is max+max |
89 | if (!OP::Operation(NumericStats::GetMax<T>(lstats), NumericStats::GetMax<T>(rstats), max)) { |
90 | return true; |
91 | } |
92 | new_min = Value::Numeric(type, min); |
93 | new_max = Value::Numeric(type, max); |
94 | return false; |
95 | } |
96 | }; |
97 | |
98 | struct SubtractPropagateStatistics { |
99 | template <class T, class OP> |
100 | static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, |
101 | Value &new_max) { |
102 | T min, max; |
103 | if (!OP::Operation(NumericStats::GetMin<T>(lstats), NumericStats::GetMax<T>(rstats), min)) { |
104 | return true; |
105 | } |
106 | if (!OP::Operation(NumericStats::GetMax<T>(lstats), NumericStats::GetMin<T>(rstats), max)) { |
107 | return true; |
108 | } |
109 | new_min = Value::Numeric(type, min); |
110 | new_max = Value::Numeric(type, max); |
111 | return false; |
112 | } |
113 | }; |
114 | |
115 | struct DecimalArithmeticBindData : public FunctionData { |
116 | DecimalArithmeticBindData() : check_overflow(true) { |
117 | } |
118 | |
119 | unique_ptr<FunctionData> Copy() const override { |
120 | auto res = make_uniq<DecimalArithmeticBindData>(); |
121 | res->check_overflow = check_overflow; |
122 | return std::move(res); |
123 | } |
124 | |
125 | bool Equals(const FunctionData &other_p) const override { |
126 | auto other = other_p.Cast<DecimalArithmeticBindData>(); |
127 | return other.check_overflow == check_overflow; |
128 | } |
129 | |
130 | bool check_overflow; |
131 | }; |
132 | |
133 | template <class OP, class PROPAGATE, class BASEOP> |
134 | static unique_ptr<BaseStatistics> PropagateNumericStats(ClientContext &context, FunctionStatisticsInput &input) { |
135 | auto &child_stats = input.child_stats; |
136 | auto &expr = input.expr; |
137 | D_ASSERT(child_stats.size() == 2); |
138 | // can only propagate stats if the children have stats |
139 | auto &lstats = child_stats[0]; |
140 | auto &rstats = child_stats[1]; |
141 | Value new_min, new_max; |
142 | bool potential_overflow = true; |
143 | if (NumericStats::HasMinMax(stats: lstats) && NumericStats::HasMinMax(stats: rstats)) { |
144 | switch (expr.return_type.InternalType()) { |
145 | case PhysicalType::INT8: |
146 | potential_overflow = |
147 | PROPAGATE::template Operation<int8_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
148 | break; |
149 | case PhysicalType::INT16: |
150 | potential_overflow = |
151 | PROPAGATE::template Operation<int16_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
152 | break; |
153 | case PhysicalType::INT32: |
154 | potential_overflow = |
155 | PROPAGATE::template Operation<int32_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
156 | break; |
157 | case PhysicalType::INT64: |
158 | potential_overflow = |
159 | PROPAGATE::template Operation<int64_t, OP>(expr.return_type, lstats, rstats, new_min, new_max); |
160 | break; |
161 | default: |
162 | return nullptr; |
163 | } |
164 | } |
165 | if (potential_overflow) { |
166 | new_min = Value(expr.return_type); |
167 | new_max = Value(expr.return_type); |
168 | } else { |
169 | // no potential overflow: replace with non-overflowing operator |
170 | if (input.bind_data) { |
171 | auto &bind_data = input.bind_data->Cast<DecimalArithmeticBindData>(); |
172 | bind_data.check_overflow = false; |
173 | } |
174 | expr.function.function = GetScalarIntegerFunction<BASEOP>(expr.return_type.InternalType()); |
175 | } |
176 | auto result = NumericStats::CreateEmpty(type: expr.return_type); |
177 | NumericStats::SetMin(stats&: result, val: new_min); |
178 | NumericStats::SetMax(stats&: result, val: new_max); |
179 | result.CombineValidity(left&: lstats, right&: rstats); |
180 | return result.ToUnique(); |
181 | } |
182 | |
183 | template <class OP, class OPOVERFLOWCHECK, bool IS_SUBTRACT = false> |
184 | unique_ptr<FunctionData> BindDecimalAddSubtract(ClientContext &context, ScalarFunction &bound_function, |
185 | vector<unique_ptr<Expression>> &arguments) { |
186 | |
187 | auto bind_data = make_uniq<DecimalArithmeticBindData>(); |
188 | |
189 | // get the max width and scale of the input arguments |
190 | uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; |
191 | for (idx_t i = 0; i < arguments.size(); i++) { |
192 | if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { |
193 | continue; |
194 | } |
195 | uint8_t width, scale; |
196 | auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); |
197 | if (!can_convert) { |
198 | throw InternalException("Could not convert type %s to a decimal." , arguments[i]->return_type.ToString()); |
199 | } |
200 | max_width = MaxValue<uint8_t>(a: width, b: max_width); |
201 | max_scale = MaxValue<uint8_t>(a: scale, b: max_scale); |
202 | max_width_over_scale = MaxValue<uint8_t>(a: width - scale, b: max_width_over_scale); |
203 | } |
204 | D_ASSERT(max_width > 0); |
205 | // for addition/subtraction, we add 1 to the width to ensure we don't overflow |
206 | auto required_width = MaxValue<uint8_t>(a: max_scale + max_width_over_scale, b: max_width) + 1; |
207 | if (required_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64) { |
208 | // we don't automatically promote past the hugeint boundary to avoid the large hugeint performance penalty |
209 | bind_data->check_overflow = true; |
210 | required_width = Decimal::MAX_WIDTH_INT64; |
211 | } |
212 | if (required_width > Decimal::MAX_WIDTH_DECIMAL) { |
213 | // target width does not fit in decimal at all: truncate the scale and perform overflow detection |
214 | bind_data->check_overflow = true; |
215 | required_width = Decimal::MAX_WIDTH_DECIMAL; |
216 | } |
217 | // arithmetic between two decimal arguments: check the types of the input arguments |
218 | LogicalType result_type = LogicalType::DECIMAL(width: required_width, scale: max_scale); |
219 | // we cast all input types to the specified type |
220 | for (idx_t i = 0; i < arguments.size(); i++) { |
221 | // first check if the cast is necessary |
222 | // if the argument has a matching scale and internal type as the output type, no casting is necessary |
223 | auto &argument_type = arguments[i]->return_type; |
224 | uint8_t width, scale; |
225 | argument_type.GetDecimalProperties(width, scale); |
226 | if (scale == DecimalType::GetScale(type: result_type) && argument_type.InternalType() == result_type.InternalType()) { |
227 | bound_function.arguments[i] = argument_type; |
228 | } else { |
229 | bound_function.arguments[i] = result_type; |
230 | } |
231 | } |
232 | bound_function.return_type = result_type; |
233 | // now select the physical function to execute |
234 | if (bind_data->check_overflow) { |
235 | bound_function.function = GetScalarBinaryFunction<OPOVERFLOWCHECK>(result_type.InternalType()); |
236 | } else { |
237 | bound_function.function = GetScalarBinaryFunction<OP>(result_type.InternalType()); |
238 | } |
239 | if (result_type.InternalType() != PhysicalType::INT128) { |
240 | if (IS_SUBTRACT) { |
241 | bound_function.statistics = |
242 | PropagateNumericStats<TryDecimalSubtract, SubtractPropagateStatistics, SubtractOperator>; |
243 | } else { |
244 | bound_function.statistics = PropagateNumericStats<TryDecimalAdd, AddPropagateStatistics, AddOperator>; |
245 | } |
246 | } |
247 | return std::move(bind_data); |
248 | } |
249 | |
250 | static void SerializeDecimalArithmetic(FieldWriter &writer, const FunctionData *bind_data_p, |
251 | const ScalarFunction &function) { |
252 | auto &bind_data = bind_data_p->Cast<DecimalArithmeticBindData>(); |
253 | writer.WriteField(element: bind_data.check_overflow); |
254 | writer.WriteSerializable(element: function.return_type); |
255 | writer.WriteRegularSerializableList(elements: function.arguments); |
256 | } |
257 | |
258 | // TODO this is partially duplicated from the bind |
259 | template <class OP, class OPOVERFLOWCHECK, bool IS_SUBTRACT = false> |
260 | unique_ptr<FunctionData> DeserializeDecimalArithmetic(PlanDeserializationState &state, FieldReader &reader, |
261 | ScalarFunction &bound_function) { |
262 | // re-change the function pointers |
263 | auto check_overflow = reader.ReadRequired<bool>(); |
264 | auto return_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>(); |
265 | auto arguments = reader.template ReadRequiredSerializableList<LogicalType, LogicalType>(); |
266 | |
267 | if (check_overflow) { |
268 | bound_function.function = GetScalarBinaryFunction<OPOVERFLOWCHECK>(return_type.InternalType()); |
269 | } else { |
270 | bound_function.function = GetScalarBinaryFunction<OP>(return_type.InternalType()); |
271 | } |
272 | bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again |
273 | bound_function.return_type = return_type; |
274 | bound_function.arguments = arguments; |
275 | |
276 | auto bind_data = make_uniq<DecimalArithmeticBindData>(); |
277 | bind_data->check_overflow = check_overflow; |
278 | return std::move(bind_data); |
279 | } |
280 | |
281 | unique_ptr<FunctionData> NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, |
282 | vector<unique_ptr<Expression>> &arguments) { |
283 | bound_function.return_type = arguments[0]->return_type; |
284 | bound_function.arguments[0] = arguments[0]->return_type; |
285 | return nullptr; |
286 | } |
287 | |
288 | ScalarFunction AddFun::GetFunction(const LogicalType &type) { |
289 | D_ASSERT(type.IsNumeric()); |
290 | if (type.id() == LogicalTypeId::DECIMAL) { |
291 | return ScalarFunction("+" , {type}, type, ScalarFunction::NopFunction, NopDecimalBind); |
292 | } else { |
293 | return ScalarFunction("+" , {type}, type, ScalarFunction::NopFunction); |
294 | } |
295 | } |
296 | |
297 | ScalarFunction AddFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { |
298 | if (left_type.IsNumeric() && left_type.id() == right_type.id()) { |
299 | if (left_type.id() == LogicalTypeId::DECIMAL) { |
300 | auto function = ScalarFunction("+" , {left_type, right_type}, left_type, nullptr, |
301 | BindDecimalAddSubtract<AddOperator, DecimalAddOverflowCheck>); |
302 | function.serialize = SerializeDecimalArithmetic; |
303 | function.deserialize = DeserializeDecimalArithmetic<AddOperator, DecimalAddOverflowCheck>; |
304 | return function; |
305 | } else if (left_type.IsIntegral() && left_type.id() != LogicalTypeId::HUGEINT) { |
306 | return ScalarFunction("+" , {left_type, right_type}, left_type, |
307 | GetScalarIntegerFunction<AddOperatorOverflowCheck>(type: left_type.InternalType()), nullptr, |
308 | nullptr, PropagateNumericStats<TryAddOperator, AddPropagateStatistics, AddOperator>); |
309 | } else { |
310 | return ScalarFunction("+" , {left_type, right_type}, left_type, |
311 | GetScalarBinaryFunction<AddOperator>(type: left_type.InternalType())); |
312 | } |
313 | } |
314 | |
315 | switch (left_type.id()) { |
316 | case LogicalTypeId::DATE: |
317 | if (right_type.id() == LogicalTypeId::INTEGER) { |
318 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::DATE, |
319 | ScalarFunction::BinaryFunction<date_t, int32_t, date_t, AddOperator>); |
320 | } else if (right_type.id() == LogicalTypeId::INTERVAL) { |
321 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::DATE, |
322 | ScalarFunction::BinaryFunction<date_t, interval_t, date_t, AddOperator>); |
323 | } else if (right_type.id() == LogicalTypeId::TIME) { |
324 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
325 | ScalarFunction::BinaryFunction<date_t, dtime_t, timestamp_t, AddOperator>); |
326 | } |
327 | break; |
328 | case LogicalTypeId::INTEGER: |
329 | if (right_type.id() == LogicalTypeId::DATE) { |
330 | return ScalarFunction("+" , {left_type, right_type}, right_type, |
331 | ScalarFunction::BinaryFunction<int32_t, date_t, date_t, AddOperator>); |
332 | } |
333 | break; |
334 | case LogicalTypeId::INTERVAL: |
335 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
336 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::INTERVAL, |
337 | ScalarFunction::BinaryFunction<interval_t, interval_t, interval_t, AddOperator>); |
338 | } else if (right_type.id() == LogicalTypeId::DATE) { |
339 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::DATE, |
340 | ScalarFunction::BinaryFunction<interval_t, date_t, date_t, AddOperator>); |
341 | } else if (right_type.id() == LogicalTypeId::TIME) { |
342 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIME, |
343 | ScalarFunction::BinaryFunction<interval_t, dtime_t, dtime_t, AddTimeOperator>); |
344 | } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { |
345 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
346 | ScalarFunction::BinaryFunction<interval_t, timestamp_t, timestamp_t, AddOperator>); |
347 | } |
348 | break; |
349 | case LogicalTypeId::TIME: |
350 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
351 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIME, |
352 | ScalarFunction::BinaryFunction<dtime_t, interval_t, dtime_t, AddTimeOperator>); |
353 | } else if (right_type.id() == LogicalTypeId::DATE) { |
354 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
355 | ScalarFunction::BinaryFunction<dtime_t, date_t, timestamp_t, AddOperator>); |
356 | } |
357 | break; |
358 | case LogicalTypeId::TIMESTAMP: |
359 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
360 | return ScalarFunction("+" , {left_type, right_type}, LogicalType::TIMESTAMP, |
361 | ScalarFunction::BinaryFunction<timestamp_t, interval_t, timestamp_t, AddOperator>); |
362 | } |
363 | break; |
364 | default: |
365 | break; |
366 | } |
367 | // LCOV_EXCL_START |
368 | throw NotImplementedException("AddFun for types %s, %s" , EnumUtil::ToString(value: left_type.id()), |
369 | EnumUtil::ToString(value: right_type.id())); |
370 | // LCOV_EXCL_STOP |
371 | } |
372 | |
373 | void AddFun::RegisterFunction(BuiltinFunctions &set) { |
374 | ScalarFunctionSet functions("+" ); |
375 | for (auto &type : LogicalType::Numeric()) { |
376 | // unary add function is a nop, but only exists for numeric types |
377 | functions.AddFunction(function: GetFunction(type)); |
378 | // binary add function adds two numbers together |
379 | functions.AddFunction(function: GetFunction(left_type: type, right_type: type)); |
380 | } |
381 | // we can add integers to dates |
382 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTEGER)); |
383 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTEGER, right_type: LogicalType::DATE)); |
384 | // we can add intervals together |
385 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::INTERVAL)); |
386 | // we can add intervals to dates/times/timestamps |
387 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTERVAL)); |
388 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::DATE)); |
389 | |
390 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::INTERVAL)); |
391 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::TIME)); |
392 | |
393 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::INTERVAL)); |
394 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::TIMESTAMP)); |
395 | |
396 | // we can add times to dates |
397 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::DATE)); |
398 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::TIME)); |
399 | |
400 | // we can add lists together |
401 | functions.AddFunction(function: ListConcatFun::GetFunction()); |
402 | |
403 | set.AddFunction(set: functions); |
404 | |
405 | functions.name = "add" ; |
406 | set.AddFunction(set: functions); |
407 | } |
408 | |
409 | //===--------------------------------------------------------------------===// |
410 | // - [subtract] |
411 | //===--------------------------------------------------------------------===// |
412 | struct NegateOperator { |
413 | template <class T> |
414 | static bool CanNegate(T input) { |
415 | using Limits = std::numeric_limits<T>; |
416 | return !(Limits::is_integer && Limits::is_signed && Limits::lowest() == input); |
417 | } |
418 | |
419 | template <class TA, class TR> |
420 | static inline TR Operation(TA input) { |
421 | auto cast = (TR)input; |
422 | if (!CanNegate<TR>(cast)) { |
423 | throw OutOfRangeException("Overflow in negation of integer!" ); |
424 | } |
425 | return -cast; |
426 | } |
427 | }; |
428 | |
429 | template <> |
430 | bool NegateOperator::CanNegate(float input) { |
431 | return true; |
432 | } |
433 | |
434 | template <> |
435 | bool NegateOperator::CanNegate(double input) { |
436 | return true; |
437 | } |
438 | |
439 | template <> |
440 | interval_t NegateOperator::Operation(interval_t input) { |
441 | interval_t result; |
442 | result.months = NegateOperator::Operation<int32_t, int32_t>(input: input.months); |
443 | result.days = NegateOperator::Operation<int32_t, int32_t>(input: input.days); |
444 | result.micros = NegateOperator::Operation<int64_t, int64_t>(input: input.micros); |
445 | return result; |
446 | } |
447 | |
448 | struct DecimalNegateBindData : public FunctionData { |
449 | DecimalNegateBindData() : bound_type(LogicalTypeId::INVALID) { |
450 | } |
451 | |
452 | unique_ptr<FunctionData> Copy() const override { |
453 | auto res = make_uniq<DecimalNegateBindData>(); |
454 | res->bound_type = bound_type; |
455 | return std::move(res); |
456 | } |
457 | |
458 | bool Equals(const FunctionData &other_p) const override { |
459 | auto other = other_p.Cast<DecimalNegateBindData>(); |
460 | return other.bound_type == bound_type; |
461 | } |
462 | |
463 | LogicalTypeId bound_type; |
464 | }; |
465 | |
466 | unique_ptr<FunctionData> DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, |
467 | vector<unique_ptr<Expression>> &arguments) { |
468 | |
469 | auto bind_data = make_uniq<DecimalNegateBindData>(); |
470 | |
471 | auto &decimal_type = arguments[0]->return_type; |
472 | auto width = DecimalType::GetWidth(type: decimal_type); |
473 | if (width <= Decimal::MAX_WIDTH_INT16) { |
474 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::SMALLINT); |
475 | } else if (width <= Decimal::MAX_WIDTH_INT32) { |
476 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::INTEGER); |
477 | } else if (width <= Decimal::MAX_WIDTH_INT64) { |
478 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::BIGINT); |
479 | } else { |
480 | D_ASSERT(width <= Decimal::MAX_WIDTH_INT128); |
481 | bound_function.function = ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type: LogicalTypeId::HUGEINT); |
482 | } |
483 | decimal_type.Verify(); |
484 | bound_function.arguments[0] = decimal_type; |
485 | bound_function.return_type = decimal_type; |
486 | return nullptr; |
487 | } |
488 | |
489 | struct NegatePropagateStatistics { |
490 | template <class T> |
491 | static bool Operation(LogicalType type, BaseStatistics &istats, Value &new_min, Value &new_max) { |
492 | auto max_value = NumericStats::GetMax<T>(istats); |
493 | auto min_value = NumericStats::GetMin<T>(istats); |
494 | if (!NegateOperator::CanNegate<T>(min_value) || !NegateOperator::CanNegate<T>(max_value)) { |
495 | return true; |
496 | } |
497 | // new min is -max |
498 | new_min = Value::Numeric(type, NegateOperator::Operation<T, T>(max_value)); |
499 | // new max is -min |
500 | new_max = Value::Numeric(type, NegateOperator::Operation<T, T>(min_value)); |
501 | return false; |
502 | } |
503 | }; |
504 | |
505 | static unique_ptr<BaseStatistics> NegateBindStatistics(ClientContext &context, FunctionStatisticsInput &input) { |
506 | auto &child_stats = input.child_stats; |
507 | auto &expr = input.expr; |
508 | D_ASSERT(child_stats.size() == 1); |
509 | // can only propagate stats if the children have stats |
510 | auto &istats = child_stats[0]; |
511 | Value new_min, new_max; |
512 | bool potential_overflow = true; |
513 | if (NumericStats::HasMinMax(stats: istats)) { |
514 | switch (expr.return_type.InternalType()) { |
515 | case PhysicalType::INT8: |
516 | potential_overflow = |
517 | NegatePropagateStatistics::Operation<int8_t>(type: expr.return_type, istats, new_min, new_max); |
518 | break; |
519 | case PhysicalType::INT16: |
520 | potential_overflow = |
521 | NegatePropagateStatistics::Operation<int16_t>(type: expr.return_type, istats, new_min, new_max); |
522 | break; |
523 | case PhysicalType::INT32: |
524 | potential_overflow = |
525 | NegatePropagateStatistics::Operation<int32_t>(type: expr.return_type, istats, new_min, new_max); |
526 | break; |
527 | case PhysicalType::INT64: |
528 | potential_overflow = |
529 | NegatePropagateStatistics::Operation<int64_t>(type: expr.return_type, istats, new_min, new_max); |
530 | break; |
531 | default: |
532 | return nullptr; |
533 | } |
534 | } |
535 | if (potential_overflow) { |
536 | new_min = Value(expr.return_type); |
537 | new_max = Value(expr.return_type); |
538 | } |
539 | auto stats = NumericStats::CreateEmpty(type: expr.return_type); |
540 | NumericStats::SetMin(stats, val: new_min); |
541 | NumericStats::SetMax(stats, val: new_max); |
542 | stats.CopyValidity(stats&: istats); |
543 | return stats.ToUnique(); |
544 | } |
545 | |
546 | ScalarFunction SubtractFun::GetFunction(const LogicalType &type) { |
547 | if (type.id() == LogicalTypeId::INTERVAL) { |
548 | return ScalarFunction("-" , {type}, type, ScalarFunction::UnaryFunction<interval_t, interval_t, NegateOperator>); |
549 | } else if (type.id() == LogicalTypeId::DECIMAL) { |
550 | return ScalarFunction("-" , {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); |
551 | } else { |
552 | D_ASSERT(type.IsNumeric()); |
553 | return ScalarFunction("-" , {type}, type, ScalarFunction::GetScalarUnaryFunction<NegateOperator>(type), nullptr, |
554 | nullptr, NegateBindStatistics); |
555 | } |
556 | } |
557 | |
558 | ScalarFunction SubtractFun::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { |
559 | if (left_type.IsNumeric() && left_type.id() == right_type.id()) { |
560 | if (left_type.id() == LogicalTypeId::DECIMAL) { |
561 | auto function = |
562 | ScalarFunction("-" , {left_type, right_type}, left_type, nullptr, |
563 | BindDecimalAddSubtract<SubtractOperator, DecimalSubtractOverflowCheck, true>); |
564 | function.serialize = SerializeDecimalArithmetic; |
565 | function.deserialize = DeserializeDecimalArithmetic<SubtractOperator, DecimalSubtractOverflowCheck>; |
566 | return function; |
567 | } else if (left_type.IsIntegral() && left_type.id() != LogicalTypeId::HUGEINT) { |
568 | return ScalarFunction( |
569 | "-" , {left_type, right_type}, left_type, |
570 | GetScalarIntegerFunction<SubtractOperatorOverflowCheck>(type: left_type.InternalType()), nullptr, nullptr, |
571 | PropagateNumericStats<TrySubtractOperator, SubtractPropagateStatistics, SubtractOperator>); |
572 | |
573 | } else { |
574 | return ScalarFunction("-" , {left_type, right_type}, left_type, |
575 | GetScalarBinaryFunction<SubtractOperator>(type: left_type.InternalType())); |
576 | } |
577 | } |
578 | |
579 | switch (left_type.id()) { |
580 | case LogicalTypeId::DATE: |
581 | if (right_type.id() == LogicalTypeId::DATE) { |
582 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::BIGINT, |
583 | ScalarFunction::BinaryFunction<date_t, date_t, int64_t, SubtractOperator>); |
584 | |
585 | } else if (right_type.id() == LogicalTypeId::INTEGER) { |
586 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::DATE, |
587 | ScalarFunction::BinaryFunction<date_t, int32_t, date_t, SubtractOperator>); |
588 | } else if (right_type.id() == LogicalTypeId::INTERVAL) { |
589 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::DATE, |
590 | ScalarFunction::BinaryFunction<date_t, interval_t, date_t, SubtractOperator>); |
591 | } |
592 | break; |
593 | case LogicalTypeId::TIMESTAMP: |
594 | if (right_type.id() == LogicalTypeId::TIMESTAMP) { |
595 | return ScalarFunction( |
596 | "-" , {left_type, right_type}, LogicalType::INTERVAL, |
597 | ScalarFunction::BinaryFunction<timestamp_t, timestamp_t, interval_t, SubtractOperator>); |
598 | } else if (right_type.id() == LogicalTypeId::INTERVAL) { |
599 | return ScalarFunction( |
600 | "-" , {left_type, right_type}, LogicalType::TIMESTAMP, |
601 | ScalarFunction::BinaryFunction<timestamp_t, interval_t, timestamp_t, SubtractOperator>); |
602 | } |
603 | break; |
604 | case LogicalTypeId::INTERVAL: |
605 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
606 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::INTERVAL, |
607 | ScalarFunction::BinaryFunction<interval_t, interval_t, interval_t, SubtractOperator>); |
608 | } |
609 | break; |
610 | case LogicalTypeId::TIME: |
611 | if (right_type.id() == LogicalTypeId::INTERVAL) { |
612 | return ScalarFunction("-" , {left_type, right_type}, LogicalType::TIME, |
613 | ScalarFunction::BinaryFunction<dtime_t, interval_t, dtime_t, SubtractTimeOperator>); |
614 | } |
615 | break; |
616 | default: |
617 | break; |
618 | } |
619 | // LCOV_EXCL_START |
620 | throw NotImplementedException("SubtractFun for types %s, %s" , EnumUtil::ToString(value: left_type.id()), |
621 | EnumUtil::ToString(value: right_type.id())); |
622 | // LCOV_EXCL_STOP |
623 | } |
624 | |
625 | void SubtractFun::RegisterFunction(BuiltinFunctions &set) { |
626 | ScalarFunctionSet functions("-" ); |
627 | for (auto &type : LogicalType::Numeric()) { |
628 | // unary subtract function, negates the input (i.e. multiplies by -1) |
629 | functions.AddFunction(function: GetFunction(type)); |
630 | // binary subtract function "a - b", subtracts b from a |
631 | functions.AddFunction(function: GetFunction(left_type: type, right_type: type)); |
632 | } |
633 | // we can subtract dates from each other |
634 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::DATE)); |
635 | // we can subtract integers from dates |
636 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTEGER)); |
637 | // we can subtract timestamps from each other |
638 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::TIMESTAMP)); |
639 | // we can subtract intervals from each other |
640 | functions.AddFunction(function: GetFunction(left_type: LogicalType::INTERVAL, right_type: LogicalType::INTERVAL)); |
641 | // we can subtract intervals from dates/times/timestamps, but not the other way around |
642 | functions.AddFunction(function: GetFunction(left_type: LogicalType::DATE, right_type: LogicalType::INTERVAL)); |
643 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIME, right_type: LogicalType::INTERVAL)); |
644 | functions.AddFunction(function: GetFunction(left_type: LogicalType::TIMESTAMP, right_type: LogicalType::INTERVAL)); |
645 | // we can negate intervals |
646 | functions.AddFunction(function: GetFunction(type: LogicalType::INTERVAL)); |
647 | set.AddFunction(set: functions); |
648 | |
649 | functions.name = "subtract" ; |
650 | set.AddFunction(set: functions); |
651 | } |
652 | |
653 | //===--------------------------------------------------------------------===// |
654 | // * [multiply] |
655 | //===--------------------------------------------------------------------===// |
656 | struct MultiplyPropagateStatistics { |
657 | template <class T, class OP> |
658 | static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, |
659 | Value &new_max) { |
660 | // statistics propagation on the multiplication is slightly less straightforward because of negative numbers |
661 | // the new min/max depend on the signs of the input types |
662 | // if both are positive the result is [lmin * rmin][lmax * rmax] |
663 | // if lmin/lmax are negative the result is [lmin * rmax][lmax * rmin] |
664 | // etc |
665 | // rather than doing all this switcheroo we just multiply all combinations of lmin/lmax with rmin/rmax |
666 | // and check what the minimum/maximum value is |
667 | T lvals[] {NumericStats::GetMin<T>(lstats), NumericStats::GetMax<T>(lstats)}; |
668 | T rvals[] {NumericStats::GetMin<T>(rstats), NumericStats::GetMax<T>(rstats)}; |
669 | T min = NumericLimits<T>::Maximum(); |
670 | T max = NumericLimits<T>::Minimum(); |
671 | // multiplications |
672 | for (idx_t l = 0; l < 2; l++) { |
673 | for (idx_t r = 0; r < 2; r++) { |
674 | T result; |
675 | if (!OP::Operation(lvals[l], rvals[r], result)) { |
676 | // potential overflow |
677 | return true; |
678 | } |
679 | if (result < min) { |
680 | min = result; |
681 | } |
682 | if (result > max) { |
683 | max = result; |
684 | } |
685 | } |
686 | } |
687 | new_min = Value::Numeric(type, min); |
688 | new_max = Value::Numeric(type, max); |
689 | return false; |
690 | } |
691 | }; |
692 | |
693 | unique_ptr<FunctionData> BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, |
694 | vector<unique_ptr<Expression>> &arguments) { |
695 | |
696 | auto bind_data = make_uniq<DecimalArithmeticBindData>(); |
697 | |
698 | uint8_t result_width = 0, result_scale = 0; |
699 | uint8_t max_width = 0; |
700 | for (idx_t i = 0; i < arguments.size(); i++) { |
701 | if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { |
702 | continue; |
703 | } |
704 | uint8_t width, scale; |
705 | auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); |
706 | if (!can_convert) { |
707 | throw InternalException("Could not convert type %s to a decimal?" , arguments[i]->return_type.ToString()); |
708 | } |
709 | if (width > max_width) { |
710 | max_width = width; |
711 | } |
712 | result_width += width; |
713 | result_scale += scale; |
714 | } |
715 | D_ASSERT(max_width > 0); |
716 | if (result_scale > Decimal::MAX_WIDTH_DECIMAL) { |
717 | throw OutOfRangeException( |
718 | "Needed scale %d to accurately represent the multiplication result, but this is out of range of the " |
719 | "DECIMAL type. Max scale is %d; could not perform an accurate multiplication. Either add a cast to DOUBLE, " |
720 | "or add an explicit cast to a decimal with a lower scale." , |
721 | result_scale, Decimal::MAX_WIDTH_DECIMAL); |
722 | } |
723 | if (result_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64 && |
724 | result_scale < Decimal::MAX_WIDTH_INT64) { |
725 | bind_data->check_overflow = true; |
726 | result_width = Decimal::MAX_WIDTH_INT64; |
727 | } |
728 | if (result_width > Decimal::MAX_WIDTH_DECIMAL) { |
729 | bind_data->check_overflow = true; |
730 | result_width = Decimal::MAX_WIDTH_DECIMAL; |
731 | } |
732 | LogicalType result_type = LogicalType::DECIMAL(width: result_width, scale: result_scale); |
733 | // since our scale is the summation of our input scales, we do not need to cast to the result scale |
734 | // however, we might need to cast to the correct internal type |
735 | for (idx_t i = 0; i < arguments.size(); i++) { |
736 | auto &argument_type = arguments[i]->return_type; |
737 | if (argument_type.InternalType() == result_type.InternalType()) { |
738 | bound_function.arguments[i] = argument_type; |
739 | } else { |
740 | uint8_t width, scale; |
741 | if (!argument_type.GetDecimalProperties(width, scale)) { |
742 | scale = 0; |
743 | } |
744 | |
745 | bound_function.arguments[i] = LogicalType::DECIMAL(width: result_width, scale); |
746 | } |
747 | } |
748 | result_type.Verify(); |
749 | bound_function.return_type = result_type; |
750 | // now select the physical function to execute |
751 | if (bind_data->check_overflow) { |
752 | bound_function.function = GetScalarBinaryFunction<DecimalMultiplyOverflowCheck>(type: result_type.InternalType()); |
753 | } else { |
754 | bound_function.function = GetScalarBinaryFunction<MultiplyOperator>(type: result_type.InternalType()); |
755 | } |
756 | if (result_type.InternalType() != PhysicalType::INT128) { |
757 | bound_function.statistics = |
758 | PropagateNumericStats<TryDecimalMultiply, MultiplyPropagateStatistics, MultiplyOperator>; |
759 | } |
760 | return std::move(bind_data); |
761 | } |
762 | |
763 | void MultiplyFun::RegisterFunction(BuiltinFunctions &set) { |
764 | ScalarFunctionSet functions("*" ); |
765 | for (auto &type : LogicalType::Numeric()) { |
766 | if (type.id() == LogicalTypeId::DECIMAL) { |
767 | ScalarFunction function({type, type}, type, nullptr, BindDecimalMultiply); |
768 | function.serialize = SerializeDecimalArithmetic; |
769 | function.deserialize = DeserializeDecimalArithmetic<MultiplyOperator, DecimalMultiplyOverflowCheck>; |
770 | functions.AddFunction(function); |
771 | } else if (TypeIsIntegral(type: type.InternalType()) && type.id() != LogicalTypeId::HUGEINT) { |
772 | functions.AddFunction(function: ScalarFunction( |
773 | {type, type}, type, GetScalarIntegerFunction<MultiplyOperatorOverflowCheck>(type: type.InternalType()), |
774 | nullptr, nullptr, |
775 | PropagateNumericStats<TryMultiplyOperator, MultiplyPropagateStatistics, MultiplyOperator>)); |
776 | } else { |
777 | functions.AddFunction( |
778 | function: ScalarFunction({type, type}, type, GetScalarBinaryFunction<MultiplyOperator>(type: type.InternalType()))); |
779 | } |
780 | } |
781 | functions.AddFunction( |
782 | function: ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, |
783 | ScalarFunction::BinaryFunction<interval_t, int64_t, interval_t, MultiplyOperator>)); |
784 | functions.AddFunction( |
785 | function: ScalarFunction({LogicalType::BIGINT, LogicalType::INTERVAL}, LogicalType::INTERVAL, |
786 | ScalarFunction::BinaryFunction<int64_t, interval_t, interval_t, MultiplyOperator>)); |
787 | set.AddFunction(set: functions); |
788 | |
789 | functions.name = "multiply" ; |
790 | set.AddFunction(set: functions); |
791 | } |
792 | |
793 | //===--------------------------------------------------------------------===// |
794 | // / [divide] |
795 | //===--------------------------------------------------------------------===// |
796 | template <> |
797 | float DivideOperator::Operation(float left, float right) { |
798 | auto result = left / right; |
799 | return result; |
800 | } |
801 | |
802 | template <> |
803 | double DivideOperator::Operation(double left, double right) { |
804 | auto result = left / right; |
805 | return result; |
806 | } |
807 | |
808 | template <> |
809 | hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) { |
810 | if (right.lower == 0 && right.upper == 0) { |
811 | throw InternalException("Hugeint division by zero!" ); |
812 | } |
813 | return left / right; |
814 | } |
815 | |
816 | template <> |
817 | interval_t DivideOperator::Operation(interval_t left, int64_t right) { |
818 | left.days /= right; |
819 | left.months /= right; |
820 | left.micros /= right; |
821 | return left; |
822 | } |
823 | |
824 | struct BinaryNumericDivideWrapper { |
825 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
826 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { |
827 | if (left == NumericLimits<LEFT_TYPE>::Minimum() && right == -1) { |
828 | throw OutOfRangeException("Overflow in division of %d / %d" , left, right); |
829 | } else if (right == 0) { |
830 | mask.SetInvalid(idx); |
831 | return left; |
832 | } else { |
833 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
834 | } |
835 | } |
836 | |
837 | static bool AddsNulls() { |
838 | return true; |
839 | } |
840 | }; |
841 | |
842 | struct BinaryZeroIsNullWrapper { |
843 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
844 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { |
845 | if (right == 0) { |
846 | mask.SetInvalid(idx); |
847 | return left; |
848 | } else { |
849 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
850 | } |
851 | } |
852 | |
853 | static bool AddsNulls() { |
854 | return true; |
855 | } |
856 | }; |
857 | |
858 | struct BinaryZeroIsNullHugeintWrapper { |
859 | template <class FUNC, class OP, class LEFT_TYPE, class RIGHT_TYPE, class RESULT_TYPE> |
860 | static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { |
861 | if (right.upper == 0 && right.lower == 0) { |
862 | mask.SetInvalid(idx); |
863 | return left; |
864 | } else { |
865 | return OP::template Operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE>(left, right); |
866 | } |
867 | } |
868 | |
869 | static bool AddsNulls() { |
870 | return true; |
871 | } |
872 | }; |
873 | |
874 | template <class TA, class TB, class TC, class OP, class ZWRAPPER = BinaryZeroIsNullWrapper> |
875 | static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { |
876 | BinaryExecutor::Execute<TA, TB, TC, OP, ZWRAPPER>(input.data[0], input.data[1], result, input.size()); |
877 | } |
878 | |
879 | template <class OP> |
880 | static scalar_function_t GetBinaryFunctionIgnoreZero(const LogicalType &type) { |
881 | switch (type.id()) { |
882 | case LogicalTypeId::TINYINT: |
883 | return BinaryScalarFunctionIgnoreZero<int8_t, int8_t, int8_t, OP, BinaryNumericDivideWrapper>; |
884 | case LogicalTypeId::SMALLINT: |
885 | return BinaryScalarFunctionIgnoreZero<int16_t, int16_t, int16_t, OP, BinaryNumericDivideWrapper>; |
886 | case LogicalTypeId::INTEGER: |
887 | return BinaryScalarFunctionIgnoreZero<int32_t, int32_t, int32_t, OP, BinaryNumericDivideWrapper>; |
888 | case LogicalTypeId::BIGINT: |
889 | return BinaryScalarFunctionIgnoreZero<int64_t, int64_t, int64_t, OP, BinaryNumericDivideWrapper>; |
890 | case LogicalTypeId::UTINYINT: |
891 | return BinaryScalarFunctionIgnoreZero<uint8_t, uint8_t, uint8_t, OP>; |
892 | case LogicalTypeId::USMALLINT: |
893 | return BinaryScalarFunctionIgnoreZero<uint16_t, uint16_t, uint16_t, OP>; |
894 | case LogicalTypeId::UINTEGER: |
895 | return BinaryScalarFunctionIgnoreZero<uint32_t, uint32_t, uint32_t, OP>; |
896 | case LogicalTypeId::UBIGINT: |
897 | return BinaryScalarFunctionIgnoreZero<uint64_t, uint64_t, uint64_t, OP>; |
898 | case LogicalTypeId::HUGEINT: |
899 | return BinaryScalarFunctionIgnoreZero<hugeint_t, hugeint_t, hugeint_t, OP, BinaryZeroIsNullHugeintWrapper>; |
900 | case LogicalTypeId::FLOAT: |
901 | return BinaryScalarFunctionIgnoreZero<float, float, float, OP>; |
902 | case LogicalTypeId::DOUBLE: |
903 | return BinaryScalarFunctionIgnoreZero<double, double, double, OP>; |
904 | default: |
905 | throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction" ); |
906 | } |
907 | } |
908 | |
909 | void DivideFun::RegisterFunction(BuiltinFunctions &set) { |
910 | ScalarFunctionSet fp_divide("/" ); |
911 | fp_divide.AddFunction(function: ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, |
912 | GetBinaryFunctionIgnoreZero<DivideOperator>(type: LogicalType::FLOAT))); |
913 | fp_divide.AddFunction(function: ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, |
914 | GetBinaryFunctionIgnoreZero<DivideOperator>(type: LogicalType::DOUBLE))); |
915 | fp_divide.AddFunction( |
916 | function: ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, |
917 | BinaryScalarFunctionIgnoreZero<interval_t, int64_t, interval_t, DivideOperator>)); |
918 | set.AddFunction(set: fp_divide); |
919 | |
920 | ScalarFunctionSet full_divide("//" ); |
921 | for (auto &type : LogicalType::Numeric()) { |
922 | if (type.id() == LogicalTypeId::DECIMAL) { |
923 | continue; |
924 | } else { |
925 | full_divide.AddFunction( |
926 | function: ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<DivideOperator>(type))); |
927 | } |
928 | } |
929 | set.AddFunction(set: full_divide); |
930 | |
931 | full_divide.name = "divide" ; |
932 | set.AddFunction(set: full_divide); |
933 | } |
934 | |
935 | //===--------------------------------------------------------------------===// |
936 | // % [modulo] |
937 | //===--------------------------------------------------------------------===// |
938 | template <> |
939 | float ModuloOperator::Operation(float left, float right) { |
940 | D_ASSERT(right != 0); |
941 | auto result = std::fmod(x: left, y: right); |
942 | return result; |
943 | } |
944 | |
945 | template <> |
946 | double ModuloOperator::Operation(double left, double right) { |
947 | D_ASSERT(right != 0); |
948 | auto result = std::fmod(x: left, y: right); |
949 | return result; |
950 | } |
951 | |
952 | template <> |
953 | hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) { |
954 | if (right.lower == 0 && right.upper == 0) { |
955 | throw InternalException("Hugeint division by zero!" ); |
956 | } |
957 | return left % right; |
958 | } |
959 | |
960 | void ModFun::RegisterFunction(BuiltinFunctions &set) { |
961 | ScalarFunctionSet functions("%" ); |
962 | for (auto &type : LogicalType::Numeric()) { |
963 | if (type.id() == LogicalTypeId::DECIMAL) { |
964 | continue; |
965 | } else { |
966 | functions.AddFunction( |
967 | function: ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero<ModuloOperator>(type))); |
968 | } |
969 | } |
970 | set.AddFunction(set: functions); |
971 | functions.name = "mod" ; |
972 | set.AddFunction(set: functions); |
973 | } |
974 | |
975 | } // namespace duckdb |
976 | |