#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnsCommon.h>
#include <Common/assert_cast.h>
#include <AggregateFunctions/AggregateFunctionState.h>
#include <DataStreams/ColumnGathererStream.h>
#include <IO/WriteBufferFromArena.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/SipHash.h>
#include <Common/AlignedBuffer.h>
#include <Common/typeid_cast.h>
#include <Common/Arena.h>

#include <AggregateFunctions/AggregateFunctionMLMethod.h>

#include <cstdint>
#include <limits>
15
16namespace DB
17{
18
namespace ErrorCodes
{
    extern const int PARAMETER_OUT_OF_BOUND;
    extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Used by ColumnAggregateFunction::insert(); declared here like every other code this file throws.
    extern const int LOGICAL_ERROR;
}
25
26
27ColumnAggregateFunction::~ColumnAggregateFunction()
28{
29 if (!func->hasTrivialDestructor() && !src)
30 for (auto val : data)
31 func->destroy(val);
32}
33
34void ColumnAggregateFunction::addArena(ConstArenaPtr arena_)
35{
36 foreign_arenas.push_back(arena_);
37}
38
39MutableColumnPtr ColumnAggregateFunction::convertToValues() const
40{
41 /** If the aggregate function returns an unfinalized/unfinished state,
42 * then you just need to copy pointers to it and also shared ownership of data.
43 *
44 * Also replace the aggregate function with the nested function.
45 * That is, if this column is the states of the aggregate function `aggState`,
46 * then we return the same column, but with the states of the aggregate function `agg`.
47 * These are the same states, changing only the function to which they correspond.
48 *
49 * Further is quite difficult to understand.
50 * Example when this happens:
51 *
52 * SELECT k, finalizeAggregation(quantileTimingState(0.5)(x)) FROM ... GROUP BY k WITH TOTALS
53 *
54 * This calculates the aggregate function `quantileTimingState`.
55 * Its return type AggregateFunction(quantileTiming(0.5), UInt64)`.
56 * Due to the presence of WITH TOTALS, during aggregation the states of this aggregate function will be stored
57 * in the ColumnAggregateFunction column of type
58 * AggregateFunction(quantileTimingState(0.5), UInt64).
59 * Then, in `TotalsHavingBlockInputStream`, it will be called `convertToValues` method,
60 * to get the "ready" values.
61 * But it just converts a column of type
62 * `AggregateFunction(quantileTimingState(0.5), UInt64)`
63 * into `AggregateFunction(quantileTiming(0.5), UInt64)`
64 * - in the same states.
65 *
66 * Then `finalizeAggregation` function will be calculated, which will call `convertToValues` already on the result.
67 * And this converts a column of type
68 * AggregateFunction(quantileTiming(0.5), UInt64)
69 * into UInt16 - already finished result of `quantileTiming`.
70 */
71 if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
72 {
73 auto res = createView();
74 res->set(function_state->getNestedFunction());
75 res->data.assign(data.begin(), data.end());
76 return res;
77 }
78
79 MutableColumnPtr res = func->getReturnType()->createColumn();
80 res->reserve(data.size());
81
82 for (auto val : data)
83 func->insertResultInto(val, *res);
84
85 return res;
86}
87
88MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
89{
90 MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn();
91 res->reserve(data.size());
92
93 auto ML_function = func.get();
94 if (ML_function)
95 {
96 if (data.size() == 1)
97 {
98 /// Case for const column. Predict using single model.
99 ML_function->predictValues(data[0], *res, block, 0, block.rows(), arguments, context);
100 }
101 else
102 {
103 /// Case for non-constant column. Use different aggregate function for each row.
104 size_t row_num = 0;
105 for (auto val : data)
106 {
107 ML_function->predictValues(val, *res, block, row_num, 1, arguments, context);
108 ++row_num;
109 }
110 }
111 }
112 else
113 {
114 throw Exception("Illegal aggregate function is passed",
115 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
116 }
117 return res;
118}
119
120void ColumnAggregateFunction::ensureOwnership()
121{
122 if (src)
123 {
124 /// We must copy all data from src and take ownership.
125 size_t size = data.size();
126
127 Arena & arena = createOrGetArena();
128 size_t size_of_state = func->sizeOfData();
129 size_t align_of_state = func->alignOfData();
130
131 size_t rollback_pos = 0;
132 try
133 {
134 for (size_t i = 0; i < size; ++i)
135 {
136 ConstAggregateDataPtr old_place = data[i];
137 data[i] = arena.alignedAlloc(size_of_state, align_of_state);
138 func->create(data[i]);
139 ++rollback_pos;
140 func->merge(data[i], old_place, &arena);
141 }
142 }
143 catch (...)
144 {
145 /// If we failed to take ownership, destroy all temporary data.
146
147 if (!func->hasTrivialDestructor())
148 for (size_t i = 0; i < rollback_pos; ++i)
149 func->destroy(data[i]);
150
151 throw;
152 }
153
154 /// Now we own all data.
155 src.reset();
156 }
157}
158
159
160void ColumnAggregateFunction::insertRangeFrom(const IColumn & from, size_t start, size_t length)
161{
162 const ColumnAggregateFunction & from_concrete = assert_cast<const ColumnAggregateFunction &>(from);
163
164 if (start + length > from_concrete.data.size())
165 throw Exception("Parameters start = " + toString(start) + ", length = " + toString(length)
166 + " are out of bound in ColumnAggregateFunction::insertRangeFrom method"
167 " (data.size() = "
168 + toString(from_concrete.data.size())
169 + ").",
170 ErrorCodes::PARAMETER_OUT_OF_BOUND);
171
172 if (!empty() && src.get() != &from_concrete)
173 {
174 /// Must create new states of aggregate function and take ownership of it,
175 /// because ownership of states of aggregate function cannot be shared for individual rows,
176 /// (only as a whole).
177
178 size_t end = start + length;
179 for (size_t i = start; i < end; ++i)
180 insertFrom(from, i);
181 }
182 else
183 {
184 /// Keep shared ownership of aggregation states.
185 src = from_concrete.getPtr();
186
187 size_t old_size = data.size();
188 data.resize(old_size + length);
189 memcpy(data.data() + old_size, &from_concrete.data[start], length * sizeof(data[0]));
190 }
191}
192
193
194ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_size_hint) const
195{
196 size_t size = data.size();
197 if (size != filter.size())
198 throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
199
200 if (size == 0)
201 return cloneEmpty();
202
203 auto res = createView();
204 auto & res_data = res->data;
205
206 if (result_size_hint)
207 res_data.reserve(result_size_hint > 0 ? result_size_hint : size);
208
209 for (size_t i = 0; i < size; ++i)
210 if (filter[i])
211 res_data.push_back(data[i]);
212
213 /// To save RAM in case of too strong filtering.
214 if (res_data.size() * 2 < res_data.capacity())
215 res_data = Container(res_data.cbegin(), res_data.cend());
216
217 return res;
218}
219
220
221ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limit) const
222{
223 size_t size = data.size();
224
225 if (limit == 0)
226 limit = size;
227 else
228 limit = std::min(size, limit);
229
230 if (perm.size() < limit)
231 throw Exception("Size of permutation is less than required.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
232
233 auto res = createView();
234
235 res->data.resize(limit);
236 for (size_t i = 0; i < limit; ++i)
237 res->data[i] = data[perm[i]];
238
239 return res;
240}
241
242ColumnPtr ColumnAggregateFunction::index(const IColumn & indexes, size_t limit) const
243{
244 return selectIndexImpl(*this, indexes, limit);
245}
246
247template <typename Type>
248ColumnPtr ColumnAggregateFunction::indexImpl(const PaddedPODArray<Type> & indexes, size_t limit) const
249{
250 auto res = createView();
251
252 res->data.resize(limit);
253 for (size_t i = 0; i < limit; ++i)
254 res->data[i] = data[indexes[i]];
255
256 return res;
257}
258
259INSTANTIATE_INDEX_IMPL(ColumnAggregateFunction)
260
261/// Is required to support operations with Set
262void ColumnAggregateFunction::updateHashWithValue(size_t n, SipHash & hash) const
263{
264 WriteBufferFromOwnString wbuf;
265 func->serialize(data[n], wbuf);
266 hash.update(wbuf.str().c_str(), wbuf.str().size());
267}
268
269/// The returned size is less than real size. The reason is that some parts of
270/// aggregate function data may be allocated on shared arenas. These arenas are
271/// used for several blocks, and also may be updated concurrently from other
272/// threads, so we can't know the size of these data.
273size_t ColumnAggregateFunction::byteSize() const
274{
275 return data.size() * sizeof(data[0])
276 + (my_arena ? my_arena->size() : 0);
277}
278
279/// Like in byteSize(), the size is underestimated.
280size_t ColumnAggregateFunction::allocatedBytes() const
281{
282 return data.allocated_bytes()
283 + (my_arena ? my_arena->size() : 0);
284}
285
286void ColumnAggregateFunction::protect()
287{
288 data.protect();
289}
290
291MutableColumnPtr ColumnAggregateFunction::cloneEmpty() const
292{
293 return create(func);
294}
295
296String ColumnAggregateFunction::getTypeString() const
297{
298 return DataTypeAggregateFunction(func, func->getArgumentTypes(), func->getParameters()).getName();
299}
300
301Field ColumnAggregateFunction::operator[](size_t n) const
302{
303 Field field = AggregateFunctionStateData();
304 field.get<AggregateFunctionStateData &>().name = getTypeString();
305 {
306 WriteBufferFromString buffer(field.get<AggregateFunctionStateData &>().data);
307 func->serialize(data[n], buffer);
308 }
309 return field;
310}
311
312void ColumnAggregateFunction::get(size_t n, Field & res) const
313{
314 res = AggregateFunctionStateData();
315 res.get<AggregateFunctionStateData &>().name = getTypeString();
316 {
317 WriteBufferFromString buffer(res.get<AggregateFunctionStateData &>().data);
318 func->serialize(data[n], buffer);
319 }
320}
321
322StringRef ColumnAggregateFunction::getDataAt(size_t n) const
323{
324 return StringRef(reinterpret_cast<const char *>(&data[n]), sizeof(data[n]));
325}
326
327void ColumnAggregateFunction::insertData(const char * pos, size_t /*length*/)
328{
329 ensureOwnership();
330 data.push_back(*reinterpret_cast<const AggregateDataPtr *>(pos));
331}
332
333void ColumnAggregateFunction::insertFrom(const IColumn & from, size_t n)
334{
335 /// Must create new state of aggregate function and take ownership of it,
336 /// because ownership of states of aggregate function cannot be shared for individual rows,
337 /// (only as a whole, see comment above).
338 ensureOwnership();
339 insertDefault();
340 insertMergeFrom(from, n);
341}
342
343void ColumnAggregateFunction::insertFrom(ConstAggregateDataPtr place)
344{
345 ensureOwnership();
346 insertDefault();
347 insertMergeFrom(place);
348}
349
350void ColumnAggregateFunction::insertMergeFrom(ConstAggregateDataPtr place)
351{
352 func->merge(data.back(), place, &createOrGetArena());
353}
354
355void ColumnAggregateFunction::insertMergeFrom(const IColumn & from, size_t n)
356{
357 insertMergeFrom(assert_cast<const ColumnAggregateFunction &>(from).data[n]);
358}
359
360Arena & ColumnAggregateFunction::createOrGetArena()
361{
362 if (unlikely(!my_arena))
363 my_arena = std::make_shared<Arena>();
364
365 return *my_arena.get();
366}
367
368
369static void pushBackAndCreateState(ColumnAggregateFunction::Container & data, Arena & arena, IAggregateFunction * func)
370{
371 data.push_back(arena.alignedAlloc(func->sizeOfData(), func->alignOfData()));
372 try
373 {
374 func->create(data.back());
375 }
376 catch (...)
377 {
378 data.pop_back();
379 throw;
380 }
381}
382
383void ColumnAggregateFunction::insert(const Field & x)
384{
385 String type_string = getTypeString();
386
387 if (x.getType() != Field::Types::AggregateFunctionState)
388 throw Exception(String("Inserting field of type ") + x.getTypeName() + " into ColumnAggregateFunction. "
389 "Expected " + Field::Types::toString(Field::Types::AggregateFunctionState), ErrorCodes::LOGICAL_ERROR);
390
391 auto & field_name = x.get<const AggregateFunctionStateData &>().name;
392 if (type_string != field_name)
393 throw Exception("Cannot insert filed with type " + field_name + " into column with type " + type_string,
394 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
395
396 ensureOwnership();
397 Arena & arena = createOrGetArena();
398 pushBackAndCreateState(data, arena, func.get());
399 ReadBufferFromString read_buffer(x.get<const AggregateFunctionStateData &>().data);
400 func->deserialize(data.back(), read_buffer, &arena);
401}
402
403void ColumnAggregateFunction::insertDefault()
404{
405 ensureOwnership();
406 Arena & arena = createOrGetArena();
407 pushBackAndCreateState(data, arena, func.get());
408}
409
410StringRef ColumnAggregateFunction::serializeValueIntoArena(size_t n, Arena & dst, const char *& begin) const
411{
412 WriteBufferFromArena out(dst, begin);
413 func->serialize(data[n], out);
414 return out.finish();
415}
416
417const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char * src_arena)
418{
419 ensureOwnership();
420
421 /** Parameter "src_arena" points to Arena, from which we will deserialize the state.
422 * And "dst_arena" is another Arena, that aggregate function state will use to store its data.
423 */
424 Arena & dst_arena = createOrGetArena();
425 pushBackAndCreateState(data, dst_arena, func.get());
426
427 /** We will read from src_arena.
428 * There is no limit for reading - it is assumed, that we can read all that we need after src_arena pointer.
429 * Buf ReadBufferFromMemory requires some bound. We will use arbitrary big enough number, that will not overflow pointer.
430 * NOTE Technically, this is not compatible with C++ standard,
431 * as we cannot legally compare pointers after last element + 1 of some valid memory region.
432 * Probably this will not work under UBSan.
433 */
434 ReadBufferFromMemory read_buffer(src_arena, std::numeric_limits<char *>::max() - src_arena - 1);
435 func->deserialize(data.back(), read_buffer, &dst_arena);
436
437 return read_buffer.position();
438}
439
440void ColumnAggregateFunction::popBack(size_t n)
441{
442 size_t size = data.size();
443 size_t new_size = size - n;
444
445 if (!src)
446 for (size_t i = new_size; i < size; ++i)
447 func->destroy(data[i]);
448
449 data.resize_assume_reserved(new_size);
450}
451
452ColumnPtr ColumnAggregateFunction::replicate(const IColumn::Offsets & offsets) const
453{
454 size_t size = data.size();
455 if (size != offsets.size())
456 throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
457
458 if (size == 0)
459 return cloneEmpty();
460
461 auto res = createView();
462 auto & res_data = res->data;
463 res_data.reserve(offsets.back());
464
465 IColumn::Offset prev_offset = 0;
466 for (size_t i = 0; i < size; ++i)
467 {
468 size_t size_to_replicate = offsets[i] - prev_offset;
469 prev_offset = offsets[i];
470
471 for (size_t j = 0; j < size_to_replicate; ++j)
472 res_data.push_back(data[i]);
473 }
474
475 return res;
476}
477
478MutableColumns ColumnAggregateFunction::scatter(IColumn::ColumnIndex num_columns, const IColumn::Selector & selector) const
479{
480 /// Columns with scattered values will point to this column as the owner of values.
481 MutableColumns columns(num_columns);
482 for (auto & column : columns)
483 column = createView();
484
485 size_t num_rows = size();
486
487 {
488 size_t reserve_size = num_rows / num_columns * 1.1; /// 1.1 is just a guess. Better to use n-sigma rule.
489
490 if (reserve_size > 1)
491 for (auto & column : columns)
492 column->reserve(reserve_size);
493 }
494
495 for (size_t i = 0; i < num_rows; ++i)
496 assert_cast<ColumnAggregateFunction &>(*columns[selector[i]]).data.push_back(data[i]);
497
498 return columns;
499}
500
501void ColumnAggregateFunction::getPermutation(bool /*reverse*/, size_t /*limit*/, int /*nan_direction_hint*/, IColumn::Permutation & res) const
502{
503 size_t s = data.size();
504 res.resize(s);
505 for (size_t i = 0; i < s; ++i)
506 res[i] = i;
507}
508
509void ColumnAggregateFunction::gather(ColumnGathererStream & gatherer)
510{
511 gatherer.gather(*this);
512}
513
514void ColumnAggregateFunction::getExtremes(Field & min, Field & max) const
515{
516 /// Place serialized default values into min/max.
517
518 AlignedBuffer place_buffer(func->sizeOfData(), func->alignOfData());
519 AggregateDataPtr place = place_buffer.data();
520
521 AggregateFunctionStateData serialized;
522 serialized.name = getTypeString();
523
524 func->create(place);
525 try
526 {
527 WriteBufferFromString buffer(serialized.data);
528 func->serialize(place, buffer);
529 }
530 catch (...)
531 {
532 func->destroy(place);
533 throw;
534 }
535 func->destroy(place);
536
537 min = serialized;
538 max = serialized;
539}
540
541namespace
542{
543
544ConstArenas concatArenas(const ConstArenas & array, ConstArenaPtr arena)
545{
546 ConstArenas result = array;
547 if (arena)
548 result.push_back(std::move(arena));
549
550 return result;
551}
552
553}
554
555ColumnAggregateFunction::MutablePtr ColumnAggregateFunction::createView() const
556{
557 auto res = create(func, concatArenas(foreign_arenas, my_arena));
558 res->src = getPtr();
559 return res;
560}
561
562ColumnAggregateFunction::ColumnAggregateFunction(const ColumnAggregateFunction & src_)
563 : foreign_arenas(concatArenas(src_.foreign_arenas, src_.my_arena)),
564 func(src_.func), src(src_.getPtr()), data(src_.data.begin(), src_.data.end())
565{
566}
567
568}
569