1 | #pragma once |
2 | |
3 | #include <IO/WriteHelpers.h> |
4 | #include <IO/ReadHelpers.h> |
5 | |
6 | #include <DataTypes/DataTypeArray.h> |
7 | #include <DataTypes/DataTypesNumber.h> |
8 | #include <DataTypes/DataTypeString.h> |
9 | |
10 | #include <Columns/ColumnVector.h> |
11 | #include <Columns/ColumnArray.h> |
12 | #include <Columns/ColumnString.h> |
13 | |
14 | #include <Common/ArenaAllocator.h> |
15 | #include <Common/assert_cast.h> |
16 | |
17 | #include <AggregateFunctions/IAggregateFunction.h> |
18 | |
19 | #include <common/likely.h> |
20 | #include <type_traits> |
21 | |
22 | #define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE 0xFFFFFF |
23 | |
24 | |
25 | namespace DB |
26 | { |
27 | |
28 | namespace ErrorCodes |
29 | { |
30 | extern const int TOO_LARGE_ARRAY_SIZE; |
31 | extern const int LOGICAL_ERROR; |
32 | } |
33 | |
34 | |
35 | /// A particular case is an implementation for numeric types. |
36 | template <typename T> |
37 | struct GroupArrayNumericData |
38 | { |
39 | // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena |
40 | using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>; |
41 | using Array = PODArray<T, 32, Allocator>; |
42 | |
43 | Array value; |
44 | }; |
45 | |
46 | |
47 | template <typename T, typename Tlimit_num_elems> |
48 | class GroupArrayNumericImpl final |
49 | : public IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>> |
50 | { |
51 | static constexpr bool limit_num_elems = Tlimit_num_elems::value; |
52 | DataTypePtr & data_type; |
53 | UInt64 max_elems; |
54 | |
55 | public: |
56 | explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) |
57 | : IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>({data_type_}, {}) |
58 | , data_type(this->argument_types[0]), max_elems(max_elems_) {} |
59 | |
60 | String getName() const override { return "groupArray" ; } |
61 | |
62 | DataTypePtr getReturnType() const override |
63 | { |
64 | return std::make_shared<DataTypeArray>(data_type); |
65 | } |
66 | |
67 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
68 | { |
69 | if (limit_num_elems && this->data(place).value.size() >= max_elems) |
70 | return; |
71 | |
72 | this->data(place).value.push_back(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], arena); |
73 | } |
74 | |
75 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override |
76 | { |
77 | auto & cur_elems = this->data(place); |
78 | auto & rhs_elems = this->data(rhs); |
79 | |
80 | if (!limit_num_elems) |
81 | { |
82 | if (rhs_elems.value.size()) |
83 | cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena); |
84 | } |
85 | else |
86 | { |
87 | UInt64 elems_to_insert = std::min(static_cast<size_t>(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); |
88 | if (elems_to_insert) |
89 | cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.begin() + elems_to_insert, arena); |
90 | } |
91 | } |
92 | |
93 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
94 | { |
95 | const auto & value = this->data(place).value; |
96 | size_t size = value.size(); |
97 | writeVarUInt(size, buf); |
98 | buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0])); |
99 | } |
100 | |
101 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
102 | { |
103 | size_t size = 0; |
104 | readVarUInt(size, buf); |
105 | |
106 | if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE)) |
107 | throw Exception("Too large array size" , ErrorCodes::TOO_LARGE_ARRAY_SIZE); |
108 | |
109 | if (limit_num_elems && unlikely(size > max_elems)) |
110 | throw Exception("Too large array size, it should not exceed " + toString(max_elems), ErrorCodes::TOO_LARGE_ARRAY_SIZE); |
111 | |
112 | auto & value = this->data(place).value; |
113 | |
114 | value.resize(size, arena); |
115 | buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0])); |
116 | } |
117 | |
118 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
119 | { |
120 | const auto & value = this->data(place).value; |
121 | size_t size = value.size(); |
122 | |
123 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
124 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
125 | |
126 | offsets_to.push_back(offsets_to.back() + size); |
127 | |
128 | if (size) |
129 | { |
130 | typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData(); |
131 | data_to.insert(this->data(place).value.begin(), this->data(place).value.end()); |
132 | } |
133 | } |
134 | |
135 | bool allocatesMemoryInArena() const override |
136 | { |
137 | return true; |
138 | } |
139 | }; |
140 | |
141 | |
142 | /// General case |
143 | |
144 | |
145 | /// Nodes used to implement a linked list for storage of groupArray states |
146 | |
147 | template <typename Node> |
148 | struct GroupArrayListNodeBase |
149 | { |
150 | Node * next; |
151 | UInt64 size; // size of payload |
152 | |
153 | /// Returns pointer to actual payload |
154 | char * data() |
155 | { |
156 | static_assert(sizeof(GroupArrayListNodeBase) == sizeof(Node)); |
157 | return reinterpret_cast<char *>(this) + sizeof(Node); |
158 | } |
159 | |
160 | /// Clones existing node (does not modify next field) |
161 | Node * clone(Arena * arena) |
162 | { |
163 | return reinterpret_cast<Node *>(const_cast<char *>(arena->alignedInsert(reinterpret_cast<char *>(this), sizeof(Node) + size, alignof(Node)))); |
164 | } |
165 | |
166 | /// Write node to buffer |
167 | void write(WriteBuffer & buf) |
168 | { |
169 | writeVarUInt(size, buf); |
170 | buf.write(data(), size); |
171 | } |
172 | |
173 | /// Reads and allocates node from ReadBuffer's data (doesn't set next) |
174 | static Node * read(ReadBuffer & buf, Arena * arena) |
175 | { |
176 | UInt64 size; |
177 | readVarUInt(size, buf); |
178 | |
179 | Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node))); |
180 | node->size = size; |
181 | buf.read(node->data(), size); |
182 | return node; |
183 | } |
184 | }; |
185 | |
186 | struct GroupArrayListNodeString : public GroupArrayListNodeBase<GroupArrayListNodeString> |
187 | { |
188 | using Node = GroupArrayListNodeString; |
189 | |
190 | /// Create node from string |
191 | static Node * allocate(const IColumn & column, size_t row_num, Arena * arena) |
192 | { |
193 | StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num); |
194 | |
195 | Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node))); |
196 | node->next = nullptr; |
197 | node->size = string.size; |
198 | memcpy(node->data(), string.data, string.size); |
199 | |
200 | return node; |
201 | } |
202 | |
203 | void insertInto(IColumn & column) |
204 | { |
205 | assert_cast<ColumnString &>(column).insertData(data(), size); |
206 | } |
207 | }; |
208 | |
209 | struct GroupArrayListNodeGeneral : public GroupArrayListNodeBase<GroupArrayListNodeGeneral> |
210 | { |
211 | using Node = GroupArrayListNodeGeneral; |
212 | |
213 | static Node * allocate(const IColumn & column, size_t row_num, Arena * arena) |
214 | { |
215 | const char * begin = arena->alignedAlloc(sizeof(Node), alignof(Node)); |
216 | StringRef value = column.serializeValueIntoArena(row_num, *arena, begin); |
217 | |
218 | Node * node = reinterpret_cast<Node *>(const_cast<char *>(begin)); |
219 | node->next = nullptr; |
220 | node->size = value.size; |
221 | |
222 | return node; |
223 | } |
224 | |
225 | void insertInto(IColumn & column) |
226 | { |
227 | column.deserializeAndInsertFromArena(data()); |
228 | } |
229 | }; |
230 | |
231 | |
232 | template <typename Node> |
233 | struct GroupArrayGeneralListData |
234 | { |
235 | UInt64 elems = 0; |
236 | Node * first = nullptr; |
237 | Node * last = nullptr; |
238 | }; |
239 | |
240 | |
241 | /// Implementation of groupArray for String or any ComplexObject via linked list |
242 | /// It has poor performance in case of many small objects |
243 | template <typename Node, bool limit_num_elems> |
244 | class GroupArrayGeneralListImpl final |
245 | : public IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, limit_num_elems>> |
246 | { |
247 | using Data = GroupArrayGeneralListData<Node>; |
248 | static Data & data(AggregateDataPtr place) { return *reinterpret_cast<Data*>(place); } |
249 | static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); } |
250 | |
251 | DataTypePtr & data_type; |
252 | UInt64 max_elems; |
253 | |
254 | public: |
255 | GroupArrayGeneralListImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) |
256 | : IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, limit_num_elems>>({data_type_}, {}) |
257 | , data_type(this->argument_types[0]), max_elems(max_elems_) {} |
258 | |
259 | String getName() const override { return "groupArray" ; } |
260 | |
261 | DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); } |
262 | |
263 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
264 | { |
265 | if (limit_num_elems && data(place).elems >= max_elems) |
266 | return; |
267 | |
268 | Node * node = Node::allocate(*columns[0], row_num, arena); |
269 | |
270 | if (unlikely(!data(place).first)) |
271 | { |
272 | data(place).first = node; |
273 | data(place).last = node; |
274 | } |
275 | else |
276 | { |
277 | data(place).last->next = node; |
278 | data(place).last = node; |
279 | } |
280 | |
281 | ++data(place).elems; |
282 | } |
283 | |
284 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override |
285 | { |
286 | /// It is sadly, but rhs's Arena could be destroyed |
287 | |
288 | if (!data(rhs).first) /// rhs state is empty |
289 | return; |
290 | |
291 | UInt64 new_elems; |
292 | UInt64 cur_elems = data(place).elems; |
293 | if (limit_num_elems) |
294 | { |
295 | if (data(place).elems >= max_elems) |
296 | return; |
297 | |
298 | new_elems = std::min(data(place).elems + data(rhs).elems, max_elems); |
299 | } |
300 | else |
301 | { |
302 | new_elems = data(place).elems + data(rhs).elems; |
303 | } |
304 | |
305 | Node * p_rhs = data(rhs).first; |
306 | Node * p_lhs; |
307 | |
308 | if (unlikely(!data(place).last)) /// lhs state is empty |
309 | { |
310 | p_lhs = p_rhs->clone(arena); |
311 | data(place).first = data(place).last = p_lhs; |
312 | p_rhs = p_rhs->next; |
313 | ++cur_elems; |
314 | } |
315 | else |
316 | { |
317 | p_lhs = data(place).last; |
318 | } |
319 | |
320 | for (; cur_elems < new_elems; ++cur_elems) |
321 | { |
322 | Node * p_new = p_rhs->clone(arena); |
323 | p_lhs->next = p_new; |
324 | p_rhs = p_rhs->next; |
325 | p_lhs = p_new; |
326 | } |
327 | |
328 | p_lhs->next = nullptr; |
329 | data(place).last = p_lhs; |
330 | data(place).elems = new_elems; |
331 | } |
332 | |
333 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
334 | { |
335 | writeVarUInt(data(place).elems, buf); |
336 | |
337 | Node * p = data(place).first; |
338 | while (p) |
339 | { |
340 | p->write(buf); |
341 | p = p->next; |
342 | } |
343 | } |
344 | |
345 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
346 | { |
347 | UInt64 elems; |
348 | readVarUInt(elems, buf); |
349 | data(place).elems = elems; |
350 | |
351 | if (unlikely(elems == 0)) |
352 | return; |
353 | |
354 | if (unlikely(elems > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE)) |
355 | throw Exception("Too large array size" , ErrorCodes::TOO_LARGE_ARRAY_SIZE); |
356 | |
357 | if (limit_num_elems && unlikely(elems > max_elems)) |
358 | throw Exception("Too large array size, it should not exceed " + toString(max_elems), ErrorCodes::TOO_LARGE_ARRAY_SIZE); |
359 | |
360 | Node * prev = Node::read(buf, arena); |
361 | data(place).first = prev; |
362 | |
363 | for (UInt64 i = 1; i < elems; ++i) |
364 | { |
365 | Node * cur = Node::read(buf, arena); |
366 | prev->next = cur; |
367 | prev = cur; |
368 | } |
369 | |
370 | prev->next = nullptr; |
371 | data(place).last = prev; |
372 | } |
373 | |
374 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
375 | { |
376 | auto & column_array = assert_cast<ColumnArray &>(to); |
377 | |
378 | auto & offsets = column_array.getOffsets(); |
379 | offsets.push_back(offsets.back() + data(place).elems); |
380 | |
381 | auto & column_data = column_array.getData(); |
382 | |
383 | if (std::is_same_v<Node, GroupArrayListNodeString>) |
384 | { |
385 | auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets(); |
386 | string_offsets.reserve(string_offsets.size() + data(place).elems); |
387 | } |
388 | |
389 | Node * p = data(place).first; |
390 | while (p) |
391 | { |
392 | p->insertInto(column_data); |
393 | p = p->next; |
394 | } |
395 | } |
396 | |
397 | bool allocatesMemoryInArena() const override |
398 | { |
399 | return true; |
400 | } |
401 | }; |
402 | |
403 | #undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE |
404 | |
405 | } |
406 | |