| 1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
| 2 | #include <AggregateFunctions/AggregateFunctionSumMap.h> |
| 3 | #include <AggregateFunctions/Helpers.h> |
| 4 | #include <AggregateFunctions/FactoryHelpers.h> |
| 5 | #include <Functions/FunctionHelpers.h> |
| 6 | #include <IO/WriteHelpers.h> |
| 7 | #include "registerAggregateFunctions.h" |
| 8 | |
| 9 | |
| 10 | namespace DB |
| 11 | { |
| 12 | |
| 13 | namespace |
| 14 | { |
| 15 | |
| 16 | struct WithOverflowPolicy |
| 17 | { |
| 18 | /// Overflow, meaning that the returned type is the same as the input type. |
| 19 | static DataTypePtr promoteType(const DataTypePtr & data_type) { return data_type; } |
| 20 | }; |
| 21 | |
| 22 | struct WithoutOverflowPolicy |
| 23 | { |
| 24 | /// No overflow, meaning we promote the types if necessary. |
| 25 | static DataTypePtr promoteType(const DataTypePtr & data_type) |
| 26 | { |
| 27 | if (!data_type->canBePromoted()) |
| 28 | throw Exception{"Values to be summed are expected to be Numeric, Float or Decimal." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
| 29 | |
| 30 | return data_type->promoteNumericType(); |
| 31 | } |
| 32 | }; |
| 33 | |
| 34 | template <typename T> |
| 35 | using SumMapWithOverflow = AggregateFunctionSumMap<T, WithOverflowPolicy>; |
| 36 | |
| 37 | template <typename T> |
| 38 | using SumMapWithoutOverflow = AggregateFunctionSumMap<T, WithoutOverflowPolicy>; |
| 39 | |
| 40 | template <typename T> |
| 41 | using SumMapFilteredWithOverflow = AggregateFunctionSumMapFiltered<T, WithOverflowPolicy>; |
| 42 | |
| 43 | template <typename T> |
| 44 | using SumMapFilteredWithoutOverflow = AggregateFunctionSumMapFiltered<T, WithoutOverflowPolicy>; |
| 45 | |
| 46 | using SumMapArgs = std::pair<DataTypePtr, DataTypes>; |
| 47 | |
| 48 | SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments) |
| 49 | { |
| 50 | if (arguments.size() < 2) |
| 51 | throw Exception("Aggregate function " + name + " requires at least two arguments of Array type." , |
| 52 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 53 | |
| 54 | const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get()); |
| 55 | if (!array_type) |
| 56 | throw Exception("First argument for function " + name + " must be an array." , |
| 57 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 58 | |
| 59 | |
| 60 | DataTypePtr keys_type = array_type->getNestedType(); |
| 61 | |
| 62 | DataTypes values_types; |
| 63 | values_types.reserve(arguments.size() - 1); |
| 64 | for (size_t i = 1; i < arguments.size(); ++i) |
| 65 | { |
| 66 | array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get()); |
| 67 | if (!array_type) |
| 68 | throw Exception("Argument #" + toString(i) + " for function " + name + " must be an array." , |
| 69 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 70 | values_types.push_back(array_type->getNestedType()); |
| 71 | } |
| 72 | |
| 73 | return {std::move(keys_type), std::move(values_types)}; |
| 74 | } |
| 75 | |
| 76 | template <template <typename> class Function> |
| 77 | AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) |
| 78 | { |
| 79 | assertNoParameters(name, params); |
| 80 | |
| 81 | auto [keys_type, values_types] = parseArguments(name, arguments); |
| 82 | |
| 83 | AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments)); |
| 84 | if (!res) |
| 85 | res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments)); |
| 86 | if (!res) |
| 87 | throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 88 | |
| 89 | return res; |
| 90 | } |
| 91 | |
| 92 | template <template <typename> class Function> |
| 93 | AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params) |
| 94 | { |
| 95 | if (params.size() != 1) |
| 96 | throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type." , |
| 97 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 98 | |
| 99 | Array keys_to_keep; |
| 100 | if (!params.front().tryGet<Array>(keys_to_keep)) |
| 101 | throw Exception("Aggregate function " + name + " requires an Array as parameter." , |
| 102 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 103 | |
| 104 | auto [keys_type, values_types] = parseArguments(name, arguments); |
| 105 | |
| 106 | AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); |
| 107 | if (!res) |
| 108 | res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); |
| 109 | if (!res) |
| 110 | throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 111 | |
| 112 | return res; |
| 113 | } |
| 114 | } |
| 115 | |
| 116 | void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory) |
| 117 | { |
| 118 | factory.registerFunction("sumMap" , createAggregateFunctionSumMap<SumMapWithoutOverflow>); |
| 119 | factory.registerFunction("sumMapWithOverflow" , createAggregateFunctionSumMap<SumMapWithOverflow>); |
| 120 | factory.registerFunction("sumMapFiltered" , createAggregateFunctionSumMapFiltered<SumMapFilteredWithoutOverflow>); |
| 121 | factory.registerFunction("sumMapFilteredWithOverflow" , createAggregateFunctionSumMapFiltered<SumMapFilteredWithOverflow>); |
| 122 | } |
| 123 | |
| 124 | } |
| 125 | |