1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionAvgWeighted.h> |
3 | #include <AggregateFunctions/Helpers.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | #include "registerAggregateFunctions.h" |
6 | |
7 | namespace DB |
8 | { |
9 | |
10 | namespace |
11 | { |
12 | |
13 | template <typename T> |
14 | struct AvgWeighted |
15 | { |
16 | using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>; |
17 | using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgData<FieldType, FieldType>>; |
18 | }; |
19 | |
20 | template <typename T> |
21 | using AggregateFuncAvgWeighted = typename AvgWeighted<T>::Function; |
22 | |
23 | AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters) |
24 | { |
25 | assertNoParameters(name, parameters); |
26 | assertBinary(name, argument_types); |
27 | |
28 | AggregateFunctionPtr res; |
29 | const auto data_type = static_cast<const DataTypePtr>(argument_types[0]); |
30 | const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]); |
31 | if (!data_type->equals(*data_type_weight)) |
32 | throw Exception("Different types " + data_type->getName() + " and " + data_type_weight->getName() + " of arguments for aggregate function " + name, |
33 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
34 | if (isDecimal(data_type)) |
35 | res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*data_type, *data_type, argument_types)); |
36 | else |
37 | res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*data_type, argument_types)); |
38 | |
39 | if (!res) |
40 | throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name, |
41 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
42 | return res; |
43 | } |
44 | |
45 | } |
46 | |
47 | void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory) |
48 | { |
49 | factory.registerFunction("avgWeighted" , createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive); |
50 | } |
51 | |
52 | } |
53 | |