1 | #pragma once |
2 | |
3 | #include <Columns/ColumnArray.h> |
4 | #include <Common/assert_cast.h> |
5 | #include <DataTypes/DataTypeArray.h> |
6 | #include <AggregateFunctions/IAggregateFunction.h> |
7 | |
8 | #include <IO/WriteBuffer.h> |
9 | #include <IO/ReadBuffer.h> |
10 | #include <IO/WriteHelpers.h> |
11 | #include <IO/ReadHelpers.h> |
12 | |
13 | |
14 | namespace DB |
15 | { |
16 | |
/// Error codes thrown by the -ForEach combinator in this header.
/// Every code referenced below is declared here explicitly instead of relying on
/// a declaration that some other included header may or may not provide.
namespace ErrorCodes
{
    extern const int PARAMETER_OUT_OF_BOUND;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
23 | |
24 | |
/// Per-group state of the -ForEach combinator: a flat buffer of nested aggregate
/// states, one per array position, stored back to back.
struct AggregateFunctionForEachData
{
    /// How many nested states the buffer below currently holds.
    size_t dynamic_array_size{0};

    /// Start of the nested-state buffer; null until storage is first allocated.
    char * array_of_aggregate_datas{nullptr};
};
30 | |
31 | /** Adaptor for aggregate functions. |
32 | * Adding -ForEach suffix to aggregate function |
33 | * will convert that aggregate function to a function, accepting arrays, |
34 | * and applies aggregation for each corresponding elements of arrays independently, |
35 | * returning arrays of aggregated values on corresponding positions. |
36 | * |
37 | * Example: sumForEach of: |
38 | * [1, 2], |
39 | * [3, 4, 5], |
40 | * [6, 7] |
41 | * will return: |
42 | * [10, 13, 5] |
43 | * |
44 | * TODO Allow variable number of arguments. |
45 | */ |
46 | class AggregateFunctionForEach final : public IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach> |
47 | { |
48 | private: |
49 | AggregateFunctionPtr nested_func; |
50 | size_t nested_size_of_data = 0; |
51 | size_t num_arguments; |
52 | |
53 | AggregateFunctionForEachData & ensureAggregateData(AggregateDataPtr place, size_t new_size, Arena & arena) const |
54 | { |
55 | AggregateFunctionForEachData & state = data(place); |
56 | |
57 | /// Ensure we have aggreate states for new_size elements, allocate |
58 | /// from arena if needed. When reallocating, we can't copy the |
59 | /// states to new buffer with memcpy, because they may contain pointers |
60 | /// to themselves. In particular, this happens when a state contains |
61 | /// a PODArrayWithStackMemory, which stores small number of elements |
62 | /// inline. This is why we create new empty states in the new buffer, |
63 | /// and merge the old states to them. |
64 | size_t old_size = state.dynamic_array_size; |
65 | if (old_size < new_size) |
66 | { |
67 | char * old_state = state.array_of_aggregate_datas; |
68 | char * new_state = arena.alignedAlloc( |
69 | new_size * nested_size_of_data, |
70 | nested_func->alignOfData()); |
71 | |
72 | size_t i; |
73 | try |
74 | { |
75 | for (i = 0; i < new_size; ++i) |
76 | { |
77 | nested_func->create(&new_state[i * nested_size_of_data]); |
78 | } |
79 | } |
80 | catch (...) |
81 | { |
82 | size_t cleanup_size = i; |
83 | |
84 | for (i = 0; i < cleanup_size; ++i) |
85 | { |
86 | nested_func->destroy(&new_state[i * nested_size_of_data]); |
87 | } |
88 | |
89 | throw; |
90 | } |
91 | |
92 | for (i = 0; i < old_size; i++) |
93 | { |
94 | nested_func->merge(&new_state[i * nested_size_of_data], |
95 | &old_state[i * nested_size_of_data], |
96 | &arena); |
97 | } |
98 | |
99 | state.array_of_aggregate_datas = new_state; |
100 | state.dynamic_array_size = new_size; |
101 | } |
102 | |
103 | return state; |
104 | } |
105 | |
106 | public: |
107 | AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments) |
108 | : IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, {}) |
109 | , nested_func(nested_), num_arguments(arguments.size()) |
110 | { |
111 | nested_size_of_data = nested_func->sizeOfData(); |
112 | |
113 | if (arguments.empty()) |
114 | throw Exception("Aggregate function " + getName() + " require at least one argument" , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
115 | |
116 | for (const auto & type : arguments) |
117 | if (!isArray(type)) |
118 | throw Exception("All arguments for aggregate function " + getName() + " must be arrays" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
119 | } |
120 | |
121 | String getName() const override |
122 | { |
123 | return nested_func->getName() + "ForEach" ; |
124 | } |
125 | |
126 | DataTypePtr getReturnType() const override |
127 | { |
128 | return std::make_shared<DataTypeArray>(nested_func->getReturnType()); |
129 | } |
130 | |
131 | void destroy(AggregateDataPtr place) const noexcept override |
132 | { |
133 | AggregateFunctionForEachData & state = data(place); |
134 | |
135 | char * nested_state = state.array_of_aggregate_datas; |
136 | for (size_t i = 0; i < state.dynamic_array_size; ++i) |
137 | { |
138 | nested_func->destroy(nested_state); |
139 | nested_state += nested_size_of_data; |
140 | } |
141 | } |
142 | |
143 | bool hasTrivialDestructor() const override |
144 | { |
145 | return nested_func->hasTrivialDestructor(); |
146 | } |
147 | |
148 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
149 | { |
150 | const IColumn * nested[num_arguments]; |
151 | |
152 | for (size_t i = 0; i < num_arguments; ++i) |
153 | nested[i] = &assert_cast<const ColumnArray &>(*columns[i]).getData(); |
154 | |
155 | const ColumnArray & first_array_column = assert_cast<const ColumnArray &>(*columns[0]); |
156 | const IColumn::Offsets & offsets = first_array_column.getOffsets(); |
157 | |
158 | size_t begin = offsets[row_num - 1]; |
159 | size_t end = offsets[row_num]; |
160 | |
161 | /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. |
162 | for (size_t i = 1; i < num_arguments; ++i) |
163 | { |
164 | const ColumnArray & ith_column = assert_cast<const ColumnArray &>(*columns[i]); |
165 | const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); |
166 | |
167 | if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) |
168 | throw Exception("Arrays passed to " + getName() + " aggregate function have different sizes" , ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); |
169 | } |
170 | |
171 | AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena); |
172 | |
173 | char * nested_state = state.array_of_aggregate_datas; |
174 | for (size_t i = begin; i < end; ++i) |
175 | { |
176 | nested_func->add(nested_state, nested, i, arena); |
177 | nested_state += nested_size_of_data; |
178 | } |
179 | } |
180 | |
181 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override |
182 | { |
183 | const AggregateFunctionForEachData & rhs_state = data(rhs); |
184 | AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena); |
185 | |
186 | const char * rhs_nested_state = rhs_state.array_of_aggregate_datas; |
187 | char * nested_state = state.array_of_aggregate_datas; |
188 | |
189 | for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) |
190 | { |
191 | nested_func->merge(nested_state, rhs_nested_state, arena); |
192 | |
193 | rhs_nested_state += nested_size_of_data; |
194 | nested_state += nested_size_of_data; |
195 | } |
196 | } |
197 | |
198 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
199 | { |
200 | const AggregateFunctionForEachData & state = data(place); |
201 | writeBinary(state.dynamic_array_size, buf); |
202 | |
203 | const char * nested_state = state.array_of_aggregate_datas; |
204 | for (size_t i = 0; i < state.dynamic_array_size; ++i) |
205 | { |
206 | nested_func->serialize(nested_state, buf); |
207 | nested_state += nested_size_of_data; |
208 | } |
209 | } |
210 | |
211 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
212 | { |
213 | AggregateFunctionForEachData & state = data(place); |
214 | |
215 | size_t new_size = 0; |
216 | readBinary(new_size, buf); |
217 | |
218 | ensureAggregateData(place, new_size, *arena); |
219 | |
220 | char * nested_state = state.array_of_aggregate_datas; |
221 | for (size_t i = 0; i < new_size; ++i) |
222 | { |
223 | nested_func->deserialize(nested_state, buf, arena); |
224 | nested_state += nested_size_of_data; |
225 | } |
226 | } |
227 | |
228 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
229 | { |
230 | const AggregateFunctionForEachData & state = data(place); |
231 | |
232 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
233 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
234 | IColumn & elems_to = arr_to.getData(); |
235 | |
236 | const char * nested_state = state.array_of_aggregate_datas; |
237 | for (size_t i = 0; i < state.dynamic_array_size; ++i) |
238 | { |
239 | nested_func->insertResultInto(nested_state, elems_to); |
240 | nested_state += nested_size_of_data; |
241 | } |
242 | |
243 | offsets_to.push_back(offsets_to.back() + state.dynamic_array_size); |
244 | } |
245 | |
246 | bool allocatesMemoryInArena() const override |
247 | { |
248 | return true; |
249 | } |
250 | }; |
251 | |
252 | |
253 | } |
254 | |