1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionAvg.h> |
3 | #include <AggregateFunctions/Helpers.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | #include "registerAggregateFunctions.h" |
6 | |
7 | namespace DB |
8 | { |
9 | |
10 | namespace |
11 | { |
12 | |
13 | template <typename T> |
14 | struct Avg |
15 | { |
16 | using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>; |
17 | using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType, UInt64>>; |
18 | }; |
19 | |
20 | template <typename T> |
21 | using AggregateFuncAvg = typename Avg<T>::Function; |
22 | |
23 | AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters) |
24 | { |
25 | assertNoParameters(name, parameters); |
26 | assertUnary(name, argument_types); |
27 | |
28 | AggregateFunctionPtr res; |
29 | DataTypePtr data_type = argument_types[0]; |
30 | if (isDecimal(data_type)) |
31 | res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type, argument_types)); |
32 | else |
33 | res.reset(createWithNumericType<AggregateFuncAvg>(*data_type, argument_types)); |
34 | |
35 | if (!res) |
36 | throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, |
37 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
38 | return res; |
39 | } |
40 | |
41 | } |
42 | |
43 | void registerAggregateFunctionAvg(AggregateFunctionFactory & factory) |
44 | { |
45 | factory.registerFunction("avg" , createAggregateFunctionAvg, AggregateFunctionFactory::CaseInsensitive); |
46 | } |
47 | |
48 | } |
49 | |