1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/AggregateFunctionAvg.h>
3#include <AggregateFunctions/Helpers.h>
4#include <AggregateFunctions/FactoryHelpers.h>
5#include "registerAggregateFunctions.h"
6
7namespace DB
8{
9
10namespace
11{
12
13template <typename T>
14struct Avg
15{
16 using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
17 using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType, UInt64>>;
18};
19
20template <typename T>
21using AggregateFuncAvg = typename Avg<T>::Function;
22
23AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters)
24{
25 assertNoParameters(name, parameters);
26 assertUnary(name, argument_types);
27
28 AggregateFunctionPtr res;
29 DataTypePtr data_type = argument_types[0];
30 if (isDecimal(data_type))
31 res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type, argument_types));
32 else
33 res.reset(createWithNumericType<AggregateFuncAvg>(*data_type, argument_types));
34
35 if (!res)
36 throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
37 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
38 return res;
39}
40
41}
42
43void registerAggregateFunctionAvg(AggregateFunctionFactory & factory)
44{
45 factory.registerFunction("avg", createAggregateFunctionAvg, AggregateFunctionFactory::CaseInsensitive);
46}
47
48}
49