1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
3#include <AggregateFunctions/Helpers.h>
4#include <AggregateFunctions/FactoryHelpers.h>
5#include "registerAggregateFunctions.h"
6
7namespace DB
8{
9
10namespace
11{
12
13template <typename T>
14struct AvgWeighted
15{
16 using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
17 using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgData<FieldType, FieldType>>;
18};
19
20template <typename T>
21using AggregateFuncAvgWeighted = typename AvgWeighted<T>::Function;
22
23AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters)
24{
25 assertNoParameters(name, parameters);
26 assertBinary(name, argument_types);
27
28 AggregateFunctionPtr res;
29 const auto data_type = static_cast<const DataTypePtr>(argument_types[0]);
30 const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]);
31 if (!data_type->equals(*data_type_weight))
32 throw Exception("Different types " + data_type->getName() + " and " + data_type_weight->getName() + " of arguments for aggregate function " + name,
33 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
34 if (isDecimal(data_type))
35 res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*data_type, *data_type, argument_types));
36 else
37 res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*data_type, argument_types));
38
39 if (!res)
40 throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
41 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
42 return res;
43}
44
45}
46
47void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
48{
49 factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
50}
51
52}
53