#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnsCommon.h>
#include <Common/assert_cast.h>
#include <AggregateFunctions/AggregateFunctionState.h>
#include <DataStreams/ColumnGathererStream.h>
#include <IO/WriteBufferFromArena.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/SipHash.h>
#include <Common/AlignedBuffer.h>
#include <Common/typeid_cast.h>
#include <Common/Arena.h>

#include <AggregateFunctions/AggregateFunctionMLMethod.h>

#include <cstdint>
#include <limits>
15 | |
16 | namespace DB |
17 | { |
18 | |
namespace ErrorCodes
{
    extern const int PARAMETER_OUT_OF_BOUND;
    extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Used by ColumnAggregateFunction::insert(); declared here so this file does not
    /// depend on some other header happening to declare it.
    extern const int LOGICAL_ERROR;
}
25 | |
26 | |
27 | ColumnAggregateFunction::~ColumnAggregateFunction() |
28 | { |
29 | if (!func->hasTrivialDestructor() && !src) |
30 | for (auto val : data) |
31 | func->destroy(val); |
32 | } |
33 | |
34 | void ColumnAggregateFunction::addArena(ConstArenaPtr arena_) |
35 | { |
36 | foreign_arenas.push_back(arena_); |
37 | } |
38 | |
39 | MutableColumnPtr ColumnAggregateFunction::convertToValues() const |
40 | { |
41 | /** If the aggregate function returns an unfinalized/unfinished state, |
42 | * then you just need to copy pointers to it and also shared ownership of data. |
43 | * |
44 | * Also replace the aggregate function with the nested function. |
45 | * That is, if this column is the states of the aggregate function `aggState`, |
46 | * then we return the same column, but with the states of the aggregate function `agg`. |
47 | * These are the same states, changing only the function to which they correspond. |
48 | * |
49 | * Further is quite difficult to understand. |
50 | * Example when this happens: |
51 | * |
52 | * SELECT k, finalizeAggregation(quantileTimingState(0.5)(x)) FROM ... GROUP BY k WITH TOTALS |
53 | * |
54 | * This calculates the aggregate function `quantileTimingState`. |
55 | * Its return type AggregateFunction(quantileTiming(0.5), UInt64)`. |
56 | * Due to the presence of WITH TOTALS, during aggregation the states of this aggregate function will be stored |
57 | * in the ColumnAggregateFunction column of type |
58 | * AggregateFunction(quantileTimingState(0.5), UInt64). |
59 | * Then, in `TotalsHavingBlockInputStream`, it will be called `convertToValues` method, |
60 | * to get the "ready" values. |
61 | * But it just converts a column of type |
62 | * `AggregateFunction(quantileTimingState(0.5), UInt64)` |
63 | * into `AggregateFunction(quantileTiming(0.5), UInt64)` |
64 | * - in the same states. |
65 | * |
66 | * Then `finalizeAggregation` function will be calculated, which will call `convertToValues` already on the result. |
67 | * And this converts a column of type |
68 | * AggregateFunction(quantileTiming(0.5), UInt64) |
69 | * into UInt16 - already finished result of `quantileTiming`. |
70 | */ |
71 | if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get())) |
72 | { |
73 | auto res = createView(); |
74 | res->set(function_state->getNestedFunction()); |
75 | res->data.assign(data.begin(), data.end()); |
76 | return res; |
77 | } |
78 | |
79 | MutableColumnPtr res = func->getReturnType()->createColumn(); |
80 | res->reserve(data.size()); |
81 | |
82 | for (auto val : data) |
83 | func->insertResultInto(val, *res); |
84 | |
85 | return res; |
86 | } |
87 | |
88 | MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const |
89 | { |
90 | MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn(); |
91 | res->reserve(data.size()); |
92 | |
93 | auto ML_function = func.get(); |
94 | if (ML_function) |
95 | { |
96 | if (data.size() == 1) |
97 | { |
98 | /// Case for const column. Predict using single model. |
99 | ML_function->predictValues(data[0], *res, block, 0, block.rows(), arguments, context); |
100 | } |
101 | else |
102 | { |
103 | /// Case for non-constant column. Use different aggregate function for each row. |
104 | size_t row_num = 0; |
105 | for (auto val : data) |
106 | { |
107 | ML_function->predictValues(val, *res, block, row_num, 1, arguments, context); |
108 | ++row_num; |
109 | } |
110 | } |
111 | } |
112 | else |
113 | { |
114 | throw Exception("Illegal aggregate function is passed" , |
115 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
116 | } |
117 | return res; |
118 | } |
119 | |
120 | void ColumnAggregateFunction::ensureOwnership() |
121 | { |
122 | if (src) |
123 | { |
124 | /// We must copy all data from src and take ownership. |
125 | size_t size = data.size(); |
126 | |
127 | Arena & arena = createOrGetArena(); |
128 | size_t size_of_state = func->sizeOfData(); |
129 | size_t align_of_state = func->alignOfData(); |
130 | |
131 | size_t rollback_pos = 0; |
132 | try |
133 | { |
134 | for (size_t i = 0; i < size; ++i) |
135 | { |
136 | ConstAggregateDataPtr old_place = data[i]; |
137 | data[i] = arena.alignedAlloc(size_of_state, align_of_state); |
138 | func->create(data[i]); |
139 | ++rollback_pos; |
140 | func->merge(data[i], old_place, &arena); |
141 | } |
142 | } |
143 | catch (...) |
144 | { |
145 | /// If we failed to take ownership, destroy all temporary data. |
146 | |
147 | if (!func->hasTrivialDestructor()) |
148 | for (size_t i = 0; i < rollback_pos; ++i) |
149 | func->destroy(data[i]); |
150 | |
151 | throw; |
152 | } |
153 | |
154 | /// Now we own all data. |
155 | src.reset(); |
156 | } |
157 | } |
158 | |
159 | |
160 | void ColumnAggregateFunction::insertRangeFrom(const IColumn & from, size_t start, size_t length) |
161 | { |
162 | const ColumnAggregateFunction & from_concrete = assert_cast<const ColumnAggregateFunction &>(from); |
163 | |
164 | if (start + length > from_concrete.data.size()) |
165 | throw Exception("Parameters start = " + toString(start) + ", length = " + toString(length) |
166 | + " are out of bound in ColumnAggregateFunction::insertRangeFrom method" |
167 | " (data.size() = " |
168 | + toString(from_concrete.data.size()) |
169 | + ")." , |
170 | ErrorCodes::PARAMETER_OUT_OF_BOUND); |
171 | |
172 | if (!empty() && src.get() != &from_concrete) |
173 | { |
174 | /// Must create new states of aggregate function and take ownership of it, |
175 | /// because ownership of states of aggregate function cannot be shared for individual rows, |
176 | /// (only as a whole). |
177 | |
178 | size_t end = start + length; |
179 | for (size_t i = start; i < end; ++i) |
180 | insertFrom(from, i); |
181 | } |
182 | else |
183 | { |
184 | /// Keep shared ownership of aggregation states. |
185 | src = from_concrete.getPtr(); |
186 | |
187 | size_t old_size = data.size(); |
188 | data.resize(old_size + length); |
189 | memcpy(data.data() + old_size, &from_concrete.data[start], length * sizeof(data[0])); |
190 | } |
191 | } |
192 | |
193 | |
194 | ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_size_hint) const |
195 | { |
196 | size_t size = data.size(); |
197 | if (size != filter.size()) |
198 | throw Exception("Size of filter doesn't match size of column." , ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); |
199 | |
200 | if (size == 0) |
201 | return cloneEmpty(); |
202 | |
203 | auto res = createView(); |
204 | auto & res_data = res->data; |
205 | |
206 | if (result_size_hint) |
207 | res_data.reserve(result_size_hint > 0 ? result_size_hint : size); |
208 | |
209 | for (size_t i = 0; i < size; ++i) |
210 | if (filter[i]) |
211 | res_data.push_back(data[i]); |
212 | |
213 | /// To save RAM in case of too strong filtering. |
214 | if (res_data.size() * 2 < res_data.capacity()) |
215 | res_data = Container(res_data.cbegin(), res_data.cend()); |
216 | |
217 | return res; |
218 | } |
219 | |
220 | |
221 | ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limit) const |
222 | { |
223 | size_t size = data.size(); |
224 | |
225 | if (limit == 0) |
226 | limit = size; |
227 | else |
228 | limit = std::min(size, limit); |
229 | |
230 | if (perm.size() < limit) |
231 | throw Exception("Size of permutation is less than required." , ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); |
232 | |
233 | auto res = createView(); |
234 | |
235 | res->data.resize(limit); |
236 | for (size_t i = 0; i < limit; ++i) |
237 | res->data[i] = data[perm[i]]; |
238 | |
239 | return res; |
240 | } |
241 | |
242 | ColumnPtr ColumnAggregateFunction::index(const IColumn & indexes, size_t limit) const |
243 | { |
244 | return selectIndexImpl(*this, indexes, limit); |
245 | } |
246 | |
247 | template <typename Type> |
248 | ColumnPtr ColumnAggregateFunction::indexImpl(const PaddedPODArray<Type> & indexes, size_t limit) const |
249 | { |
250 | auto res = createView(); |
251 | |
252 | res->data.resize(limit); |
253 | for (size_t i = 0; i < limit; ++i) |
254 | res->data[i] = data[indexes[i]]; |
255 | |
256 | return res; |
257 | } |
258 | |
259 | INSTANTIATE_INDEX_IMPL(ColumnAggregateFunction) |
260 | |
261 | /// Is required to support operations with Set |
262 | void ColumnAggregateFunction::updateHashWithValue(size_t n, SipHash & hash) const |
263 | { |
264 | WriteBufferFromOwnString wbuf; |
265 | func->serialize(data[n], wbuf); |
266 | hash.update(wbuf.str().c_str(), wbuf.str().size()); |
267 | } |
268 | |
269 | /// The returned size is less than real size. The reason is that some parts of |
270 | /// aggregate function data may be allocated on shared arenas. These arenas are |
271 | /// used for several blocks, and also may be updated concurrently from other |
272 | /// threads, so we can't know the size of these data. |
273 | size_t ColumnAggregateFunction::byteSize() const |
274 | { |
275 | return data.size() * sizeof(data[0]) |
276 | + (my_arena ? my_arena->size() : 0); |
277 | } |
278 | |
279 | /// Like in byteSize(), the size is underestimated. |
280 | size_t ColumnAggregateFunction::allocatedBytes() const |
281 | { |
282 | return data.allocated_bytes() |
283 | + (my_arena ? my_arena->size() : 0); |
284 | } |
285 | |
286 | void ColumnAggregateFunction::protect() |
287 | { |
288 | data.protect(); |
289 | } |
290 | |
291 | MutableColumnPtr ColumnAggregateFunction::cloneEmpty() const |
292 | { |
293 | return create(func); |
294 | } |
295 | |
296 | String ColumnAggregateFunction::getTypeString() const |
297 | { |
298 | return DataTypeAggregateFunction(func, func->getArgumentTypes(), func->getParameters()).getName(); |
299 | } |
300 | |
301 | Field ColumnAggregateFunction::operator[](size_t n) const |
302 | { |
303 | Field field = AggregateFunctionStateData(); |
304 | field.get<AggregateFunctionStateData &>().name = getTypeString(); |
305 | { |
306 | WriteBufferFromString buffer(field.get<AggregateFunctionStateData &>().data); |
307 | func->serialize(data[n], buffer); |
308 | } |
309 | return field; |
310 | } |
311 | |
312 | void ColumnAggregateFunction::get(size_t n, Field & res) const |
313 | { |
314 | res = AggregateFunctionStateData(); |
315 | res.get<AggregateFunctionStateData &>().name = getTypeString(); |
316 | { |
317 | WriteBufferFromString buffer(res.get<AggregateFunctionStateData &>().data); |
318 | func->serialize(data[n], buffer); |
319 | } |
320 | } |
321 | |
322 | StringRef ColumnAggregateFunction::getDataAt(size_t n) const |
323 | { |
324 | return StringRef(reinterpret_cast<const char *>(&data[n]), sizeof(data[n])); |
325 | } |
326 | |
327 | void ColumnAggregateFunction::insertData(const char * pos, size_t /*length*/) |
328 | { |
329 | ensureOwnership(); |
330 | data.push_back(*reinterpret_cast<const AggregateDataPtr *>(pos)); |
331 | } |
332 | |
333 | void ColumnAggregateFunction::insertFrom(const IColumn & from, size_t n) |
334 | { |
335 | /// Must create new state of aggregate function and take ownership of it, |
336 | /// because ownership of states of aggregate function cannot be shared for individual rows, |
337 | /// (only as a whole, see comment above). |
338 | ensureOwnership(); |
339 | insertDefault(); |
340 | insertMergeFrom(from, n); |
341 | } |
342 | |
343 | void ColumnAggregateFunction::insertFrom(ConstAggregateDataPtr place) |
344 | { |
345 | ensureOwnership(); |
346 | insertDefault(); |
347 | insertMergeFrom(place); |
348 | } |
349 | |
350 | void ColumnAggregateFunction::insertMergeFrom(ConstAggregateDataPtr place) |
351 | { |
352 | func->merge(data.back(), place, &createOrGetArena()); |
353 | } |
354 | |
355 | void ColumnAggregateFunction::insertMergeFrom(const IColumn & from, size_t n) |
356 | { |
357 | insertMergeFrom(assert_cast<const ColumnAggregateFunction &>(from).data[n]); |
358 | } |
359 | |
360 | Arena & ColumnAggregateFunction::createOrGetArena() |
361 | { |
362 | if (unlikely(!my_arena)) |
363 | my_arena = std::make_shared<Arena>(); |
364 | |
365 | return *my_arena.get(); |
366 | } |
367 | |
368 | |
369 | static void pushBackAndCreateState(ColumnAggregateFunction::Container & data, Arena & arena, IAggregateFunction * func) |
370 | { |
371 | data.push_back(arena.alignedAlloc(func->sizeOfData(), func->alignOfData())); |
372 | try |
373 | { |
374 | func->create(data.back()); |
375 | } |
376 | catch (...) |
377 | { |
378 | data.pop_back(); |
379 | throw; |
380 | } |
381 | } |
382 | |
383 | void ColumnAggregateFunction::insert(const Field & x) |
384 | { |
385 | String type_string = getTypeString(); |
386 | |
387 | if (x.getType() != Field::Types::AggregateFunctionState) |
388 | throw Exception(String("Inserting field of type " ) + x.getTypeName() + " into ColumnAggregateFunction. " |
389 | "Expected " + Field::Types::toString(Field::Types::AggregateFunctionState), ErrorCodes::LOGICAL_ERROR); |
390 | |
391 | auto & field_name = x.get<const AggregateFunctionStateData &>().name; |
392 | if (type_string != field_name) |
393 | throw Exception("Cannot insert filed with type " + field_name + " into column with type " + type_string, |
394 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
395 | |
396 | ensureOwnership(); |
397 | Arena & arena = createOrGetArena(); |
398 | pushBackAndCreateState(data, arena, func.get()); |
399 | ReadBufferFromString read_buffer(x.get<const AggregateFunctionStateData &>().data); |
400 | func->deserialize(data.back(), read_buffer, &arena); |
401 | } |
402 | |
403 | void ColumnAggregateFunction::insertDefault() |
404 | { |
405 | ensureOwnership(); |
406 | Arena & arena = createOrGetArena(); |
407 | pushBackAndCreateState(data, arena, func.get()); |
408 | } |
409 | |
410 | StringRef ColumnAggregateFunction::serializeValueIntoArena(size_t n, Arena & dst, const char *& begin) const |
411 | { |
412 | WriteBufferFromArena out(dst, begin); |
413 | func->serialize(data[n], out); |
414 | return out.finish(); |
415 | } |
416 | |
417 | const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char * src_arena) |
418 | { |
419 | ensureOwnership(); |
420 | |
421 | /** Parameter "src_arena" points to Arena, from which we will deserialize the state. |
422 | * And "dst_arena" is another Arena, that aggregate function state will use to store its data. |
423 | */ |
424 | Arena & dst_arena = createOrGetArena(); |
425 | pushBackAndCreateState(data, dst_arena, func.get()); |
426 | |
427 | /** We will read from src_arena. |
428 | * There is no limit for reading - it is assumed, that we can read all that we need after src_arena pointer. |
429 | * Buf ReadBufferFromMemory requires some bound. We will use arbitrary big enough number, that will not overflow pointer. |
430 | * NOTE Technically, this is not compatible with C++ standard, |
431 | * as we cannot legally compare pointers after last element + 1 of some valid memory region. |
432 | * Probably this will not work under UBSan. |
433 | */ |
434 | ReadBufferFromMemory read_buffer(src_arena, std::numeric_limits<char *>::max() - src_arena - 1); |
435 | func->deserialize(data.back(), read_buffer, &dst_arena); |
436 | |
437 | return read_buffer.position(); |
438 | } |
439 | |
440 | void ColumnAggregateFunction::popBack(size_t n) |
441 | { |
442 | size_t size = data.size(); |
443 | size_t new_size = size - n; |
444 | |
445 | if (!src) |
446 | for (size_t i = new_size; i < size; ++i) |
447 | func->destroy(data[i]); |
448 | |
449 | data.resize_assume_reserved(new_size); |
450 | } |
451 | |
452 | ColumnPtr ColumnAggregateFunction::replicate(const IColumn::Offsets & offsets) const |
453 | { |
454 | size_t size = data.size(); |
455 | if (size != offsets.size()) |
456 | throw Exception("Size of offsets doesn't match size of column." , ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); |
457 | |
458 | if (size == 0) |
459 | return cloneEmpty(); |
460 | |
461 | auto res = createView(); |
462 | auto & res_data = res->data; |
463 | res_data.reserve(offsets.back()); |
464 | |
465 | IColumn::Offset prev_offset = 0; |
466 | for (size_t i = 0; i < size; ++i) |
467 | { |
468 | size_t size_to_replicate = offsets[i] - prev_offset; |
469 | prev_offset = offsets[i]; |
470 | |
471 | for (size_t j = 0; j < size_to_replicate; ++j) |
472 | res_data.push_back(data[i]); |
473 | } |
474 | |
475 | return res; |
476 | } |
477 | |
478 | MutableColumns ColumnAggregateFunction::scatter(IColumn::ColumnIndex num_columns, const IColumn::Selector & selector) const |
479 | { |
480 | /// Columns with scattered values will point to this column as the owner of values. |
481 | MutableColumns columns(num_columns); |
482 | for (auto & column : columns) |
483 | column = createView(); |
484 | |
485 | size_t num_rows = size(); |
486 | |
487 | { |
488 | size_t reserve_size = num_rows / num_columns * 1.1; /// 1.1 is just a guess. Better to use n-sigma rule. |
489 | |
490 | if (reserve_size > 1) |
491 | for (auto & column : columns) |
492 | column->reserve(reserve_size); |
493 | } |
494 | |
495 | for (size_t i = 0; i < num_rows; ++i) |
496 | assert_cast<ColumnAggregateFunction &>(*columns[selector[i]]).data.push_back(data[i]); |
497 | |
498 | return columns; |
499 | } |
500 | |
501 | void ColumnAggregateFunction::getPermutation(bool /*reverse*/, size_t /*limit*/, int /*nan_direction_hint*/, IColumn::Permutation & res) const |
502 | { |
503 | size_t s = data.size(); |
504 | res.resize(s); |
505 | for (size_t i = 0; i < s; ++i) |
506 | res[i] = i; |
507 | } |
508 | |
509 | void ColumnAggregateFunction::gather(ColumnGathererStream & gatherer) |
510 | { |
511 | gatherer.gather(*this); |
512 | } |
513 | |
514 | void ColumnAggregateFunction::getExtremes(Field & min, Field & max) const |
515 | { |
516 | /// Place serialized default values into min/max. |
517 | |
518 | AlignedBuffer place_buffer(func->sizeOfData(), func->alignOfData()); |
519 | AggregateDataPtr place = place_buffer.data(); |
520 | |
521 | AggregateFunctionStateData serialized; |
522 | serialized.name = getTypeString(); |
523 | |
524 | func->create(place); |
525 | try |
526 | { |
527 | WriteBufferFromString buffer(serialized.data); |
528 | func->serialize(place, buffer); |
529 | } |
530 | catch (...) |
531 | { |
532 | func->destroy(place); |
533 | throw; |
534 | } |
535 | func->destroy(place); |
536 | |
537 | min = serialized; |
538 | max = serialized; |
539 | } |
540 | |
541 | namespace |
542 | { |
543 | |
544 | ConstArenas concatArenas(const ConstArenas & array, ConstArenaPtr arena) |
545 | { |
546 | ConstArenas result = array; |
547 | if (arena) |
548 | result.push_back(std::move(arena)); |
549 | |
550 | return result; |
551 | } |
552 | |
553 | } |
554 | |
555 | ColumnAggregateFunction::MutablePtr ColumnAggregateFunction::createView() const |
556 | { |
557 | auto res = create(func, concatArenas(foreign_arenas, my_arena)); |
558 | res->src = getPtr(); |
559 | return res; |
560 | } |
561 | |
562 | ColumnAggregateFunction::ColumnAggregateFunction(const ColumnAggregateFunction & src_) |
563 | : foreign_arenas(concatArenas(src_.foreign_arenas, src_.my_arena)), |
564 | func(src_.func), src(src_.getPtr()), data(src_.data.begin(), src_.data.end()) |
565 | { |
566 | } |
567 | |
568 | } |
569 | |