| 1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
| 2 | #include <AggregateFunctions/AggregateFunctionCombinatorFactory.h> |
| 3 | |
| 4 | #include <DataTypes/DataTypeAggregateFunction.h> |
| 5 | #include <DataTypes/DataTypeArray.h> |
| 6 | #include <DataTypes/DataTypeNullable.h> |
| 7 | #include <DataTypes/DataTypesNumber.h> |
| 8 | #include <DataTypes/DataTypeLowCardinality.h> |
| 9 | |
| 10 | #include <IO/WriteBuffer.h> |
| 11 | #include <IO/WriteHelpers.h> |
| 12 | |
| 13 | #include <Interpreters/Context.h> |
| 14 | |
| 15 | #include <Common/StringUtils/StringUtils.h> |
| 16 | #include <Common/typeid_cast.h> |
| 17 | |
| 18 | #include <Poco/String.h> |
| 19 | #include "registerAggregateFunctions.h" |
| 20 | |
| 21 | |
| 22 | namespace DB |
| 23 | { |
| 24 | |
/// Error codes thrown from this file; declared `extern` here, so their definitions live in another translation unit.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int UNKNOWN_AGGREGATE_FUNCTION;
}
| 30 | |
| 31 | |
| 32 | void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) |
| 33 | { |
| 34 | if (creator == nullptr) |
| 35 | throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " |
| 36 | " a null constructor" , ErrorCodes::LOGICAL_ERROR); |
| 37 | |
| 38 | if (!aggregate_functions.emplace(name, creator).second) |
| 39 | throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique" , |
| 40 | ErrorCodes::LOGICAL_ERROR); |
| 41 | |
| 42 | if (case_sensitiveness == CaseInsensitive |
| 43 | && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second) |
| 44 | throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique" , |
| 45 | ErrorCodes::LOGICAL_ERROR); |
| 46 | } |
| 47 | |
| 48 | static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types) |
| 49 | { |
| 50 | DataTypes res_types; |
| 51 | res_types.reserve(types.size()); |
| 52 | for (const auto & type : types) |
| 53 | res_types.emplace_back(recursiveRemoveLowCardinality(type)); |
| 54 | |
| 55 | return res_types; |
| 56 | } |
| 57 | |
| 58 | AggregateFunctionPtr AggregateFunctionFactory::get( |
| 59 | const String & name, |
| 60 | const DataTypes & argument_types, |
| 61 | const Array & parameters, |
| 62 | int recursion_level) const |
| 63 | { |
| 64 | auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); |
| 65 | |
| 66 | /// If one of types is Nullable, we apply aggregate function combinator "Null". |
| 67 | |
| 68 | if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), |
| 69 | [](const auto & type) { return type->isNullable(); })) |
| 70 | { |
| 71 | AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null" ); |
| 72 | if (!combinator) |
| 73 | throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments." , ErrorCodes::LOGICAL_ERROR); |
| 74 | |
| 75 | DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); |
| 76 | Array nested_parameters = combinator->transformParameters(parameters); |
| 77 | |
| 78 | AggregateFunctionPtr nested_function; |
| 79 | |
| 80 | /// A little hack - if we have NULL arguments, don't even create nested function. |
| 81 | /// Combinator will check if nested_function was created. |
| 82 | if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), |
| 83 | [](const auto & type) { return type->onlyNull(); })) |
| 84 | nested_function = getImpl(name, nested_types, nested_parameters, recursion_level); |
| 85 | |
| 86 | return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters); |
| 87 | } |
| 88 | |
| 89 | auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level); |
| 90 | if (!res) |
| 91 | throw Exception("Logical error: AggregateFunctionFactory returned nullptr" , ErrorCodes::LOGICAL_ERROR); |
| 92 | return res; |
| 93 | } |
| 94 | |
| 95 | |
| 96 | AggregateFunctionPtr AggregateFunctionFactory::getImpl( |
| 97 | const String & name_param, |
| 98 | const DataTypes & argument_types, |
| 99 | const Array & parameters, |
| 100 | int recursion_level) const |
| 101 | { |
| 102 | String name = getAliasToOrName(name_param); |
| 103 | /// Find by exact match. |
| 104 | if (auto it = aggregate_functions.find(name); it != aggregate_functions.end()) |
| 105 | return it->second(name, argument_types, parameters); |
| 106 | |
| 107 | /// Find by case-insensitive name. |
| 108 | /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names. |
| 109 | if (recursion_level == 0) |
| 110 | { |
| 111 | if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end()) |
| 112 | return it->second(name, argument_types, parameters); |
| 113 | } |
| 114 | |
| 115 | /// Combinators of aggregate functions. |
| 116 | /// For every aggregate function 'agg' and combiner '-Comb' there is combined aggregate function with name 'aggComb', |
| 117 | /// that can have different number and/or types of arguments, different result type and different behaviour. |
| 118 | |
| 119 | if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name)) |
| 120 | { |
| 121 | if (combinator->isForInternalUsageOnly()) |
| 122 | throw Exception("Aggregate function combinator '" + combinator->getName() + "' is only for internal usage" , ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); |
| 123 | |
| 124 | String nested_name = name.substr(0, name.size() - combinator->getName().size()); |
| 125 | DataTypes nested_types = combinator->transformArguments(argument_types); |
| 126 | Array nested_parameters = combinator->transformParameters(parameters); |
| 127 | |
| 128 | AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1); |
| 129 | |
| 130 | return combinator->transformAggregateFunction(nested_function, argument_types, parameters); |
| 131 | } |
| 132 | |
| 133 | auto hints = this->getHints(name); |
| 134 | if (!hints.empty()) |
| 135 | throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); |
| 136 | else |
| 137 | throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); |
| 138 | } |
| 139 | |
| 140 | |
| 141 | AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const |
| 142 | { |
| 143 | return isAggregateFunctionName(name) |
| 144 | ? get(name, argument_types, parameters) |
| 145 | : nullptr; |
| 146 | } |
| 147 | |
| 148 | |
| 149 | bool AggregateFunctionFactory::isAggregateFunctionName(const String & name, int recursion_level) const |
| 150 | { |
| 151 | if (aggregate_functions.count(name) || isAlias(name)) |
| 152 | return true; |
| 153 | |
| 154 | String name_lowercase = Poco::toLower(name); |
| 155 | if (recursion_level == 0 && (case_insensitive_aggregate_functions.count(name_lowercase) || isAlias(name_lowercase))) |
| 156 | return true; |
| 157 | |
| 158 | if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name)) |
| 159 | return isAggregateFunctionName(name.substr(0, name.size() - combinator->getName().size()), recursion_level + 1); |
| 160 | |
| 161 | return false; |
| 162 | } |
| 163 | |
| 164 | AggregateFunctionFactory & AggregateFunctionFactory::instance() |
| 165 | { |
| 166 | static AggregateFunctionFactory ret; |
| 167 | return ret; |
| 168 | } |
| 169 | |
| 170 | } |
| 171 | |