1#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
2
3#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/execution/aggregate_hashtable.hpp"
6#include "duckdb/main/client_context.hpp"
7#include "duckdb/parallel/interrupt.hpp"
8#include "duckdb/parallel/pipeline.hpp"
9#include "duckdb/parallel/task_scheduler.hpp"
10#include "duckdb/parallel/thread_context.hpp"
11#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
12#include "duckdb/planner/expression/bound_constant_expression.hpp"
13#include "duckdb/planner/expression/bound_reference_expression.hpp"
14#include "duckdb/parallel/base_pipeline_event.hpp"
15#include "duckdb/common/atomic.hpp"
16#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
17
18namespace duckdb {
19
20HashAggregateGroupingData::HashAggregateGroupingData(GroupingSet &grouping_set_p,
21 const GroupedAggregateData &grouped_aggregate_data,
22 unique_ptr<DistinctAggregateCollectionInfo> &info)
23 : table_data(grouping_set_p, grouped_aggregate_data) {
24 if (info) {
25 distinct_data = make_uniq<DistinctAggregateData>(args&: *info, args&: grouping_set_p, args: &grouped_aggregate_data.groups);
26 }
27}
28
29bool HashAggregateGroupingData::HasDistinct() const {
30 return distinct_data != nullptr;
31}
32
33HashAggregateGroupingGlobalState::HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data,
34 ClientContext &context) {
35 table_state = data.table_data.GetGlobalSinkState(context);
36 if (data.HasDistinct()) {
37 distinct_state = make_uniq<DistinctAggregateState>(args&: *data.distinct_data, args&: context);
38 }
39}
40
41HashAggregateGroupingLocalState::HashAggregateGroupingLocalState(const PhysicalHashAggregate &op,
42 const HashAggregateGroupingData &data,
43 ExecutionContext &context) {
44 table_state = data.table_data.GetLocalSinkState(context);
45 if (!data.HasDistinct()) {
46 return;
47 }
48 auto &distinct_data = *data.distinct_data;
49
50 auto &distinct_indices = op.distinct_collection_info->Indices();
51 D_ASSERT(!distinct_indices.empty());
52
53 distinct_states.resize(new_size: op.distinct_collection_info->aggregates.size());
54 auto &table_map = op.distinct_collection_info->table_map;
55
56 for (auto &idx : distinct_indices) {
57 idx_t table_idx = table_map[idx];
58 auto &radix_table = distinct_data.radix_tables[table_idx];
59 if (radix_table == nullptr) {
60 // This aggregate has identical input as another aggregate, so no table is created for it
61 continue;
62 }
63 // Initialize the states of the radix tables used for the distinct aggregates
64 distinct_states[table_idx] = radix_table->GetLocalSinkState(context);
65 }
66}
67
68static vector<LogicalType> CreateGroupChunkTypes(vector<unique_ptr<Expression>> &groups) {
69 set<idx_t> group_indices;
70
71 if (groups.empty()) {
72 return {};
73 }
74
75 for (auto &group : groups) {
76 D_ASSERT(group->type == ExpressionType::BOUND_REF);
77 auto &bound_ref = group->Cast<BoundReferenceExpression>();
78 group_indices.insert(x: bound_ref.index);
79 }
80 idx_t highest_index = *group_indices.rbegin();
81 vector<LogicalType> types(highest_index + 1, LogicalType::SQLNULL);
82 for (auto &group : groups) {
83 auto &bound_ref = group->Cast<BoundReferenceExpression>();
84 types[bound_ref.index] = bound_ref.return_type;
85 }
86 return types;
87}
88
89bool PhysicalHashAggregate::CanSkipRegularSink() const {
90 if (!filter_indexes.empty()) {
91 // If we have filters, we can't skip the regular sink, because we might lose groups otherwise.
92 return false;
93 }
94 if (grouped_aggregate_data.aggregates.empty()) {
95 // When there are no aggregates, we have to add to the main ht right away
96 return false;
97 }
98 if (!non_distinct_filter.empty()) {
99 return false;
100 }
101 return true;
102}
103
104PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types,
105 vector<unique_ptr<Expression>> expressions, idx_t estimated_cardinality)
106 : PhysicalHashAggregate(context, std::move(types), std::move(expressions), {}, estimated_cardinality) {
107}
108
109PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types,
110 vector<unique_ptr<Expression>> expressions,
111 vector<unique_ptr<Expression>> groups_p, idx_t estimated_cardinality)
112 : PhysicalHashAggregate(context, std::move(types), std::move(expressions), std::move(groups_p), {}, {},
113 estimated_cardinality) {
114}
115
116PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types,
117 vector<unique_ptr<Expression>> expressions,
118 vector<unique_ptr<Expression>> groups_p,
119 vector<GroupingSet> grouping_sets_p,
120 vector<unsafe_vector<idx_t>> grouping_functions_p,
121 idx_t estimated_cardinality)
122 : PhysicalOperator(PhysicalOperatorType::HASH_GROUP_BY, std::move(types), estimated_cardinality),
123 grouping_sets(std::move(grouping_sets_p)) {
124 // get a list of all aggregates to be computed
125 const idx_t group_count = groups_p.size();
126 if (grouping_sets.empty()) {
127 GroupingSet set;
128 for (idx_t i = 0; i < group_count; i++) {
129 set.insert(x: i);
130 }
131 grouping_sets.push_back(x: std::move(set));
132 }
133 input_group_types = CreateGroupChunkTypes(groups&: groups_p);
134
135 grouped_aggregate_data.InitializeGroupby(groups: std::move(groups_p), expressions: std::move(expressions),
136 grouping_functions: std::move(grouping_functions_p));
137
138 auto &aggregates = grouped_aggregate_data.aggregates;
139 // filter_indexes must be pre-built, not lazily instantiated in parallel...
140 // Because everything that lives in this class should be read-only at execution time
141 idx_t aggregate_input_idx = 0;
142 for (idx_t i = 0; i < aggregates.size(); i++) {
143 auto &aggregate = aggregates[i];
144 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
145 aggregate_input_idx += aggr.children.size();
146 if (aggr.aggr_type == AggregateType::DISTINCT) {
147 distinct_filter.push_back(x: i);
148 } else if (aggr.aggr_type == AggregateType::NON_DISTINCT) {
149 non_distinct_filter.push_back(x: i);
150 } else { // LCOV_EXCL_START
151 throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate");
152 } // LCOV_EXCL_STOP
153 }
154
155 for (idx_t i = 0; i < aggregates.size(); i++) {
156 auto &aggregate = aggregates[i];
157 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
158 if (aggr.filter) {
159 auto &bound_ref_expr = aggr.filter->Cast<BoundReferenceExpression>();
160 if (!filter_indexes.count(x: aggr.filter.get())) {
161 // Replace the bound reference expression's index with the corresponding index of the payload chunk
162 filter_indexes[aggr.filter.get()] = bound_ref_expr.index;
163 bound_ref_expr.index = aggregate_input_idx;
164 }
165 aggregate_input_idx++;
166 }
167 }
168
169 distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates&: grouped_aggregate_data.aggregates);
170
171 for (idx_t i = 0; i < grouping_sets.size(); i++) {
172 groupings.emplace_back(args&: grouping_sets[i], args&: grouped_aggregate_data, args&: distinct_collection_info);
173 }
174}
175
176//===--------------------------------------------------------------------===//
177// Sink
178//===--------------------------------------------------------------------===//
179class HashAggregateGlobalState : public GlobalSinkState {
180public:
181 HashAggregateGlobalState(const PhysicalHashAggregate &op, ClientContext &context) {
182 grouping_states.reserve(n: op.groupings.size());
183 for (idx_t i = 0; i < op.groupings.size(); i++) {
184 auto &grouping = op.groupings[i];
185 grouping_states.emplace_back(args: grouping, args&: context);
186 }
187 vector<LogicalType> filter_types;
188 for (auto &aggr : op.grouped_aggregate_data.aggregates) {
189 auto &aggregate = aggr->Cast<BoundAggregateExpression>();
190 for (auto &child : aggregate.children) {
191 payload_types.push_back(x: child->return_type);
192 }
193 if (aggregate.filter) {
194 filter_types.push_back(x: aggregate.filter->return_type);
195 }
196 }
197 payload_types.reserve(n: payload_types.size() + filter_types.size());
198 payload_types.insert(position: payload_types.end(), first: filter_types.begin(), last: filter_types.end());
199 }
200
201 vector<HashAggregateGroupingGlobalState> grouping_states;
202 vector<LogicalType> payload_types;
203 //! Whether or not the aggregate is finished
204 bool finished = false;
205};
206
207class HashAggregateLocalState : public LocalSinkState {
208public:
209 HashAggregateLocalState(const PhysicalHashAggregate &op, ExecutionContext &context) {
210
211 auto &payload_types = op.grouped_aggregate_data.payload_types;
212 if (!payload_types.empty()) {
213 aggregate_input_chunk.InitializeEmpty(types: payload_types);
214 }
215
216 grouping_states.reserve(n: op.groupings.size());
217 for (auto &grouping : op.groupings) {
218 grouping_states.emplace_back(args: op, args: grouping, args&: context);
219 }
220 // The filter set is only needed here for the distinct aggregates
221 // the filtering of data for the regular aggregates is done within the hashtable
222 vector<AggregateObject> aggregate_objects;
223 for (auto &aggregate : op.grouped_aggregate_data.aggregates) {
224 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
225 aggregate_objects.emplace_back(args: &aggr);
226 }
227
228 filter_set.Initialize(context&: context.client, aggregates: aggregate_objects, payload_types);
229 }
230
231 DataChunk aggregate_input_chunk;
232 vector<HashAggregateGroupingLocalState> grouping_states;
233 AggregateFilterDataSet filter_set;
234};
235
236void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) {
237 auto &gstate = state.Cast<HashAggregateGlobalState>();
238 for (auto &grouping_state : gstate.grouping_states) {
239 auto &radix_state = grouping_state.table_state;
240 RadixPartitionedHashTable::SetMultiScan(*radix_state);
241 if (!grouping_state.distinct_state) {
242 continue;
243 }
244 }
245}
246
247unique_ptr<GlobalSinkState> PhysicalHashAggregate::GetGlobalSinkState(ClientContext &context) const {
248 return make_uniq<HashAggregateGlobalState>(args: *this, args&: context);
249}
250
251unique_ptr<LocalSinkState> PhysicalHashAggregate::GetLocalSinkState(ExecutionContext &context) const {
252 return make_uniq<HashAggregateLocalState>(args: *this, args&: context);
253}
254
255void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input,
256 idx_t grouping_idx) const {
257 auto &sink = input.local_state.Cast<HashAggregateLocalState>();
258 auto &global_sink = input.global_state.Cast<HashAggregateGlobalState>();
259
260 auto &grouping_gstate = global_sink.grouping_states[grouping_idx];
261 auto &grouping_lstate = sink.grouping_states[grouping_idx];
262 auto &distinct_info = *distinct_collection_info;
263
264 auto &distinct_state = grouping_gstate.distinct_state;
265 auto &distinct_data = groupings[grouping_idx].distinct_data;
266
267 DataChunk empty_chunk;
268
269 // Create an empty filter for Sink, since we don't need to update any aggregate states here
270 unsafe_vector<idx_t> empty_filter;
271
272 for (idx_t &idx : distinct_info.indices) {
273 auto &aggregate = grouped_aggregate_data.aggregates[idx]->Cast<BoundAggregateExpression>();
274
275 D_ASSERT(distinct_info.table_map.count(idx));
276 idx_t table_idx = distinct_info.table_map[idx];
277 if (!distinct_data->radix_tables[table_idx]) {
278 continue;
279 }
280 D_ASSERT(distinct_data->radix_tables[table_idx]);
281 auto &radix_table = *distinct_data->radix_tables[table_idx];
282 auto &radix_global_sink = *distinct_state->radix_states[table_idx];
283 auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx];
284
285 InterruptState interrupt_state;
286 OperatorSinkInput sink_input {.global_state: radix_global_sink, .local_state: radix_local_sink, .interrupt_state: interrupt_state};
287
288 if (aggregate.filter) {
289 DataChunk filter_chunk;
290 auto &filtered_data = sink.filter_set.GetFilterData(aggr_idx: idx);
291 filter_chunk.InitializeEmpty(types: filtered_data.filtered_payload.GetTypes());
292
293 // Add the filter Vector (BOOL)
294 auto it = filter_indexes.find(x: aggregate.filter.get());
295 D_ASSERT(it != filter_indexes.end());
296 D_ASSERT(it->second < chunk.data.size());
297 auto &filter_bound_ref = aggregate.filter->Cast<BoundReferenceExpression>();
298 filter_chunk.data[filter_bound_ref.index].Reference(other&: chunk.data[it->second]);
299 filter_chunk.SetCardinality(chunk.size());
300
301 // We cant use the AggregateFilterData::ApplyFilter method, because the chunk we need to
302 // apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those.
303 SelectionVector sel_vec(STANDARD_VECTOR_SIZE);
304 idx_t count = filtered_data.filter_executor.SelectExpression(input&: filter_chunk, sel&: sel_vec);
305
306 if (count == 0) {
307 continue;
308 }
309
310 // Because the 'input' chunk needs to be re-used after this, we need to create
311 // a duplicate of it, that we can apply the filter to
312 DataChunk filtered_input;
313 filtered_input.InitializeEmpty(types: chunk.GetTypes());
314
315 for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) {
316 auto &group = grouped_aggregate_data.groups[group_idx];
317 auto &bound_ref = group->Cast<BoundReferenceExpression>();
318 filtered_input.data[bound_ref.index].Reference(other&: chunk.data[bound_ref.index]);
319 }
320 for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) {
321 auto &child = aggregate.children[child_idx];
322 auto &bound_ref = child->Cast<BoundReferenceExpression>();
323
324 filtered_input.data[bound_ref.index].Reference(other&: chunk.data[bound_ref.index]);
325 }
326 filtered_input.Slice(sel_vector: sel_vec, count);
327 filtered_input.SetCardinality(count);
328
329 radix_table.Sink(context, chunk&: filtered_input, input&: sink_input, aggregate_input_chunk&: empty_chunk, filter: empty_filter);
330 } else {
331 radix_table.Sink(context, chunk, input&: sink_input, aggregate_input_chunk&: empty_chunk, filter: empty_filter);
332 }
333 }
334}
335
336void PhysicalHashAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const {
337 for (idx_t i = 0; i < groupings.size(); i++) {
338 SinkDistinctGrouping(context, chunk, input, grouping_idx: i);
339 }
340}
341
342SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk,
343 OperatorSinkInput &input) const {
344 auto &llstate = input.local_state.Cast<HashAggregateLocalState>();
345 auto &gstate = input.global_state.Cast<HashAggregateGlobalState>();
346
347 if (distinct_collection_info) {
348 SinkDistinct(context, chunk, input);
349 }
350
351 if (CanSkipRegularSink()) {
352 return SinkResultType::NEED_MORE_INPUT;
353 }
354
355 DataChunk &aggregate_input_chunk = llstate.aggregate_input_chunk;
356
357 auto &aggregates = grouped_aggregate_data.aggregates;
358 idx_t aggregate_input_idx = 0;
359
360 // Populate the aggregate child vectors
361 for (auto &aggregate : aggregates) {
362 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
363 for (auto &child_expr : aggr.children) {
364 D_ASSERT(child_expr->type == ExpressionType::BOUND_REF);
365 auto &bound_ref_expr = child_expr->Cast<BoundReferenceExpression>();
366 D_ASSERT(bound_ref_expr.index < chunk.data.size());
367 aggregate_input_chunk.data[aggregate_input_idx++].Reference(other&: chunk.data[bound_ref_expr.index]);
368 }
369 }
370 // Populate the filter vectors
371 for (auto &aggregate : aggregates) {
372 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
373 if (aggr.filter) {
374 auto it = filter_indexes.find(x: aggr.filter.get());
375 D_ASSERT(it != filter_indexes.end());
376 D_ASSERT(it->second < chunk.data.size());
377 aggregate_input_chunk.data[aggregate_input_idx++].Reference(other&: chunk.data[it->second]);
378 }
379 }
380
381 aggregate_input_chunk.SetCardinality(chunk.size());
382 aggregate_input_chunk.Verify();
383
384 // For every grouping set there is one radix_table
385 for (idx_t i = 0; i < groupings.size(); i++) {
386 auto &grouping_gstate = gstate.grouping_states[i];
387 auto &grouping_lstate = llstate.grouping_states[i];
388 InterruptState interrupt_state;
389 OperatorSinkInput sink_input {.global_state: *grouping_gstate.table_state, .local_state: *grouping_lstate.table_state, .interrupt_state: interrupt_state};
390
391 auto &grouping = groupings[i];
392 auto &table = grouping.table_data;
393 table.Sink(context, chunk, input&: sink_input, aggregate_input_chunk, filter: non_distinct_filter);
394 }
395
396 return SinkResultType::NEED_MORE_INPUT;
397}
398
399void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, GlobalSinkState &state,
400 LocalSinkState &lstate) const {
401 auto &global_sink = state.Cast<HashAggregateGlobalState>();
402 auto &sink = lstate.Cast<HashAggregateLocalState>();
403
404 if (!distinct_collection_info) {
405 return;
406 }
407 for (idx_t i = 0; i < groupings.size(); i++) {
408 auto &grouping_gstate = global_sink.grouping_states[i];
409 auto &grouping_lstate = sink.grouping_states[i];
410
411 auto &distinct_data = groupings[i].distinct_data;
412 auto &distinct_state = grouping_gstate.distinct_state;
413
414 const auto table_count = distinct_data->radix_tables.size();
415 for (idx_t table_idx = 0; table_idx < table_count; table_idx++) {
416 if (!distinct_data->radix_tables[table_idx]) {
417 continue;
418 }
419 auto &radix_table = *distinct_data->radix_tables[table_idx];
420 auto &radix_global_sink = *distinct_state->radix_states[table_idx];
421 auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx];
422
423 radix_table.Combine(context, state&: radix_global_sink, lstate&: radix_local_sink);
424 }
425 }
426}
427
428void PhysicalHashAggregate::Combine(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate) const {
429 auto &gstate = state.Cast<HashAggregateGlobalState>();
430 auto &llstate = lstate.Cast<HashAggregateLocalState>();
431
432 CombineDistinct(context, state, lstate);
433
434 if (CanSkipRegularSink()) {
435 return;
436 }
437 for (idx_t i = 0; i < groupings.size(); i++) {
438 auto &grouping_gstate = gstate.grouping_states[i];
439 auto &grouping_lstate = llstate.grouping_states[i];
440
441 auto &grouping = groupings[i];
442 auto &table = grouping.table_data;
443 table.Combine(context, state&: *grouping_gstate.table_state, lstate&: *grouping_lstate.table_state);
444 }
445}
446
447//! REGULAR FINALIZE EVENT
448
449class HashAggregateMergeEvent : public BasePipelineEvent {
450public:
451 HashAggregateMergeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p, Pipeline *pipeline_p)
452 : BasePipelineEvent(*pipeline_p), op(op_p), gstate(gstate_p) {
453 }
454
455 const PhysicalHashAggregate &op;
456 HashAggregateGlobalState &gstate;
457
458public:
459 void Schedule() override {
460 vector<shared_ptr<Task>> tasks;
461 for (idx_t i = 0; i < op.groupings.size(); i++) {
462 auto &grouping_gstate = gstate.grouping_states[i];
463
464 auto &grouping = op.groupings[i];
465 auto &table = grouping.table_data;
466 table.ScheduleTasks(executor&: pipeline->executor, event: shared_from_this(), state&: *grouping_gstate.table_state, tasks);
467 }
468 D_ASSERT(!tasks.empty());
469 SetTasks(std::move(tasks));
470 }
471};
472
473//! REGULAR FINALIZE FROM DISTINCT FINALIZE
474
475class HashAggregateFinalizeTask : public ExecutorTask {
476public:
477 HashAggregateFinalizeTask(Pipeline &pipeline, shared_ptr<Event> event_p, HashAggregateGlobalState &state_p,
478 ClientContext &context, const PhysicalHashAggregate &op)
479 : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), gstate(state_p),
480 context(context), op(op) {
481 }
482
483 TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
484 op.FinalizeInternal(pipeline, event&: *event, context, gstate, check_distinct: false);
485 D_ASSERT(!gstate.finished);
486 gstate.finished = true;
487 event->FinishTask();
488 return TaskExecutionResult::TASK_FINISHED;
489 }
490
491private:
492 Pipeline &pipeline;
493 shared_ptr<Event> event;
494 HashAggregateGlobalState &gstate;
495 ClientContext &context;
496 const PhysicalHashAggregate &op;
497};
498
499class HashAggregateFinalizeEvent : public BasePipelineEvent {
500public:
501 HashAggregateFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p,
502 Pipeline *pipeline_p, ClientContext &context)
503 : BasePipelineEvent(*pipeline_p), op(op_p), gstate(gstate_p), context(context) {
504 }
505
506 const PhysicalHashAggregate &op;
507 HashAggregateGlobalState &gstate;
508 ClientContext &context;
509
510public:
511 void Schedule() override {
512 vector<shared_ptr<Task>> tasks;
513 tasks.push_back(x: make_uniq<HashAggregateFinalizeTask>(args&: *pipeline, args: shared_from_this(), args&: gstate, args&: context, args: op));
514 D_ASSERT(!tasks.empty());
515 SetTasks(std::move(tasks));
516 }
517};
518
519//! DISTINCT FINALIZE TASK
520
521class HashDistinctAggregateFinalizeTask : public ExecutorTask {
522public:
523 HashDistinctAggregateFinalizeTask(Pipeline &pipeline, shared_ptr<Event> event_p, HashAggregateGlobalState &state_p,
524 ClientContext &context, const PhysicalHashAggregate &op,
525 vector<vector<unique_ptr<GlobalSourceState>>> &global_sources_p)
526 : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), gstate(state_p),
527 context(context), op(op), global_sources(global_sources_p) {
528 }
529
530 void AggregateDistinctGrouping(DistinctAggregateCollectionInfo &info,
531 const HashAggregateGroupingData &grouping_data,
532 HashAggregateGroupingGlobalState &grouping_state, idx_t grouping_idx) {
533 auto &aggregates = info.aggregates;
534 auto &data = *grouping_data.distinct_data;
535 auto &state = *grouping_state.distinct_state;
536 auto &table_state = *grouping_state.table_state;
537
538 ThreadContext temp_thread_context(context);
539 ExecutionContext temp_exec_context(context, temp_thread_context, &pipeline);
540
541 auto temp_local_state = grouping_data.table_data.GetLocalSinkState(context&: temp_exec_context);
542
543 // Create a chunk that mimics the 'input' chunk in Sink, for storing the group vectors
544 DataChunk group_chunk;
545 if (!op.input_group_types.empty()) {
546 group_chunk.Initialize(context, types: op.input_group_types);
547 }
548
549 auto &groups = op.grouped_aggregate_data.groups;
550 const idx_t group_by_size = groups.size();
551
552 DataChunk aggregate_input_chunk;
553 if (!gstate.payload_types.empty()) {
554 aggregate_input_chunk.Initialize(context, types: gstate.payload_types);
555 }
556
557 idx_t payload_idx;
558 idx_t next_payload_idx = 0;
559
560 for (idx_t i = 0; i < op.grouped_aggregate_data.aggregates.size(); i++) {
561 auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();
562
563 // Forward the payload idx
564 payload_idx = next_payload_idx;
565 next_payload_idx = payload_idx + aggregate.children.size();
566
567 // If aggregate is not distinct, skip it
568 if (!data.IsDistinct(index: i)) {
569 continue;
570 }
571 D_ASSERT(data.info.table_map.count(i));
572 auto table_idx = data.info.table_map.at(k: i);
573 auto &radix_table_p = data.radix_tables[table_idx];
574
575 // Create a duplicate of the output_chunk, because of multi-threading we cant alter the original
576 DataChunk output_chunk;
577 output_chunk.Initialize(context, types: state.distinct_output_chunks[table_idx]->GetTypes());
578
579 auto &global_source = global_sources[grouping_idx][i];
580 auto local_source = radix_table_p->GetLocalSourceState(context&: temp_exec_context);
581
582 // Fetch all the data from the aggregate ht, and Sink it into the main ht
583 while (true) {
584 output_chunk.Reset();
585 group_chunk.Reset();
586 aggregate_input_chunk.Reset();
587
588 InterruptState interrupt_state;
589 OperatorSourceInput source_input {.global_state: *global_source, .local_state: *local_source, .interrupt_state: interrupt_state};
590 auto res = radix_table_p->GetData(context&: temp_exec_context, chunk&: output_chunk, sink_state&: *state.radix_states[table_idx],
591 input&: source_input);
592
593 if (res == SourceResultType::FINISHED) {
594 D_ASSERT(output_chunk.size() == 0);
595 break;
596 } else if (res == SourceResultType::BLOCKED) {
597 throw InternalException(
598 "Unexpected interrupt from radix table GetData in HashDistinctAggregateFinalizeTask");
599 }
600
601 auto &grouped_aggregate_data = *data.grouped_aggregate_data[table_idx];
602
603 for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) {
604 auto &group = grouped_aggregate_data.groups[group_idx];
605 auto &bound_ref_expr = group->Cast<BoundReferenceExpression>();
606 group_chunk.data[bound_ref_expr.index].Reference(other&: output_chunk.data[group_idx]);
607 }
608 group_chunk.SetCardinality(output_chunk);
609
610 for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size;
611 child_idx++) {
612 aggregate_input_chunk.data[payload_idx + child_idx].Reference(
613 other&: output_chunk.data[group_by_size + child_idx]);
614 }
615 aggregate_input_chunk.SetCardinality(output_chunk);
616
617 // Sink it into the main ht
618 OperatorSinkInput sink_input {.global_state: table_state, .local_state: *temp_local_state, .interrupt_state: interrupt_state};
619 grouping_data.table_data.Sink(context&: temp_exec_context, chunk&: group_chunk, input&: sink_input, aggregate_input_chunk, filter: {i});
620 }
621 }
622 grouping_data.table_data.Combine(context&: temp_exec_context, state&: table_state, lstate&: *temp_local_state);
623 }
624
625 TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
626 D_ASSERT(op.distinct_collection_info);
627 auto &info = *op.distinct_collection_info;
628 for (idx_t i = 0; i < op.groupings.size(); i++) {
629 auto &grouping = op.groupings[i];
630 auto &grouping_state = gstate.grouping_states[i];
631 AggregateDistinctGrouping(info, grouping_data: grouping, grouping_state, grouping_idx: i);
632 }
633 event->FinishTask();
634 return TaskExecutionResult::TASK_FINISHED;
635 }
636
637private:
638 Pipeline &pipeline;
639 shared_ptr<Event> event;
640 HashAggregateGlobalState &gstate;
641 ClientContext &context;
642 const PhysicalHashAggregate &op;
643 vector<vector<unique_ptr<GlobalSourceState>>> &global_sources;
644};
645
646//! DISTINCT FINALIZE EVENT
647
648// TODO: Create tasks and run these in parallel instead of doing this all in Schedule, single threaded
649class HashDistinctAggregateFinalizeEvent : public BasePipelineEvent {
650public:
651 HashDistinctAggregateFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p,
652 Pipeline &pipeline_p, ClientContext &context)
653 : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), context(context) {
654 }
655 const PhysicalHashAggregate &op;
656 HashAggregateGlobalState &gstate;
657 ClientContext &context;
658 //! The GlobalSourceStates for all the radix tables of the distinct aggregates
659 vector<vector<unique_ptr<GlobalSourceState>>> global_sources;
660
661public:
662 void Schedule() override {
663 global_sources = CreateGlobalSources();
664
665 vector<shared_ptr<Task>> tasks;
666 auto &scheduler = TaskScheduler::GetScheduler(context);
667 auto number_of_threads = scheduler.NumberOfThreads();
668 tasks.reserve(n: number_of_threads);
669 for (int32_t i = 0; i < number_of_threads; i++) {
670 tasks.push_back(x: make_uniq<HashDistinctAggregateFinalizeTask>(args&: *pipeline, args: shared_from_this(), args&: gstate, args&: context,
671 args: op, args&: global_sources));
672 }
673 D_ASSERT(!tasks.empty());
674 SetTasks(std::move(tasks));
675 }
676
677 void FinishEvent() override {
678 //! Now that everything is added to the main ht, we can actually finalize
679 auto new_event = make_shared<HashAggregateFinalizeEvent>(args: op, args&: gstate, args: pipeline.get(), args&: context);
680 this->InsertEvent(replacement_event: std::move(new_event));
681 }
682
683private:
684 vector<vector<unique_ptr<GlobalSourceState>>> CreateGlobalSources() {
685 vector<vector<unique_ptr<GlobalSourceState>>> grouping_sources;
686 grouping_sources.reserve(n: op.groupings.size());
687 for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) {
688 auto &grouping = op.groupings[grouping_idx];
689 auto &data = *grouping.distinct_data;
690
691 vector<unique_ptr<GlobalSourceState>> aggregate_sources;
692 aggregate_sources.reserve(n: op.grouped_aggregate_data.aggregates.size());
693
694 for (idx_t i = 0; i < op.grouped_aggregate_data.aggregates.size(); i++) {
695 auto &aggregate = op.grouped_aggregate_data.aggregates[i];
696 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
697
698 if (!aggr.IsDistinct()) {
699 aggregate_sources.push_back(x: nullptr);
700 continue;
701 }
702
703 D_ASSERT(data.info.table_map.count(i));
704 auto table_idx = data.info.table_map.at(k: i);
705 auto &radix_table_p = data.radix_tables[table_idx];
706 aggregate_sources.push_back(x: radix_table_p->GetGlobalSourceState(context));
707 }
708 grouping_sources.push_back(x: std::move(aggregate_sources));
709 }
710 return grouping_sources;
711 }
712};
713
714//! DISTINCT COMBINE EVENT
715
716class HashDistinctCombineFinalizeEvent : public BasePipelineEvent {
717public:
718 HashDistinctCombineFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p,
719 Pipeline &pipeline_p, ClientContext &client)
720 : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), client(client) {
721 }
722
723 const PhysicalHashAggregate &op;
724 HashAggregateGlobalState &gstate;
725 ClientContext &client;
726
727public:
728 void Schedule() override {
729 vector<shared_ptr<Task>> tasks;
730 for (idx_t i = 0; i < op.groupings.size(); i++) {
731 auto &grouping = op.groupings[i];
732 auto &distinct_data = *grouping.distinct_data;
733 auto &distinct_state = *gstate.grouping_states[i].distinct_state;
734 for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) {
735 if (!distinct_data.radix_tables[table_idx]) {
736 continue;
737 }
738 distinct_data.radix_tables[table_idx]->ScheduleTasks(executor&: pipeline->executor, event: shared_from_this(),
739 state&: *distinct_state.radix_states[table_idx], tasks);
740 }
741 }
742
743 D_ASSERT(!tasks.empty());
744 SetTasks(std::move(tasks));
745 }
746
747 void FinishEvent() override {
748 //! Now that all tables are combined, it's time to do the distinct aggregations
749 auto new_event = make_shared<HashDistinctAggregateFinalizeEvent>(args: op, args&: gstate, args&: *pipeline, args&: client);
750 this->InsertEvent(replacement_event: std::move(new_event));
751 }
752};
753
754//! FINALIZE
755
756SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context,
757 GlobalSinkState &gstate_p) const {
758 auto &gstate = gstate_p.Cast<HashAggregateGlobalState>();
759 D_ASSERT(distinct_collection_info);
760
761 bool any_partitioned = false;
762 for (idx_t i = 0; i < groupings.size(); i++) {
763 auto &grouping = groupings[i];
764 auto &distinct_data = *grouping.distinct_data;
765 auto &distinct_state = *gstate.grouping_states[i].distinct_state;
766
767 for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) {
768 if (!distinct_data.radix_tables[table_idx]) {
769 continue;
770 }
771 auto &radix_table = distinct_data.radix_tables[table_idx];
772 auto &radix_state = *distinct_state.radix_states[table_idx];
773 bool partitioned = radix_table->Finalize(context, gstate_p&: radix_state);
774 if (partitioned) {
775 any_partitioned = true;
776 }
777 }
778 }
779 if (any_partitioned) {
780 // If any of the groupings are partitioned then we first need to combine those, then aggregate
781 auto new_event = make_shared<HashDistinctCombineFinalizeEvent>(args: *this, args&: gstate, args&: pipeline, args&: context);
782 event.InsertEvent(replacement_event: std::move(new_event));
783 } else {
784 // Hashtables aren't partitioned, they dont need to be joined first
785 // so we can already compute the aggregate
786 auto new_event = make_shared<HashDistinctAggregateFinalizeEvent>(args: *this, args&: gstate, args&: pipeline, args&: context);
787 event.InsertEvent(replacement_event: std::move(new_event));
788 }
789 return SinkFinalizeType::READY;
790}
791
792SinkFinalizeType PhysicalHashAggregate::FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context,
793 GlobalSinkState &gstate_p, bool check_distinct) const {
794 auto &gstate = gstate_p.Cast<HashAggregateGlobalState>();
795
796 if (check_distinct && distinct_collection_info) {
797 // There are distinct aggregates
798 // If these are partitioned those need to be combined first
799 // Then we Finalize again, skipping this step
800 return FinalizeDistinct(pipeline, event, context, gstate_p);
801 }
802
803 bool any_partitioned = false;
804 for (idx_t i = 0; i < groupings.size(); i++) {
805 auto &grouping = groupings[i];
806 auto &grouping_gstate = gstate.grouping_states[i];
807
808 bool is_partitioned = grouping.table_data.Finalize(context, gstate_p&: *grouping_gstate.table_state);
809 if (is_partitioned) {
810 any_partitioned = true;
811 }
812 }
813 if (any_partitioned) {
814 auto new_event = make_shared<HashAggregateMergeEvent>(args: *this, args&: gstate, args: &pipeline);
815 event.InsertEvent(replacement_event: std::move(new_event));
816 }
817 return SinkFinalizeType::READY;
818}
819
820SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context,
821 GlobalSinkState &gstate_p) const {
822 return FinalizeInternal(pipeline, event, context, gstate_p, check_distinct: true);
823}
824
825//===--------------------------------------------------------------------===//
826// Source
827//===--------------------------------------------------------------------===//
828class PhysicalHashAggregateGlobalSourceState : public GlobalSourceState {
829public:
830 PhysicalHashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op)
831 : op(op), state_index(0) {
832 for (auto &grouping : op.groupings) {
833 auto &rt = grouping.table_data;
834 radix_states.push_back(x: rt.GetGlobalSourceState(context));
835 }
836 }
837
838 const PhysicalHashAggregate &op;
839 mutex lock;
840 atomic<idx_t> state_index;
841
842 vector<unique_ptr<GlobalSourceState>> radix_states;
843
844public:
845 idx_t MaxThreads() override {
846 // If there are no tables, we only need one thread.
847 if (op.groupings.empty()) {
848 return 1;
849 }
850
851 auto &ht_state = op.sink_state->Cast<HashAggregateGlobalState>();
852 idx_t count = 0;
853 for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) {
854 auto &grouping = op.groupings[sidx];
855 auto &grouping_gstate = ht_state.grouping_states[sidx];
856 count += grouping.table_data.Size(sink_state&: *grouping_gstate.table_state);
857 }
858 return MaxValue<idx_t>(a: 1, b: count / STANDARD_VECTOR_SIZE);
859 }
860};
861
862unique_ptr<GlobalSourceState> PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const {
863 return make_uniq<PhysicalHashAggregateGlobalSourceState>(args&: context, args: *this);
864}
865
866class PhysicalHashAggregateLocalSourceState : public LocalSourceState {
867public:
868 explicit PhysicalHashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) {
869 for (auto &grouping : op.groupings) {
870 auto &rt = grouping.table_data;
871 radix_states.push_back(x: rt.GetLocalSourceState(context));
872 }
873 }
874
875 vector<unique_ptr<LocalSourceState>> radix_states;
876};
877
878unique_ptr<LocalSourceState> PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context,
879 GlobalSourceState &gstate) const {
880 return make_uniq<PhysicalHashAggregateLocalSourceState>(args&: context, args: *this);
881}
882
883SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk,
884 OperatorSourceInput &input) const {
885 auto &sink_gstate = sink_state->Cast<HashAggregateGlobalState>();
886 auto &gstate = input.global_state.Cast<PhysicalHashAggregateGlobalSourceState>();
887 auto &lstate = input.local_state.Cast<PhysicalHashAggregateLocalSourceState>();
888 while (true) {
889 idx_t radix_idx = gstate.state_index;
890 if (radix_idx >= groupings.size()) {
891 break;
892 }
893 auto &grouping = groupings[radix_idx];
894 auto &radix_table = grouping.table_data;
895 auto &grouping_gstate = sink_gstate.grouping_states[radix_idx];
896
897 InterruptState interrupt_state;
898 OperatorSourceInput source_input {.global_state: *gstate.radix_states[radix_idx], .local_state: *lstate.radix_states[radix_idx],
899 .interrupt_state: interrupt_state};
900 auto res = radix_table.GetData(context, chunk, sink_state&: *grouping_gstate.table_state, input&: source_input);
901 if (chunk.size() != 0) {
902 return SourceResultType::HAVE_MORE_OUTPUT;
903 } else if (res == SourceResultType::BLOCKED) {
904 throw InternalException("Unexpectedly Blocked from radix_table");
905 }
906
907 // move to the next table
908 lock_guard<mutex> l(gstate.lock);
909 radix_idx++;
910 if (radix_idx > gstate.state_index) {
911 // we have not yet worked on the table
912 // move the global index forwards
913 gstate.state_index = radix_idx;
914 }
915 }
916
917 return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
918}
919
920string PhysicalHashAggregate::ParamsToString() const {
921 string result;
922 auto &groups = grouped_aggregate_data.groups;
923 auto &aggregates = grouped_aggregate_data.aggregates;
924 for (idx_t i = 0; i < groups.size(); i++) {
925 if (i > 0) {
926 result += "\n";
927 }
928 result += groups[i]->GetName();
929 }
930 for (idx_t i = 0; i < aggregates.size(); i++) {
931 auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();
932 if (i > 0 || !groups.empty()) {
933 result += "\n";
934 }
935 result += aggregates[i]->GetName();
936 if (aggregate.filter) {
937 result += " Filter: " + aggregate.filter->GetName();
938 }
939 }
940 return result;
941}
942
943} // namespace duckdb
944