1#pragma once
2
3#include <IO/WriteHelpers.h>
4#include <IO/ReadHelpers.h>
5
6#include <DataTypes/DataTypeArray.h>
7#include <DataTypes/DataTypesNumber.h>
8#include <DataTypes/DataTypeString.h>
9
10#include <Columns/ColumnVector.h>
11#include <Columns/ColumnArray.h>
12#include <Columns/ColumnString.h>
13
14#include <Common/ArenaAllocator.h>
15#include <Common/assert_cast.h>
16
17#include <AggregateFunctions/IAggregateFunction.h>
18
19#include <common/likely.h>
20#include <type_traits>
21
/// Upper bound on the element count accepted when deserializing a groupArray state
/// (0xFFFFFF = 16777215). Both deserialize() implementations below reject larger sizes
/// with TOO_LARGE_ARRAY_SIZE before allocating storage for the elements.
/// The macro is #undef'd at the end of this header so it does not leak to includers.
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE 0xFFFFFF
23
24
25namespace DB
26{
27
namespace ErrorCodes
{
    /// Error codes referenced by this header; they are only declared here and defined elsewhere in the project.
    extern const int LOGICAL_ERROR;
    extern const int TOO_LARGE_ARRAY_SIZE;
}
33
34
35/// A particular case is an implementation for numeric types.
36template <typename T>
37struct GroupArrayNumericData
38{
39 // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
40 using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
41 using Array = PODArray<T, 32, Allocator>;
42
43 Array value;
44};
45
46
47template <typename T, typename Tlimit_num_elems>
48class GroupArrayNumericImpl final
49 : public IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>
50{
51 static constexpr bool limit_num_elems = Tlimit_num_elems::value;
52 DataTypePtr & data_type;
53 UInt64 max_elems;
54
55public:
56 explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
57 : IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>({data_type_}, {})
58 , data_type(this->argument_types[0]), max_elems(max_elems_) {}
59
60 String getName() const override { return "groupArray"; }
61
62 DataTypePtr getReturnType() const override
63 {
64 return std::make_shared<DataTypeArray>(data_type);
65 }
66
67 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
68 {
69 if (limit_num_elems && this->data(place).value.size() >= max_elems)
70 return;
71
72 this->data(place).value.push_back(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], arena);
73 }
74
75 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
76 {
77 auto & cur_elems = this->data(place);
78 auto & rhs_elems = this->data(rhs);
79
80 if (!limit_num_elems)
81 {
82 if (rhs_elems.value.size())
83 cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
84 }
85 else
86 {
87 UInt64 elems_to_insert = std::min(static_cast<size_t>(max_elems) - cur_elems.value.size(), rhs_elems.value.size());
88 if (elems_to_insert)
89 cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.begin() + elems_to_insert, arena);
90 }
91 }
92
93 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
94 {
95 const auto & value = this->data(place).value;
96 size_t size = value.size();
97 writeVarUInt(size, buf);
98 buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
99 }
100
101 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
102 {
103 size_t size = 0;
104 readVarUInt(size, buf);
105
106 if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
107 throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
108
109 if (limit_num_elems && unlikely(size > max_elems))
110 throw Exception("Too large array size, it should not exceed " + toString(max_elems), ErrorCodes::TOO_LARGE_ARRAY_SIZE);
111
112 auto & value = this->data(place).value;
113
114 value.resize(size, arena);
115 buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
116 }
117
118 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
119 {
120 const auto & value = this->data(place).value;
121 size_t size = value.size();
122
123 ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
124 ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
125
126 offsets_to.push_back(offsets_to.back() + size);
127
128 if (size)
129 {
130 typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
131 data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
132 }
133 }
134
135 bool allocatesMemoryInArena() const override
136 {
137 return true;
138 }
139};
140
141
/// General case: String and arbitrary element types, stored as a linked list of arena-allocated nodes.
143
144
145/// Nodes used to implement a linked list for storage of groupArray states
146
147template <typename Node>
148struct GroupArrayListNodeBase
149{
150 Node * next;
151 UInt64 size; // size of payload
152
153 /// Returns pointer to actual payload
154 char * data()
155 {
156 static_assert(sizeof(GroupArrayListNodeBase) == sizeof(Node));
157 return reinterpret_cast<char *>(this) + sizeof(Node);
158 }
159
160 /// Clones existing node (does not modify next field)
161 Node * clone(Arena * arena)
162 {
163 return reinterpret_cast<Node *>(const_cast<char *>(arena->alignedInsert(reinterpret_cast<char *>(this), sizeof(Node) + size, alignof(Node))));
164 }
165
166 /// Write node to buffer
167 void write(WriteBuffer & buf)
168 {
169 writeVarUInt(size, buf);
170 buf.write(data(), size);
171 }
172
173 /// Reads and allocates node from ReadBuffer's data (doesn't set next)
174 static Node * read(ReadBuffer & buf, Arena * arena)
175 {
176 UInt64 size;
177 readVarUInt(size, buf);
178
179 Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node)));
180 node->size = size;
181 buf.read(node->data(), size);
182 return node;
183 }
184};
185
186struct GroupArrayListNodeString : public GroupArrayListNodeBase<GroupArrayListNodeString>
187{
188 using Node = GroupArrayListNodeString;
189
190 /// Create node from string
191 static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
192 {
193 StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
194
195 Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node)));
196 node->next = nullptr;
197 node->size = string.size;
198 memcpy(node->data(), string.data, string.size);
199
200 return node;
201 }
202
203 void insertInto(IColumn & column)
204 {
205 assert_cast<ColumnString &>(column).insertData(data(), size);
206 }
207};
208
209struct GroupArrayListNodeGeneral : public GroupArrayListNodeBase<GroupArrayListNodeGeneral>
210{
211 using Node = GroupArrayListNodeGeneral;
212
213 static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
214 {
215 const char * begin = arena->alignedAlloc(sizeof(Node), alignof(Node));
216 StringRef value = column.serializeValueIntoArena(row_num, *arena, begin);
217
218 Node * node = reinterpret_cast<Node *>(const_cast<char *>(begin));
219 node->next = nullptr;
220 node->size = value.size;
221
222 return node;
223 }
224
225 void insertInto(IColumn & column)
226 {
227 column.deserializeAndInsertFromArena(data());
228 }
229};
230
231
232template <typename Node>
233struct GroupArrayGeneralListData
234{
235 UInt64 elems = 0;
236 Node * first = nullptr;
237 Node * last = nullptr;
238};
239
240
241/// Implementation of groupArray for String or any ComplexObject via linked list
242/// It has poor performance in case of many small objects
243template <typename Node, bool limit_num_elems>
244class GroupArrayGeneralListImpl final
245 : public IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, limit_num_elems>>
246{
247 using Data = GroupArrayGeneralListData<Node>;
248 static Data & data(AggregateDataPtr place) { return *reinterpret_cast<Data*>(place); }
249 static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
250
251 DataTypePtr & data_type;
252 UInt64 max_elems;
253
254public:
255 GroupArrayGeneralListImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
256 : IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, limit_num_elems>>({data_type_}, {})
257 , data_type(this->argument_types[0]), max_elems(max_elems_) {}
258
259 String getName() const override { return "groupArray"; }
260
261 DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
262
263 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
264 {
265 if (limit_num_elems && data(place).elems >= max_elems)
266 return;
267
268 Node * node = Node::allocate(*columns[0], row_num, arena);
269
270 if (unlikely(!data(place).first))
271 {
272 data(place).first = node;
273 data(place).last = node;
274 }
275 else
276 {
277 data(place).last->next = node;
278 data(place).last = node;
279 }
280
281 ++data(place).elems;
282 }
283
284 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
285 {
286 /// It is sadly, but rhs's Arena could be destroyed
287
288 if (!data(rhs).first) /// rhs state is empty
289 return;
290
291 UInt64 new_elems;
292 UInt64 cur_elems = data(place).elems;
293 if (limit_num_elems)
294 {
295 if (data(place).elems >= max_elems)
296 return;
297
298 new_elems = std::min(data(place).elems + data(rhs).elems, max_elems);
299 }
300 else
301 {
302 new_elems = data(place).elems + data(rhs).elems;
303 }
304
305 Node * p_rhs = data(rhs).first;
306 Node * p_lhs;
307
308 if (unlikely(!data(place).last)) /// lhs state is empty
309 {
310 p_lhs = p_rhs->clone(arena);
311 data(place).first = data(place).last = p_lhs;
312 p_rhs = p_rhs->next;
313 ++cur_elems;
314 }
315 else
316 {
317 p_lhs = data(place).last;
318 }
319
320 for (; cur_elems < new_elems; ++cur_elems)
321 {
322 Node * p_new = p_rhs->clone(arena);
323 p_lhs->next = p_new;
324 p_rhs = p_rhs->next;
325 p_lhs = p_new;
326 }
327
328 p_lhs->next = nullptr;
329 data(place).last = p_lhs;
330 data(place).elems = new_elems;
331 }
332
333 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
334 {
335 writeVarUInt(data(place).elems, buf);
336
337 Node * p = data(place).first;
338 while (p)
339 {
340 p->write(buf);
341 p = p->next;
342 }
343 }
344
345 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
346 {
347 UInt64 elems;
348 readVarUInt(elems, buf);
349 data(place).elems = elems;
350
351 if (unlikely(elems == 0))
352 return;
353
354 if (unlikely(elems > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
355 throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
356
357 if (limit_num_elems && unlikely(elems > max_elems))
358 throw Exception("Too large array size, it should not exceed " + toString(max_elems), ErrorCodes::TOO_LARGE_ARRAY_SIZE);
359
360 Node * prev = Node::read(buf, arena);
361 data(place).first = prev;
362
363 for (UInt64 i = 1; i < elems; ++i)
364 {
365 Node * cur = Node::read(buf, arena);
366 prev->next = cur;
367 prev = cur;
368 }
369
370 prev->next = nullptr;
371 data(place).last = prev;
372 }
373
374 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
375 {
376 auto & column_array = assert_cast<ColumnArray &>(to);
377
378 auto & offsets = column_array.getOffsets();
379 offsets.push_back(offsets.back() + data(place).elems);
380
381 auto & column_data = column_array.getData();
382
383 if (std::is_same_v<Node, GroupArrayListNodeString>)
384 {
385 auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets();
386 string_offsets.reserve(string_offsets.size() + data(place).elems);
387 }
388
389 Node * p = data(place).first;
390 while (p)
391 {
392 p->insertInto(column_data);
393 p = p->next;
394 }
395 }
396
397 bool allocatesMemoryInArena() const override
398 {
399 return true;
400 }
401};
402
403#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE
404
405}
406