1#include <AggregateFunctions/AggregateFunctionState.h>
2#include <AggregateFunctions/AggregateFunctionMerge.h>
3#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
4#include <DataTypes/DataTypeAggregateFunction.h>
5#include "registerAggregateFunctions.h"
6
7
8namespace DB
9{
10
/// Error codes thrown from this translation unit.
/// Only codes that are actually used below are declared here; unused
/// declarations are dead code and are rejected by the project style check.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
}
16
17class AggregateFunctionCombinatorState final : public IAggregateFunctionCombinator
18{
19public:
20 String getName() const override { return "State"; }
21
22 DataTypes transformArguments(const DataTypes & arguments) const override
23 {
24 return arguments;
25 }
26
27 AggregateFunctionPtr transformAggregateFunction(
28 const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
29 {
30 return std::make_shared<AggregateFunctionState>(nested_function, arguments, params);
31 }
32};
33
34void registerAggregateFunctionCombinatorState(AggregateFunctionCombinatorFactory & factory)
35{
36 factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorState>());
37}
38
39
40DataTypePtr AggregateFunctionState::getReturnType() const
41{
42 auto ptr = std::make_shared<DataTypeAggregateFunction>(nested_func, arguments, params);
43
44 /// Special case: it is -MergeState combinator.
45 /// We must return AggregateFunction(agg, ...) instead of AggregateFunction(aggMerge, ...)
46 if (typeid_cast<const AggregateFunctionMerge *>(ptr->getFunction().get()))
47 {
48 if (arguments.size() != 1)
49 throw Exception("Combinator -MergeState expects only one argument", ErrorCodes::BAD_ARGUMENTS);
50
51 if (!typeid_cast<const DataTypeAggregateFunction *>(arguments[0].get()))
52 throw Exception("Combinator -MergeState expects argument with AggregateFunction type", ErrorCodes::BAD_ARGUMENTS);
53
54 return arguments[0];
55 }
56 if (arguments.size() > 0)
57 {
58 DataTypePtr argument_type_ptr = arguments[0];
59 WhichDataType which(*argument_type_ptr);
60 if (which.idx == TypeIndex::AggregateFunction)
61 {
62 if (arguments.size() != 1)
63 throw Exception("Nested aggregation expects only one argument", ErrorCodes::BAD_ARGUMENTS);
64 return arguments[0];
65 }
66 }
67
68 return ptr;
69}
70
71}
72