#include "duckdb/common/sort/partition_state.hpp"

#include "duckdb/common/types/column/column_data_consumer.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/parallel/event.hpp"

#include <numeric>

namespace duckdb {

PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions,
                                                   const Orders &orders, const Types &payload_types, bool external)
    : count(0) {

	RowLayout payload_layout;
	payload_layout.Initialize(payload_types);
	global_sort = make_uniq<GlobalSortState>(buffer_manager, orders, payload_layout);
	global_sort->external = external;

	// Set up a comparator for the partition subset
	partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size());
}

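// Compare only the partition-key prefix of two sort-key tuples.
// A non-zero result means the two rows belong to different partitions.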
int PartitionGlobalHashGroup::ComparePartitions(const SBIterator &left, const SBIterator &right) const {
	int part_cmp = 0;
	if (partition_layout.all_constant) {
		part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size);
	} else {
		part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, partition_layout,
		                                     left.external);
	}
	return part_cmp;
}

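// Compute the validity masks consumed by the window executor:
// partition_mask is set at the first row of each partition,
// order_mask is set at the first row of each peer group (i.e. wherever the full ordering changes).
// Row 0 always starts both a new partition and a new peer group.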
void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask) {
	D_ASSERT(count > 0);

	SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN);
	SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN);

	partition_mask.SetValidUnsafe(0);
	order_mask.SetValidUnsafe(0);
	for (++curr; curr.GetIndex() < count; ++curr) {
		// Compare the partition subset first because if that differs, then so does the full ordering
		const auto part_cmp = ComparePartitions(prev, curr);

		if (part_cmp) {
			partition_mask.SetValidUnsafe(curr.GetIndex());
			order_mask.SetValidUnsafe(curr.GetIndex());
		} else if (prev.Compare(curr)) {
			order_mask.SetValidUnsafe(curr.GetIndex());
		}
		++prev;
	}
}

void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders,
                                                 const vector<unique_ptr<Expression>> &partition_bys,
                                                 const Orders &order_bys,
                                                 const vector<unique_ptr<BaseStatistics>> &partition_stats) {

	// we sort by both 1) partition by expression list and 2) order by expressions
	const auto partition_cols = partition_bys.size();
	for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) {
		auto &pexpr = partition_bys[prt_idx];

		if (partition_stats.empty() || !partition_stats[prt_idx]) {
			orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr);
		} else {
			orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(),
			                    partition_stats[prt_idx]->ToUnique());
		}
		partitions.emplace_back(orders.back().Copy());
	}

	for (const auto &order : order_bys) {
		orders.emplace_back(order.Copy());
	}
}

PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context,
                                                   const vector<unique_ptr<Expression>> &partition_bys,
                                                   const vector<BoundOrderByNode> &order_bys,
                                                   const Types &payload_types,
                                                   const vector<unique_ptr<BaseStatistics>> &partition_stats,
                                                   idx_t estimated_cardinality)
    : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)),
      payload_types(payload_types), memory_per_thread(0), count(0) {

	GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats);

	memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context);
	external = ClientConfig::GetConfig(context).force_external;

	if (!orders.empty()) {
		grouping_types = payload_types;
		grouping_types.push_back(LogicalType::HASH);

		ResizeGroupingData(estimated_cardinality);
	}
}

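// Grow the number of radix partitions so that the average partition stays near one row group
// (STANDARD_ROW_GROUPS_SIZE rows). The radix bit count starts at 4 and is capped at 10, and once
// any data has been combined into grouping_data the layout is frozen.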
void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) {
	// Have we started to combine? Then just live with it.
	if (grouping_data && !grouping_data->GetPartitions().empty()) {
		return;
	}
	// Is the average partition size too large?
	const idx_t partition_size = STANDARD_ROW_GROUPS_SIZE;
	const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0;
	auto new_bits = bits ? bits : 4;
	while (new_bits < 10 && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) {
		++new_bits;
	}

	// Repartition the grouping data
	if (new_bits != bits) {
		const auto hash_col_idx = payload_types.size();
		grouping_data = make_uniq<RadixPartitionedColumnData>(context, grouping_types, new_bits, hash_col_idx);
	}
}

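// Bring a thread-local partition up to the current global radix bit count by re-appending its
// data into a freshly created partition with the new fan-out.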
void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	// We are done if the local_partition is right sized.
	auto &local_radix = local_partition->Cast<RadixPartitionedColumnData>();
	if (local_radix.GetRadixBits() == grouping_data->GetRadixBits()) {
		return;
	}

	// If the local partition is now too small, flush it and reallocate
	auto new_partition = grouping_data->CreateShared();
	auto new_append = make_uniq<PartitionedColumnDataAppendState>();
	new_partition->InitializeAppendState(*new_append);

	local_partition->FlushAppendState(*local_append);
	auto &local_groups = local_partition->GetPartitions();
	for (auto &local_group : local_groups) {
		ColumnDataScanState scanner;
		local_group->InitializeScan(scanner);

		DataChunk scan_chunk;
		local_group->InitializeScanChunk(scan_chunk);
		for (scan_chunk.Reset(); local_group->Scan(scanner, scan_chunk); scan_chunk.Reset()) {
			new_partition->Append(*new_append, scan_chunk);
		}
	}

	// The append state has stale pointers to the old local partition, so nuke it from orbit.
	new_partition->FlushAppendState(*new_append);

	local_partition = std::move(new_partition);
	local_append = make_uniq<PartitionedColumnDataAppendState>();
	local_partition->InitializeAppendState(*local_append);
}

void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	// Make sure grouping_data doesn't change under us.
	lock_guard<mutex> guard(lock);

	if (!local_partition) {
		local_partition = grouping_data->CreateShared();
		local_append = make_uniq<PartitionedColumnDataAppendState>();
		local_partition->InitializeAppendState(*local_append);
		return;
	}

	// Grow the groups if they are too big
	ResizeGroupingData(count);

	// Sync local partition to have the same bit count
	SyncLocalPartition(local_partition, local_append);
}

void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	if (!local_partition) {
		return;
	}
	local_partition->FlushAppendState(*local_append);

	// Make sure grouping_data doesn't change under us.
	// Combine has an internal mutex, so this is single-threaded anyway.
	lock_guard<mutex> guard(lock);
	SyncLocalPartition(local_partition, local_append);
	grouping_data->Combine(*local_partition);
}

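// Convert one hash group's column data into sort state: compute the sort keys with an
// ExpressionExecutor, sink key/payload chunks into a LocalSortState, and sort the accumulated
// data early whenever it exceeds the per-thread memory budget.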
void PartitionGlobalSinkState::BuildSortState(ColumnDataCollection &group_data, PartitionGlobalHashGroup &hash_group) {
	auto &global_sort = *hash_group.global_sort;

	// Set up the sort expression computation.
	vector<LogicalType> sort_types;
	ExpressionExecutor executor(context);
	for (auto &order : orders) {
		auto &oexpr = order.expression;
		sort_types.emplace_back(oexpr->return_type);
		executor.AddExpression(*oexpr);
	}
	DataChunk sort_chunk;
	sort_chunk.Initialize(allocator, sort_types);

	// Copy the data from the group into the sort code.
	LocalSortState local_sort;
	local_sort.Initialize(global_sort, global_sort.buffer_manager);

	// Strip hash column
	DataChunk payload_chunk;
	payload_chunk.Initialize(allocator, payload_types);

	vector<column_t> column_ids;
	column_ids.reserve(payload_types.size());
	for (column_t i = 0; i < payload_types.size(); ++i) {
		column_ids.emplace_back(i);
	}
	ColumnDataConsumer scanner(group_data, column_ids);
	ColumnDataConsumerScanState chunk_state;
	chunk_state.current_chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY;
	scanner.InitializeScan();
	for (auto chunk_idx = scanner.ChunkCount(); chunk_idx-- > 0;) {
		if (!scanner.AssignChunk(chunk_state)) {
			break;
		}
		scanner.ScanChunk(chunk_state, payload_chunk);

		sort_chunk.Reset();
		executor.Execute(payload_chunk, sort_chunk);

		local_sort.SinkChunk(sort_chunk, payload_chunk);
		if (local_sort.SizeInBytes() > memory_per_thread) {
			local_sort.Sort(global_sort, true);
		}
		scanner.FinishChunk(chunk_state);
	}

	global_sort.AddLocalState(local_sort);

	hash_group.count += group_data.Count();
}

// Per-thread sink state
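// The local state distinguishes three cases (mirroring the OVER() / OVER(...) / OVER(PARTITION BY...)
// comments below):
//   OVER()                  - sort_cols == 0: rows are materialized directly into row data blocks
//   OVER(ORDER BY ...)      - no partition keys: all rows share a single (constant) hash
//   OVER(PARTITION BY ...)  - partition keys are hashed and rows are radix-partitioned by that hash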
PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p)
    : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) {

	vector<LogicalType> group_types;
	for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) {
		auto &pexpr = *gstate.partitions[prt_idx].expression.get();
		group_types.push_back(pexpr.return_type);
		executor.AddExpression(pexpr);
	}
	sort_cols = gstate.orders.size() + group_types.size();

	if (sort_cols) {
		if (!group_types.empty()) {
			// OVER(PARTITION BY...)
			group_chunk.Initialize(allocator, group_types);
		}
		// OVER(...)
		auto payload_types = gstate.payload_types;
		payload_types.emplace_back(LogicalType::HASH);
		payload_chunk.Initialize(allocator, payload_types);
	} else {
		// OVER()
		payload_layout.Initialize(gstate.payload_types);
	}
}

void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) {
	const auto count = input_chunk.size();
	if (group_chunk.ColumnCount() > 0) {
		// OVER(PARTITION BY...) (hash grouping)
		group_chunk.Reset();
		executor.Execute(input_chunk, group_chunk);
		VectorOperations::Hash(group_chunk.data[0], hash_vector, count);
		for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) {
			VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count);
		}
	} else {
		// OVER(...) (sorting)
		// Single partition => single hash value
		hash_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
		auto hashes = ConstantVector::GetData<hash_t>(hash_vector);
		hashes[0] = 0;
	}
}

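// Sink a chunk of input rows. The OVER() path appends rows directly to a paged row collection;
// otherwise the partition hash is appended as an extra payload column and the chunk is routed
// into the thread-local radix partitions.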
void PartitionLocalSinkState::Sink(DataChunk &input_chunk) {
	gstate.count += input_chunk.size();

	// OVER()
	if (sort_cols == 0) {
		// No sorts, so build paged row chunks
		if (!rows) {
			const auto entry_size = payload_layout.GetRowWidth();
			const auto capacity = MaxValue<idx_t>(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1);
			rows = make_uniq<RowDataCollection>(gstate.buffer_manager, capacity, entry_size);
			strings = make_uniq<RowDataCollection>(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true);
		}
		const auto row_count = input_chunk.size();
		const auto row_sel = FlatVector::IncrementalSelectionVector();
		Vector addresses(LogicalType::POINTER);
		auto key_locations = FlatVector::GetData<data_ptr_t>(addresses);
		const auto prev_rows_blocks = rows->blocks.size();
		auto handles = rows->Build(row_count, key_locations, nullptr, row_sel);
		auto input_data = input_chunk.ToUnifiedFormat();
		RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel,
		                       row_count);
		// Mark that row blocks contain pointers (heap blocks are pinned)
		if (!payload_layout.AllConstant()) {
			D_ASSERT(strings->keep_pinned);
			for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) {
				rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink");
			}
		}
		return;
	}

	// OVER(...)
	payload_chunk.Reset();
	auto &hash_vector = payload_chunk.data.back();
	Hash(input_chunk, hash_vector);
	for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) {
		payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]);
	}
	payload_chunk.SetCardinality(input_chunk);

	gstate.UpdateLocalPartition(local_partition, local_append);
	local_partition->Append(*local_append, payload_chunk);
}

void PartitionLocalSinkState::Combine() {
	// OVER()
	if (sort_cols == 0) {
		// Only one partition again, so need a global lock.
		lock_guard<mutex> glock(gstate.lock);
		if (gstate.rows) {
			if (rows) {
				gstate.rows->Merge(*rows);
				gstate.strings->Merge(*strings);
				rows.reset();
				strings.reset();
			}
		} else {
			gstate.rows = std::move(rows);
			gstate.strings = std::move(strings);
		}
		return;
	}

	// OVER(...)
	gstate.CombineLocalPartition(local_partition, local_append);
}

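// Each non-empty radix bin gets its own hash group and global sort state;
// bin_groups maps the radix bin back to the index of that hash group.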
PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data,
                                                     hash_t hash_bin)
    : sink(sink), group_data(std::move(group_data)), stage(PartitionSortStage::INIT), total_tasks(0),
      tasks_assigned(0), tasks_completed(0) {

	const auto group_idx = sink.hash_groups.size();
	auto new_group = make_uniq<PartitionGlobalHashGroup>(sink.buffer_manager, sink.partitions, sink.orders,
	                                                     sink.payload_types, sink.external);
	sink.hash_groups.emplace_back(std::move(new_group));

	hash_group = sink.hash_groups[group_idx].get();
	global_sort = sink.hash_groups[group_idx]->global_sort.get();

	sink.bin_groups[hash_bin] = group_idx;
}

void PartitionLocalMergeState::Prepare() {
	auto &global_sort = *merge_state->global_sort;
	merge_state->sink.BuildSortState(*merge_state->group_data, *merge_state->hash_group);
	merge_state->group_data.reset();

	global_sort.PrepareMergePhase();
}

void PartitionLocalMergeState::Merge() {
	auto &global_sort = *merge_state->global_sort;
	MergeSorter merge_sorter(global_sort, global_sort.buffer_manager);
	merge_sorter.PerformInMergeRound();
}

void PartitionLocalMergeState::ExecuteTask() {
	switch (stage) {
	case PartitionSortStage::PREPARE:
		Prepare();
		break;
	case PartitionSortStage::MERGE:
		Merge();
		break;
	default:
		throw InternalException("Unexpected PartitionGlobalMergeState in ExecuteTask!");
	}

	merge_state->CompleteTask();
	finished = true;
}

bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) {
	lock_guard<mutex> guard(lock);

	if (tasks_assigned >= total_tasks) {
		return false;
	}

	local_state.merge_state = this;
	local_state.stage = stage;
	local_state.finished = false;
	tasks_assigned++;

	return true;
}

void PartitionGlobalMergeState::CompleteTask() {
	lock_guard<mutex> guard(lock);

	++tasks_completed;
}

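// Advance the per-group state machine: INIT -> PREPARE -> MERGE (repeated rounds) -> SORTED.
// Returns true if a new stage with tasks was prepared; returns false if other threads still have
// outstanding tasks for the current stage, or if the group has reached SORTED.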
bool PartitionGlobalMergeState::TryPrepareNextStage() {
	lock_guard<mutex> guard(lock);

	if (tasks_completed < total_tasks) {
		return false;
	}

	tasks_assigned = tasks_completed = 0;

	switch (stage) {
	case PartitionSortStage::INIT:
		total_tasks = 1;
		stage = PartitionSortStage::PREPARE;
		return true;

	case PartitionSortStage::PREPARE:
		total_tasks = global_sort->sorted_blocks.size() / 2;
		if (!total_tasks) {
			break;
		}
		stage = PartitionSortStage::MERGE;
		global_sort->InitializeMergeRound();
		return true;

	case PartitionSortStage::MERGE:
		global_sort->CompleteMergeRound(true);
		total_tasks = global_sort->sorted_blocks.size() / 2;
		if (!total_tasks) {
			break;
		}
		global_sort->InitializeMergeRound();
		return true;

	case PartitionSortStage::SORTED:
		break;
	}

	stage = PartitionSortStage::SORTED;

	return false;
}

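// bin_groups is sized to the radix fan-out and filled with partitions.size() as a sentinel so that
// empty bins can be recognized later; only non-empty bins get a merge state.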
PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) {
	// Schedule all the sorts for maximum thread utilisation
	auto &partitions = sink.grouping_data->GetPartitions();
	sink.bin_groups.resize(partitions.size(), partitions.size());
	for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) {
		auto &group_data = partitions[hash_bin];
		// Prepare for merge sort phase
		if (group_data->Count()) {
			auto state = make_uniq<PartitionGlobalMergeState>(sink, std::move(group_data), hash_bin);
			states.emplace_back(std::move(state));
		}
	}
}

class PartitionMergeTask : public ExecutorTask {
public:
	PartitionMergeTask(shared_ptr<Event> event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p)
	    : ExecutorTask(context_p), event(std::move(event_p)), hash_groups(hash_groups_p) {
	}

	TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override;

private:
	shared_ptr<Event> event;
	PartitionLocalMergeState local_state;
	PartitionGlobalMergeStates &hash_groups;
};

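// Each merge task loops over the hash groups, claiming whatever work is available. It keeps a
// high water mark ("sorted") over the prefix of fully sorted groups so they are skipped on later
// passes, and only returns once every group is sorted or an error is raised.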
TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) {
	// Loop until all hash groups are done
	size_t sorted = 0;
	while (sorted < hash_groups.states.size()) {
		// First check if there is an unfinished task for this thread
		if (executor.HasError()) {
			return TaskExecutionResult::TASK_ERROR;
		}
		if (!local_state.TaskFinished()) {
			local_state.ExecuteTask();
			continue;
		}

		// Thread is done with its assigned task, try to fetch new work
		for (auto group = sorted; group < hash_groups.states.size(); ++group) {
			auto &global_state = hash_groups.states[group];
			if (global_state->IsSorted()) {
				// This hash group is done
				// Update the high water mark of densely completed groups
				if (sorted == group) {
					++sorted;
				}
				continue;
			}

			// Try to assign work for this hash group to this thread
			if (global_state->AssignTask(local_state)) {
				// We assigned a task to this thread!
				// Break out of this loop to re-enter the top-level loop and execute the task
				break;
			}

			// Hash group global state couldn't assign a task to this thread
			// Try to prepare the next stage
			if (!global_state->TryPrepareNextStage()) {
				// This current hash group is not yet done
				// But we were not able to assign a task for it to this thread
				// See if the next hash group is better
				continue;
			}

			// We were able to prepare the next stage for this hash group!
			// Try to assign a task once more
			if (global_state->AssignTask(local_state)) {
				// We assigned a task to this thread!
				// Break out of this loop to re-enter the top-level loop and execute the task
				break;
			}

			// We were able to prepare the next merge round,
			// but we were not able to assign a task for it to this thread
			// The tasks were assigned to other threads while this thread waited for the lock
			// Go to the next iteration to see if another hash group has a task
		}
	}

	event->FinishTask();
	return TaskExecutionResult::TASK_FINISHED;
}

void PartitionMergeEvent::Schedule() {
	auto &context = pipeline->GetClientContext();

	// Schedule tasks equal to the number of threads, which will each merge multiple partitions
	auto &ts = TaskScheduler::GetScheduler(context);
	idx_t num_threads = ts.NumberOfThreads();

	vector<shared_ptr<Task>> merge_tasks;
	for (idx_t tnum = 0; tnum < num_threads; tnum++) {
		merge_tasks.emplace_back(make_uniq<PartitionMergeTask>(shared_from_this(), context, merge_states));
	}
	SetTasks(std::move(merge_tasks));
}

} // namespace duckdb