1#include "duckdb/execution/radix_partitioned_hashtable.hpp"
2
3#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
4#include "duckdb/parallel/event.hpp"
5#include "duckdb/parallel/task_scheduler.hpp"
6#include "duckdb/planner/expression/bound_reference_expression.hpp"
7
8namespace duckdb {
9
// compute the GROUPING values
// for each argument of the GROUPING function, we check if the hash table groups on this particular group
// if it does, we return 0, otherwise we return 1
// we then use bitshifts to combine these values
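// for example, GROUPING(a, b) with a grouping set that groups on "a" but not on "b" yields
// bit 0 for "a" (grouped, most significant) and bit 1 for "b" (not grouped): 0b01 = 1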
void RadixPartitionedHashTable::SetGroupingValues() {
    auto &grouping_functions = op.GetGroupingFunctions();
    for (auto &grouping : grouping_functions) {
        int64_t grouping_value = 0;
        D_ASSERT(grouping.size() < sizeof(int64_t) * 8);
        for (idx_t i = 0; i < grouping.size(); i++) {
            if (grouping_set.find(grouping[i]) == grouping_set.end()) {
                // we don't group on this value!
                grouping_value += (int64_t)1 << (grouping.size() - (i + 1));
            }
        }
        grouping_values.push_back(Value::BIGINT(grouping_value));
    }
}

RadixPartitionedHashTable::RadixPartitionedHashTable(GroupingSet &grouping_set_p, const GroupedAggregateData &op_p)
    : grouping_set(grouping_set_p), op(op_p) {

    // any group that is not part of the grouping set is emitted as NULL in the output
    auto groups_count = op.GroupCount();
    for (idx_t i = 0; i < groups_count; i++) {
        if (grouping_set.find(i) == grouping_set.end()) {
            null_groups.push_back(i);
        }
    }

    // the radix limit is the number of groups a single thread must have seen before we trigger
    // radix partitioning across all threads; 10000 seems like a good compromise here
    radix_limit = 10000;

    if (grouping_set.empty()) {
        // fake a single group with a constant value for aggregation without groups
        group_types.emplace_back(LogicalType::TINYINT);
    }
    for (auto &entry : grouping_set) {
        D_ASSERT(entry < op.group_types.size());
        group_types.push_back(op.group_types[entry]);
    }
    SetGroupingValues();
}

//===--------------------------------------------------------------------===//
// Sink
//===--------------------------------------------------------------------===//
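// During the sink phase every thread fills a thread-local PartitionableHashTable. Once any thread
// has seen at least "radix_limit" groups, all threads start radix-partitioning their local tables.
// Combine collects the local tables in the global state; Finalize then either merges them into a
// single global HT (few groups), or creates one target HT per radix partition and lets parallel
// RadixAggregateFinalizeTask instances (scheduled via ScheduleTasks) do the merging (many groups).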
class RadixHTGlobalState : public GlobalSinkState {
    //! The maximum number of radix partitions; the actual number is the minimum of this and the thread count
    constexpr const static idx_t MAX_RADIX_PARTITIONS = 32;

public:
    explicit RadixHTGlobalState(ClientContext &context)
        : is_empty(true), multi_scan(true), partitioned(false),
          partition_info(
              MinValue<idx_t>(MAX_RADIX_PARTITIONS, TaskScheduler::GetScheduler(context).NumberOfThreads())) {
    }

    vector<unique_ptr<PartitionableHashTable>> intermediate_hts;
    vector<shared_ptr<GroupedAggregateHashTable>> finalized_hts;

    //! Whether or not any tuples were added to the HT
    bool is_empty;
    //! Whether or not the hash table should be scannable multiple times
    bool multi_scan;
    //! The lock for updating the global aggregate state
    mutex lock;
    //! Whether or not any thread has crossed the partitioning threshold
    atomic<bool> partitioned;

    //! Whether or not Finalize has been called
    bool is_finalized = false;
    //! Whether or not the finalized HTs are radix-partitioned
    bool is_partitioned = false;

    RadixPartitionInfo partition_info;
    AggregateHTAppendState append_state;
};

class RadixHTLocalState : public LocalSinkState {
public:
    explicit RadixHTLocalState(const RadixPartitionedHashTable &ht) : total_groups(0), is_empty(true) {
        // if there are no groups we create a fake group so everything has the same group
        group_chunk.InitializeEmpty(ht.group_types);
        if (ht.grouping_set.empty()) {
            group_chunk.data[0].Reference(Value::TINYINT(42));
        }
    }

    DataChunk group_chunk;
    //! The aggregate HT
    unique_ptr<PartitionableHashTable> ht;
    //! The total number of groups found by this thread
    idx_t total_groups;

    //! Whether or not any tuples were added to the HT
    bool is_empty;
};

void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &state) {
    auto &gstate = state.Cast<RadixHTGlobalState>();
    gstate.multi_scan = true;
}

unique_ptr<GlobalSinkState> RadixPartitionedHashTable::GetGlobalSinkState(ClientContext &context) const {
    return make_uniq<RadixHTGlobalState>(context);
}

unique_ptr<LocalSinkState> RadixPartitionedHashTable::GetLocalSinkState(ExecutionContext &context) const {
    return make_uniq<RadixHTLocalState>(*this);
}

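// illustrative example: with grouping_set = {0, 2}, where the bound references of groups 0 and 2 point to
// input columns 3 and 1, group_chunk.data[0] references input_chunk.data[3] and group_chunk.data[1]
// references input_chunk.data[1] (zero-copy, via Vector::Reference)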
void RadixPartitionedHashTable::PopulateGroupChunk(DataChunk &group_chunk, DataChunk &input_chunk) const {
    idx_t chunk_index = 0;
    // Populate the group_chunk
    for (auto &group_idx : grouping_set) {
        // Retrieve the expression containing the index in the input chunk
        auto &group = op.groups[group_idx];
        D_ASSERT(group->type == ExpressionType::BOUND_REF);
        auto &bound_ref_expr = group->Cast<BoundReferenceExpression>();
        // Reference from input_chunk[group.index] -> group_chunk[chunk_index]
        group_chunk.data[chunk_index++].Reference(input_chunk.data[bound_ref_expr.index]);
    }
    group_chunk.SetCardinality(input_chunk.size());
    group_chunk.Verify();
}

void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input,
                                     DataChunk &payload_input, const unsafe_vector<idx_t> &filter) const {
    auto &llstate = input.local_state.Cast<RadixHTLocalState>();
    auto &gstate = input.global_state.Cast<RadixHTGlobalState>();
    D_ASSERT(!gstate.is_finalized);

    DataChunk &group_chunk = llstate.group_chunk;
    PopulateGroupChunk(group_chunk, chunk);

    // if we have non-combinable aggregates (e.g. string_agg) we cannot keep parallel hash tables
    if (ForceSingleHT(input.global_state)) {
        lock_guard<mutex> glock(gstate.lock);
        gstate.is_empty = gstate.is_empty && group_chunk.size() == 0;
        if (gstate.finalized_hts.empty()) {
            // Create a finalized HT in the global state that we can populate
            gstate.finalized_hts.push_back(
                make_shared<GroupedAggregateHashTable>(context.client, Allocator::Get(context.client), group_types,
                                                       op.payload_types, op.bindings, HtEntryType::HT_WIDTH_64));
        }
        D_ASSERT(gstate.finalized_hts.size() == 1);
        D_ASSERT(gstate.finalized_hts[0]);
        llstate.total_groups +=
            gstate.finalized_hts[0]->AddChunk(gstate.append_state, group_chunk, payload_input, filter);
        return;
    }

    if (group_chunk.size() > 0) {
        llstate.is_empty = false;
    }

    if (!llstate.ht) {
        llstate.ht =
            make_uniq<PartitionableHashTable>(context.client, Allocator::Get(context.client), gstate.partition_info,
                                              group_types, op.payload_types, op.bindings);
    }

    // add the chunk to the local HT; partition it once any thread has crossed the radix limit
    llstate.total_groups += llstate.ht->AddChunk(group_chunk, payload_input,
                                                 gstate.partitioned && gstate.partition_info.n_partitions > 1, filter);
    if (llstate.total_groups >= radix_limit) {
        gstate.partitioned = true;
    }
}

void RadixPartitionedHashTable::Combine(ExecutionContext &context, GlobalSinkState &state,
                                        LocalSinkState &lstate) const {
    auto &llstate = lstate.Cast<RadixHTLocalState>();
    auto &gstate = state.Cast<RadixHTGlobalState>();
    D_ASSERT(!gstate.is_finalized);

    // this does not do a lot of work; it just pushes the local HTs into the global state so that we can later
    // combine them in parallel

    if (ForceSingleHT(state)) {
        D_ASSERT(gstate.finalized_hts.size() <= 1);
        return;
    }

    if (!llstate.ht) {
        return; // no data
    }

    if (!llstate.ht->IsPartitioned() && gstate.partition_info.n_partitions > 1 && gstate.partitioned) {
        llstate.ht->Partition();
    }

    // we will never add new values to these HTs so we can drop the first part of the HT
    llstate.ht->Finalize();

    lock_guard<mutex> glock(gstate.lock);
    if (!llstate.is_empty) {
        gstate.is_empty = false;
    }
    // at this point we just collect the HTs; the RadixAggregateFinalizeTask (below) will merge them in parallel
    gstate.intermediate_hts.push_back(std::move(llstate.ht));
}

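// Finalize returns true if additional (parallel) finalize tasks still have to be scheduled through
// ScheduleTasks, i.e., in the radix-partitioned case; otherwise the HT is fully finalized on return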
bool RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState &gstate_p) const {
    auto &gstate = gstate_p.Cast<RadixHTGlobalState>();
    D_ASSERT(!gstate.is_finalized);
    gstate.is_finalized = true;

    // special case if we have non-combinable aggregates
    // we have already aggregated into a global shared HT that does not require any additional finalization steps
    if (ForceSingleHT(gstate)) {
        D_ASSERT(gstate.finalized_hts.size() <= 1);
        D_ASSERT(gstate.finalized_hts.empty() || gstate.finalized_hts[0]);
        return false;
    }

    // we can have two cases now: non-partitioned for few groups, and radix-partitioned for very many groups.
    // go through all of the child HTs and see if we ever called Partition() on any of them
    // if we did, it's the latter case.
    bool any_partitioned = false;
    for (auto &pht : gstate.intermediate_hts) {
        if (pht->IsPartitioned()) {
            any_partitioned = true;
            break;
        }
    }

    auto &allocator = Allocator::Get(context);
    if (any_partitioned) {
        // if one is partitioned, all have to be
        // this should mostly have already happened in Combine, but if not we do it here
        for (auto &pht : gstate.intermediate_hts) {
            if (!pht->IsPartitioned()) {
                pht->Partition();
            }
        }
        // schedule additional tasks to combine the partial HTs
        gstate.finalized_hts.resize(gstate.partition_info.n_partitions);
        for (idx_t r = 0; r < gstate.partition_info.n_partitions; r++) {
            gstate.finalized_hts[r] = make_shared<GroupedAggregateHashTable>(
                context, allocator, group_types, op.payload_types, op.bindings, HtEntryType::HT_WIDTH_64);
        }
        gstate.is_partitioned = true;
        return true;
    } else { // in the non-partitioned case we immediately combine all the unpartitioned HTs created by the threads
        // TODO possible optimization: if the total count fits the limit for a 32-bit HT, use that one
        // create this HT here so Finalize needs no lock on gstate

        gstate.finalized_hts.push_back(make_shared<GroupedAggregateHashTable>(
            context, allocator, group_types, op.payload_types, op.bindings, HtEntryType::HT_WIDTH_64));
        for (auto &pht : gstate.intermediate_hts) {
            auto unpartitioned = pht->GetUnpartitioned();
            for (auto &unpartitioned_ht : unpartitioned) {
                D_ASSERT(unpartitioned_ht);
                gstate.finalized_hts[0]->Combine(*unpartitioned_ht);
                unpartitioned_ht.reset();
            }
            unpartitioned.clear();
        }
        D_ASSERT(gstate.finalized_hts[0]);
        gstate.finalized_hts[0]->Finalize();
        return false;
    }
}

// this task is run in multiple threads; each task combines one radix partition of all intermediate HTs
// into the corresponding finalized global HT
class RadixAggregateFinalizeTask : public ExecutorTask {
public:
    RadixAggregateFinalizeTask(Executor &executor, shared_ptr<Event> event_p, RadixHTGlobalState &state_p,
                               idx_t radix_p)
        : ExecutorTask(executor), event(std::move(event_p)), state(state_p), radix(radix_p) {
    }

    static void FinalizeHT(RadixHTGlobalState &gstate, idx_t radix) {
        D_ASSERT(gstate.partition_info.n_partitions <= gstate.finalized_hts.size());
        D_ASSERT(gstate.finalized_hts[radix]);
        for (auto &pht : gstate.intermediate_hts) {
            for (auto &ht : pht->GetPartition(radix)) {
                gstate.finalized_hts[radix]->Combine(*ht);
                ht.reset();
            }
        }
        gstate.finalized_hts[radix]->Finalize();
    }

    TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
        FinalizeHT(state, radix);
        event->FinishTask();
        return TaskExecutionResult::TASK_FINISHED;
    }

private:
    shared_ptr<Event> event;
    RadixHTGlobalState &state;
    idx_t radix;
};

void RadixPartitionedHashTable::ScheduleTasks(Executor &executor, const shared_ptr<Event> &event,
                                              GlobalSinkState &state, vector<shared_ptr<Task>> &tasks) const {
    auto &gstate = state.Cast<RadixHTGlobalState>();
    if (!gstate.is_partitioned) {
        return;
    }
    for (idx_t r = 0; r < gstate.partition_info.n_partitions; r++) {
        D_ASSERT(gstate.partition_info.n_partitions <= gstate.finalized_hts.size());
        D_ASSERT(gstate.finalized_hts[r]);
        tasks.push_back(make_uniq<RadixAggregateFinalizeTask>(executor, event, gstate, r));
    }
}

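// a single (global) HT is forced when there are fewer than two radix partitions, e.g. when only a
// single thread is available (see the partition_info initialization in RadixHTGlobalState)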
bool RadixPartitionedHashTable::ForceSingleHT(GlobalSinkState &state) const {
    auto &gstate = state.Cast<RadixHTGlobalState>();
    return gstate.partition_info.n_partitions < 2;
}

//===--------------------------------------------------------------------===//
// Source
//===--------------------------------------------------------------------===//
class RadixHTGlobalSourceState : public GlobalSourceState {
public:
    explicit RadixHTGlobalSourceState(Allocator &allocator, const RadixPartitionedHashTable &ht)
        : ht_index(0), initialized(false), finished(false) {
    }

    //! Heavy handed for now.
    mutex lock;
    //! The current position to scan the HT for output tuples
    idx_t ht_index;
    //! The set of aggregate scan states
    unsafe_unique_array<TupleDataParallelScanState> ht_scan_states;
    atomic<bool> initialized;
    atomic<bool> finished;
};

class RadixHTLocalSourceState : public LocalSourceState {
public:
    explicit RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &ht) {
        auto &allocator = Allocator::Get(context.client);
        auto scan_chunk_types = ht.group_types;
        for (auto &aggr_type : ht.op.aggregate_return_types) {
            scan_chunk_types.push_back(aggr_type);
        }
        scan_chunk.Initialize(allocator, scan_chunk_types);
    }

    //! Materialized GROUP BY expressions & aggregates
    DataChunk scan_chunk;
    //! HT index
    idx_t ht_index = DConstants::INVALID_INDEX;
    //! A reference to the current HT that we are scanning
    shared_ptr<GroupedAggregateHashTable> ht;
    //! Scan state for the current HT
    TupleDataLocalScanState scan_state;
};

unique_ptr<GlobalSourceState> RadixPartitionedHashTable::GetGlobalSourceState(ClientContext &context) const {
    return make_uniq<RadixHTGlobalSourceState>(Allocator::Get(context), *this);
}

unique_ptr<LocalSourceState> RadixPartitionedHashTable::GetLocalSourceState(ExecutionContext &context) const {
    return make_uniq<RadixHTLocalSourceState>(context, *this);
}

idx_t RadixPartitionedHashTable::Size(GlobalSinkState &sink_state) const {
    auto &gstate = sink_state.Cast<RadixHTGlobalState>();
    if (gstate.is_empty && grouping_set.empty()) {
        return 1;
    }

    idx_t count = 0;
    for (const auto &ht : gstate.finalized_hts) {
        count += ht->Count();
    }
    return count;
}

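// GetData emits the aggregation result: it first handles the special case of an aggregation without
// groups over empty input, then lazily initializes one parallel scan state per finalized HT, after
// which the threads cooperatively scan the finalized HTs one after another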
SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, DataChunk &chunk,
                                                    GlobalSinkState &sink_state, OperatorSourceInput &input) const {
    auto &gstate = sink_state.Cast<RadixHTGlobalState>();
    auto &state = input.global_state.Cast<RadixHTGlobalSourceState>();
    auto &lstate = input.local_state.Cast<RadixHTLocalSourceState>();
    D_ASSERT(gstate.is_finalized);
    if (state.finished) {
        return SourceResultType::FINISHED;
    }

    // special case hack to sort out aggregating from empty intermediates
    // for aggregations without groups
    if (gstate.is_empty && grouping_set.empty()) {
        D_ASSERT(chunk.ColumnCount() == null_groups.size() + op.aggregates.size() + op.grouping_functions.size());
        // for each aggregate column, finalize an initial (empty) aggregate state
        // (e.g. COUNT yields 0, SUM yields NULL)
        chunk.SetCardinality(1);
        for (auto null_group : null_groups) {
            chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR);
            ConstantVector::SetNull(chunk.data[null_group], true);
        }
        for (idx_t i = 0; i < op.aggregates.size(); i++) {
            D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE);
            auto &aggr = op.aggregates[i]->Cast<BoundAggregateExpression>();
            auto aggr_state = make_unsafe_uniq_array<data_t>(aggr.function.state_size());
            aggr.function.initialize(aggr_state.get());

            AggregateInputData aggr_input_data(aggr.bind_info.get(), Allocator::DefaultAllocator());
            Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get())));
            aggr.function.finalize(state_vector, aggr_input_data, chunk.data[null_groups.size() + i], 1, 0);
            if (aggr.function.destructor) {
                aggr.function.destructor(state_vector, aggr_input_data, 1);
            }
        }
        // Place the grouping values (all the groups of the grouping_set condensed into a single value)
        // behind the null groups + aggregates
        for (idx_t i = 0; i < op.grouping_functions.size(); i++) {
            chunk.data[null_groups.size() + op.aggregates.size() + i].Reference(grouping_values[i]);
        }
        state.finished = true;
        return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
    }
    if (gstate.is_empty) {
        state.finished = true;
        return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
    }
    idx_t elements_found = 0;

    lstate.scan_chunk.Reset();
    if (!state.initialized) {
        // double-checked locking: only the first thread to arrive initializes the scan states
        lock_guard<mutex> l(state.lock);
        if (!state.initialized) {
            auto &finalized_hts = gstate.finalized_hts;
            state.ht_scan_states = make_unsafe_uniq_array<TupleDataParallelScanState>(finalized_hts.size());

            // scan all layout columns except the last one, which holds the group hashes
            const auto &layout = gstate.finalized_hts[0]->GetDataCollection().GetLayout();
            vector<column_t> column_ids;
            column_ids.reserve(layout.ColumnCount() - 1);
            for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) {
                column_ids.emplace_back(col_idx);
            }

            for (idx_t ht_idx = 0; ht_idx < finalized_hts.size(); ht_idx++) {
                gstate.finalized_hts[ht_idx]->GetDataCollection().InitializeScan(
                    state.ht_scan_states.get()[ht_idx].scan_state, column_ids);
            }
            state.initialized = true;
        }
    }

    auto &local_scan_state = lstate.scan_state;
    while (true) {
        D_ASSERT(state.ht_scan_states);
        idx_t ht_index;
        {
            lock_guard<mutex> l(state.lock);
            ht_index = state.ht_index;
            if (ht_index >= gstate.finalized_hts.size()) {
                state.finished = true;
                return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
            }
        }
        D_ASSERT(ht_index < gstate.finalized_hts.size());
        if (lstate.ht_index != DConstants::INVALID_INDEX && ht_index != lstate.ht_index) {
            lstate.ht->GetDataCollection().FinalizePinState(local_scan_state.pin_state);
        }
        lstate.ht_index = ht_index;
        lstate.ht = gstate.finalized_hts[ht_index];
        D_ASSERT(lstate.ht);

        auto &global_scan_state = state.ht_scan_states[ht_index];
        elements_found = lstate.ht->Scan(global_scan_state, local_scan_state, lstate.scan_chunk);
        if (elements_found > 0) {
            break;
        }
        lstate.ht->GetDataCollection().FinalizePinState(local_scan_state.pin_state);

        // move to the next hash table
        lock_guard<mutex> l(state.lock);
        ht_index++;
        if (ht_index > state.ht_index) {
            // no other thread has advanced past this table yet; move the global index forwards
            if (!gstate.multi_scan) {
                gstate.finalized_hts[state.ht_index].reset();
            }
            state.ht_index = ht_index;
        }
    }

    // compute the final projection list
    chunk.SetCardinality(elements_found);

    idx_t chunk_index = 0;
    for (auto &entry : grouping_set) {
        chunk.data[entry].Reference(lstate.scan_chunk.data[chunk_index++]);
    }
    for (auto null_group : null_groups) {
        chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR);
        ConstantVector::SetNull(chunk.data[null_group], true);
    }
    D_ASSERT(grouping_set.size() + null_groups.size() == op.GroupCount());
    for (idx_t col_idx = 0; col_idx < op.aggregates.size(); col_idx++) {
        chunk.data[op.GroupCount() + col_idx].Reference(lstate.scan_chunk.data[group_types.size() + col_idx]);
    }
    D_ASSERT(op.grouping_functions.size() == grouping_values.size());
    for (idx_t i = 0; i < op.grouping_functions.size(); i++) {
        chunk.data[op.GroupCount() + op.aggregates.size() + i].Reference(grouping_values[i]);
    }
    return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
}

} // namespace duckdb