1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/AggregateFunctionSumMap.h>
3#include <AggregateFunctions/Helpers.h>
4#include <AggregateFunctions/FactoryHelpers.h>
5#include <Functions/FunctionHelpers.h>
6#include <IO/WriteHelpers.h>
7#include "registerAggregateFunctions.h"
8
9
10namespace DB
11{
12
13namespace
14{
15
16struct WithOverflowPolicy
17{
18 /// Overflow, meaning that the returned type is the same as the input type.
19 static DataTypePtr promoteType(const DataTypePtr & data_type) { return data_type; }
20};
21
22struct WithoutOverflowPolicy
23{
24 /// No overflow, meaning we promote the types if necessary.
25 static DataTypePtr promoteType(const DataTypePtr & data_type)
26 {
27 if (!data_type->canBePromoted())
28 throw Exception{"Values to be summed are expected to be Numeric, Float or Decimal.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
29
30 return data_type->promoteNumericType();
31 }
32};
33
34template <typename T>
35using SumMapWithOverflow = AggregateFunctionSumMap<T, WithOverflowPolicy>;
36
37template <typename T>
38using SumMapWithoutOverflow = AggregateFunctionSumMap<T, WithoutOverflowPolicy>;
39
40template <typename T>
41using SumMapFilteredWithOverflow = AggregateFunctionSumMapFiltered<T, WithOverflowPolicy>;
42
43template <typename T>
44using SumMapFilteredWithoutOverflow = AggregateFunctionSumMapFiltered<T, WithoutOverflowPolicy>;
45
46using SumMapArgs = std::pair<DataTypePtr, DataTypes>;
47
48SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
49{
50 if (arguments.size() < 2)
51 throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
52 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
53
54 const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
55 if (!array_type)
56 throw Exception("First argument for function " + name + " must be an array.",
57 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
58
59
60 DataTypePtr keys_type = array_type->getNestedType();
61
62 DataTypes values_types;
63 values_types.reserve(arguments.size() - 1);
64 for (size_t i = 1; i < arguments.size(); ++i)
65 {
66 array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
67 if (!array_type)
68 throw Exception("Argument #" + toString(i) + " for function " + name + " must be an array.",
69 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
70 values_types.push_back(array_type->getNestedType());
71 }
72
73 return {std::move(keys_type), std::move(values_types)};
74}
75
76template <template <typename> class Function>
77AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
78{
79 assertNoParameters(name, params);
80
81 auto [keys_type, values_types] = parseArguments(name, arguments);
82
83 AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments));
84 if (!res)
85 res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments));
86 if (!res)
87 throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
88
89 return res;
90}
91
92template <template <typename> class Function>
93AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
94{
95 if (params.size() != 1)
96 throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.",
97 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
98
99 Array keys_to_keep;
100 if (!params.front().tryGet<Array>(keys_to_keep))
101 throw Exception("Aggregate function " + name + " requires an Array as parameter.",
102 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
103
104 auto [keys_type, values_types] = parseArguments(name, arguments);
105
106 AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
107 if (!res)
108 res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
109 if (!res)
110 throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
111
112 return res;
113}
114}
115
116void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
117{
118 factory.registerFunction("sumMap", createAggregateFunctionSumMap<SumMapWithoutOverflow>);
119 factory.registerFunction("sumMapWithOverflow", createAggregateFunctionSumMap<SumMapWithOverflow>);
120 factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered<SumMapFilteredWithoutOverflow>);
121 factory.registerFunction("sumMapFilteredWithOverflow", createAggregateFunctionSumMapFiltered<SumMapFilteredWithOverflow>);
122}
123
124}
125