1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
3
4#include <DataTypes/DataTypeAggregateFunction.h>
5#include <DataTypes/DataTypeArray.h>
6#include <DataTypes/DataTypeNullable.h>
7#include <DataTypes/DataTypesNumber.h>
8#include <DataTypes/DataTypeLowCardinality.h>
9
10#include <IO/WriteBuffer.h>
11#include <IO/WriteHelpers.h>
12
13#include <Interpreters/Context.h>
14
15#include <Common/StringUtils/StringUtils.h>
16#include <Common/typeid_cast.h>
17
18#include <Poco/String.h>
19#include "registerAggregateFunctions.h"
20
21
22namespace DB
23{
24
namespace ErrorCodes
{
    /// Declared extern here; the definitions live in another translation unit.
    /// Only the codes thrown by this file are declared.
    extern const int LOGICAL_ERROR;
    extern const int UNKNOWN_AGGREGATE_FUNCTION;
}
30
31
32void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness)
33{
34 if (creator == nullptr)
35 throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
36 " a null constructor", ErrorCodes::LOGICAL_ERROR);
37
38 if (!aggregate_functions.emplace(name, creator).second)
39 throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
40 ErrorCodes::LOGICAL_ERROR);
41
42 if (case_sensitiveness == CaseInsensitive
43 && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second)
44 throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
45 ErrorCodes::LOGICAL_ERROR);
46}
47
48static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types)
49{
50 DataTypes res_types;
51 res_types.reserve(types.size());
52 for (const auto & type : types)
53 res_types.emplace_back(recursiveRemoveLowCardinality(type));
54
55 return res_types;
56}
57
58AggregateFunctionPtr AggregateFunctionFactory::get(
59 const String & name,
60 const DataTypes & argument_types,
61 const Array & parameters,
62 int recursion_level) const
63{
64 auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
65
66 /// If one of types is Nullable, we apply aggregate function combinator "Null".
67
68 if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
69 [](const auto & type) { return type->isNullable(); }))
70 {
71 AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null");
72 if (!combinator)
73 throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", ErrorCodes::LOGICAL_ERROR);
74
75 DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
76 Array nested_parameters = combinator->transformParameters(parameters);
77
78 AggregateFunctionPtr nested_function;
79
80 /// A little hack - if we have NULL arguments, don't even create nested function.
81 /// Combinator will check if nested_function was created.
82 if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
83 [](const auto & type) { return type->onlyNull(); }))
84 nested_function = getImpl(name, nested_types, nested_parameters, recursion_level);
85
86 return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters);
87 }
88
89 auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level);
90 if (!res)
91 throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR);
92 return res;
93}
94
95
96AggregateFunctionPtr AggregateFunctionFactory::getImpl(
97 const String & name_param,
98 const DataTypes & argument_types,
99 const Array & parameters,
100 int recursion_level) const
101{
102 String name = getAliasToOrName(name_param);
103 /// Find by exact match.
104 if (auto it = aggregate_functions.find(name); it != aggregate_functions.end())
105 return it->second(name, argument_types, parameters);
106
107 /// Find by case-insensitive name.
108 /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names.
109 if (recursion_level == 0)
110 {
111 if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end())
112 return it->second(name, argument_types, parameters);
113 }
114
115 /// Combinators of aggregate functions.
116 /// For every aggregate function 'agg' and combiner '-Comb' there is combined aggregate function with name 'aggComb',
117 /// that can have different number and/or types of arguments, different result type and different behaviour.
118
119 if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name))
120 {
121 if (combinator->isForInternalUsageOnly())
122 throw Exception("Aggregate function combinator '" + combinator->getName() + "' is only for internal usage", ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
123
124 String nested_name = name.substr(0, name.size() - combinator->getName().size());
125 DataTypes nested_types = combinator->transformArguments(argument_types);
126 Array nested_parameters = combinator->transformParameters(parameters);
127
128 AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1);
129
130 return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
131 }
132
133 auto hints = this->getHints(name);
134 if (!hints.empty())
135 throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
136 else
137 throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
138}
139
140
141AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const
142{
143 return isAggregateFunctionName(name)
144 ? get(name, argument_types, parameters)
145 : nullptr;
146}
147
148
149bool AggregateFunctionFactory::isAggregateFunctionName(const String & name, int recursion_level) const
150{
151 if (aggregate_functions.count(name) || isAlias(name))
152 return true;
153
154 String name_lowercase = Poco::toLower(name);
155 if (recursion_level == 0 && (case_insensitive_aggregate_functions.count(name_lowercase) || isAlias(name_lowercase)))
156 return true;
157
158 if (AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix(name))
159 return isAggregateFunctionName(name.substr(0, name.size() - combinator->getName().size()), recursion_level + 1);
160
161 return false;
162}
163
164AggregateFunctionFactory & AggregateFunctionFactory::instance()
165{
166 static AggregateFunctionFactory ret;
167 return ret;
168}
169
170}
171