1#include "duckdb/common/sort/partition_state.hpp"
2
3#include "duckdb/common/types/column/column_data_consumer.hpp"
4#include "duckdb/common/row_operations/row_operations.hpp"
5#include "duckdb/main/config.hpp"
6#include "duckdb/parallel/event.hpp"
7
8#include <numeric>
9
10namespace duckdb {
11
PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions,
                                                   const Orders &orders, const Types &payload_types, bool external)
    : count(0) {

	RowLayout payload_layout;
	payload_layout.Initialize(payload_types);
	global_sort = make_uniq<GlobalSortState>(buffer_manager, orders, payload_layout);
	global_sort->external = external;

	// Set up a comparator for the partition subset
	partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size());
}

int PartitionGlobalHashGroup::ComparePartitions(const SBIterator &left, const SBIterator &right) const {
	int part_cmp = 0;
	if (partition_layout.all_constant) {
		part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size);
	} else {
		part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, partition_layout,
		                                     left.external);
	}
	return part_cmp;
}

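// Compute the validity masks consumed by the window operator: partition_mask gets a bit set at
// the first row of every partition, and order_mask gets a bit set at the first row of every peer
// group (i.e. wherever the full ordering changes, which includes every partition boundary).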
void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask) {
	D_ASSERT(count > 0);

	SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN);
	SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN);

	partition_mask.SetValidUnsafe(0);
	order_mask.SetValidUnsafe(0);
	for (++curr; curr.GetIndex() < count; ++curr) {
		// Compare the partition subset first because if that differs, then so does the full ordering
		const auto part_cmp = ComparePartitions(prev, curr);

		if (part_cmp) {
			partition_mask.SetValidUnsafe(curr.GetIndex());
			order_mask.SetValidUnsafe(curr.GetIndex());
		} else if (prev.Compare(curr)) {
			order_mask.SetValidUnsafe(curr.GetIndex());
		}
		++prev;
	}
}

void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders,
                                                 const vector<unique_ptr<Expression>> &partition_bys,
                                                 const Orders &order_bys,
                                                 const vector<unique_ptr<BaseStatistics>> &partition_stats) {

	// We sort by both 1) the PARTITION BY expression list and 2) the ORDER BY expressions
	const auto partition_cols = partition_bys.size();
	for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) {
		auto &pexpr = partition_bys[prt_idx];

		if (partition_stats.empty() || !partition_stats[prt_idx]) {
			orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr);
		} else {
			orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(),
			                    partition_stats[prt_idx]->ToUnique());
		}
		partitions.emplace_back(orders.back().Copy());
	}

	for (const auto &order : order_bys) {
		orders.emplace_back(order.Copy());
	}
}

PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context,
                                                   const vector<unique_ptr<Expression>> &partition_bys,
                                                   const vector<BoundOrderByNode> &order_bys,
                                                   const Types &payload_types,
                                                   const vector<unique_ptr<BaseStatistics>> &partition_stats,
                                                   idx_t estimated_cardinality)
    : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)),
      payload_types(payload_types), memory_per_thread(0), count(0) {

	GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats);

	memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context);
	external = ClientConfig::GetConfig(context).force_external;

	if (!orders.empty()) {
		grouping_types = payload_types;
		grouping_types.push_back(LogicalType::HASH);

		ResizeGroupingData(estimated_cardinality);
	}
}

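// Pick the radix fan-out for the grouping data. Starting from the current bit count (or 4 bits for
// a fresh sink), keep increasing the bit count until the estimated rows per partition drop to
// roughly one row group (STANDARD_ROW_GROUPS_SIZE), capped at 10 bits (1024 partitions).
// Once any data has been combined, the existing layout is kept.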
void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) {
	// Have we started to combine? Then just live with it.
	if (grouping_data && !grouping_data->GetPartitions().empty()) {
		return;
	}
	// Is the average partition size too large?
	const idx_t partition_size = STANDARD_ROW_GROUPS_SIZE;
	const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0;
	auto new_bits = bits ? bits : 4;
	while (new_bits < 10 && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) {
		++new_bits;
	}

	// Repartition the grouping data
	if (new_bits != bits) {
		const auto hash_col_idx = payload_types.size();
		grouping_data = make_uniq<RadixPartitionedColumnData>(context, grouping_types, new_bits, hash_col_idx);
	}
}

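// If the global radix bit count has grown since this thread last appended, its local partition
// still uses the old, smaller fan-out. Rescan the local data and re-append it into a fresh
// partition with the current bit count, so that Combine always merges partitions with identical
// layouts.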
void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	// We are done if the local_partition is right sized.
	auto &local_radix = local_partition->Cast<RadixPartitionedColumnData>();
	if (local_radix.GetRadixBits() == grouping_data->GetRadixBits()) {
		return;
	}

	// If the local partition is now too small, flush it and reallocate
	auto new_partition = grouping_data->CreateShared();
	auto new_append = make_uniq<PartitionedColumnDataAppendState>();
	new_partition->InitializeAppendState(*new_append);

	local_partition->FlushAppendState(*local_append);
	auto &local_groups = local_partition->GetPartitions();
	for (auto &local_group : local_groups) {
		ColumnDataScanState scanner;
		local_group->InitializeScan(scanner);

		DataChunk scan_chunk;
		local_group->InitializeScanChunk(scan_chunk);
		for (scan_chunk.Reset(); local_group->Scan(scanner, scan_chunk); scan_chunk.Reset()) {
			new_partition->Append(*new_append, scan_chunk);
		}
	}

	// The append state has stale pointers to the old local partition, so nuke it from orbit.
	new_partition->FlushAppendState(*new_append);

	local_partition = std::move(new_partition);
	local_append = make_uniq<PartitionedColumnDataAppendState>();
	local_partition->InitializeAppendState(*local_append);
}

void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	// Make sure grouping_data doesn't change under us.
	lock_guard<mutex> guard(lock);

	if (!local_partition) {
		local_partition = grouping_data->CreateShared();
		local_append = make_uniq<PartitionedColumnDataAppendState>();
		local_partition->InitializeAppendState(*local_append);
		return;
	}

	// Grow the groups if they are too big
	ResizeGroupingData(count);

	// Sync local partition to have the same bit count
	SyncLocalPartition(local_partition, local_append);
}

void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	if (!local_partition) {
		return;
	}
	local_partition->FlushAppendState(*local_append);

	// Make sure grouping_data doesn't change under us.
	// Combine has an internal mutex, so this is single-threaded anyway.
	lock_guard<mutex> guard(lock);
	SyncLocalPartition(local_partition, local_append);
	grouping_data->Combine(*local_partition);
}

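// Convert one hash group's column data into sorted row data: compute the sort key columns with an
// ExpressionExecutor, sink key and payload chunks into a LocalSortState, sort early whenever the
// local state exceeds the per-thread memory budget, and finally hand the local state to the
// group's GlobalSortState.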
void PartitionGlobalSinkState::BuildSortState(ColumnDataCollection &group_data, PartitionGlobalHashGroup &hash_group) {
	auto &global_sort = *hash_group.global_sort;

	// Set up the sort expression computation.
	vector<LogicalType> sort_types;
	ExpressionExecutor executor(context);
	for (auto &order : orders) {
		auto &oexpr = order.expression;
		sort_types.emplace_back(oexpr->return_type);
		executor.AddExpression(*oexpr);
	}
	DataChunk sort_chunk;
	sort_chunk.Initialize(allocator, sort_types);

	// Copy the data from the group into the sort state.
	LocalSortState local_sort;
	local_sort.Initialize(global_sort, global_sort.buffer_manager);

	// Strip hash column
	DataChunk payload_chunk;
	payload_chunk.Initialize(allocator, payload_types);

	vector<column_t> column_ids;
	column_ids.reserve(payload_types.size());
	for (column_t i = 0; i < payload_types.size(); ++i) {
		column_ids.emplace_back(i);
	}
	ColumnDataConsumer scanner(group_data, column_ids);
	ColumnDataConsumerScanState chunk_state;
	chunk_state.current_chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY;
	scanner.InitializeScan();
	for (auto chunk_idx = scanner.ChunkCount(); chunk_idx-- > 0;) {
		if (!scanner.AssignChunk(chunk_state)) {
			break;
		}
		scanner.ScanChunk(chunk_state, payload_chunk);

		sort_chunk.Reset();
		executor.Execute(payload_chunk, sort_chunk);

		local_sort.SinkChunk(sort_chunk, payload_chunk);
		if (local_sort.SizeInBytes() > memory_per_thread) {
			local_sort.Sort(global_sort, true);
		}
		scanner.FinishChunk(chunk_state);
	}

	global_sort.AddLocalState(local_sort);

	hash_group.count += group_data.Count();
}

// Per-thread sink state
PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p)
    : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) {

	vector<LogicalType> group_types;
	for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) {
		auto &pexpr = *gstate.partitions[prt_idx].expression.get();
		group_types.push_back(pexpr.return_type);
		executor.AddExpression(pexpr);
	}
	sort_cols = gstate.orders.size() + group_types.size();

	if (sort_cols) {
		if (!group_types.empty()) {
			// OVER(PARTITION BY...)
			group_chunk.Initialize(allocator, group_types);
		}
		// OVER(...)
		auto payload_types = gstate.payload_types;
		payload_types.emplace_back(LogicalType::HASH);
		payload_chunk.Initialize(allocator, payload_types);
	} else {
		// OVER()
		payload_layout.Initialize(gstate.payload_types);
	}
}

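// Compute one hash per row over the PARTITION BY columns. Without PARTITION BY there is only a
// single partition, so every row gets the constant hash 0.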
void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) {
	const auto count = input_chunk.size();
	if (group_chunk.ColumnCount() > 0) {
		// OVER(PARTITION BY...) (hash grouping)
		group_chunk.Reset();
		executor.Execute(input_chunk, group_chunk);
		VectorOperations::Hash(group_chunk.data[0], hash_vector, count);
		for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) {
			VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count);
		}
	} else {
		// OVER(...) (sorting)
		// Single partition => single hash value
		hash_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
		auto hashes = ConstantVector::GetData<hash_t>(hash_vector);
		hashes[0] = 0;
	}
}

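// Sink a chunk of input. OVER() (no sort columns) appends rows directly into paged row
// collections; otherwise the chunk is tagged with its partition hash and appended into the
// radix-partitioned grouping data.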
void PartitionLocalSinkState::Sink(DataChunk &input_chunk) {
	gstate.count += input_chunk.size();

	// OVER()
	if (sort_cols == 0) {
		// No sorts, so build paged row chunks
		if (!rows) {
			const auto entry_size = payload_layout.GetRowWidth();
			const auto capacity = MaxValue<idx_t>(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1);
			rows = make_uniq<RowDataCollection>(gstate.buffer_manager, capacity, entry_size);
			strings = make_uniq<RowDataCollection>(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true);
		}
		const auto row_count = input_chunk.size();
		const auto row_sel = FlatVector::IncrementalSelectionVector();
		Vector addresses(LogicalType::POINTER);
		auto key_locations = FlatVector::GetData<data_ptr_t>(addresses);
		const auto prev_rows_blocks = rows->blocks.size();
		auto handles = rows->Build(row_count, key_locations, nullptr, row_sel);
		auto input_data = input_chunk.ToUnifiedFormat();
		RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel,
		                       row_count);
		// Mark that row blocks contain pointers (heap blocks are pinned)
		if (!payload_layout.AllConstant()) {
			D_ASSERT(strings->keep_pinned);
			for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) {
				rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink");
			}
		}
		return;
	}

	// OVER(...)
	payload_chunk.Reset();
	auto &hash_vector = payload_chunk.data.back();
	Hash(input_chunk, hash_vector);
	for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) {
		payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]);
	}
	payload_chunk.SetCardinality(input_chunk);

	gstate.UpdateLocalPartition(local_partition, local_append);
	local_partition->Append(*local_append, payload_chunk);
}

void PartitionLocalSinkState::Combine() {
	// OVER()
	if (sort_cols == 0) {
		// Only one partition again, so we need a global lock.
		lock_guard<mutex> glock(gstate.lock);
		if (gstate.rows) {
			if (rows) {
				gstate.rows->Merge(*rows);
				gstate.strings->Merge(*strings);
				rows.reset();
				strings.reset();
			}
		} else {
			gstate.rows = std::move(rows);
			gstate.strings = std::move(strings);
		}
		return;
	}

	// OVER(...)
	gstate.CombineLocalPartition(local_partition, local_append);
}

PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data,
                                                     hash_t hash_bin)
    : sink(sink), group_data(std::move(group_data)), stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0),
      tasks_completed(0) {

	const auto group_idx = sink.hash_groups.size();
	auto new_group = make_uniq<PartitionGlobalHashGroup>(sink.buffer_manager, sink.partitions, sink.orders,
	                                                     sink.payload_types, sink.external);
	sink.hash_groups.emplace_back(std::move(new_group));

	hash_group = sink.hash_groups[group_idx].get();
	global_sort = sink.hash_groups[group_idx]->global_sort.get();

	sink.bin_groups[hash_bin] = group_idx;
}

void PartitionLocalMergeState::Prepare() {
	auto &global_sort = *merge_state->global_sort;
	merge_state->sink.BuildSortState(*merge_state->group_data, *merge_state->hash_group);
	merge_state->group_data.reset();

	global_sort.PrepareMergePhase();
}

void PartitionLocalMergeState::Merge() {
	auto &global_sort = *merge_state->global_sort;
	MergeSorter merge_sorter(global_sort, global_sort.buffer_manager);
	merge_sorter.PerformInMergeRound();
}

void PartitionLocalMergeState::ExecuteTask() {
	switch (stage) {
	case PartitionSortStage::PREPARE:
		Prepare();
		break;
	case PartitionSortStage::MERGE:
		Merge();
		break;
	default:
		throw InternalException("Unexpected PartitionSortStage in ExecuteTask!");
	}

	merge_state->CompleteTask();
	finished = true;
}

bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) {
	lock_guard<mutex> guard(lock);

	if (tasks_assigned >= total_tasks) {
		return false;
	}

	local_state.merge_state = this;
	local_state.stage = stage;
	local_state.finished = false;
	tasks_assigned++;

	return true;
}

void PartitionGlobalMergeState::CompleteTask() {
	lock_guard<mutex> guard(lock);

	++tasks_completed;
}

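// Advance the per-group state machine once every task of the current stage has completed:
// INIT schedules the single PREPARE task, PREPARE schedules one MERGE task per pair of sorted
// blocks, and each finished MERGE round either schedules the next round or, when fewer than two
// sorted blocks remain, moves the group to SORTED.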
bool PartitionGlobalMergeState::TryPrepareNextStage() {
	lock_guard<mutex> guard(lock);

	if (tasks_completed < total_tasks) {
		return false;
	}

	tasks_assigned = tasks_completed = 0;

	switch (stage) {
	case PartitionSortStage::INIT:
		total_tasks = 1;
		stage = PartitionSortStage::PREPARE;
		return true;

	case PartitionSortStage::PREPARE:
		total_tasks = global_sort->sorted_blocks.size() / 2;
		if (!total_tasks) {
			break;
		}
		stage = PartitionSortStage::MERGE;
		global_sort->InitializeMergeRound();
		return true;

	case PartitionSortStage::MERGE:
		global_sort->CompleteMergeRound(true);
		total_tasks = global_sort->sorted_blocks.size() / 2;
		if (!total_tasks) {
			break;
		}
		global_sort->InitializeMergeRound();
		return true;

	case PartitionSortStage::SORTED:
		break;
	}

	stage = PartitionSortStage::SORTED;

	return false;
}

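// One PartitionGlobalMergeState is created per non-empty hash bin; bin_groups maps each hash bin
// to its hash group index, with empty bins left at the sentinel value partitions.size().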
PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) {
	// Schedule all the sorts for maximum thread utilisation
	auto &partitions = sink.grouping_data->GetPartitions();
	sink.bin_groups.resize(partitions.size(), partitions.size());
	for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) {
		auto &group_data = partitions[hash_bin];
		// Prepare for merge sort phase
		if (group_data->Count()) {
			auto state = make_uniq<PartitionGlobalMergeState>(sink, std::move(group_data), hash_bin);
			states.emplace_back(std::move(state));
		}
	}
}

class PartitionMergeTask : public ExecutorTask {
public:
	PartitionMergeTask(shared_ptr<Event> event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p)
	    : ExecutorTask(context_p), event(std::move(event_p)), hash_groups(hash_groups_p) {
	}

	TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override;

private:
	shared_ptr<Event> event;
	PartitionLocalMergeState local_state;
	PartitionGlobalMergeStates &hash_groups;
};

TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) {
	// Loop until all hash groups are done
	size_t sorted = 0;
	while (sorted < hash_groups.states.size()) {
		// First check if there is an unfinished task for this thread
		if (executor.HasError()) {
			return TaskExecutionResult::TASK_ERROR;
		}
		if (!local_state.TaskFinished()) {
			local_state.ExecuteTask();
			continue;
		}

		// Thread is done with its assigned task, try to fetch new work
		for (auto group = sorted; group < hash_groups.states.size(); ++group) {
			auto &global_state = hash_groups.states[group];
			if (global_state->IsSorted()) {
				// This hash group is done
				// Update the high water mark of densely completed groups
				if (sorted == group) {
					++sorted;
				}
				continue;
			}

			// Try to assign work for this hash group to this thread
			if (global_state->AssignTask(local_state)) {
				// We assigned a task to this thread!
				// Break out of this loop to re-enter the top-level loop and execute the task
				break;
			}

			// Hash group global state couldn't assign a task to this thread
			// Try to prepare the next stage
			if (!global_state->TryPrepareNextStage()) {
				// This current hash group is not yet done
				// But we were not able to assign a task for it to this thread
				// See if the next hash group is better
				continue;
			}

			// We were able to prepare the next stage for this hash group!
			// Try to assign a task once more
			if (global_state->AssignTask(local_state)) {
				// We assigned a task to this thread!
				// Break out of this loop to re-enter the top-level loop and execute the task
				break;
			}

			// We were able to prepare the next merge round,
			// but we were not able to assign a task for it to this thread
			// The tasks were assigned to other threads while this thread waited for the lock
			// Go to the next iteration to see if another hash group has a task
		}
	}

	event->FinishTask();
	return TaskExecutionResult::TASK_FINISHED;
}

void PartitionMergeEvent::Schedule() {
	auto &context = pipeline->GetClientContext();

	// Schedule tasks equal to the number of threads, which will each merge multiple partitions
	auto &ts = TaskScheduler::GetScheduler(context);
	idx_t num_threads = ts.NumberOfThreads();

	vector<shared_ptr<Task>> merge_tasks;
	for (idx_t tnum = 0; tnum < num_threads; tnum++) {
		merge_tasks.emplace_back(make_uniq<PartitionMergeTask>(shared_from_this(), context, merge_states));
	}
	SetTasks(std::move(merge_tasks));
}

} // namespace duckdb