1 | #include <AggregateFunctions/AggregateFunctionForEach.h> |
---|---|
2 | #include <AggregateFunctions/AggregateFunctionCombinatorFactory.h> |
3 | #include <Common/typeid_cast.h> |
4 | #include "registerAggregateFunctions.h" |
5 | |
6 | |
7 | namespace DB |
8 | { |
9 | |
/// Error codes referenced by this file. These are `extern` declarations only; the
/// definitions are provided elsewhere in the project. Declare exactly the codes this
/// file throws (the combinator below throws ILLEGAL_TYPE_OF_ARGUMENT).
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
14 | |
15 | class AggregateFunctionCombinatorForEach final : public IAggregateFunctionCombinator |
16 | { |
17 | public: |
18 | String getName() const override { return "ForEach"; } |
19 | |
20 | DataTypes transformArguments(const DataTypes & arguments) const override |
21 | { |
22 | DataTypes nested_arguments; |
23 | for (const auto & type : arguments) |
24 | { |
25 | if (const DataTypeArray * array = typeid_cast<const DataTypeArray *>(type.get())) |
26 | nested_arguments.push_back(array->getNestedType()); |
27 | else |
28 | throw Exception("Illegal type "+ type->getName() + " of argument" |
29 | " for aggregate function with "+ getName() + " suffix. Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
30 | } |
31 | |
32 | return nested_arguments; |
33 | } |
34 | |
35 | AggregateFunctionPtr transformAggregateFunction( |
36 | const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override |
37 | { |
38 | return std::make_shared<AggregateFunctionForEach>(nested_function, arguments); |
39 | } |
40 | }; |
41 | |
42 | void registerAggregateFunctionCombinatorForEach(AggregateFunctionCombinatorFactory & factory) |
43 | { |
44 | factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorForEach>()); |
45 | } |
46 | |
47 | } |
48 |