1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionCombinatorFactory.h> |
3 | |
4 | #include <DataTypes/DataTypeAggregateFunction.h> |
5 | #include <DataTypes/DataTypeArray.h> |
6 | #include <DataTypes/DataTypeNullable.h> |
7 | #include <DataTypes/DataTypesNumber.h> |
8 | #include <DataTypes/DataTypeLowCardinality.h> |
9 | |
10 | #include <IO/WriteBuffer.h> |
11 | #include <IO/WriteHelpers.h> |
12 | |
13 | #include <Interpreters/Context.h> |
14 | |
15 | #include <Common/StringUtils/StringUtils.h> |
16 | #include <Common/typeid_cast.h> |
17 | |
18 | #include <Poco/String.h> |
19 | #include "registerAggregateFunctions.h" |
20 | |
21 | |
22 | namespace DB |
23 | { |
24 | |
namespace ErrorCodes
{
    /// Error codes thrown by this file. They are only declared here (extern);
    /// their definitions live in another translation unit.
    extern const int LOGICAL_ERROR;
    extern const int UNKNOWN_AGGREGATE_FUNCTION;
}
30 | |
31 | |
32 | void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) |
33 | { |
34 | if (creator == nullptr) |
35 | throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " |
36 | " a null constructor" , ErrorCodes::LOGICAL_ERROR); |
37 | |
38 | if (!aggregate_functions.emplace(name, creator).second) |
39 | throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique" , |
40 | ErrorCodes::LOGICAL_ERROR); |
41 | |
42 | if (case_sensitiveness == CaseInsensitive |
43 | && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second) |
44 | throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique" , |
45 | ErrorCodes::LOGICAL_ERROR); |
46 | } |
47 | |
48 | static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types) |
49 | { |
50 | DataTypes res_types; |
51 | res_types.reserve(types.size()); |
52 | for (const auto & type : types) |
53 | res_types.emplace_back(recursiveRemoveLowCardinality(type)); |
54 | |
55 | return res_types; |
56 | } |
57 | |
58 | AggregateFunctionPtr AggregateFunctionFactory::get( |
59 | const String & name, |
60 | const DataTypes & argument_types, |
61 | const Array & parameters, |
62 | int recursion_level) const |
63 | { |
64 | auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); |
65 | |
66 | /// If one of types is Nullable, we apply aggregate function combinator "Null". |
67 | |
68 | if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), |
69 | [](const auto & type) { return type->isNullable(); })) |
70 | { |
71 | AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null" ); |
72 | if (!combinator) |
73 | throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments." , ErrorCodes::LOGICAL_ERROR); |
74 | |
75 | DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); |
76 | Array nested_parameters = combinator->transformParameters(parameters); |
77 | |
78 | AggregateFunctionPtr nested_function; |
79 | |
80 | /// A little hack - if we have NULL arguments, don't even create nested function. |
81 | /// Combinator will check if nested_function was created. |
82 | if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), |
83 | [](const auto & type) { return type->onlyNull(); })) |
84 | nested_function = getImpl(name, nested_types, nested_parameters, recursion_level); |
85 | |
86 | return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters); |
87 | } |
88 | |
89 | auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level); |
90 | if (!res) |
91 | throw Exception("Logical error: AggregateFunctionFactory returned nullptr" , ErrorCodes::LOGICAL_ERROR); |
92 | return res; |
93 | } |
94 | |
95 | |
96 | AggregateFunctionPtr AggregateFunctionFactory::getImpl( |
97 | const String & name_param, |
98 | const DataTypes & argument_types, |
99 | const Array & parameters, |
100 | int recursion_level) const |
101 | { |
102 | String name = getAliasToOrName(name_param); |
103 | /// Find by exact match. |
104 | if (auto it = aggregate_functions.find(name); it != aggregate_functions.end()) |
105 | return it->second(name, argument_types, parameters); |
106 | |
107 | /// Find by case-insensitive name. |
108 | /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names. |
109 | if (recursion_level == 0) |
110 | { |
111 | if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end()) |
112 | return it->second(name, argument_types, parameters); |
113 | } |
114 | |
115 | /// Combinators of aggregate functions. |
116 | /// For every aggregate function 'agg' and combiner '-Comb' there is combined aggregate function with name 'aggComb', |
117 | /// that can have different number and/or types of arguments, different result type and different behaviour. |
118 | |
119 | if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name)) |
120 | { |
121 | if (combinator->isForInternalUsageOnly()) |
122 | throw Exception("Aggregate function combinator '" + combinator->getName() + "' is only for internal usage" , ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); |
123 | |
124 | String nested_name = name.substr(0, name.size() - combinator->getName().size()); |
125 | DataTypes nested_types = combinator->transformArguments(argument_types); |
126 | Array nested_parameters = combinator->transformParameters(parameters); |
127 | |
128 | AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1); |
129 | |
130 | return combinator->transformAggregateFunction(nested_function, argument_types, parameters); |
131 | } |
132 | |
133 | auto hints = this->getHints(name); |
134 | if (!hints.empty()) |
135 | throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); |
136 | else |
137 | throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); |
138 | } |
139 | |
140 | |
141 | AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const |
142 | { |
143 | return isAggregateFunctionName(name) |
144 | ? get(name, argument_types, parameters) |
145 | : nullptr; |
146 | } |
147 | |
148 | |
149 | bool AggregateFunctionFactory::isAggregateFunctionName(const String & name, int recursion_level) const |
150 | { |
151 | if (aggregate_functions.count(name) || isAlias(name)) |
152 | return true; |
153 | |
154 | String name_lowercase = Poco::toLower(name); |
155 | if (recursion_level == 0 && (case_insensitive_aggregate_functions.count(name_lowercase) || isAlias(name_lowercase))) |
156 | return true; |
157 | |
158 | if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name)) |
159 | return isAggregateFunctionName(name.substr(0, name.size() - combinator->getName().size()), recursion_level + 1); |
160 | |
161 | return false; |
162 | } |
163 | |
164 | AggregateFunctionFactory & AggregateFunctionFactory::instance() |
165 | { |
166 | static AggregateFunctionFactory ret; |
167 | return ret; |
168 | } |
169 | |
170 | } |
171 | |