1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/AggregateFunctionSum.h>
3#include <AggregateFunctions/Helpers.h>
4#include <AggregateFunctions/FactoryHelpers.h>
5#include "registerAggregateFunctions.h"
6
7
8namespace DB
9{
10
11namespace
12{
13
14template <typename T>
15struct SumSimple
16{
17 /// @note It uses slow Decimal128 (cause we need such a variant). sumWithOverflow is faster for Decimal32/64
18 using ResultType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
19 using AggregateDataType = AggregateFunctionSumData<ResultType>;
20 using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>;
21};
22
23template <typename T>
24struct SumSameType
25{
26 using ResultType = T;
27 using AggregateDataType = AggregateFunctionSumData<ResultType>;
28 using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>;
29};
30
31template <typename T>
32struct SumKahan
33{
34 using ResultType = Float64;
35 using AggregateDataType = AggregateFunctionSumKahanData<ResultType>;
36 using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>;
37};
38
39template <typename T> using AggregateFunctionSumSimple = typename SumSimple<T>::Function;
40template <typename T> using AggregateFunctionSumWithOverflow = typename SumSameType<T>::Function;
41template <typename T> using AggregateFunctionSumKahan =
42 std::conditional_t<IsDecimalNumber<T>, typename SumSimple<T>::Function, typename SumKahan<T>::Function>;
43
44
45template <template <typename> class Function>
46AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const DataTypes & argument_types, const Array & parameters)
47{
48 assertNoParameters(name, parameters);
49 assertUnary(name, argument_types);
50
51 AggregateFunctionPtr res;
52 DataTypePtr data_type = argument_types[0];
53 if (isDecimal(data_type))
54 res.reset(createWithDecimalType<Function>(*data_type, *data_type, argument_types));
55 else
56 res.reset(createWithNumericType<Function>(*data_type, argument_types));
57
58 if (!res)
59 throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
60 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
61 return res;
62}
63
64}
65
66void registerAggregateFunctionSum(AggregateFunctionFactory & factory)
67{
68 factory.registerFunction("sum", createAggregateFunctionSum<AggregateFunctionSumSimple>, AggregateFunctionFactory::CaseInsensitive);
69 factory.registerFunction("sumWithOverflow", createAggregateFunctionSum<AggregateFunctionSumWithOverflow>);
70 factory.registerFunction("sumKahan", createAggregateFunctionSum<AggregateFunctionSumKahan>);
71}
72
73}
74