1#pragma once
2
3#include <Columns/ColumnArray.h>
4#include <Common/assert_cast.h>
5#include <DataTypes/DataTypeArray.h>
6#include <AggregateFunctions/IAggregateFunction.h>
7
8#include <IO/WriteBuffer.h>
9#include <IO/ReadBuffer.h>
10#include <IO/WriteHelpers.h>
11#include <IO/ReadHelpers.h>
12
13
14namespace DB
15{
16
/// Error codes thrown from this header (extern declarations only; the
/// definitions are expected to live elsewhere in the project).
namespace ErrorCodes
{
    extern const int PARAMETER_OUT_OF_BOUND;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
    /// Used by the AggregateFunctionForEach constructor for non-array arguments.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
23
24
/// Per-group state of the -ForEach combinator: a buffer of nested aggregate
/// states laid out back to back, together with how many of them it holds.
struct AggregateFunctionForEachData
{
    /// Number of nested states stored in `array_of_aggregate_datas`.
    size_t dynamic_array_size{0};
    /// Start of the nested-state buffer; nullptr while no states exist.
    char * array_of_aggregate_datas{nullptr};
};
30
31/** Adaptor for aggregate functions.
32 * Adding -ForEach suffix to aggregate function
33 * will convert that aggregate function to a function, accepting arrays,
34 * and applies aggregation for each corresponding elements of arrays independently,
35 * returning arrays of aggregated values on corresponding positions.
36 *
37 * Example: sumForEach of:
38 * [1, 2],
39 * [3, 4, 5],
40 * [6, 7]
41 * will return:
42 * [10, 13, 5]
43 *
44 * TODO Allow variable number of arguments.
45 */
46class AggregateFunctionForEach final : public IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>
47{
48private:
49 AggregateFunctionPtr nested_func;
50 size_t nested_size_of_data = 0;
51 size_t num_arguments;
52
53 AggregateFunctionForEachData & ensureAggregateData(AggregateDataPtr place, size_t new_size, Arena & arena) const
54 {
55 AggregateFunctionForEachData & state = data(place);
56
57 /// Ensure we have aggreate states for new_size elements, allocate
58 /// from arena if needed. When reallocating, we can't copy the
59 /// states to new buffer with memcpy, because they may contain pointers
60 /// to themselves. In particular, this happens when a state contains
61 /// a PODArrayWithStackMemory, which stores small number of elements
62 /// inline. This is why we create new empty states in the new buffer,
63 /// and merge the old states to them.
64 size_t old_size = state.dynamic_array_size;
65 if (old_size < new_size)
66 {
67 char * old_state = state.array_of_aggregate_datas;
68 char * new_state = arena.alignedAlloc(
69 new_size * nested_size_of_data,
70 nested_func->alignOfData());
71
72 size_t i;
73 try
74 {
75 for (i = 0; i < new_size; ++i)
76 {
77 nested_func->create(&new_state[i * nested_size_of_data]);
78 }
79 }
80 catch (...)
81 {
82 size_t cleanup_size = i;
83
84 for (i = 0; i < cleanup_size; ++i)
85 {
86 nested_func->destroy(&new_state[i * nested_size_of_data]);
87 }
88
89 throw;
90 }
91
92 for (i = 0; i < old_size; i++)
93 {
94 nested_func->merge(&new_state[i * nested_size_of_data],
95 &old_state[i * nested_size_of_data],
96 &arena);
97 }
98
99 state.array_of_aggregate_datas = new_state;
100 state.dynamic_array_size = new_size;
101 }
102
103 return state;
104 }
105
106public:
107 AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments)
108 : IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, {})
109 , nested_func(nested_), num_arguments(arguments.size())
110 {
111 nested_size_of_data = nested_func->sizeOfData();
112
113 if (arguments.empty())
114 throw Exception("Aggregate function " + getName() + " require at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
115
116 for (const auto & type : arguments)
117 if (!isArray(type))
118 throw Exception("All arguments for aggregate function " + getName() + " must be arrays", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
119 }
120
121 String getName() const override
122 {
123 return nested_func->getName() + "ForEach";
124 }
125
126 DataTypePtr getReturnType() const override
127 {
128 return std::make_shared<DataTypeArray>(nested_func->getReturnType());
129 }
130
131 void destroy(AggregateDataPtr place) const noexcept override
132 {
133 AggregateFunctionForEachData & state = data(place);
134
135 char * nested_state = state.array_of_aggregate_datas;
136 for (size_t i = 0; i < state.dynamic_array_size; ++i)
137 {
138 nested_func->destroy(nested_state);
139 nested_state += nested_size_of_data;
140 }
141 }
142
143 bool hasTrivialDestructor() const override
144 {
145 return nested_func->hasTrivialDestructor();
146 }
147
148 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
149 {
150 const IColumn * nested[num_arguments];
151
152 for (size_t i = 0; i < num_arguments; ++i)
153 nested[i] = &assert_cast<const ColumnArray &>(*columns[i]).getData();
154
155 const ColumnArray & first_array_column = assert_cast<const ColumnArray &>(*columns[0]);
156 const IColumn::Offsets & offsets = first_array_column.getOffsets();
157
158 size_t begin = offsets[row_num - 1];
159 size_t end = offsets[row_num];
160
161 /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
162 for (size_t i = 1; i < num_arguments; ++i)
163 {
164 const ColumnArray & ith_column = assert_cast<const ColumnArray &>(*columns[i]);
165 const IColumn::Offsets & ith_offsets = ith_column.getOffsets();
166
167 if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin))
168 throw Exception("Arrays passed to " + getName() + " aggregate function have different sizes", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
169 }
170
171 AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena);
172
173 char * nested_state = state.array_of_aggregate_datas;
174 for (size_t i = begin; i < end; ++i)
175 {
176 nested_func->add(nested_state, nested, i, arena);
177 nested_state += nested_size_of_data;
178 }
179 }
180
181 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
182 {
183 const AggregateFunctionForEachData & rhs_state = data(rhs);
184 AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena);
185
186 const char * rhs_nested_state = rhs_state.array_of_aggregate_datas;
187 char * nested_state = state.array_of_aggregate_datas;
188
189 for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i)
190 {
191 nested_func->merge(nested_state, rhs_nested_state, arena);
192
193 rhs_nested_state += nested_size_of_data;
194 nested_state += nested_size_of_data;
195 }
196 }
197
198 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
199 {
200 const AggregateFunctionForEachData & state = data(place);
201 writeBinary(state.dynamic_array_size, buf);
202
203 const char * nested_state = state.array_of_aggregate_datas;
204 for (size_t i = 0; i < state.dynamic_array_size; ++i)
205 {
206 nested_func->serialize(nested_state, buf);
207 nested_state += nested_size_of_data;
208 }
209 }
210
211 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
212 {
213 AggregateFunctionForEachData & state = data(place);
214
215 size_t new_size = 0;
216 readBinary(new_size, buf);
217
218 ensureAggregateData(place, new_size, *arena);
219
220 char * nested_state = state.array_of_aggregate_datas;
221 for (size_t i = 0; i < new_size; ++i)
222 {
223 nested_func->deserialize(nested_state, buf, arena);
224 nested_state += nested_size_of_data;
225 }
226 }
227
228 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
229 {
230 const AggregateFunctionForEachData & state = data(place);
231
232 ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
233 ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
234 IColumn & elems_to = arr_to.getData();
235
236 const char * nested_state = state.array_of_aggregate_datas;
237 for (size_t i = 0; i < state.dynamic_array_size; ++i)
238 {
239 nested_func->insertResultInto(nested_state, elems_to);
240 nested_state += nested_size_of_data;
241 }
242
243 offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
244 }
245
246 bool allocatesMemoryInArena() const override
247 {
248 return true;
249 }
250};
251
252
253}
254