1 | #include <DataStreams/SummingSortedBlockInputStream.h> |
2 | #include <DataTypes/DataTypesNumber.h> |
3 | #include <DataTypes/NestedUtils.h> |
4 | #include <DataTypes/DataTypeTuple.h> |
5 | #include <DataTypes/DataTypeArray.h> |
6 | #include <DataTypes/DataTypeAggregateFunction.h> |
7 | #include <Columns/ColumnAggregateFunction.h> |
8 | #include <Columns/ColumnTuple.h> |
9 | #include <Common/StringUtils/StringUtils.h> |
10 | #include <Common/FieldVisitors.h> |
11 | #include <common/logger_useful.h> |
12 | #include <Common/typeid_cast.h> |
13 | #include <Common/assert_cast.h> |
14 | |
15 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
16 | #include <Functions/FunctionFactory.h> |
17 | #include <Functions/FunctionHelpers.h> |
18 | #include <Interpreters/Context.h> |
19 | |
20 | |
21 | namespace DB |
22 | { |
23 | |
24 | namespace ErrorCodes |
25 | { |
26 | extern const int LOGICAL_ERROR; |
27 | } |
28 | |
29 | |
30 | namespace |
31 | { |
32 | bool isInPrimaryKey(const SortDescription & description, const std::string & name, const size_t number) |
33 | { |
34 | for (auto & desc : description) |
35 | if (desc.column_name == name || (desc.column_name.empty() && desc.column_number == number)) |
36 | return true; |
37 | |
38 | return false; |
39 | } |
40 | } |
41 | |
42 | |
43 | SummingSortedBlockInputStream::SummingSortedBlockInputStream( |
44 | const BlockInputStreams & inputs_, |
45 | const SortDescription & description_, |
46 | /// List of columns to be summed. If empty, all numeric columns that are not in the description are taken. |
47 | const Names & column_names_to_sum, |
48 | size_t max_block_size_) |
49 | : MergingSortedBlockInputStream(inputs_, description_, max_block_size_) |
50 | { |
51 | current_row.resize(num_columns); |
52 | |
53 | /// name of nested structure -> the column numbers that refer to it. |
54 | std::unordered_map<std::string, std::vector<size_t>> discovered_maps; |
55 | |
56 | /** Fill in the column numbers, which must be summed. |
57 | * This can only be numeric columns that are not part of the sort key. |
58 | * If a non-empty column_names_to_sum is specified, then we only take these columns. |
59 | * Some columns from column_names_to_sum may not be found. This is ignored. |
60 | */ |
61 | for (size_t i = 0; i < num_columns; ++i) |
62 | { |
63 | const ColumnWithTypeAndName & column = header.safeGetByPosition(i); |
64 | |
65 | /// Discover nested Maps and find columns for summation |
66 | if (typeid_cast<const DataTypeArray *>(column.type.get())) |
67 | { |
68 | const auto map_name = Nested::extractTableName(column.name); |
69 | /// if nested table name ends with `Map` it is a possible candidate for special handling |
70 | if (map_name == column.name || !endsWith(map_name, "Map" )) |
71 | { |
72 | column_numbers_not_to_aggregate.push_back(i); |
73 | continue; |
74 | } |
75 | |
76 | discovered_maps[map_name].emplace_back(i); |
77 | } |
78 | else |
79 | { |
80 | bool is_agg_func = WhichDataType(column.type).isAggregateFunction(); |
81 | |
82 | /// There are special const columns for example after prewere sections. |
83 | if ((!column.type->isSummable() && !is_agg_func) || isColumnConst(*column.column)) |
84 | { |
85 | column_numbers_not_to_aggregate.push_back(i); |
86 | continue; |
87 | } |
88 | |
89 | /// Are they inside the PK? |
90 | if (isInPrimaryKey(description, column.name, i)) |
91 | { |
92 | column_numbers_not_to_aggregate.push_back(i); |
93 | continue; |
94 | } |
95 | |
96 | if (column_names_to_sum.empty() |
97 | || column_names_to_sum.end() != |
98 | std::find(column_names_to_sum.begin(), column_names_to_sum.end(), column.name)) |
99 | { |
100 | // Create aggregator to sum this column |
101 | AggregateDescription desc; |
102 | desc.is_agg_func_type = is_agg_func; |
103 | desc.column_numbers = {i}; |
104 | |
105 | if (!is_agg_func) |
106 | { |
107 | desc.init("sumWithOverflow" , {column.type}); |
108 | } |
109 | |
110 | columns_to_aggregate.emplace_back(std::move(desc)); |
111 | } |
112 | else |
113 | { |
114 | // Column is not going to be summed, use last value |
115 | column_numbers_not_to_aggregate.push_back(i); |
116 | } |
117 | } |
118 | } |
119 | |
120 | /// select actual nested Maps from list of candidates |
121 | for (const auto & map : discovered_maps) |
122 | { |
123 | /// map should contain at least two elements (key -> value) |
124 | if (map.second.size() < 2) |
125 | { |
126 | for (auto col : map.second) |
127 | column_numbers_not_to_aggregate.push_back(col); |
128 | continue; |
129 | } |
130 | |
131 | /// no elements of map could be in primary key |
132 | auto column_num_it = map.second.begin(); |
133 | for (; column_num_it != map.second.end(); ++column_num_it) |
134 | if (isInPrimaryKey(description, header.safeGetByPosition(*column_num_it).name, *column_num_it)) |
135 | break; |
136 | if (column_num_it != map.second.end()) |
137 | { |
138 | for (auto col : map.second) |
139 | column_numbers_not_to_aggregate.push_back(col); |
140 | continue; |
141 | } |
142 | |
143 | DataTypes argument_types; |
144 | AggregateDescription desc; |
145 | MapDescription map_desc; |
146 | |
147 | column_num_it = map.second.begin(); |
148 | for (; column_num_it != map.second.end(); ++column_num_it) |
149 | { |
150 | const ColumnWithTypeAndName & key_col = header.safeGetByPosition(*column_num_it); |
151 | const String & name = key_col.name; |
152 | const IDataType & nested_type = *static_cast<const DataTypeArray *>(key_col.type.get())->getNestedType(); |
153 | |
154 | if (column_num_it == map.second.begin() |
155 | || endsWith(name, "ID" ) |
156 | || endsWith(name, "Key" ) |
157 | || endsWith(name, "Type" )) |
158 | { |
159 | if (!nested_type.isValueRepresentedByInteger()) |
160 | break; |
161 | |
162 | map_desc.key_col_nums.push_back(*column_num_it); |
163 | } |
164 | else |
165 | { |
166 | if (!nested_type.isSummable()) |
167 | break; |
168 | |
169 | map_desc.val_col_nums.push_back(*column_num_it); |
170 | } |
171 | |
172 | // Add column to function arguments |
173 | desc.column_numbers.push_back(*column_num_it); |
174 | argument_types.push_back(key_col.type); |
175 | } |
176 | |
177 | if (column_num_it != map.second.end()) |
178 | { |
179 | for (auto col : map.second) |
180 | column_numbers_not_to_aggregate.push_back(col); |
181 | continue; |
182 | } |
183 | |
184 | if (map_desc.key_col_nums.size() == 1) |
185 | { |
186 | // Create summation for all value columns in the map |
187 | desc.init("sumMapWithOverflow" , argument_types); |
188 | columns_to_aggregate.emplace_back(std::move(desc)); |
189 | } |
190 | else |
191 | { |
192 | // Fall back to legacy mergeMaps for composite keys |
193 | for (auto col : map.second) |
194 | column_numbers_not_to_aggregate.push_back(col); |
195 | maps_to_sum.emplace_back(std::move(map_desc)); |
196 | } |
197 | } |
198 | } |
199 | |
200 | |
201 | void SummingSortedBlockInputStream::insertCurrentRowIfNeeded(MutableColumns & merged_columns) |
202 | { |
203 | /// We have nothing to aggregate. It means that it could be non-zero, because we have columns_not_to_aggregate. |
204 | if (columns_to_aggregate.empty()) |
205 | current_row_is_zero = false; |
206 | |
207 | for (auto & desc : columns_to_aggregate) |
208 | { |
209 | // Do not insert if the aggregation state hasn't been created |
210 | if (desc.created) |
211 | { |
212 | if (desc.is_agg_func_type) |
213 | { |
214 | current_row_is_zero = false; |
215 | } |
216 | else |
217 | { |
218 | try |
219 | { |
220 | desc.function->insertResultInto(desc.state.data(), *desc.merged_column); |
221 | |
222 | /// Update zero status of current row |
223 | if (desc.column_numbers.size() == 1) |
224 | { |
225 | // Flag row as non-empty if at least one column number if non-zero |
226 | current_row_is_zero = current_row_is_zero && desc.merged_column->isDefaultAt(desc.merged_column->size() - 1); |
227 | } |
228 | else |
229 | { |
230 | /// It is sumMapWithOverflow aggregate function. |
231 | /// Assume that the row isn't empty in this case (just because it is compatible with previous version) |
232 | current_row_is_zero = false; |
233 | } |
234 | } |
235 | catch (...) |
236 | { |
237 | desc.destroyState(); |
238 | throw; |
239 | } |
240 | } |
241 | desc.destroyState(); |
242 | } |
243 | else |
244 | desc.merged_column->insertDefault(); |
245 | } |
246 | |
247 | /// If it is "zero" row, then rollback the insertion |
248 | /// (at this moment we need rollback only cols from columns_to_aggregate) |
249 | if (current_row_is_zero) |
250 | { |
251 | for (auto & desc : columns_to_aggregate) |
252 | desc.merged_column->popBack(1); |
253 | |
254 | return; |
255 | } |
256 | |
257 | for (auto i : column_numbers_not_to_aggregate) |
258 | merged_columns[i]->insert(current_row[i]); |
259 | |
260 | /// Update per-block and per-group flags |
261 | ++merged_rows; |
262 | } |
263 | |
264 | |
265 | Block SummingSortedBlockInputStream::readImpl() |
266 | { |
267 | if (finished) |
268 | return Block(); |
269 | |
270 | MutableColumns merged_columns; |
271 | init(merged_columns); |
272 | |
273 | if (has_collation) |
274 | throw Exception("Logical error: " + getName() + " does not support collations" , ErrorCodes::LOGICAL_ERROR); |
275 | |
276 | if (merged_columns.empty()) |
277 | return {}; |
278 | |
279 | /// Update aggregation result columns for current block |
280 | for (auto & desc : columns_to_aggregate) |
281 | { |
282 | // Wrap aggregated columns in a tuple to match function signature |
283 | if (!desc.is_agg_func_type && isTuple(desc.function->getReturnType())) |
284 | { |
285 | size_t tuple_size = desc.column_numbers.size(); |
286 | MutableColumns tuple_columns(tuple_size); |
287 | for (size_t i = 0; i < tuple_size; ++i) |
288 | tuple_columns[i] = header.safeGetByPosition(desc.column_numbers[i]).column->cloneEmpty(); |
289 | |
290 | desc.merged_column = ColumnTuple::create(std::move(tuple_columns)); |
291 | } |
292 | else |
293 | desc.merged_column = header.safeGetByPosition(desc.column_numbers[0]).column->cloneEmpty(); |
294 | } |
295 | |
296 | merge(merged_columns, queue_without_collation); |
297 | Block res = header.cloneWithColumns(std::move(merged_columns)); |
298 | |
299 | /// Place aggregation results into block. |
300 | for (auto & desc : columns_to_aggregate) |
301 | { |
302 | if (!desc.is_agg_func_type && isTuple(desc.function->getReturnType())) |
303 | { |
304 | /// Unpack tuple into block. |
305 | size_t tuple_size = desc.column_numbers.size(); |
306 | for (size_t i = 0; i < tuple_size; ++i) |
307 | res.getByPosition(desc.column_numbers[i]).column = assert_cast<const ColumnTuple &>(*desc.merged_column).getColumnPtr(i); |
308 | } |
309 | else |
310 | res.getByPosition(desc.column_numbers[0]).column = std::move(desc.merged_column); |
311 | } |
312 | |
313 | return res; |
314 | } |
315 | |
316 | |
317 | void SummingSortedBlockInputStream::merge(MutableColumns & merged_columns, std::priority_queue<SortCursor> & queue) |
318 | { |
319 | merged_rows = 0; |
320 | |
321 | /// Take the rows in needed order and put them in `merged_columns` until rows no more than `max_block_size` |
322 | while (!queue.empty()) |
323 | { |
324 | SortCursor current = queue.top(); |
325 | |
326 | setPrimaryKeyRef(next_key, current); |
327 | |
328 | bool key_differs; |
329 | |
330 | if (current_key.empty()) /// The first key encountered. |
331 | { |
332 | key_differs = true; |
333 | current_row_is_zero = true; |
334 | } |
335 | else |
336 | key_differs = next_key != current_key; |
337 | |
338 | if (key_differs) |
339 | { |
340 | if (!current_key.empty()) |
341 | /// Write the data for the previous group. |
342 | insertCurrentRowIfNeeded(merged_columns); |
343 | |
344 | if (merged_rows >= max_block_size) |
345 | { |
346 | /// The block is now full and the last row is calculated completely. |
347 | current_key.reset(); |
348 | return; |
349 | } |
350 | |
351 | current_key.swap(next_key); |
352 | |
353 | setRow(current_row, current); |
354 | |
355 | /// Reset aggregation states for next row |
356 | for (auto & desc : columns_to_aggregate) |
357 | desc.createState(); |
358 | |
359 | // Start aggregations with current row |
360 | addRow(current); |
361 | |
362 | if (maps_to_sum.empty()) |
363 | { |
364 | /// We have only columns_to_aggregate. The status of current row will be determined |
365 | /// in 'insertCurrentRowIfNeeded' method on the values of aggregate functions. |
366 | current_row_is_zero = true; |
367 | } |
368 | else |
369 | { |
370 | /// We have complex maps that will be summed with 'mergeMap' method. |
371 | /// The single row is considered non zero, and the status after merging with other rows |
372 | /// will be determined in the branch below (when key_differs == false). |
373 | current_row_is_zero = false; |
374 | } |
375 | } |
376 | else |
377 | { |
378 | addRow(current); |
379 | |
380 | // Merge maps only for same rows |
381 | for (const auto & desc : maps_to_sum) |
382 | if (mergeMap(desc, current_row, current)) |
383 | current_row_is_zero = false; |
384 | } |
385 | |
386 | queue.pop(); |
387 | |
388 | if (!current->isLast()) |
389 | { |
390 | current->next(); |
391 | queue.push(current); |
392 | } |
393 | else |
394 | { |
395 | /// We get the next block from the corresponding source, if there is one. |
396 | fetchNextBlock(current, queue); |
397 | } |
398 | } |
399 | |
400 | /// We will write the data for the last group, if it is non-zero. |
401 | /// If it is zero, and without it the output stream will be empty, we will write it anyway. |
402 | insertCurrentRowIfNeeded(merged_columns); |
403 | finished = true; |
404 | } |
405 | |
406 | |
407 | bool SummingSortedBlockInputStream::mergeMap(const MapDescription & desc, Row & row, SortCursor & cursor) |
408 | { |
409 | /// Strongly non-optimal. |
410 | |
411 | Row & left = row; |
412 | Row right(left.size()); |
413 | |
414 | for (size_t col_num : desc.key_col_nums) |
415 | right[col_num] = (*cursor->all_columns[col_num])[cursor->pos].template get<Array>(); |
416 | |
417 | for (size_t col_num : desc.val_col_nums) |
418 | right[col_num] = (*cursor->all_columns[col_num])[cursor->pos].template get<Array>(); |
419 | |
420 | auto at_ith_column_jth_row = [&](const Row & matrix, size_t i, size_t j) -> const Field & |
421 | { |
422 | return matrix[i].get<Array>()[j]; |
423 | }; |
424 | |
425 | auto tuple_of_nth_columns_at_jth_row = [&](const Row & matrix, const ColumnNumbers & col_nums, size_t j) -> Array |
426 | { |
427 | size_t size = col_nums.size(); |
428 | Array res(size); |
429 | for (size_t col_num_index = 0; col_num_index < size; ++col_num_index) |
430 | res[col_num_index] = at_ith_column_jth_row(matrix, col_nums[col_num_index], j); |
431 | return res; |
432 | }; |
433 | |
434 | std::map<Array, Array> merged; |
435 | |
436 | auto accumulate = [](Array & dst, const Array & src) |
437 | { |
438 | bool has_non_zero = false; |
439 | size_t size = dst.size(); |
440 | for (size_t i = 0; i < size; ++i) |
441 | if (applyVisitor(FieldVisitorSum(src[i]), dst[i])) |
442 | has_non_zero = true; |
443 | return has_non_zero; |
444 | }; |
445 | |
446 | auto merge = [&](const Row & matrix) |
447 | { |
448 | size_t rows = matrix[desc.key_col_nums[0]].get<Array>().size(); |
449 | |
450 | for (size_t j = 0; j < rows; ++j) |
451 | { |
452 | Array key = tuple_of_nth_columns_at_jth_row(matrix, desc.key_col_nums, j); |
453 | Array value = tuple_of_nth_columns_at_jth_row(matrix, desc.val_col_nums, j); |
454 | |
455 | auto it = merged.find(key); |
456 | if (merged.end() == it) |
457 | merged.emplace(std::move(key), std::move(value)); |
458 | else |
459 | { |
460 | if (!accumulate(it->second, value)) |
461 | merged.erase(it); |
462 | } |
463 | } |
464 | }; |
465 | |
466 | merge(left); |
467 | merge(right); |
468 | |
469 | for (size_t col_num : desc.key_col_nums) |
470 | row[col_num] = Array(merged.size()); |
471 | for (size_t col_num : desc.val_col_nums) |
472 | row[col_num] = Array(merged.size()); |
473 | |
474 | size_t row_num = 0; |
475 | for (const auto & key_value : merged) |
476 | { |
477 | for (size_t col_num_index = 0, size = desc.key_col_nums.size(); col_num_index < size; ++col_num_index) |
478 | row[desc.key_col_nums[col_num_index]].get<Array>()[row_num] = key_value.first[col_num_index]; |
479 | |
480 | for (size_t col_num_index = 0, size = desc.val_col_nums.size(); col_num_index < size; ++col_num_index) |
481 | row[desc.val_col_nums[col_num_index]].get<Array>()[row_num] = key_value.second[col_num_index]; |
482 | |
483 | ++row_num; |
484 | } |
485 | |
486 | return row_num != 0; |
487 | } |
488 | |
489 | |
490 | void SummingSortedBlockInputStream::addRow(SortCursor & cursor) |
491 | { |
492 | for (auto & desc : columns_to_aggregate) |
493 | { |
494 | if (!desc.created) |
495 | throw Exception("Logical error in SummingSortedBlockInputStream, there are no description" , ErrorCodes::LOGICAL_ERROR); |
496 | |
497 | if (desc.is_agg_func_type) |
498 | { |
499 | // desc.state is not used for AggregateFunction types |
500 | auto & col = cursor->all_columns[desc.column_numbers[0]]; |
501 | assert_cast<ColumnAggregateFunction &>(*desc.merged_column).insertMergeFrom(*col, cursor->pos); |
502 | } |
503 | else |
504 | { |
505 | // Specialized case for unary functions |
506 | if (desc.column_numbers.size() == 1) |
507 | { |
508 | auto & col = cursor->all_columns[desc.column_numbers[0]]; |
509 | desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, nullptr); |
510 | } |
511 | else |
512 | { |
513 | // Gather all source columns into a vector |
514 | ColumnRawPtrs columns(desc.column_numbers.size()); |
515 | for (size_t i = 0; i < desc.column_numbers.size(); ++i) |
516 | columns[i] = cursor->all_columns[desc.column_numbers[i]]; |
517 | |
518 | desc.add_function(desc.function.get(), desc.state.data(), columns.data(), cursor->pos, nullptr); |
519 | } |
520 | } |
521 | } |
522 | } |
523 | |
524 | } |
525 | |