1#include <DataStreams/SummingSortedBlockInputStream.h>
2#include <DataTypes/DataTypesNumber.h>
3#include <DataTypes/NestedUtils.h>
4#include <DataTypes/DataTypeTuple.h>
5#include <DataTypes/DataTypeArray.h>
6#include <DataTypes/DataTypeAggregateFunction.h>
7#include <Columns/ColumnAggregateFunction.h>
8#include <Columns/ColumnTuple.h>
9#include <Common/StringUtils/StringUtils.h>
10#include <Common/FieldVisitors.h>
11#include <common/logger_useful.h>
12#include <Common/typeid_cast.h>
13#include <Common/assert_cast.h>
14
15#include <AggregateFunctions/AggregateFunctionFactory.h>
16#include <Functions/FunctionFactory.h>
17#include <Functions/FunctionHelpers.h>
18#include <Interpreters/Context.h>
19
20
21namespace DB
22{
23
/// Error codes used by this file (defined elsewhere in the project).
namespace ErrorCodes
{
    /// Thrown on internal invariant violations (e.g. unsupported collation, missing aggregation state).
    extern const int LOGICAL_ERROR;
}
28
29
30namespace
31{
32 bool isInPrimaryKey(const SortDescription & description, const std::string & name, const size_t number)
33 {
34 for (auto & desc : description)
35 if (desc.column_name == name || (desc.column_name.empty() && desc.column_number == number))
36 return true;
37
38 return false;
39 }
40}
41
42
43SummingSortedBlockInputStream::SummingSortedBlockInputStream(
44 const BlockInputStreams & inputs_,
45 const SortDescription & description_,
46 /// List of columns to be summed. If empty, all numeric columns that are not in the description are taken.
47 const Names & column_names_to_sum,
48 size_t max_block_size_)
49 : MergingSortedBlockInputStream(inputs_, description_, max_block_size_)
50{
51 current_row.resize(num_columns);
52
53 /// name of nested structure -> the column numbers that refer to it.
54 std::unordered_map<std::string, std::vector<size_t>> discovered_maps;
55
56 /** Fill in the column numbers, which must be summed.
57 * This can only be numeric columns that are not part of the sort key.
58 * If a non-empty column_names_to_sum is specified, then we only take these columns.
59 * Some columns from column_names_to_sum may not be found. This is ignored.
60 */
61 for (size_t i = 0; i < num_columns; ++i)
62 {
63 const ColumnWithTypeAndName & column = header.safeGetByPosition(i);
64
65 /// Discover nested Maps and find columns for summation
66 if (typeid_cast<const DataTypeArray *>(column.type.get()))
67 {
68 const auto map_name = Nested::extractTableName(column.name);
69 /// if nested table name ends with `Map` it is a possible candidate for special handling
70 if (map_name == column.name || !endsWith(map_name, "Map"))
71 {
72 column_numbers_not_to_aggregate.push_back(i);
73 continue;
74 }
75
76 discovered_maps[map_name].emplace_back(i);
77 }
78 else
79 {
80 bool is_agg_func = WhichDataType(column.type).isAggregateFunction();
81
82 /// There are special const columns for example after prewere sections.
83 if ((!column.type->isSummable() && !is_agg_func) || isColumnConst(*column.column))
84 {
85 column_numbers_not_to_aggregate.push_back(i);
86 continue;
87 }
88
89 /// Are they inside the PK?
90 if (isInPrimaryKey(description, column.name, i))
91 {
92 column_numbers_not_to_aggregate.push_back(i);
93 continue;
94 }
95
96 if (column_names_to_sum.empty()
97 || column_names_to_sum.end() !=
98 std::find(column_names_to_sum.begin(), column_names_to_sum.end(), column.name))
99 {
100 // Create aggregator to sum this column
101 AggregateDescription desc;
102 desc.is_agg_func_type = is_agg_func;
103 desc.column_numbers = {i};
104
105 if (!is_agg_func)
106 {
107 desc.init("sumWithOverflow", {column.type});
108 }
109
110 columns_to_aggregate.emplace_back(std::move(desc));
111 }
112 else
113 {
114 // Column is not going to be summed, use last value
115 column_numbers_not_to_aggregate.push_back(i);
116 }
117 }
118 }
119
120 /// select actual nested Maps from list of candidates
121 for (const auto & map : discovered_maps)
122 {
123 /// map should contain at least two elements (key -> value)
124 if (map.second.size() < 2)
125 {
126 for (auto col : map.second)
127 column_numbers_not_to_aggregate.push_back(col);
128 continue;
129 }
130
131 /// no elements of map could be in primary key
132 auto column_num_it = map.second.begin();
133 for (; column_num_it != map.second.end(); ++column_num_it)
134 if (isInPrimaryKey(description, header.safeGetByPosition(*column_num_it).name, *column_num_it))
135 break;
136 if (column_num_it != map.second.end())
137 {
138 for (auto col : map.second)
139 column_numbers_not_to_aggregate.push_back(col);
140 continue;
141 }
142
143 DataTypes argument_types;
144 AggregateDescription desc;
145 MapDescription map_desc;
146
147 column_num_it = map.second.begin();
148 for (; column_num_it != map.second.end(); ++column_num_it)
149 {
150 const ColumnWithTypeAndName & key_col = header.safeGetByPosition(*column_num_it);
151 const String & name = key_col.name;
152 const IDataType & nested_type = *static_cast<const DataTypeArray *>(key_col.type.get())->getNestedType();
153
154 if (column_num_it == map.second.begin()
155 || endsWith(name, "ID")
156 || endsWith(name, "Key")
157 || endsWith(name, "Type"))
158 {
159 if (!nested_type.isValueRepresentedByInteger())
160 break;
161
162 map_desc.key_col_nums.push_back(*column_num_it);
163 }
164 else
165 {
166 if (!nested_type.isSummable())
167 break;
168
169 map_desc.val_col_nums.push_back(*column_num_it);
170 }
171
172 // Add column to function arguments
173 desc.column_numbers.push_back(*column_num_it);
174 argument_types.push_back(key_col.type);
175 }
176
177 if (column_num_it != map.second.end())
178 {
179 for (auto col : map.second)
180 column_numbers_not_to_aggregate.push_back(col);
181 continue;
182 }
183
184 if (map_desc.key_col_nums.size() == 1)
185 {
186 // Create summation for all value columns in the map
187 desc.init("sumMapWithOverflow", argument_types);
188 columns_to_aggregate.emplace_back(std::move(desc));
189 }
190 else
191 {
192 // Fall back to legacy mergeMaps for composite keys
193 for (auto col : map.second)
194 column_numbers_not_to_aggregate.push_back(col);
195 maps_to_sum.emplace_back(std::move(map_desc));
196 }
197 }
198}
199
200
201void SummingSortedBlockInputStream::insertCurrentRowIfNeeded(MutableColumns & merged_columns)
202{
203 /// We have nothing to aggregate. It means that it could be non-zero, because we have columns_not_to_aggregate.
204 if (columns_to_aggregate.empty())
205 current_row_is_zero = false;
206
207 for (auto & desc : columns_to_aggregate)
208 {
209 // Do not insert if the aggregation state hasn't been created
210 if (desc.created)
211 {
212 if (desc.is_agg_func_type)
213 {
214 current_row_is_zero = false;
215 }
216 else
217 {
218 try
219 {
220 desc.function->insertResultInto(desc.state.data(), *desc.merged_column);
221
222 /// Update zero status of current row
223 if (desc.column_numbers.size() == 1)
224 {
225 // Flag row as non-empty if at least one column number if non-zero
226 current_row_is_zero = current_row_is_zero && desc.merged_column->isDefaultAt(desc.merged_column->size() - 1);
227 }
228 else
229 {
230 /// It is sumMapWithOverflow aggregate function.
231 /// Assume that the row isn't empty in this case (just because it is compatible with previous version)
232 current_row_is_zero = false;
233 }
234 }
235 catch (...)
236 {
237 desc.destroyState();
238 throw;
239 }
240 }
241 desc.destroyState();
242 }
243 else
244 desc.merged_column->insertDefault();
245 }
246
247 /// If it is "zero" row, then rollback the insertion
248 /// (at this moment we need rollback only cols from columns_to_aggregate)
249 if (current_row_is_zero)
250 {
251 for (auto & desc : columns_to_aggregate)
252 desc.merged_column->popBack(1);
253
254 return;
255 }
256
257 for (auto i : column_numbers_not_to_aggregate)
258 merged_columns[i]->insert(current_row[i]);
259
260 /// Update per-block and per-group flags
261 ++merged_rows;
262}
263
264
265Block SummingSortedBlockInputStream::readImpl()
266{
267 if (finished)
268 return Block();
269
270 MutableColumns merged_columns;
271 init(merged_columns);
272
273 if (has_collation)
274 throw Exception("Logical error: " + getName() + " does not support collations", ErrorCodes::LOGICAL_ERROR);
275
276 if (merged_columns.empty())
277 return {};
278
279 /// Update aggregation result columns for current block
280 for (auto & desc : columns_to_aggregate)
281 {
282 // Wrap aggregated columns in a tuple to match function signature
283 if (!desc.is_agg_func_type && isTuple(desc.function->getReturnType()))
284 {
285 size_t tuple_size = desc.column_numbers.size();
286 MutableColumns tuple_columns(tuple_size);
287 for (size_t i = 0; i < tuple_size; ++i)
288 tuple_columns[i] = header.safeGetByPosition(desc.column_numbers[i]).column->cloneEmpty();
289
290 desc.merged_column = ColumnTuple::create(std::move(tuple_columns));
291 }
292 else
293 desc.merged_column = header.safeGetByPosition(desc.column_numbers[0]).column->cloneEmpty();
294 }
295
296 merge(merged_columns, queue_without_collation);
297 Block res = header.cloneWithColumns(std::move(merged_columns));
298
299 /// Place aggregation results into block.
300 for (auto & desc : columns_to_aggregate)
301 {
302 if (!desc.is_agg_func_type && isTuple(desc.function->getReturnType()))
303 {
304 /// Unpack tuple into block.
305 size_t tuple_size = desc.column_numbers.size();
306 for (size_t i = 0; i < tuple_size; ++i)
307 res.getByPosition(desc.column_numbers[i]).column = assert_cast<const ColumnTuple &>(*desc.merged_column).getColumnPtr(i);
308 }
309 else
310 res.getByPosition(desc.column_numbers[0]).column = std::move(desc.merged_column);
311 }
312
313 return res;
314}
315
316
317void SummingSortedBlockInputStream::merge(MutableColumns & merged_columns, std::priority_queue<SortCursor> & queue)
318{
319 merged_rows = 0;
320
321 /// Take the rows in needed order and put them in `merged_columns` until rows no more than `max_block_size`
322 while (!queue.empty())
323 {
324 SortCursor current = queue.top();
325
326 setPrimaryKeyRef(next_key, current);
327
328 bool key_differs;
329
330 if (current_key.empty()) /// The first key encountered.
331 {
332 key_differs = true;
333 current_row_is_zero = true;
334 }
335 else
336 key_differs = next_key != current_key;
337
338 if (key_differs)
339 {
340 if (!current_key.empty())
341 /// Write the data for the previous group.
342 insertCurrentRowIfNeeded(merged_columns);
343
344 if (merged_rows >= max_block_size)
345 {
346 /// The block is now full and the last row is calculated completely.
347 current_key.reset();
348 return;
349 }
350
351 current_key.swap(next_key);
352
353 setRow(current_row, current);
354
355 /// Reset aggregation states for next row
356 for (auto & desc : columns_to_aggregate)
357 desc.createState();
358
359 // Start aggregations with current row
360 addRow(current);
361
362 if (maps_to_sum.empty())
363 {
364 /// We have only columns_to_aggregate. The status of current row will be determined
365 /// in 'insertCurrentRowIfNeeded' method on the values of aggregate functions.
366 current_row_is_zero = true;
367 }
368 else
369 {
370 /// We have complex maps that will be summed with 'mergeMap' method.
371 /// The single row is considered non zero, and the status after merging with other rows
372 /// will be determined in the branch below (when key_differs == false).
373 current_row_is_zero = false;
374 }
375 }
376 else
377 {
378 addRow(current);
379
380 // Merge maps only for same rows
381 for (const auto & desc : maps_to_sum)
382 if (mergeMap(desc, current_row, current))
383 current_row_is_zero = false;
384 }
385
386 queue.pop();
387
388 if (!current->isLast())
389 {
390 current->next();
391 queue.push(current);
392 }
393 else
394 {
395 /// We get the next block from the corresponding source, if there is one.
396 fetchNextBlock(current, queue);
397 }
398 }
399
400 /// We will write the data for the last group, if it is non-zero.
401 /// If it is zero, and without it the output stream will be empty, we will write it anyway.
402 insertCurrentRowIfNeeded(merged_columns);
403 finished = true;
404}
405
406
407bool SummingSortedBlockInputStream::mergeMap(const MapDescription & desc, Row & row, SortCursor & cursor)
408{
409 /// Strongly non-optimal.
410
411 Row & left = row;
412 Row right(left.size());
413
414 for (size_t col_num : desc.key_col_nums)
415 right[col_num] = (*cursor->all_columns[col_num])[cursor->pos].template get<Array>();
416
417 for (size_t col_num : desc.val_col_nums)
418 right[col_num] = (*cursor->all_columns[col_num])[cursor->pos].template get<Array>();
419
420 auto at_ith_column_jth_row = [&](const Row & matrix, size_t i, size_t j) -> const Field &
421 {
422 return matrix[i].get<Array>()[j];
423 };
424
425 auto tuple_of_nth_columns_at_jth_row = [&](const Row & matrix, const ColumnNumbers & col_nums, size_t j) -> Array
426 {
427 size_t size = col_nums.size();
428 Array res(size);
429 for (size_t col_num_index = 0; col_num_index < size; ++col_num_index)
430 res[col_num_index] = at_ith_column_jth_row(matrix, col_nums[col_num_index], j);
431 return res;
432 };
433
434 std::map<Array, Array> merged;
435
436 auto accumulate = [](Array & dst, const Array & src)
437 {
438 bool has_non_zero = false;
439 size_t size = dst.size();
440 for (size_t i = 0; i < size; ++i)
441 if (applyVisitor(FieldVisitorSum(src[i]), dst[i]))
442 has_non_zero = true;
443 return has_non_zero;
444 };
445
446 auto merge = [&](const Row & matrix)
447 {
448 size_t rows = matrix[desc.key_col_nums[0]].get<Array>().size();
449
450 for (size_t j = 0; j < rows; ++j)
451 {
452 Array key = tuple_of_nth_columns_at_jth_row(matrix, desc.key_col_nums, j);
453 Array value = tuple_of_nth_columns_at_jth_row(matrix, desc.val_col_nums, j);
454
455 auto it = merged.find(key);
456 if (merged.end() == it)
457 merged.emplace(std::move(key), std::move(value));
458 else
459 {
460 if (!accumulate(it->second, value))
461 merged.erase(it);
462 }
463 }
464 };
465
466 merge(left);
467 merge(right);
468
469 for (size_t col_num : desc.key_col_nums)
470 row[col_num] = Array(merged.size());
471 for (size_t col_num : desc.val_col_nums)
472 row[col_num] = Array(merged.size());
473
474 size_t row_num = 0;
475 for (const auto & key_value : merged)
476 {
477 for (size_t col_num_index = 0, size = desc.key_col_nums.size(); col_num_index < size; ++col_num_index)
478 row[desc.key_col_nums[col_num_index]].get<Array>()[row_num] = key_value.first[col_num_index];
479
480 for (size_t col_num_index = 0, size = desc.val_col_nums.size(); col_num_index < size; ++col_num_index)
481 row[desc.val_col_nums[col_num_index]].get<Array>()[row_num] = key_value.second[col_num_index];
482
483 ++row_num;
484 }
485
486 return row_num != 0;
487}
488
489
490void SummingSortedBlockInputStream::addRow(SortCursor & cursor)
491{
492 for (auto & desc : columns_to_aggregate)
493 {
494 if (!desc.created)
495 throw Exception("Logical error in SummingSortedBlockInputStream, there are no description", ErrorCodes::LOGICAL_ERROR);
496
497 if (desc.is_agg_func_type)
498 {
499 // desc.state is not used for AggregateFunction types
500 auto & col = cursor->all_columns[desc.column_numbers[0]];
501 assert_cast<ColumnAggregateFunction &>(*desc.merged_column).insertMergeFrom(*col, cursor->pos);
502 }
503 else
504 {
505 // Specialized case for unary functions
506 if (desc.column_numbers.size() == 1)
507 {
508 auto & col = cursor->all_columns[desc.column_numbers[0]];
509 desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr);
510 }
511 else
512 {
513 // Gather all source columns into a vector
514 ColumnRawPtrs columns(desc.column_numbers.size());
515 for (size_t i = 0; i < desc.column_numbers.size(); ++i)
516 columns[i] = cursor->all_columns[desc.column_numbers[i]];
517
518 desc.add_function(desc.function.get(), desc.state.data(), columns.data(), cursor->pos, nullptr);
519 }
520 }
521 }
522}
523
524}
525