1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionSumMap.h> |
3 | #include <AggregateFunctions/Helpers.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | #include <Functions/FunctionHelpers.h> |
6 | #include <IO/WriteHelpers.h> |
7 | #include "registerAggregateFunctions.h" |
8 | |
9 | |
10 | namespace DB |
11 | { |
12 | |
13 | namespace |
14 | { |
15 | |
16 | struct WithOverflowPolicy |
17 | { |
18 | /// Overflow, meaning that the returned type is the same as the input type. |
19 | static DataTypePtr promoteType(const DataTypePtr & data_type) { return data_type; } |
20 | }; |
21 | |
22 | struct WithoutOverflowPolicy |
23 | { |
24 | /// No overflow, meaning we promote the types if necessary. |
25 | static DataTypePtr promoteType(const DataTypePtr & data_type) |
26 | { |
27 | if (!data_type->canBePromoted()) |
28 | throw Exception{"Values to be summed are expected to be Numeric, Float or Decimal." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
29 | |
30 | return data_type->promoteNumericType(); |
31 | } |
32 | }; |
33 | |
34 | template <typename T> |
35 | using SumMapWithOverflow = AggregateFunctionSumMap<T, WithOverflowPolicy>; |
36 | |
37 | template <typename T> |
38 | using SumMapWithoutOverflow = AggregateFunctionSumMap<T, WithoutOverflowPolicy>; |
39 | |
40 | template <typename T> |
41 | using SumMapFilteredWithOverflow = AggregateFunctionSumMapFiltered<T, WithOverflowPolicy>; |
42 | |
43 | template <typename T> |
44 | using SumMapFilteredWithoutOverflow = AggregateFunctionSumMapFiltered<T, WithoutOverflowPolicy>; |
45 | |
46 | using SumMapArgs = std::pair<DataTypePtr, DataTypes>; |
47 | |
48 | SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments) |
49 | { |
50 | if (arguments.size() < 2) |
51 | throw Exception("Aggregate function " + name + " requires at least two arguments of Array type." , |
52 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
53 | |
54 | const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get()); |
55 | if (!array_type) |
56 | throw Exception("First argument for function " + name + " must be an array." , |
57 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
58 | |
59 | |
60 | DataTypePtr keys_type = array_type->getNestedType(); |
61 | |
62 | DataTypes values_types; |
63 | values_types.reserve(arguments.size() - 1); |
64 | for (size_t i = 1; i < arguments.size(); ++i) |
65 | { |
66 | array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get()); |
67 | if (!array_type) |
68 | throw Exception("Argument #" + toString(i) + " for function " + name + " must be an array." , |
69 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
70 | values_types.push_back(array_type->getNestedType()); |
71 | } |
72 | |
73 | return {std::move(keys_type), std::move(values_types)}; |
74 | } |
75 | |
76 | template <template <typename> class Function> |
77 | AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) |
78 | { |
79 | assertNoParameters(name, params); |
80 | |
81 | auto [keys_type, values_types] = parseArguments(name, arguments); |
82 | |
83 | AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments)); |
84 | if (!res) |
85 | res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments)); |
86 | if (!res) |
87 | throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
88 | |
89 | return res; |
90 | } |
91 | |
92 | template <template <typename> class Function> |
93 | AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params) |
94 | { |
95 | if (params.size() != 1) |
96 | throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type." , |
97 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
98 | |
99 | Array keys_to_keep; |
100 | if (!params.front().tryGet<Array>(keys_to_keep)) |
101 | throw Exception("Aggregate function " + name + " requires an Array as parameter." , |
102 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
103 | |
104 | auto [keys_type, values_types] = parseArguments(name, arguments); |
105 | |
106 | AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); |
107 | if (!res) |
108 | res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); |
109 | if (!res) |
110 | throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
111 | |
112 | return res; |
113 | } |
114 | } |
115 | |
116 | void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory) |
117 | { |
118 | factory.registerFunction("sumMap" , createAggregateFunctionSumMap<SumMapWithoutOverflow>); |
119 | factory.registerFunction("sumMapWithOverflow" , createAggregateFunctionSumMap<SumMapWithOverflow>); |
120 | factory.registerFunction("sumMapFiltered" , createAggregateFunctionSumMapFiltered<SumMapFilteredWithoutOverflow>); |
121 | factory.registerFunction("sumMapFilteredWithOverflow" , createAggregateFunctionSumMapFiltered<SumMapFilteredWithOverflow>); |
122 | } |
123 | |
124 | } |
125 | |