1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionSum.h> |
3 | #include <AggregateFunctions/Helpers.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | #include "registerAggregateFunctions.h" |
6 | |
7 | |
8 | namespace DB |
9 | { |
10 | |
11 | namespace |
12 | { |
13 | |
14 | template <typename T> |
15 | struct SumSimple |
16 | { |
17 | /// @note It uses slow Decimal128 (cause we need such a variant). sumWithOverflow is faster for Decimal32/64 |
18 | using ResultType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>; |
19 | using AggregateDataType = AggregateFunctionSumData<ResultType>; |
20 | using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>; |
21 | }; |
22 | |
23 | template <typename T> |
24 | struct SumSameType |
25 | { |
26 | using ResultType = T; |
27 | using AggregateDataType = AggregateFunctionSumData<ResultType>; |
28 | using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>; |
29 | }; |
30 | |
31 | template <typename T> |
32 | struct SumKahan |
33 | { |
34 | using ResultType = Float64; |
35 | using AggregateDataType = AggregateFunctionSumKahanData<ResultType>; |
36 | using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>; |
37 | }; |
38 | |
39 | template <typename T> using AggregateFunctionSumSimple = typename SumSimple<T>::Function; |
40 | template <typename T> using AggregateFunctionSumWithOverflow = typename SumSameType<T>::Function; |
41 | template <typename T> using AggregateFunctionSumKahan = |
42 | std::conditional_t<IsDecimalNumber<T>, typename SumSimple<T>::Function, typename SumKahan<T>::Function>; |
43 | |
44 | |
45 | template <template <typename> class Function> |
46 | AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const DataTypes & argument_types, const Array & parameters) |
47 | { |
48 | assertNoParameters(name, parameters); |
49 | assertUnary(name, argument_types); |
50 | |
51 | AggregateFunctionPtr res; |
52 | DataTypePtr data_type = argument_types[0]; |
53 | if (isDecimal(data_type)) |
54 | res.reset(createWithDecimalType<Function>(*data_type, *data_type, argument_types)); |
55 | else |
56 | res.reset(createWithNumericType<Function>(*data_type, argument_types)); |
57 | |
58 | if (!res) |
59 | throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, |
60 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
61 | return res; |
62 | } |
63 | |
64 | } |
65 | |
66 | void registerAggregateFunctionSum(AggregateFunctionFactory & factory) |
67 | { |
68 | factory.registerFunction("sum" , createAggregateFunctionSum<AggregateFunctionSumSimple>, AggregateFunctionFactory::CaseInsensitive); |
69 | factory.registerFunction("sumWithOverflow" , createAggregateFunctionSum<AggregateFunctionSumWithOverflow>); |
70 | factory.registerFunction("sumKahan" , createAggregateFunctionSum<AggregateFunctionSumKahan>); |
71 | } |
72 | |
73 | } |
74 | |