1#pragma once
2
3// Include this first, because `#define _asan_poison_address` from
4// llvm/Support/Compiler.h conflicts with its forward declaration in
5// sanitizer/asan_interface.h
6#include <Common/Arena.h>
7
8#include <DataTypes/DataTypesNumber.h>
9#include <DataTypes/DataTypesDecimal.h>
10#include <DataTypes/DataTypeDate.h>
11#include <DataTypes/DataTypeDateTime.h>
12#include <DataTypes/DataTypeDateTime64.h>
13#include <DataTypes/DataTypeInterval.h>
14#include <DataTypes/DataTypeAggregateFunction.h>
15#include <DataTypes/Native.h>
16#include <DataTypes/NumberTraits.h>
17#include <Columns/ColumnVector.h>
18#include <Columns/ColumnDecimal.h>
19#include <Columns/ColumnConst.h>
20#include <Columns/ColumnAggregateFunction.h>
21#include "IFunctionImpl.h"
22#include "FunctionHelpers.h"
23#include "intDiv.h"
24#include "castTypeToEither.h"
25#include "FunctionFactory.h"
26#include <Common/typeid_cast.h>
27#include <Common/assert_cast.h>
28#include <Common/config.h>
29
30#if USE_EMBEDDED_COMPILER
31#pragma GCC diagnostic push
32#pragma GCC diagnostic ignored "-Wunused-parameter"
33#include <llvm/IR/IRBuilder.h>
34#pragma GCC diagnostic pop
35#endif
36
37
38namespace DB
39{
40
41namespace ErrorCodes
42{
43 extern const int ILLEGAL_COLUMN;
44 extern const int ILLEGAL_TYPE_OF_ARGUMENT;
45 extern const int LOGICAL_ERROR;
46 extern const int DECIMAL_OVERFLOW;
47 extern const int CANNOT_ADD_DIFFERENT_AGGREGATE_STATES;
48 extern const int ILLEGAL_DIVISION;
49}
50
51
52/** Arithmetic operations: +, -, *, /, %,
53 * intDiv (integer division)
54 * Bitwise operations: |, &, ^, ~.
55 * Etc.
56 */
57
58template <typename A, typename B, typename Op, typename ResultType_ = typename Op::ResultType>
59struct BinaryOperationImplBase
60{
61 using ResultType = ResultType_;
62
63 static void NO_INLINE vector_vector(const PaddedPODArray<A> & a, const PaddedPODArray<B> & b, PaddedPODArray<ResultType> & c)
64 {
65 size_t size = a.size();
66 for (size_t i = 0; i < size; ++i)
67 c[i] = Op::template apply<ResultType>(a[i], b[i]);
68 }
69
70 static void NO_INLINE vector_constant(const PaddedPODArray<A> & a, B b, PaddedPODArray<ResultType> & c)
71 {
72 size_t size = a.size();
73 for (size_t i = 0; i < size; ++i)
74 c[i] = Op::template apply<ResultType>(a[i], b);
75 }
76
77 static void NO_INLINE constant_vector(A a, const PaddedPODArray<B> & b, PaddedPODArray<ResultType> & c)
78 {
79 size_t size = b.size();
80 for (size_t i = 0; i < size; ++i)
81 c[i] = Op::template apply<ResultType>(a, b[i]);
82 }
83
84 static ResultType constant_constant(A a, B b)
85 {
86 return Op::template apply<ResultType>(a, b);
87 }
88};
89
90template <typename A, typename B, typename Op, typename ResultType = typename Op::ResultType>
91struct BinaryOperationImpl : BinaryOperationImplBase<A, B, Op, ResultType>
92{
93};
94
95
96template <typename, typename> struct PlusImpl;
97template <typename, typename> struct MinusImpl;
98template <typename, typename> struct MultiplyImpl;
99template <typename, typename> struct DivideFloatingImpl;
100template <typename, typename> struct DivideIntegralImpl;
101template <typename, typename> struct DivideIntegralOrZeroImpl;
102template <typename, typename> struct LeastBaseImpl;
103template <typename, typename> struct GreatestBaseImpl;
104template <typename, typename> struct ModuloImpl;
105
106
107/// Binary operations for Decimals need scale args
108/// +|- scale one of args (which scale factor is not 1). ScaleR = oneof(Scale1, Scale2);
109/// * no agrs scale. ScaleR = Scale1 + Scale2;
110/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::getScale()).
111template <typename A, typename B, template <typename, typename> typename Operation, typename ResultType_, bool _check_overflow = true>
112struct DecimalBinaryOperation
113{
114 static constexpr bool is_plus_minus = std::is_same_v<Operation<Int32, Int32>, PlusImpl<Int32, Int32>> ||
115 std::is_same_v<Operation<Int32, Int32>, MinusImpl<Int32, Int32>>;
116 static constexpr bool is_multiply = std::is_same_v<Operation<Int32, Int32>, MultiplyImpl<Int32, Int32>>;
117 static constexpr bool is_float_division = std::is_same_v<Operation<Int32, Int32>, DivideFloatingImpl<Int32, Int32>>;
118 static constexpr bool is_int_division = std::is_same_v<Operation<Int32, Int32>, DivideIntegralImpl<Int32, Int32>> ||
119 std::is_same_v<Operation<Int32, Int32>, DivideIntegralOrZeroImpl<Int32, Int32>>;
120 static constexpr bool is_division = is_float_division || is_int_division;
121 static constexpr bool is_compare = std::is_same_v<Operation<Int32, Int32>, LeastBaseImpl<Int32, Int32>> ||
122 std::is_same_v<Operation<Int32, Int32>, GreatestBaseImpl<Int32, Int32>>;
123 static constexpr bool is_plus_minus_compare = is_plus_minus || is_compare;
124 static constexpr bool can_overflow = is_plus_minus || is_multiply;
125
126 using ResultType = ResultType_;
127 using NativeResultType = typename NativeType<ResultType>::Type;
128 using Op = std::conditional_t<is_float_division,
129 DivideIntegralImpl<NativeResultType, NativeResultType>, /// substitute divide by intDiv (throw on division by zero)
130 Operation<NativeResultType, NativeResultType>>;
131 using ColVecA = std::conditional_t<IsDecimalNumber<A>, ColumnDecimal<A>, ColumnVector<A>>;
132 using ColVecB = std::conditional_t<IsDecimalNumber<B>, ColumnDecimal<B>, ColumnVector<B>>;
133 using ArrayA = typename ColVecA::Container;
134 using ArrayB = typename ColVecB::Container;
135 using ArrayC = typename ColumnDecimal<ResultType>::Container;
136 using SelfNoOverflow = DecimalBinaryOperation<A, B, Operation, ResultType_, false>;
137
138 static void vector_vector(const ArrayA & a, const ArrayB & b, ArrayC & c, ResultType scale_a, ResultType scale_b, bool check_overflow)
139 {
140 if (check_overflow)
141 vector_vector(a, b, c, scale_a, scale_b);
142 else
143 SelfNoOverflow::vector_vector(a, b, c, scale_a, scale_b);
144 }
145
146 static void vector_constant(const ArrayA & a, B b, ArrayC & c, ResultType scale_a, ResultType scale_b, bool check_overflow)
147 {
148 if (check_overflow)
149 vector_constant(a, b, c, scale_a, scale_b);
150 else
151 SelfNoOverflow::vector_constant(a, b, c, scale_a, scale_b);
152 }
153
154 static void constant_vector(A a, const ArrayB & b, ArrayC & c, ResultType scale_a, ResultType scale_b, bool check_overflow)
155 {
156 if (check_overflow)
157 constant_vector(a, b, c, scale_a, scale_b);
158 else
159 SelfNoOverflow::constant_vector(a, b, c, scale_a, scale_b);
160 }
161
162 static ResultType constant_constant(A a, B b, ResultType scale_a, ResultType scale_b, bool check_overflow)
163 {
164 if (check_overflow)
165 return constant_constant(a, b, scale_a, scale_b);
166 else
167 return SelfNoOverflow::constant_constant(a, b, scale_a, scale_b);
168 }
169
170 static void NO_INLINE vector_vector(const ArrayA & a, const ArrayB & b, ArrayC & c,
171 ResultType scale_a [[maybe_unused]], ResultType scale_b [[maybe_unused]])
172 {
173 size_t size = a.size();
174 if constexpr (is_plus_minus_compare)
175 {
176 if (scale_a != 1)
177 {
178 for (size_t i = 0; i < size; ++i)
179 c[i] = applyScaled<true>(a[i], b[i], scale_a);
180 return;
181 }
182 else if (scale_b != 1)
183 {
184 for (size_t i = 0; i < size; ++i)
185 c[i] = applyScaled<false>(a[i], b[i], scale_b);
186 return;
187 }
188 }
189 else if constexpr (is_division && IsDecimalNumber<B>)
190 {
191 for (size_t i = 0; i < size; ++i)
192 c[i] = applyScaledDiv(a[i], b[i], scale_a);
193 return;
194 }
195
196 /// default: use it if no return before
197 for (size_t i = 0; i < size; ++i)
198 c[i] = apply(a[i], b[i]);
199 }
200
201 static void NO_INLINE vector_constant(const ArrayA & a, B b, ArrayC & c,
202 ResultType scale_a [[maybe_unused]], ResultType scale_b [[maybe_unused]])
203 {
204 size_t size = a.size();
205 if constexpr (is_plus_minus_compare)
206 {
207 if (scale_a != 1)
208 {
209 for (size_t i = 0; i < size; ++i)
210 c[i] = applyScaled<true>(a[i], b, scale_a);
211 return;
212 }
213 else if (scale_b != 1)
214 {
215 for (size_t i = 0; i < size; ++i)
216 c[i] = applyScaled<false>(a[i], b, scale_b);
217 return;
218 }
219 }
220 else if constexpr (is_division && IsDecimalNumber<B>)
221 {
222 for (size_t i = 0; i < size; ++i)
223 c[i] = applyScaledDiv(a[i], b, scale_a);
224 return;
225 }
226
227 /// default: use it if no return before
228 for (size_t i = 0; i < size; ++i)
229 c[i] = apply(a[i], b);
230 }
231
232 static void NO_INLINE constant_vector(A a, const ArrayB & b, ArrayC & c,
233 ResultType scale_a [[maybe_unused]], ResultType scale_b [[maybe_unused]])
234 {
235 size_t size = b.size();
236 if constexpr (is_plus_minus_compare)
237 {
238 if (scale_a != 1)
239 {
240 for (size_t i = 0; i < size; ++i)
241 c[i] = applyScaled<true>(a, b[i], scale_a);
242 return;
243 }
244 else if (scale_b != 1)
245 {
246 for (size_t i = 0; i < size; ++i)
247 c[i] = applyScaled<false>(a, b[i], scale_b);
248 return;
249 }
250 }
251 else if constexpr (is_division && IsDecimalNumber<B>)
252 {
253 for (size_t i = 0; i < size; ++i)
254 c[i] = applyScaledDiv(a, b[i], scale_a);
255 return;
256 }
257
258 /// default: use it if no return before
259 for (size_t i = 0; i < size; ++i)
260 c[i] = apply(a, b[i]);
261 }
262
263 static ResultType constant_constant(A a, B b, ResultType scale_a [[maybe_unused]], ResultType scale_b [[maybe_unused]])
264 {
265 if constexpr (is_plus_minus_compare)
266 {
267 if (scale_a != 1)
268 return applyScaled<true>(a, b, scale_a);
269 else if (scale_b != 1)
270 return applyScaled<false>(a, b, scale_b);
271 }
272 else if constexpr (is_division && IsDecimalNumber<B>)
273 return applyScaledDiv(a, b, scale_a);
274 return apply(a, b);
275 }
276
277private:
278 /// there's implicit type convertion here
279 static NativeResultType apply(NativeResultType a, NativeResultType b)
280 {
281 if constexpr (can_overflow && _check_overflow)
282 {
283 NativeResultType res;
284 if (Op::template apply<NativeResultType>(a, b, res))
285 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
286 return res;
287 }
288 else
289 return Op::template apply<NativeResultType>(a, b);
290 }
291
292 template <bool scale_left>
293 static NO_SANITIZE_UNDEFINED NativeResultType applyScaled(NativeResultType a, NativeResultType b, NativeResultType scale)
294 {
295 if constexpr (is_plus_minus_compare)
296 {
297 NativeResultType res;
298
299 if constexpr (_check_overflow)
300 {
301 bool overflow = false;
302 if constexpr (scale_left)
303 overflow |= common::mulOverflow(a, scale, a);
304 else
305 overflow |= common::mulOverflow(b, scale, b);
306
307 if constexpr (can_overflow)
308 overflow |= Op::template apply<NativeResultType>(a, b, res);
309 else
310 res = Op::template apply<NativeResultType>(a, b);
311
312 if (overflow)
313 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
314 }
315 else
316 {
317 if constexpr (scale_left)
318 a *= scale;
319 else
320 b *= scale;
321 res = Op::template apply<NativeResultType>(a, b);
322 }
323
324 return res;
325 }
326 }
327
328 static NO_SANITIZE_UNDEFINED NativeResultType applyScaledDiv(NativeResultType a, NativeResultType b, NativeResultType scale)
329 {
330 if constexpr (is_division)
331 {
332 if constexpr (_check_overflow)
333 {
334 bool overflow = false;
335 if constexpr (!IsDecimalNumber<A>)
336 overflow |= common::mulOverflow(scale, scale, scale);
337 overflow |= common::mulOverflow(a, scale, a);
338 if (overflow)
339 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
340 }
341 else
342 {
343 if constexpr (!IsDecimalNumber<A>)
344 scale *= scale;
345 a *= scale;
346 }
347
348 return Op::template apply<NativeResultType>(a, b);
349 }
350 }
351};
352
353
/// Used to indicate undefined operation
struct InvalidType;

/// One alternative of a `Switch`: the compile-time condition `V` together with the type `T` it selects.
template <bool V, typename T>
struct Case : std::integral_constant<bool, V>
{
    using type = T;
};

/// Switch<Case<C0, T0>, ...> -- select the first Ti for which Ci is true; InvalidType if none.
template <typename... Ts>
using Switch = typename std::disjunction<Ts..., Case<true, InvalidType>>::type;
361
362template <typename DataType> constexpr bool IsIntegral = false;
363template <> inline constexpr bool IsIntegral<DataTypeUInt8> = true;
364template <> inline constexpr bool IsIntegral<DataTypeUInt16> = true;
365template <> inline constexpr bool IsIntegral<DataTypeUInt32> = true;
366template <> inline constexpr bool IsIntegral<DataTypeUInt64> = true;
367template <> inline constexpr bool IsIntegral<DataTypeInt8> = true;
368template <> inline constexpr bool IsIntegral<DataTypeInt16> = true;
369template <> inline constexpr bool IsIntegral<DataTypeInt32> = true;
370template <> inline constexpr bool IsIntegral<DataTypeInt64> = true;
371
372template <typename DataType> constexpr bool IsFloatingPoint = false;
373template <> inline constexpr bool IsFloatingPoint<DataTypeFloat32> = true;
374template <> inline constexpr bool IsFloatingPoint<DataTypeFloat64> = true;
375
376template <typename DataType> constexpr bool IsDateOrDateTime = false;
377template <> inline constexpr bool IsDateOrDateTime<DataTypeDate> = true;
378template <> inline constexpr bool IsDateOrDateTime<DataTypeDateTime> = true;
379
380template <typename T0, typename T1> constexpr bool UseLeftDecimal = false;
381template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal128>, DataTypeDecimal<Decimal32>> = true;
382template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal128>, DataTypeDecimal<Decimal64>> = true;
383template <> inline constexpr bool UseLeftDecimal<DataTypeDecimal<Decimal64>, DataTypeDecimal<Decimal32>> = true;
384
385template <typename T> using DataTypeFromFieldType = std::conditional_t<std::is_same_v<T, NumberTraits::Error>, InvalidType, DataTypeNumber<T>>;
386
387template <template <typename, typename> class Operation, typename LeftDataType, typename RightDataType>
388struct BinaryOperationTraits
389{
390 using T0 = typename LeftDataType::FieldType;
391 using T1 = typename RightDataType::FieldType;
392private: /// it's not correct for Decimal
393 using Op = Operation<T0, T1>;
394public:
395
396 static constexpr bool allow_decimal =
397 std::is_same_v<Operation<T0, T0>, PlusImpl<T0, T0>> ||
398 std::is_same_v<Operation<T0, T0>, MinusImpl<T0, T0>> ||
399 std::is_same_v<Operation<T0, T0>, MultiplyImpl<T0, T0>> ||
400 std::is_same_v<Operation<T0, T0>, DivideFloatingImpl<T0, T0>> ||
401 std::is_same_v<Operation<T0, T0>, DivideIntegralImpl<T0, T0>> ||
402 std::is_same_v<Operation<T0, T0>, DivideIntegralOrZeroImpl<T0, T0>> ||
403 std::is_same_v<Operation<T0, T0>, LeastBaseImpl<T0, T0>> ||
404 std::is_same_v<Operation<T0, T0>, GreatestBaseImpl<T0, T0>>;
405
406 /// Appropriate result type for binary operator on numeric types. "Date" can also mean
407 /// DateTime, but if both operands are Dates, their type must be the same (e.g. Date - DateTime is invalid).
408 using ResultDataType = Switch<
409 /// Decimal cases
410 Case<!allow_decimal && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>), InvalidType>,
411 Case<IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType> && UseLeftDecimal<LeftDataType, RightDataType>, LeftDataType>,
412 Case<IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>, RightDataType>,
413 Case<IsDataTypeDecimal<LeftDataType> && !IsDataTypeDecimal<RightDataType> && IsIntegral<RightDataType>, LeftDataType>,
414 Case<!IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType> && IsIntegral<LeftDataType>, RightDataType>,
415 /// Decimal <op> Real is not supported (traditional DBs convert Decimal <op> Real to Real)
416 Case<IsDataTypeDecimal<LeftDataType> && !IsDataTypeDecimal<RightDataType> && !IsIntegral<RightDataType>, InvalidType>,
417 Case<!IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType> && !IsIntegral<LeftDataType>, InvalidType>,
418 /// number <op> number -> see corresponding impl
419 Case<!IsDateOrDateTime<LeftDataType> && !IsDateOrDateTime<RightDataType>,
420 DataTypeFromFieldType<typename Op::ResultType>>,
421 /// Date + Integral -> Date
422 /// Integral + Date -> Date
423 Case<std::is_same_v<Op, PlusImpl<T0, T1>>, Switch<
424 Case<IsIntegral<RightDataType>, LeftDataType>,
425 Case<IsIntegral<LeftDataType>, RightDataType>>>,
426 /// Date - Date -> Int32
427 /// Date - Integral -> Date
428 Case<std::is_same_v<Op, MinusImpl<T0, T1>>, Switch<
429 Case<std::is_same_v<LeftDataType, RightDataType>, DataTypeInt32>,
430 Case<IsDateOrDateTime<LeftDataType> && IsIntegral<RightDataType>, LeftDataType>>>,
431 /// least(Date, Date) -> Date
432 /// greatest(Date, Date) -> Date
433 Case<std::is_same_v<LeftDataType, RightDataType> && (std::is_same_v<Op, LeastBaseImpl<T0, T1>> || std::is_same_v<Op, GreatestBaseImpl<T0, T1>>),
434 LeftDataType>,
435 /// Date % Int32 -> int32
436 Case<std::is_same_v<Op, ModuloImpl<T0, T1>>, Switch<
437 Case<IsDateOrDateTime<LeftDataType> && IsIntegral<RightDataType>, RightDataType>,
438 Case<IsDateOrDateTime<LeftDataType> && IsFloatingPoint<RightDataType>, DataTypeInt32>>>>;
439};
440
441
442template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
443class FunctionBinaryArithmetic : public IFunction
444{
445 const Context & context;
446 bool check_decimal_overflow = true;
447
448 template <typename F>
449 static bool castType(const IDataType * type, F && f)
450 {
451 return castTypeToEither<
452 DataTypeUInt8,
453 DataTypeUInt16,
454 DataTypeUInt32,
455 DataTypeUInt64,
456 DataTypeInt8,
457 DataTypeInt16,
458 DataTypeInt32,
459 DataTypeInt64,
460 DataTypeFloat32,
461 DataTypeFloat64,
462 DataTypeDate,
463 DataTypeDateTime,
464 DataTypeDecimal<Decimal32>,
465 DataTypeDecimal<Decimal64>,
466 DataTypeDecimal<Decimal128>
467 >(type, std::forward<F>(f));
468 }
469
470 template <typename F>
471 static bool castBothTypes(const IDataType * left, const IDataType * right, F && f)
472 {
473 return castType(left, [&](const auto & left_) { return castType(right, [&](const auto & right_) { return f(left_, right_); }); });
474 }
475
476 FunctionOverloadResolverPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
477 {
478 /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
479 /// We construct another function (example: addMonths) and call it.
480
481 bool function_is_plus = std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>;
482 bool function_is_minus = std::is_same_v<Op<UInt8, UInt8>, MinusImpl<UInt8, UInt8>>;
483
484 if (!function_is_plus && !function_is_minus)
485 return {};
486
487 int interval_arg = 1;
488 const DataTypeInterval * interval_data_type = checkAndGetDataType<DataTypeInterval>(type1.get());
489 if (!interval_data_type)
490 {
491 interval_arg = 0;
492 interval_data_type = checkAndGetDataType<DataTypeInterval>(type0.get());
493 }
494 if (!interval_data_type)
495 return {};
496
497 if (interval_arg == 0 && function_is_minus)
498 throw Exception("Wrong order of arguments for function " + getName() + ": argument of type Interval cannot be first.",
499 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
500
501 const DataTypeDate * date_data_type = checkAndGetDataType<DataTypeDate>(interval_arg == 0 ? type1.get() : type0.get());
502 const DataTypeDateTime * date_time_data_type = nullptr;
503 if (!date_data_type)
504 {
505 date_time_data_type = checkAndGetDataType<DataTypeDateTime>(interval_arg == 0 ? type1.get() : type0.get());
506 if (!date_time_data_type)
507 throw Exception("Wrong argument types for function " + getName() + ": if one argument is Interval, then another must be Date or DateTime.",
508 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
509 }
510
511 std::stringstream function_name;
512 function_name << (function_is_plus ? "add" : "subtract") << interval_data_type->getKind().toString() << 's';
513
514 return FunctionFactory::instance().get(function_name.str(), context);
515 }
516
517 bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
518 {
519 if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
520 return false;
521
522 WhichDataType which0(type0);
523 WhichDataType which1(type1);
524
525 return (which0.isAggregateFunction() && which1.isNativeUInt())
526 || (which0.isNativeUInt() && which1.isAggregateFunction());
527 }
528
529 bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1) const
530 {
531 if constexpr (!std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>)
532 return false;
533
534 WhichDataType which0(type0);
535 WhichDataType which1(type1);
536
537 return which0.isAggregateFunction() && which1.isAggregateFunction();
538 }
539
540 /// Multiply aggregation state by integer constant: by merging it with itself specified number of times.
541 void executeAggregateMultiply(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const
542 {
543 ColumnNumbers new_arguments = arguments;
544 if (WhichDataType(block.getByPosition(new_arguments[1]).type).isAggregateFunction())
545 std::swap(new_arguments[0], new_arguments[1]);
546
547 if (!isColumnConst(*block.getByPosition(new_arguments[1]).column))
548 throw Exception{"Illegal column " + block.getByPosition(new_arguments[1]).column->getName()
549 + " of argument of aggregation state multiply. Should be integer constant", ErrorCodes::ILLEGAL_COLUMN};
550
551 const IColumn & agg_state_column = *block.getByPosition(new_arguments[0]).column;
552 bool agg_state_is_const = isColumnConst(agg_state_column);
553 const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>(
554 agg_state_is_const ? assert_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column);
555
556 AggregateFunctionPtr function = column.getAggregateFunction();
557
558
559 size_t size = agg_state_is_const ? 1 : input_rows_count;
560
561 auto column_to = ColumnAggregateFunction::create(function);
562 column_to->reserve(size);
563
564 auto column_from = ColumnAggregateFunction::create(function);
565 column_from->reserve(size);
566
567 for (size_t i = 0; i < size; ++i)
568 {
569 column_to->insertDefault();
570 column_from->insertFrom(column.getData()[i]);
571 }
572
573 auto & vec_to = column_to->getData();
574 auto & vec_from = column_from->getData();
575
576 UInt64 m = typeid_cast<const ColumnConst *>(block.getByPosition(new_arguments[1]).column.get())->getValue<UInt64>();
577
578 // Since we merge the function states by ourselves, we have to have an
579 // Arena for this. Pass it to the resulting column so that the arena
580 // has a proper lifetime.
581 auto arena = std::make_shared<Arena>();
582 column_to->addArena(arena);
583
584 /// We use exponentiation by squaring algorithm to perform multiplying aggregate states by N in O(log(N)) operations
585 /// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
586 while (m)
587 {
588 if (m % 2)
589 {
590 for (size_t i = 0; i < size; ++i)
591 function->merge(vec_to[i], vec_from[i], arena.get());
592 --m;
593 }
594 else
595 {
596 for (size_t i = 0; i < size; ++i)
597 function->merge(vec_from[i], vec_from[i], arena.get());
598 m /= 2;
599 }
600 }
601
602 if (agg_state_is_const)
603 block.getByPosition(result).column = ColumnConst::create(std::move(column_to), input_rows_count);
604 else
605 block.getByPosition(result).column = std::move(column_to);
606 }
607
608 /// Merge two aggregation states together.
609 void executeAggregateAddition(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const
610 {
611 const IColumn & lhs_column = *block.getByPosition(arguments[0]).column;
612 const IColumn & rhs_column = *block.getByPosition(arguments[1]).column;
613
614 bool lhs_is_const = isColumnConst(lhs_column);
615 bool rhs_is_const = isColumnConst(rhs_column);
616
617 const ColumnAggregateFunction & lhs = typeid_cast<const ColumnAggregateFunction &>(
618 lhs_is_const ? assert_cast<const ColumnConst &>(lhs_column).getDataColumn() : lhs_column);
619 const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>(
620 rhs_is_const ? assert_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column);
621
622 AggregateFunctionPtr function = lhs.getAggregateFunction();
623
624 size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count;
625
626 auto column_to = ColumnAggregateFunction::create(function);
627 column_to->reserve(size);
628
629 for (size_t i = 0; i < size; ++i)
630 {
631 column_to->insertFrom(lhs.getData()[lhs_is_const ? 0 : i]);
632 column_to->insertMergeFrom(rhs.getData()[rhs_is_const ? 0 : i]);
633 }
634
635 if (lhs_is_const && rhs_is_const)
636 block.getByPosition(result).column = ColumnConst::create(std::move(column_to), input_rows_count);
637 else
638 block.getByPosition(result).column = std::move(column_to);
639 }
640
641 void executeDateTimeIntervalPlusMinus(Block & block, const ColumnNumbers & arguments,
642 size_t result, size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const
643 {
644 ColumnNumbers new_arguments = arguments;
645
646 /// Interval argument must be second.
647 if (WhichDataType(block.getByPosition(arguments[0]).type).isInterval())
648 std::swap(new_arguments[0], new_arguments[1]);
649
650 /// Change interval argument type to its representation
651 Block new_block = block;
652 new_block.getByPosition(new_arguments[1]).type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
653
654 ColumnsWithTypeAndName new_arguments_with_type_and_name =
655 {new_block.getByPosition(new_arguments[0]), new_block.getByPosition(new_arguments[1])};
656 auto function = function_builder->build(new_arguments_with_type_and_name);
657
658 function->execute(new_block, new_arguments, result, input_rows_count);
659 block.getByPosition(result).column = new_block.getByPosition(result).column;
660 }
661
662public:
663 static constexpr auto name = Name::name;
664 static FunctionPtr create(const Context & context) { return std::make_shared<FunctionBinaryArithmetic>(context); }
665
666 FunctionBinaryArithmetic(const Context & context_)
667 : context(context_),
668 check_decimal_overflow(decimalCheckArithmeticOverflow(context))
669 {}
670
671 String getName() const override
672 {
673 return name;
674 }
675
676 size_t getNumberOfArguments() const override { return 2; }
677
678 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
679 {
680 /// Special case when multiply aggregate function state
681 if (isAggregateMultiply(arguments[0], arguments[1]))
682 {
683 if (WhichDataType(arguments[0]).isAggregateFunction())
684 return arguments[0];
685 return arguments[1];
686 }
687
688 /// Special case - addition of two aggregate functions states
689 if (isAggregateAddition(arguments[0], arguments[1]))
690 {
691 if (!arguments[0]->equals(*arguments[1]))
692 throw Exception("Cannot add aggregate states of different functions: "
693 + arguments[0]->getName() + " and " + arguments[1]->getName(), ErrorCodes::CANNOT_ADD_DIFFERENT_AGGREGATE_STATES);
694
695 return arguments[0];
696 }
697
698 /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
699 if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
700 {
701 ColumnsWithTypeAndName new_arguments(2);
702
703 for (size_t i = 0; i < 2; ++i)
704 new_arguments[i].type = arguments[i];
705
706 /// Interval argument must be second.
707 if (WhichDataType(new_arguments[0].type).isInterval())
708 std::swap(new_arguments[0], new_arguments[1]);
709
710 /// Change interval argument to its representation
711 new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
712
713 auto function = function_builder->build(new_arguments);
714 return function->getReturnType();
715 }
716
717 DataTypePtr type_res;
718 bool valid = castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right)
719 {
720 using LeftDataType = std::decay_t<decltype(left)>;
721 using RightDataType = std::decay_t<decltype(right)>;
722 using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
723 if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
724 {
725 if constexpr (IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>)
726 {
727 constexpr bool is_multiply = std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>;
728 constexpr bool is_division = std::is_same_v<Op<UInt8, UInt8>, DivideFloatingImpl<UInt8, UInt8>> ||
729 std::is_same_v<Op<UInt8, UInt8>, DivideIntegralImpl<UInt8, UInt8>> ||
730 std::is_same_v<Op<UInt8, UInt8>, DivideIntegralOrZeroImpl<UInt8, UInt8>>;
731
732 ResultDataType result_type = decimalResultType(left, right, is_multiply, is_division);
733 type_res = std::make_shared<ResultDataType>(result_type.getPrecision(), result_type.getScale());
734 }
735 else if constexpr (IsDataTypeDecimal<LeftDataType>)
736 type_res = std::make_shared<LeftDataType>(left.getPrecision(), left.getScale());
737 else if constexpr (IsDataTypeDecimal<RightDataType>)
738 type_res = std::make_shared<RightDataType>(right.getPrecision(), right.getScale());
739 else
740 type_res = std::make_shared<ResultDataType>();
741 return true;
742 }
743 return false;
744 });
745 if (!valid)
746 throw Exception("Illegal types " + arguments[0]->getName() + " and " + arguments[1]->getName() + " of arguments of function " + getName(),
747 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
748 return type_res;
749 }
750
751 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
752 {
753 /// Special case when multiply aggregate function state
754 if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
755 {
756 executeAggregateMultiply(block, arguments, result, input_rows_count);
757 return;
758 }
759
760 /// Special case - addition of two aggregate functions states
761 if (isAggregateAddition(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
762 {
763 executeAggregateAddition(block, arguments, result, input_rows_count);
764 return;
765 }
766
767 /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
768 if (auto function_builder = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
769 {
770 executeDateTimeIntervalPlusMinus(block, arguments, result, input_rows_count, function_builder);
771 return;
772 }
773
774 auto * left_generic = block.getByPosition(arguments[0]).type.get();
775 auto * right_generic = block.getByPosition(arguments[1]).type.get();
776 bool valid = castBothTypes(left_generic, right_generic, [&](const auto & left, const auto & right)
777 {
778 using LeftDataType = std::decay_t<decltype(left)>;
779 using RightDataType = std::decay_t<decltype(right)>;
780 using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
781 if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
782 {
783 constexpr bool result_is_decimal = IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>;
784 constexpr bool is_multiply = std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>;
785 constexpr bool is_division = std::is_same_v<Op<UInt8, UInt8>, DivideFloatingImpl<UInt8, UInt8>> ||
786 std::is_same_v<Op<UInt8, UInt8>, DivideIntegralImpl<UInt8, UInt8>> ||
787 std::is_same_v<Op<UInt8, UInt8>, DivideIntegralOrZeroImpl<UInt8, UInt8>>;
788
789 using T0 = typename LeftDataType::FieldType;
790 using T1 = typename RightDataType::FieldType;
791 using ResultType = typename ResultDataType::FieldType;
792 using ColVecT0 = std::conditional_t<IsDecimalNumber<T0>, ColumnDecimal<T0>, ColumnVector<T0>>;
793 using ColVecT1 = std::conditional_t<IsDecimalNumber<T1>, ColumnDecimal<T1>, ColumnVector<T1>>;
794 using ColVecResult = std::conditional_t<IsDecimalNumber<ResultType>, ColumnDecimal<ResultType>, ColumnVector<ResultType>>;
795
796 /// Decimal operations need scale. Operations are on result type.
797 using OpImpl = std::conditional_t<IsDataTypeDecimal<ResultDataType>,
798 DecimalBinaryOperation<T0, T1, Op, ResultType>,
799 BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>>;
800
801 auto col_left_raw = block.getByPosition(arguments[0]).column.get();
802 auto col_right_raw = block.getByPosition(arguments[1]).column.get();
803 if (auto col_left = checkAndGetColumnConst<ColVecT0>(col_left_raw))
804 {
805 if (auto col_right = checkAndGetColumnConst<ColVecT1>(col_right_raw))
806 {
807 /// the only case with a non-vector result
808 if constexpr (result_is_decimal)
809 {
810 ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
811 typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
812 typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
813 if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
814 scale_a = right.getScaleMultiplier();
815
816 auto res = OpImpl::constant_constant(col_left->template getValue<T0>(), col_right->template getValue<T1>(),
817 scale_a, scale_b, check_decimal_overflow);
818 block.getByPosition(result).column =
819 ResultDataType(type.getPrecision(), type.getScale()).createColumnConst(
820 col_left->size(), toField(res, type.getScale()));
821
822 }
823 else
824 {
825 auto res = OpImpl::constant_constant(col_left->template getValue<T0>(), col_right->template getValue<T1>());
826 block.getByPosition(result).column = ResultDataType().createColumnConst(col_left->size(), toField(res));
827 }
828 return true;
829 }
830 }
831
832 typename ColVecResult::MutablePtr col_res = nullptr;
833 if constexpr (result_is_decimal)
834 {
835 ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
836 col_res = ColVecResult::create(0, type.getScale());
837 }
838 else
839 col_res = ColVecResult::create();
840
841 auto & vec_res = col_res->getData();
842 vec_res.resize(block.rows());
843
844 if (auto col_left_const = checkAndGetColumnConst<ColVecT0>(col_left_raw))
845 {
846 if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
847 {
848 if constexpr (result_is_decimal)
849 {
850 ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
851
852 typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
853 typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
854 if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
855 scale_a = right.getScaleMultiplier();
856
857 OpImpl::constant_vector(col_left_const->template getValue<T0>(), col_right->getData(), vec_res,
858 scale_a, scale_b, check_decimal_overflow);
859 }
860 else
861 OpImpl::constant_vector(col_left_const->template getValue<T0>(), col_right->getData(), vec_res);
862 }
863 else
864 return false;
865 }
866 else if (auto col_left = checkAndGetColumn<ColVecT0>(col_left_raw))
867 {
868 if constexpr (result_is_decimal)
869 {
870 ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
871
872 typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
873 typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
874 if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
875 scale_a = right.getScaleMultiplier();
876 if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
877 {
878 OpImpl::vector_vector(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b,
879 check_decimal_overflow);
880 }
881 else if (auto col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_raw))
882 {
883 OpImpl::vector_constant(col_left->getData(), col_right_const->template getValue<T1>(), vec_res,
884 scale_a, scale_b, check_decimal_overflow);
885 }
886 else
887 return false;
888 }
889 else
890 {
891 if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
892 OpImpl::vector_vector(col_left->getData(), col_right->getData(), vec_res);
893 else if (auto col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_raw))
894 OpImpl::vector_constant(col_left->getData(), col_right_const->template getValue<T1>(), vec_res);
895 else
896 return false;
897 }
898 }
899 else
900 return false;
901
902 block.getByPosition(result).column = std::move(col_res);
903 return true;
904 }
905 return false;
906 });
907 if (!valid)
908 throw Exception(getName() + "'s arguments do not match the expected data types", ErrorCodes::LOGICAL_ERROR);
909 }
910
911#if USE_EMBEDDED_COMPILER
912 bool isCompilableImpl(const DataTypes & arguments) const override
913 {
914 return castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right)
915 {
916 using LeftDataType = std::decay_t<decltype(left)>;
917 using RightDataType = std::decay_t<decltype(right)>;
918 using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
919 using OpSpec = Op<typename LeftDataType::FieldType, typename RightDataType::FieldType>;
920 return !std::is_same_v<ResultDataType, InvalidType> && !IsDataTypeDecimal<ResultDataType> && OpSpec::compilable;
921 });
922 }
923
924 llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
925 {
926 llvm::Value * result = nullptr;
927 castBothTypes(types[0].get(), types[1].get(), [&](const auto & left, const auto & right)
928 {
929 using LeftDataType = std::decay_t<decltype(left)>;
930 using RightDataType = std::decay_t<decltype(right)>;
931 using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
932 using OpSpec = Op<typename LeftDataType::FieldType, typename RightDataType::FieldType>;
933 if constexpr (!std::is_same_v<ResultDataType, InvalidType> && !IsDataTypeDecimal<ResultDataType> && OpSpec::compilable)
934 {
935 auto & b = static_cast<llvm::IRBuilder<> &>(builder);
936 auto type = std::make_shared<ResultDataType>();
937 auto * lval = nativeCast(b, types[0], values[0](), type);
938 auto * rval = nativeCast(b, types[1], values[1](), type);
939 result = OpSpec::compile(b, lval, rval, std::is_signed_v<typename ResultDataType::FieldType>);
940 return true;
941 }
942 return false;
943 });
944 return result;
945 }
946#endif
947
948 bool canBeExecutedOnDefaultArguments() const override { return valid_on_default_arguments; }
949};
950
951}
952