#include "duckdb/common/sort/partition_state.hpp"

#include "duckdb/common/types/column/column_data_consumer.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/parallel/event.hpp"

#include <numeric>

namespace duckdb {

PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions,
                                                   const Orders &orders, const Types &payload_types, bool external)
    : count(0) {

	RowLayout payload_layout;
	payload_layout.Initialize(payload_types);
	global_sort = make_uniq<GlobalSortState>(buffer_manager, orders, payload_layout);
	global_sort->external = external;

	//	Set up a comparator for the partition subset
	partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size());
}

int PartitionGlobalHashGroup::ComparePartitions(const SBIterator &left, const SBIterator &right) const {
	int part_cmp = 0;
	if (partition_layout.all_constant) {
		part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size);
	} else {
		part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, partition_layout,
		                                     left.external);
	}
	return part_cmp;
}

void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask) {
	D_ASSERT(count > 0);

	SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN);
	SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN);

	partition_mask.SetValidUnsafe(0);
	order_mask.SetValidUnsafe(0);
	for (++curr; curr.GetIndex() < count; ++curr) {
		//	Compare the partition subset first because if that differs, then so does the full ordering
		const auto part_cmp = ComparePartitions(prev, curr);

		if (part_cmp) {
			partition_mask.SetValidUnsafe(curr.GetIndex());
			order_mask.SetValidUnsafe(curr.GetIndex());
		} else if (prev.Compare(curr)) {
			order_mask.SetValidUnsafe(curr.GetIndex());
		}
		++prev;
	}
}
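
//	Worked example for ComputeMasks: given four sorted rows with partition keys
//	{1, 1, 2, 2} and order keys {10, 10, 10, 20}, partition_mask is valid at
//	rows 0 and 2 (a new partition starts there), while order_mask is valid at
//	rows 0, 2 and 3 (row 3 starts a new peer group within the second partition).
//	Row 0 is always valid in both masks, since it begins the first partition.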

void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders,
                                                 const vector<unique_ptr<Expression>> &partition_bys,
                                                 const Orders &order_bys,
                                                 const vector<unique_ptr<BaseStatistics>> &partition_stats) {

	// we sort by both 1) partition by expression list and 2) order by expressions
	const auto partition_cols = partition_bys.size();
	for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) {
		auto &pexpr = partition_bys[prt_idx];

		if (partition_stats.empty() || !partition_stats[prt_idx]) {
			orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr);
		} else {
			orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(),
			                    partition_stats[prt_idx]->ToUnique());
		}
		partitions.emplace_back(orders.back().Copy());
	}

	for (const auto &order : order_bys) {
		orders.emplace_back(order.Copy());
	}
}
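
//	Example: OVER (PARTITION BY a ORDER BY b DESC) produces
//	orders = {a ASC NULLS FIRST, b DESC} and partitions = {a ASC NULLS FIRST},
//	so the global sort orders by the partition columns first and `partitions`
//	is exactly the prefix compared by ComparePartitions above.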

PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context,
                                                   const vector<unique_ptr<Expression>> &partition_bys,
                                                   const vector<BoundOrderByNode> &order_bys,
                                                   const Types &payload_types,
                                                   const vector<unique_ptr<BaseStatistics>> &partition_stats,
                                                   idx_t estimated_cardinality)
    : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)),
      payload_types(payload_types), memory_per_thread(0), count(0) {

	GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats);

	memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context);
	external = ClientConfig::GetConfig(context).force_external;

	if (!orders.empty()) {
		grouping_types = payload_types;
		grouping_types.push_back(LogicalType::HASH);

		ResizeGroupingData(estimated_cardinality);
	}
}

void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) {
	//	Have we started to combine? Then just live with it.
	if (grouping_data && !grouping_data->GetPartitions().empty()) {
		return;
	}
	//	Is the average partition size too large?
	const idx_t partition_size = STANDARD_ROW_GROUPS_SIZE;
	const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0;
	auto new_bits = bits ? bits : 4;
	while (new_bits < 10 && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) {
		++new_bits;
	}

	// Repartition the grouping data
	if (new_bits != bits) {
		const auto hash_col_idx = payload_types.size();
		grouping_data = make_uniq<RadixPartitionedColumnData>(context, grouping_types, new_bits, hash_col_idx);
	}
}
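
//	Worked example for ResizeGroupingData, assuming the default
//	STANDARD_ROW_GROUPS_SIZE of 122880 rows: for a cardinality of 10M rows,
//	starting from 4 radix bits (16 partitions), 10M / 16, 10M / 32 and
//	10M / 64 all exceed 122880, so new_bits grows to 7 (128 partitions of
//	~78K rows each). The bit count is capped at 10 (1024 partitions).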

void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	// We are done if the local_partition is right sized.
	auto &local_radix = local_partition->Cast<RadixPartitionedColumnData>();
	if (local_radix.GetRadixBits() == grouping_data->GetRadixBits()) {
		return;
	}

	// If the local partition is now too small, flush it and reallocate
	auto new_partition = grouping_data->CreateShared();
	auto new_append = make_uniq<PartitionedColumnDataAppendState>();
	new_partition->InitializeAppendState(*new_append);

	local_partition->FlushAppendState(*local_append);
	auto &local_groups = local_partition->GetPartitions();
	for (auto &local_group : local_groups) {
		ColumnDataScanState scanner;
		local_group->InitializeScan(scanner);

		DataChunk scan_chunk;
		local_group->InitializeScanChunk(scan_chunk);
		for (scan_chunk.Reset(); local_group->Scan(scanner, scan_chunk); scan_chunk.Reset()) {
			new_partition->Append(*new_append, scan_chunk);
		}
	}

	// The append state has stale pointers to the old local partition, so nuke it from orbit.
	new_partition->FlushAppendState(*new_append);

	local_partition = std::move(new_partition);
	local_append = make_uniq<PartitionedColumnDataAppendState>();
	local_partition->InitializeAppendState(*local_append);
}
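
//	Repartitioning has to rescan rather than move blocks: appending through the
//	new RadixPartitionedColumnData re-reads the hash column and re-buckets every
//	row. Going from b to b' radix bits, each coarse bucket fans out into
//	2^(b' - b) finer buckets (e.g. 4 -> 7 bits splits each bucket 8 ways).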

void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	// Make sure grouping_data doesn't change under us.
	lock_guard<mutex> guard(lock);

	if (!local_partition) {
		local_partition = grouping_data->CreateShared();
		local_append = make_uniq<PartitionedColumnDataAppendState>();
		local_partition->InitializeAppendState(*local_append);
		return;
	}

	//	Grow the groups if they are too big
	ResizeGroupingData(count);

	//	Sync local partition to have the same bit count
	SyncLocalPartition(local_partition, local_append);
}

void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
	if (!local_partition) {
		return;
	}
	local_partition->FlushAppendState(*local_append);

	// Make sure grouping_data doesn't change under us.
	// Combine has an internal mutex, so this is single-threaded anyway.
	lock_guard<mutex> guard(lock);
	SyncLocalPartition(local_partition, local_append);
	grouping_data->Combine(*local_partition);
}

void PartitionGlobalSinkState::BuildSortState(ColumnDataCollection &group_data, PartitionGlobalHashGroup &hash_group) {
	auto &global_sort = *hash_group.global_sort;

	//	Set up the sort expression computation.
	vector<LogicalType> sort_types;
	ExpressionExecutor executor(context);
	for (auto &order : orders) {
		auto &oexpr = order.expression;
		sort_types.emplace_back(oexpr->return_type);
		executor.AddExpression(*oexpr);
	}
	DataChunk sort_chunk;
	sort_chunk.Initialize(allocator, sort_types);

	// Copy the data from the group into the sort code.
	LocalSortState local_sort;
	local_sort.Initialize(global_sort, global_sort.buffer_manager);

	//	Strip hash column
	DataChunk payload_chunk;
	payload_chunk.Initialize(allocator, payload_types);

	vector<column_t> column_ids;
	column_ids.reserve(payload_types.size());
	for (column_t i = 0; i < payload_types.size(); ++i) {
		column_ids.emplace_back(i);
	}
	ColumnDataConsumer scanner(group_data, column_ids);
	ColumnDataConsumerScanState chunk_state;
	chunk_state.current_chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY;
	scanner.InitializeScan();
	for (auto chunk_idx = scanner.ChunkCount(); chunk_idx-- > 0;) {
		if (!scanner.AssignChunk(chunk_state)) {
			break;
		}
		scanner.ScanChunk(chunk_state, payload_chunk);

		sort_chunk.Reset();
		executor.Execute(payload_chunk, sort_chunk);

		local_sort.SinkChunk(sort_chunk, payload_chunk);
		if (local_sort.SizeInBytes() > memory_per_thread) {
			local_sort.Sort(global_sort, true);
		}
		scanner.FinishChunk(chunk_state);
	}

	global_sort.AddLocalState(local_sort);

	hash_group.count += group_data.Count();
}
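
//	BuildSortState streams one hash group into its GlobalSortState: each chunk
//	is scanned (zero-copy where possible), the sort keys are computed by the
//	expression executor, and key and payload are sunk into a LocalSortState.
//	Whenever the local state grows past memory_per_thread it is sorted early,
//	keeping per-thread memory bounded for external sorts.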

//	Per-thread sink state
PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p)
    : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) {

	vector<LogicalType> group_types;
	for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) {
		auto &pexpr = *gstate.partitions[prt_idx].expression.get();
		group_types.push_back(pexpr.return_type);
		executor.AddExpression(pexpr);
	}
	sort_cols = gstate.orders.size() + group_types.size();

	if (sort_cols) {
		if (!group_types.empty()) {
			// OVER(PARTITION BY...)
			group_chunk.Initialize(allocator, group_types);
		}
		// OVER(...)
		auto payload_types = gstate.payload_types;
		payload_types.emplace_back(LogicalType::HASH);
		payload_chunk.Initialize(allocator, payload_types);
	} else {
		// OVER()
		payload_layout.Initialize(gstate.payload_types);
	}
}

void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) {
	const auto count = input_chunk.size();
	if (group_chunk.ColumnCount() > 0) {
		// OVER(PARTITION BY...) (hash grouping)
		group_chunk.Reset();
		executor.Execute(input_chunk, group_chunk);
		VectorOperations::Hash(group_chunk.data[0], hash_vector, count);
		for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) {
			VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count);
		}
	} else {
		// OVER(...) (sorting)
		// Single partition => single hash value
		hash_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
		auto hashes = ConstantVector::GetData<hash_t>(hash_vector);
		hashes[0] = 0;
	}
}
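
//	Example: for OVER (PARTITION BY a, b), Hash computes Hash(a) and then folds
//	in b with CombineHash, so rows that agree on (a, b) land in the same radix
//	bucket. For OVER (ORDER BY ...) without PARTITION BY, the constant hash 0
//	routes every row into a single hash group that is simply sorted.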

void PartitionLocalSinkState::Sink(DataChunk &input_chunk) {
	gstate.count += input_chunk.size();

	// OVER()
	if (sort_cols == 0) {
		//	No sorts, so build paged row chunks
		if (!rows) {
			const auto entry_size = payload_layout.GetRowWidth();
			const auto capacity = MaxValue<idx_t>(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1);
			rows = make_uniq<RowDataCollection>(gstate.buffer_manager, capacity, entry_size);
			strings = make_uniq<RowDataCollection>(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true);
		}
		const auto row_count = input_chunk.size();
		const auto row_sel = FlatVector::IncrementalSelectionVector();
		Vector addresses(LogicalType::POINTER);
		auto key_locations = FlatVector::GetData<data_ptr_t>(addresses);
		const auto prev_rows_blocks = rows->blocks.size();
		auto handles = rows->Build(row_count, key_locations, nullptr, row_sel);
		auto input_data = input_chunk.ToUnifiedFormat();
		RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count);
		// Mark that row blocks contain pointers (heap blocks are pinned)
		if (!payload_layout.AllConstant()) {
			D_ASSERT(strings->keep_pinned);
			for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) {
				rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink");
			}
		}
		return;
	}

	// OVER(...)
	payload_chunk.Reset();
	auto &hash_vector = payload_chunk.data.back();
	Hash(input_chunk, hash_vector);
	for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) {
		payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]);
	}
	payload_chunk.SetCardinality(input_chunk);

	gstate.UpdateLocalPartition(local_partition, local_append);
	local_partition->Append(*local_append, payload_chunk);
}

void PartitionLocalSinkState::Combine() {
	// OVER()
	if (sort_cols == 0) {
		// Only one partition again, so we need a global lock.
		lock_guard<mutex> glock(gstate.lock);
		if (gstate.rows) {
			if (rows) {
				gstate.rows->Merge(*rows);
				gstate.strings->Merge(*strings);
				rows.reset();
				strings.reset();
			}
		} else {
			gstate.rows = std::move(rows);
			gstate.strings = std::move(strings);
		}
		return;
	}

	// OVER(...)
	gstate.CombineLocalPartition(local_partition, local_append);
}

PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data,
                                                     hash_t hash_bin)
    : sink(sink), group_data(std::move(group_data)), stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0),
      tasks_completed(0) {

	const auto group_idx = sink.hash_groups.size();
	auto new_group = make_uniq<PartitionGlobalHashGroup>(sink.buffer_manager, sink.partitions, sink.orders,
	                                                     sink.payload_types, sink.external);
	sink.hash_groups.emplace_back(std::move(new_group));

	hash_group = sink.hash_groups[group_idx].get();
	global_sort = sink.hash_groups[group_idx]->global_sort.get();

	sink.bin_groups[hash_bin] = group_idx;
}

void PartitionLocalMergeState::Prepare() {
	auto &global_sort = *merge_state->global_sort;
	merge_state->sink.BuildSortState(*merge_state->group_data, *merge_state->hash_group);
	merge_state->group_data.reset();

	global_sort.PrepareMergePhase();
}

void PartitionLocalMergeState::Merge() {
	auto &global_sort = *merge_state->global_sort;
	MergeSorter merge_sorter(global_sort, global_sort.buffer_manager);
	merge_sorter.PerformInMergeRound();
}

void PartitionLocalMergeState::ExecuteTask() {
	switch (stage) {
	case PartitionSortStage::PREPARE:
		Prepare();
		break;
	case PartitionSortStage::MERGE:
		Merge();
		break;
	default:
		throw InternalException("Unexpected PartitionGlobalMergeState in ExecuteTask!");
	}

	merge_state->CompleteTask();
	finished = true;
}

bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) {
	lock_guard<mutex> guard(lock);

	if (tasks_assigned >= total_tasks) {
		return false;
	}

	local_state.merge_state = this;
	local_state.stage = stage;
	local_state.finished = false;
	tasks_assigned++;

	return true;
}

void PartitionGlobalMergeState::CompleteTask() {
	lock_guard<mutex> guard(lock);

	++tasks_completed;
}

bool PartitionGlobalMergeState::TryPrepareNextStage() {
	lock_guard<mutex> guard(lock);

	if (tasks_completed < total_tasks) {
		return false;
	}

	tasks_assigned = tasks_completed = 0;

	switch (stage) {
	case PartitionSortStage::INIT:
		total_tasks = 1;
		stage = PartitionSortStage::PREPARE;
		return true;

	case PartitionSortStage::PREPARE:
		total_tasks = global_sort->sorted_blocks.size() / 2;
		if (!total_tasks) {
			break;
		}
		stage = PartitionSortStage::MERGE;
		global_sort->InitializeMergeRound();
		return true;

	case PartitionSortStage::MERGE:
		global_sort->CompleteMergeRound(true);
		total_tasks = global_sort->sorted_blocks.size() / 2;
		if (!total_tasks) {
			break;
		}
		global_sort->InitializeMergeRound();
		return true;

	case PartitionSortStage::SORTED:
		break;
	}

	stage = PartitionSortStage::SORTED;

	return false;
}
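
//	Stage machine: INIT -> PREPARE (one task) -> MERGE rounds -> SORTED.
//	Each merge round pairs up sorted runs, so the task count halves every
//	round: e.g. 8 sorted blocks yield rounds of 4, 2 and 1 merge tasks,
//	after which sorted_blocks.size() / 2 reaches 0 and the group is SORTED.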

PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) {
	// Schedule all the sorts for maximum thread utilisation
	auto &partitions = sink.grouping_data->GetPartitions();
	sink.bin_groups.resize(partitions.size(), partitions.size());
	for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) {
		auto &group_data = partitions[hash_bin];
		// Prepare for merge sort phase
		if (group_data->Count()) {
			auto state = make_uniq<PartitionGlobalMergeState>(sink, std::move(group_data), hash_bin);
			states.emplace_back(std::move(state));
		}
	}
}

class PartitionMergeTask : public ExecutorTask {
public:
	PartitionMergeTask(shared_ptr<Event> event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p)
	    : ExecutorTask(context_p), event(std::move(event_p)), hash_groups(hash_groups_p) {
	}

	TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override;

private:
	shared_ptr<Event> event;
	PartitionLocalMergeState local_state;
	PartitionGlobalMergeStates &hash_groups;
};

TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) {
	// Loop until all hash groups are done
	size_t sorted = 0;
	while (sorted < hash_groups.states.size()) {
		// First check if there is an unfinished task for this thread
		if (executor.HasError()) {
			return TaskExecutionResult::TASK_ERROR;
		}
		if (!local_state.TaskFinished()) {
			local_state.ExecuteTask();
			continue;
		}

		// Thread is done with its assigned task, try to fetch new work
		for (auto group = sorted; group < hash_groups.states.size(); ++group) {
			auto &global_state = hash_groups.states[group];
			if (global_state->IsSorted()) {
				// This hash group is done
				// Update the high water mark of densely completed groups
				if (sorted == group) {
					++sorted;
				}
				continue;
			}

			// Try to assign work for this hash group to this thread
			if (global_state->AssignTask(local_state)) {
				// We assigned a task to this thread!
				// Break out of this loop to re-enter the top-level loop and execute the task
				break;
			}

			// Hash group global state couldn't assign a task to this thread
			// Try to prepare the next stage
			if (!global_state->TryPrepareNextStage()) {
				// This current hash group is not yet done
				// But we were not able to assign a task for it to this thread
				// See if the next hash group is better
				continue;
			}

			// We were able to prepare the next stage for this hash group!
			// Try to assign a task once more
			if (global_state->AssignTask(local_state)) {
				// We assigned a task to this thread!
				// Break out of this loop to re-enter the top-level loop and execute the task
				break;
			}

			// We were able to prepare the next merge round,
			// but we were not able to assign a task for it to this thread
			// The tasks were assigned to other threads while this thread waited for the lock
			// Go to the next iteration to see if another hash group has a task
		}
	}

	event->FinishTask();
	return TaskExecutionResult::TASK_FINISHED;
}

void PartitionMergeEvent::Schedule() {
	auto &context = pipeline->GetClientContext();

	// Schedule tasks equal to the number of threads, which will each merge multiple partitions
	auto &ts = TaskScheduler::GetScheduler(context);
	idx_t num_threads = ts.NumberOfThreads();

	vector<shared_ptr<Task>> merge_tasks;
	for (idx_t tnum = 0; tnum < num_threads; tnum++) {
		merge_tasks.emplace_back(make_uniq<PartitionMergeTask>(shared_from_this(), context, merge_states));
	}
	SetTasks(std::move(merge_tasks));
}
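
//	One PartitionMergeTask is created per scheduler thread; the tasks then pull
//	work dynamically through AssignTask/TryPrepareNextStage above, so the thread
//	count does not need to match the number of hash groups.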

} // namespace duckdb