1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionFactory.h>
3#include <Functions/FunctionHelpers.h>
4#include <Columns/ColumnAggregateFunction.h>
5#include <DataTypes/DataTypeAggregateFunction.h>
6#include <Common/AlignedBuffer.h>
7#include <Common/Arena.h>
8#include <ext/scope_guard.h>
9
10
11namespace DB
12{
13
namespace ErrorCodes
{
    /// Only declared here (extern, no initializer); the definitions live in another translation unit.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int ILLEGAL_COLUMN;
}
20
21
22/** runningAccumulate(agg_state) - takes the states of the aggregate function and returns a column with values,
23 * are the result of the accumulation of these states for a set of block lines, from the first to the current line.
24 *
25 * Quite unusual function.
26 * Takes state of aggregate function (example runningAccumulate(uniqState(UserID))),
27 * and for each row of block, return result of aggregate function on merge of states of all previous rows and current row.
28 *
29 * So, result of function depends on partition of data to blocks and on order of data in block.
30 */
31class FunctionRunningAccumulate : public IFunction
32{
33public:
34 static constexpr auto name = "runningAccumulate";
35 static FunctionPtr create(const Context &)
36 {
37 return std::make_shared<FunctionRunningAccumulate>();
38 }
39
40 String getName() const override
41 {
42 return name;
43 }
44
45 bool isStateful() const override
46 {
47 return true;
48 }
49
50 bool isVariadic() const override { return true; }
51
52 size_t getNumberOfArguments() const override { return 0; }
53
54 bool isDeterministic() const override { return false; }
55
56 bool isDeterministicInScopeOfQuery() const override
57 {
58 return false;
59 }
60
61 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
62 {
63 if (arguments.size() < 1 || arguments.size() > 2)
64 throw Exception("Incorrect number of arguments of function " + getName() + ". Must be 1 or 2.",
65 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
66
67 const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
68 if (!type)
69 throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
70 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
71
72 return type->getReturnType();
73 }
74
75 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
76 {
77 const ColumnAggregateFunction * column_with_states
78 = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
79
80 if (!column_with_states)
81 throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
82 + " of first argument of function "
83 + getName(),
84 ErrorCodes::ILLEGAL_COLUMN);
85
86 ColumnPtr column_with_groups;
87
88 if (arguments.size() == 2)
89 column_with_groups = block.getByPosition(arguments[1]).column;
90
91 AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
92 const IAggregateFunction & agg_func = *aggregate_function_ptr;
93
94 AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());
95
96 /// Will pass empty arena if agg_func does not allocate memory in arena
97 std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;
98
99 auto result_column_ptr = agg_func.getReturnType()->createColumn();
100 IColumn & result_column = *result_column_ptr;
101 result_column.reserve(column_with_states->size());
102
103 const auto & states = column_with_states->getData();
104
105 bool state_created = false;
106 SCOPE_EXIT({
107 if (state_created)
108 agg_func.destroy(place.data());
109 });
110
111 size_t row_number = 0;
112 for (const auto & state_to_add : states)
113 {
114 if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0))
115 {
116 if (state_created)
117 {
118 agg_func.destroy(place.data());
119 state_created = false;
120 }
121
122 agg_func.create(place.data());
123 state_created = true;
124 }
125
126 agg_func.merge(place.data(), state_to_add, arena.get());
127 agg_func.insertResultInto(place.data(), result_column);
128
129 ++row_number;
130 }
131
132 block.getByPosition(result).column = std::move(result_column_ptr);
133 }
134};
135
136
137void registerFunctionRunningAccumulate(FunctionFactory & factory)
138{
139 factory.registerFunction<FunctionRunningAccumulate>();
140}
141
142}
143