1 | #include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" |
2 | |
3 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
5 | #include "duckdb/execution/aggregate_hashtable.hpp" |
6 | #include "duckdb/main/client_context.hpp" |
7 | #include "duckdb/parallel/interrupt.hpp" |
8 | #include "duckdb/parallel/pipeline.hpp" |
9 | #include "duckdb/parallel/task_scheduler.hpp" |
10 | #include "duckdb/parallel/thread_context.hpp" |
11 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
12 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
13 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
14 | #include "duckdb/parallel/base_pipeline_event.hpp" |
15 | #include "duckdb/common/atomic.hpp" |
16 | #include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" |
17 | |
18 | namespace duckdb { |
19 | |
20 | HashAggregateGroupingData::HashAggregateGroupingData(GroupingSet &grouping_set_p, |
21 | const GroupedAggregateData &grouped_aggregate_data, |
22 | unique_ptr<DistinctAggregateCollectionInfo> &info) |
23 | : table_data(grouping_set_p, grouped_aggregate_data) { |
24 | if (info) { |
25 | distinct_data = make_uniq<DistinctAggregateData>(args&: *info, args&: grouping_set_p, args: &grouped_aggregate_data.groups); |
26 | } |
27 | } |
28 | |
29 | bool HashAggregateGroupingData::HasDistinct() const { |
30 | return distinct_data != nullptr; |
31 | } |
32 | |
33 | HashAggregateGroupingGlobalState::HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data, |
34 | ClientContext &context) { |
35 | table_state = data.table_data.GetGlobalSinkState(context); |
36 | if (data.HasDistinct()) { |
37 | distinct_state = make_uniq<DistinctAggregateState>(args&: *data.distinct_data, args&: context); |
38 | } |
39 | } |
40 | |
41 | HashAggregateGroupingLocalState::HashAggregateGroupingLocalState(const PhysicalHashAggregate &op, |
42 | const HashAggregateGroupingData &data, |
43 | ExecutionContext &context) { |
44 | table_state = data.table_data.GetLocalSinkState(context); |
45 | if (!data.HasDistinct()) { |
46 | return; |
47 | } |
48 | auto &distinct_data = *data.distinct_data; |
49 | |
50 | auto &distinct_indices = op.distinct_collection_info->Indices(); |
51 | D_ASSERT(!distinct_indices.empty()); |
52 | |
53 | distinct_states.resize(new_size: op.distinct_collection_info->aggregates.size()); |
54 | auto &table_map = op.distinct_collection_info->table_map; |
55 | |
56 | for (auto &idx : distinct_indices) { |
57 | idx_t table_idx = table_map[idx]; |
58 | auto &radix_table = distinct_data.radix_tables[table_idx]; |
59 | if (radix_table == nullptr) { |
60 | // This aggregate has identical input as another aggregate, so no table is created for it |
61 | continue; |
62 | } |
63 | // Initialize the states of the radix tables used for the distinct aggregates |
64 | distinct_states[table_idx] = radix_table->GetLocalSinkState(context); |
65 | } |
66 | } |
67 | |
68 | static vector<LogicalType> CreateGroupChunkTypes(vector<unique_ptr<Expression>> &groups) { |
69 | set<idx_t> group_indices; |
70 | |
71 | if (groups.empty()) { |
72 | return {}; |
73 | } |
74 | |
75 | for (auto &group : groups) { |
76 | D_ASSERT(group->type == ExpressionType::BOUND_REF); |
77 | auto &bound_ref = group->Cast<BoundReferenceExpression>(); |
78 | group_indices.insert(x: bound_ref.index); |
79 | } |
80 | idx_t highest_index = *group_indices.rbegin(); |
81 | vector<LogicalType> types(highest_index + 1, LogicalType::SQLNULL); |
82 | for (auto &group : groups) { |
83 | auto &bound_ref = group->Cast<BoundReferenceExpression>(); |
84 | types[bound_ref.index] = bound_ref.return_type; |
85 | } |
86 | return types; |
87 | } |
88 | |
89 | bool PhysicalHashAggregate::CanSkipRegularSink() const { |
90 | if (!filter_indexes.empty()) { |
91 | // If we have filters, we can't skip the regular sink, because we might lose groups otherwise. |
92 | return false; |
93 | } |
94 | if (grouped_aggregate_data.aggregates.empty()) { |
95 | // When there are no aggregates, we have to add to the main ht right away |
96 | return false; |
97 | } |
98 | if (!non_distinct_filter.empty()) { |
99 | return false; |
100 | } |
101 | return true; |
102 | } |
103 | |
104 | PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types, |
105 | vector<unique_ptr<Expression>> expressions, idx_t estimated_cardinality) |
106 | : PhysicalHashAggregate(context, std::move(types), std::move(expressions), {}, estimated_cardinality) { |
107 | } |
108 | |
109 | PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types, |
110 | vector<unique_ptr<Expression>> expressions, |
111 | vector<unique_ptr<Expression>> groups_p, idx_t estimated_cardinality) |
112 | : PhysicalHashAggregate(context, std::move(types), std::move(expressions), std::move(groups_p), {}, {}, |
113 | estimated_cardinality) { |
114 | } |
115 | |
116 | PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types, |
117 | vector<unique_ptr<Expression>> expressions, |
118 | vector<unique_ptr<Expression>> groups_p, |
119 | vector<GroupingSet> grouping_sets_p, |
120 | vector<unsafe_vector<idx_t>> grouping_functions_p, |
121 | idx_t estimated_cardinality) |
122 | : PhysicalOperator(PhysicalOperatorType::HASH_GROUP_BY, std::move(types), estimated_cardinality), |
123 | grouping_sets(std::move(grouping_sets_p)) { |
124 | // get a list of all aggregates to be computed |
125 | const idx_t group_count = groups_p.size(); |
126 | if (grouping_sets.empty()) { |
127 | GroupingSet set; |
128 | for (idx_t i = 0; i < group_count; i++) { |
129 | set.insert(x: i); |
130 | } |
131 | grouping_sets.push_back(x: std::move(set)); |
132 | } |
133 | input_group_types = CreateGroupChunkTypes(groups&: groups_p); |
134 | |
135 | grouped_aggregate_data.InitializeGroupby(groups: std::move(groups_p), expressions: std::move(expressions), |
136 | grouping_functions: std::move(grouping_functions_p)); |
137 | |
138 | auto &aggregates = grouped_aggregate_data.aggregates; |
139 | // filter_indexes must be pre-built, not lazily instantiated in parallel... |
140 | // Because everything that lives in this class should be read-only at execution time |
141 | idx_t aggregate_input_idx = 0; |
142 | for (idx_t i = 0; i < aggregates.size(); i++) { |
143 | auto &aggregate = aggregates[i]; |
144 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
145 | aggregate_input_idx += aggr.children.size(); |
146 | if (aggr.aggr_type == AggregateType::DISTINCT) { |
147 | distinct_filter.push_back(x: i); |
148 | } else if (aggr.aggr_type == AggregateType::NON_DISTINCT) { |
149 | non_distinct_filter.push_back(x: i); |
150 | } else { // LCOV_EXCL_START |
151 | throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate" ); |
152 | } // LCOV_EXCL_STOP |
153 | } |
154 | |
155 | for (idx_t i = 0; i < aggregates.size(); i++) { |
156 | auto &aggregate = aggregates[i]; |
157 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
158 | if (aggr.filter) { |
159 | auto &bound_ref_expr = aggr.filter->Cast<BoundReferenceExpression>(); |
160 | if (!filter_indexes.count(x: aggr.filter.get())) { |
161 | // Replace the bound reference expression's index with the corresponding index of the payload chunk |
162 | filter_indexes[aggr.filter.get()] = bound_ref_expr.index; |
163 | bound_ref_expr.index = aggregate_input_idx; |
164 | } |
165 | aggregate_input_idx++; |
166 | } |
167 | } |
168 | |
169 | distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates&: grouped_aggregate_data.aggregates); |
170 | |
171 | for (idx_t i = 0; i < grouping_sets.size(); i++) { |
172 | groupings.emplace_back(args&: grouping_sets[i], args&: grouped_aggregate_data, args&: distinct_collection_info); |
173 | } |
174 | } |
175 | |
176 | //===--------------------------------------------------------------------===// |
177 | // Sink |
178 | //===--------------------------------------------------------------------===// |
179 | class HashAggregateGlobalState : public GlobalSinkState { |
180 | public: |
181 | HashAggregateGlobalState(const PhysicalHashAggregate &op, ClientContext &context) { |
182 | grouping_states.reserve(n: op.groupings.size()); |
183 | for (idx_t i = 0; i < op.groupings.size(); i++) { |
184 | auto &grouping = op.groupings[i]; |
185 | grouping_states.emplace_back(args: grouping, args&: context); |
186 | } |
187 | vector<LogicalType> filter_types; |
188 | for (auto &aggr : op.grouped_aggregate_data.aggregates) { |
189 | auto &aggregate = aggr->Cast<BoundAggregateExpression>(); |
190 | for (auto &child : aggregate.children) { |
191 | payload_types.push_back(x: child->return_type); |
192 | } |
193 | if (aggregate.filter) { |
194 | filter_types.push_back(x: aggregate.filter->return_type); |
195 | } |
196 | } |
197 | payload_types.reserve(n: payload_types.size() + filter_types.size()); |
198 | payload_types.insert(position: payload_types.end(), first: filter_types.begin(), last: filter_types.end()); |
199 | } |
200 | |
201 | vector<HashAggregateGroupingGlobalState> grouping_states; |
202 | vector<LogicalType> payload_types; |
203 | //! Whether or not the aggregate is finished |
204 | bool finished = false; |
205 | }; |
206 | |
207 | class HashAggregateLocalState : public LocalSinkState { |
208 | public: |
209 | HashAggregateLocalState(const PhysicalHashAggregate &op, ExecutionContext &context) { |
210 | |
211 | auto &payload_types = op.grouped_aggregate_data.payload_types; |
212 | if (!payload_types.empty()) { |
213 | aggregate_input_chunk.InitializeEmpty(types: payload_types); |
214 | } |
215 | |
216 | grouping_states.reserve(n: op.groupings.size()); |
217 | for (auto &grouping : op.groupings) { |
218 | grouping_states.emplace_back(args: op, args: grouping, args&: context); |
219 | } |
220 | // The filter set is only needed here for the distinct aggregates |
221 | // the filtering of data for the regular aggregates is done within the hashtable |
222 | vector<AggregateObject> aggregate_objects; |
223 | for (auto &aggregate : op.grouped_aggregate_data.aggregates) { |
224 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
225 | aggregate_objects.emplace_back(args: &aggr); |
226 | } |
227 | |
228 | filter_set.Initialize(context&: context.client, aggregates: aggregate_objects, payload_types); |
229 | } |
230 | |
231 | DataChunk aggregate_input_chunk; |
232 | vector<HashAggregateGroupingLocalState> grouping_states; |
233 | AggregateFilterDataSet filter_set; |
234 | }; |
235 | |
236 | void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) { |
237 | auto &gstate = state.Cast<HashAggregateGlobalState>(); |
238 | for (auto &grouping_state : gstate.grouping_states) { |
239 | auto &radix_state = grouping_state.table_state; |
240 | RadixPartitionedHashTable::SetMultiScan(*radix_state); |
241 | if (!grouping_state.distinct_state) { |
242 | continue; |
243 | } |
244 | } |
245 | } |
246 | |
247 | unique_ptr<GlobalSinkState> PhysicalHashAggregate::GetGlobalSinkState(ClientContext &context) const { |
248 | return make_uniq<HashAggregateGlobalState>(args: *this, args&: context); |
249 | } |
250 | |
251 | unique_ptr<LocalSinkState> PhysicalHashAggregate::GetLocalSinkState(ExecutionContext &context) const { |
252 | return make_uniq<HashAggregateLocalState>(args: *this, args&: context); |
253 | } |
254 | |
255 | void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, |
256 | idx_t grouping_idx) const { |
257 | auto &sink = input.local_state.Cast<HashAggregateLocalState>(); |
258 | auto &global_sink = input.global_state.Cast<HashAggregateGlobalState>(); |
259 | |
260 | auto &grouping_gstate = global_sink.grouping_states[grouping_idx]; |
261 | auto &grouping_lstate = sink.grouping_states[grouping_idx]; |
262 | auto &distinct_info = *distinct_collection_info; |
263 | |
264 | auto &distinct_state = grouping_gstate.distinct_state; |
265 | auto &distinct_data = groupings[grouping_idx].distinct_data; |
266 | |
267 | DataChunk empty_chunk; |
268 | |
269 | // Create an empty filter for Sink, since we don't need to update any aggregate states here |
270 | unsafe_vector<idx_t> empty_filter; |
271 | |
272 | for (idx_t &idx : distinct_info.indices) { |
273 | auto &aggregate = grouped_aggregate_data.aggregates[idx]->Cast<BoundAggregateExpression>(); |
274 | |
275 | D_ASSERT(distinct_info.table_map.count(idx)); |
276 | idx_t table_idx = distinct_info.table_map[idx]; |
277 | if (!distinct_data->radix_tables[table_idx]) { |
278 | continue; |
279 | } |
280 | D_ASSERT(distinct_data->radix_tables[table_idx]); |
281 | auto &radix_table = *distinct_data->radix_tables[table_idx]; |
282 | auto &radix_global_sink = *distinct_state->radix_states[table_idx]; |
283 | auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; |
284 | |
285 | InterruptState interrupt_state; |
286 | OperatorSinkInput sink_input {.global_state: radix_global_sink, .local_state: radix_local_sink, .interrupt_state: interrupt_state}; |
287 | |
288 | if (aggregate.filter) { |
289 | DataChunk filter_chunk; |
290 | auto &filtered_data = sink.filter_set.GetFilterData(aggr_idx: idx); |
291 | filter_chunk.InitializeEmpty(types: filtered_data.filtered_payload.GetTypes()); |
292 | |
293 | // Add the filter Vector (BOOL) |
294 | auto it = filter_indexes.find(x: aggregate.filter.get()); |
295 | D_ASSERT(it != filter_indexes.end()); |
296 | D_ASSERT(it->second < chunk.data.size()); |
297 | auto &filter_bound_ref = aggregate.filter->Cast<BoundReferenceExpression>(); |
298 | filter_chunk.data[filter_bound_ref.index].Reference(other&: chunk.data[it->second]); |
299 | filter_chunk.SetCardinality(chunk.size()); |
300 | |
301 | // We cant use the AggregateFilterData::ApplyFilter method, because the chunk we need to |
302 | // apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those. |
303 | SelectionVector sel_vec(STANDARD_VECTOR_SIZE); |
304 | idx_t count = filtered_data.filter_executor.SelectExpression(input&: filter_chunk, sel&: sel_vec); |
305 | |
306 | if (count == 0) { |
307 | continue; |
308 | } |
309 | |
310 | // Because the 'input' chunk needs to be re-used after this, we need to create |
311 | // a duplicate of it, that we can apply the filter to |
312 | DataChunk filtered_input; |
313 | filtered_input.InitializeEmpty(types: chunk.GetTypes()); |
314 | |
315 | for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) { |
316 | auto &group = grouped_aggregate_data.groups[group_idx]; |
317 | auto &bound_ref = group->Cast<BoundReferenceExpression>(); |
318 | filtered_input.data[bound_ref.index].Reference(other&: chunk.data[bound_ref.index]); |
319 | } |
320 | for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) { |
321 | auto &child = aggregate.children[child_idx]; |
322 | auto &bound_ref = child->Cast<BoundReferenceExpression>(); |
323 | |
324 | filtered_input.data[bound_ref.index].Reference(other&: chunk.data[bound_ref.index]); |
325 | } |
326 | filtered_input.Slice(sel_vector: sel_vec, count); |
327 | filtered_input.SetCardinality(count); |
328 | |
329 | radix_table.Sink(context, chunk&: filtered_input, input&: sink_input, aggregate_input_chunk&: empty_chunk, filter: empty_filter); |
330 | } else { |
331 | radix_table.Sink(context, chunk, input&: sink_input, aggregate_input_chunk&: empty_chunk, filter: empty_filter); |
332 | } |
333 | } |
334 | } |
335 | |
336 | void PhysicalHashAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { |
337 | for (idx_t i = 0; i < groupings.size(); i++) { |
338 | SinkDistinctGrouping(context, chunk, input, grouping_idx: i); |
339 | } |
340 | } |
341 | |
342 | SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, |
343 | OperatorSinkInput &input) const { |
344 | auto &llstate = input.local_state.Cast<HashAggregateLocalState>(); |
345 | auto &gstate = input.global_state.Cast<HashAggregateGlobalState>(); |
346 | |
347 | if (distinct_collection_info) { |
348 | SinkDistinct(context, chunk, input); |
349 | } |
350 | |
351 | if (CanSkipRegularSink()) { |
352 | return SinkResultType::NEED_MORE_INPUT; |
353 | } |
354 | |
355 | DataChunk &aggregate_input_chunk = llstate.aggregate_input_chunk; |
356 | |
357 | auto &aggregates = grouped_aggregate_data.aggregates; |
358 | idx_t aggregate_input_idx = 0; |
359 | |
360 | // Populate the aggregate child vectors |
361 | for (auto &aggregate : aggregates) { |
362 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
363 | for (auto &child_expr : aggr.children) { |
364 | D_ASSERT(child_expr->type == ExpressionType::BOUND_REF); |
365 | auto &bound_ref_expr = child_expr->Cast<BoundReferenceExpression>(); |
366 | D_ASSERT(bound_ref_expr.index < chunk.data.size()); |
367 | aggregate_input_chunk.data[aggregate_input_idx++].Reference(other&: chunk.data[bound_ref_expr.index]); |
368 | } |
369 | } |
370 | // Populate the filter vectors |
371 | for (auto &aggregate : aggregates) { |
372 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
373 | if (aggr.filter) { |
374 | auto it = filter_indexes.find(x: aggr.filter.get()); |
375 | D_ASSERT(it != filter_indexes.end()); |
376 | D_ASSERT(it->second < chunk.data.size()); |
377 | aggregate_input_chunk.data[aggregate_input_idx++].Reference(other&: chunk.data[it->second]); |
378 | } |
379 | } |
380 | |
381 | aggregate_input_chunk.SetCardinality(chunk.size()); |
382 | aggregate_input_chunk.Verify(); |
383 | |
384 | // For every grouping set there is one radix_table |
385 | for (idx_t i = 0; i < groupings.size(); i++) { |
386 | auto &grouping_gstate = gstate.grouping_states[i]; |
387 | auto &grouping_lstate = llstate.grouping_states[i]; |
388 | InterruptState interrupt_state; |
389 | OperatorSinkInput sink_input {.global_state: *grouping_gstate.table_state, .local_state: *grouping_lstate.table_state, .interrupt_state: interrupt_state}; |
390 | |
391 | auto &grouping = groupings[i]; |
392 | auto &table = grouping.table_data; |
393 | table.Sink(context, chunk, input&: sink_input, aggregate_input_chunk, filter: non_distinct_filter); |
394 | } |
395 | |
396 | return SinkResultType::NEED_MORE_INPUT; |
397 | } |
398 | |
399 | void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, GlobalSinkState &state, |
400 | LocalSinkState &lstate) const { |
401 | auto &global_sink = state.Cast<HashAggregateGlobalState>(); |
402 | auto &sink = lstate.Cast<HashAggregateLocalState>(); |
403 | |
404 | if (!distinct_collection_info) { |
405 | return; |
406 | } |
407 | for (idx_t i = 0; i < groupings.size(); i++) { |
408 | auto &grouping_gstate = global_sink.grouping_states[i]; |
409 | auto &grouping_lstate = sink.grouping_states[i]; |
410 | |
411 | auto &distinct_data = groupings[i].distinct_data; |
412 | auto &distinct_state = grouping_gstate.distinct_state; |
413 | |
414 | const auto table_count = distinct_data->radix_tables.size(); |
415 | for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { |
416 | if (!distinct_data->radix_tables[table_idx]) { |
417 | continue; |
418 | } |
419 | auto &radix_table = *distinct_data->radix_tables[table_idx]; |
420 | auto &radix_global_sink = *distinct_state->radix_states[table_idx]; |
421 | auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; |
422 | |
423 | radix_table.Combine(context, state&: radix_global_sink, lstate&: radix_local_sink); |
424 | } |
425 | } |
426 | } |
427 | |
428 | void PhysicalHashAggregate::Combine(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate) const { |
429 | auto &gstate = state.Cast<HashAggregateGlobalState>(); |
430 | auto &llstate = lstate.Cast<HashAggregateLocalState>(); |
431 | |
432 | CombineDistinct(context, state, lstate); |
433 | |
434 | if (CanSkipRegularSink()) { |
435 | return; |
436 | } |
437 | for (idx_t i = 0; i < groupings.size(); i++) { |
438 | auto &grouping_gstate = gstate.grouping_states[i]; |
439 | auto &grouping_lstate = llstate.grouping_states[i]; |
440 | |
441 | auto &grouping = groupings[i]; |
442 | auto &table = grouping.table_data; |
443 | table.Combine(context, state&: *grouping_gstate.table_state, lstate&: *grouping_lstate.table_state); |
444 | } |
445 | } |
446 | |
447 | //! REGULAR FINALIZE EVENT |
448 | |
449 | class HashAggregateMergeEvent : public BasePipelineEvent { |
450 | public: |
451 | HashAggregateMergeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p, Pipeline *pipeline_p) |
452 | : BasePipelineEvent(*pipeline_p), op(op_p), gstate(gstate_p) { |
453 | } |
454 | |
455 | const PhysicalHashAggregate &op; |
456 | HashAggregateGlobalState &gstate; |
457 | |
458 | public: |
459 | void Schedule() override { |
460 | vector<shared_ptr<Task>> tasks; |
461 | for (idx_t i = 0; i < op.groupings.size(); i++) { |
462 | auto &grouping_gstate = gstate.grouping_states[i]; |
463 | |
464 | auto &grouping = op.groupings[i]; |
465 | auto &table = grouping.table_data; |
466 | table.ScheduleTasks(executor&: pipeline->executor, event: shared_from_this(), state&: *grouping_gstate.table_state, tasks); |
467 | } |
468 | D_ASSERT(!tasks.empty()); |
469 | SetTasks(std::move(tasks)); |
470 | } |
471 | }; |
472 | |
473 | //! REGULAR FINALIZE FROM DISTINCT FINALIZE |
474 | |
475 | class HashAggregateFinalizeTask : public ExecutorTask { |
476 | public: |
477 | HashAggregateFinalizeTask(Pipeline &pipeline, shared_ptr<Event> event_p, HashAggregateGlobalState &state_p, |
478 | ClientContext &context, const PhysicalHashAggregate &op) |
479 | : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), gstate(state_p), |
480 | context(context), op(op) { |
481 | } |
482 | |
483 | TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { |
484 | op.FinalizeInternal(pipeline, event&: *event, context, gstate, check_distinct: false); |
485 | D_ASSERT(!gstate.finished); |
486 | gstate.finished = true; |
487 | event->FinishTask(); |
488 | return TaskExecutionResult::TASK_FINISHED; |
489 | } |
490 | |
491 | private: |
492 | Pipeline &pipeline; |
493 | shared_ptr<Event> event; |
494 | HashAggregateGlobalState &gstate; |
495 | ClientContext &context; |
496 | const PhysicalHashAggregate &op; |
497 | }; |
498 | |
499 | class HashAggregateFinalizeEvent : public BasePipelineEvent { |
500 | public: |
501 | HashAggregateFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p, |
502 | Pipeline *pipeline_p, ClientContext &context) |
503 | : BasePipelineEvent(*pipeline_p), op(op_p), gstate(gstate_p), context(context) { |
504 | } |
505 | |
506 | const PhysicalHashAggregate &op; |
507 | HashAggregateGlobalState &gstate; |
508 | ClientContext &context; |
509 | |
510 | public: |
511 | void Schedule() override { |
512 | vector<shared_ptr<Task>> tasks; |
513 | tasks.push_back(x: make_uniq<HashAggregateFinalizeTask>(args&: *pipeline, args: shared_from_this(), args&: gstate, args&: context, args: op)); |
514 | D_ASSERT(!tasks.empty()); |
515 | SetTasks(std::move(tasks)); |
516 | } |
517 | }; |
518 | |
519 | //! DISTINCT FINALIZE TASK |
520 | |
521 | class HashDistinctAggregateFinalizeTask : public ExecutorTask { |
522 | public: |
523 | HashDistinctAggregateFinalizeTask(Pipeline &pipeline, shared_ptr<Event> event_p, HashAggregateGlobalState &state_p, |
524 | ClientContext &context, const PhysicalHashAggregate &op, |
525 | vector<vector<unique_ptr<GlobalSourceState>>> &global_sources_p) |
526 | : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), gstate(state_p), |
527 | context(context), op(op), global_sources(global_sources_p) { |
528 | } |
529 | |
530 | void AggregateDistinctGrouping(DistinctAggregateCollectionInfo &info, |
531 | const HashAggregateGroupingData &grouping_data, |
532 | HashAggregateGroupingGlobalState &grouping_state, idx_t grouping_idx) { |
533 | auto &aggregates = info.aggregates; |
534 | auto &data = *grouping_data.distinct_data; |
535 | auto &state = *grouping_state.distinct_state; |
536 | auto &table_state = *grouping_state.table_state; |
537 | |
538 | ThreadContext temp_thread_context(context); |
539 | ExecutionContext temp_exec_context(context, temp_thread_context, &pipeline); |
540 | |
541 | auto temp_local_state = grouping_data.table_data.GetLocalSinkState(context&: temp_exec_context); |
542 | |
543 | // Create a chunk that mimics the 'input' chunk in Sink, for storing the group vectors |
544 | DataChunk group_chunk; |
545 | if (!op.input_group_types.empty()) { |
546 | group_chunk.Initialize(context, types: op.input_group_types); |
547 | } |
548 | |
549 | auto &groups = op.grouped_aggregate_data.groups; |
550 | const idx_t group_by_size = groups.size(); |
551 | |
552 | DataChunk aggregate_input_chunk; |
553 | if (!gstate.payload_types.empty()) { |
554 | aggregate_input_chunk.Initialize(context, types: gstate.payload_types); |
555 | } |
556 | |
557 | idx_t payload_idx; |
558 | idx_t next_payload_idx = 0; |
559 | |
560 | for (idx_t i = 0; i < op.grouped_aggregate_data.aggregates.size(); i++) { |
561 | auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>(); |
562 | |
563 | // Forward the payload idx |
564 | payload_idx = next_payload_idx; |
565 | next_payload_idx = payload_idx + aggregate.children.size(); |
566 | |
567 | // If aggregate is not distinct, skip it |
568 | if (!data.IsDistinct(index: i)) { |
569 | continue; |
570 | } |
571 | D_ASSERT(data.info.table_map.count(i)); |
572 | auto table_idx = data.info.table_map.at(k: i); |
573 | auto &radix_table_p = data.radix_tables[table_idx]; |
574 | |
575 | // Create a duplicate of the output_chunk, because of multi-threading we cant alter the original |
576 | DataChunk output_chunk; |
577 | output_chunk.Initialize(context, types: state.distinct_output_chunks[table_idx]->GetTypes()); |
578 | |
579 | auto &global_source = global_sources[grouping_idx][i]; |
580 | auto local_source = radix_table_p->GetLocalSourceState(context&: temp_exec_context); |
581 | |
582 | // Fetch all the data from the aggregate ht, and Sink it into the main ht |
583 | while (true) { |
584 | output_chunk.Reset(); |
585 | group_chunk.Reset(); |
586 | aggregate_input_chunk.Reset(); |
587 | |
588 | InterruptState interrupt_state; |
589 | OperatorSourceInput source_input {.global_state: *global_source, .local_state: *local_source, .interrupt_state: interrupt_state}; |
590 | auto res = radix_table_p->GetData(context&: temp_exec_context, chunk&: output_chunk, sink_state&: *state.radix_states[table_idx], |
591 | input&: source_input); |
592 | |
593 | if (res == SourceResultType::FINISHED) { |
594 | D_ASSERT(output_chunk.size() == 0); |
595 | break; |
596 | } else if (res == SourceResultType::BLOCKED) { |
597 | throw InternalException( |
598 | "Unexpected interrupt from radix table GetData in HashDistinctAggregateFinalizeTask" ); |
599 | } |
600 | |
601 | auto &grouped_aggregate_data = *data.grouped_aggregate_data[table_idx]; |
602 | |
603 | for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) { |
604 | auto &group = grouped_aggregate_data.groups[group_idx]; |
605 | auto &bound_ref_expr = group->Cast<BoundReferenceExpression>(); |
606 | group_chunk.data[bound_ref_expr.index].Reference(other&: output_chunk.data[group_idx]); |
607 | } |
608 | group_chunk.SetCardinality(output_chunk); |
609 | |
610 | for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size; |
611 | child_idx++) { |
612 | aggregate_input_chunk.data[payload_idx + child_idx].Reference( |
613 | other&: output_chunk.data[group_by_size + child_idx]); |
614 | } |
615 | aggregate_input_chunk.SetCardinality(output_chunk); |
616 | |
617 | // Sink it into the main ht |
618 | OperatorSinkInput sink_input {.global_state: table_state, .local_state: *temp_local_state, .interrupt_state: interrupt_state}; |
619 | grouping_data.table_data.Sink(context&: temp_exec_context, chunk&: group_chunk, input&: sink_input, aggregate_input_chunk, filter: {i}); |
620 | } |
621 | } |
622 | grouping_data.table_data.Combine(context&: temp_exec_context, state&: table_state, lstate&: *temp_local_state); |
623 | } |
624 | |
625 | TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { |
626 | D_ASSERT(op.distinct_collection_info); |
627 | auto &info = *op.distinct_collection_info; |
628 | for (idx_t i = 0; i < op.groupings.size(); i++) { |
629 | auto &grouping = op.groupings[i]; |
630 | auto &grouping_state = gstate.grouping_states[i]; |
631 | AggregateDistinctGrouping(info, grouping_data: grouping, grouping_state, grouping_idx: i); |
632 | } |
633 | event->FinishTask(); |
634 | return TaskExecutionResult::TASK_FINISHED; |
635 | } |
636 | |
637 | private: |
638 | Pipeline &pipeline; |
639 | shared_ptr<Event> event; |
640 | HashAggregateGlobalState &gstate; |
641 | ClientContext &context; |
642 | const PhysicalHashAggregate &op; |
643 | vector<vector<unique_ptr<GlobalSourceState>>> &global_sources; |
644 | }; |
645 | |
646 | //! DISTINCT FINALIZE EVENT |
647 | |
648 | // TODO: Create tasks and run these in parallel instead of doing this all in Schedule, single threaded |
649 | class HashDistinctAggregateFinalizeEvent : public BasePipelineEvent { |
650 | public: |
651 | HashDistinctAggregateFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p, |
652 | Pipeline &pipeline_p, ClientContext &context) |
653 | : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), context(context) { |
654 | } |
655 | const PhysicalHashAggregate &op; |
656 | HashAggregateGlobalState &gstate; |
657 | ClientContext &context; |
658 | //! The GlobalSourceStates for all the radix tables of the distinct aggregates |
659 | vector<vector<unique_ptr<GlobalSourceState>>> global_sources; |
660 | |
661 | public: |
662 | void Schedule() override { |
663 | global_sources = CreateGlobalSources(); |
664 | |
665 | vector<shared_ptr<Task>> tasks; |
666 | auto &scheduler = TaskScheduler::GetScheduler(context); |
667 | auto number_of_threads = scheduler.NumberOfThreads(); |
668 | tasks.reserve(n: number_of_threads); |
669 | for (int32_t i = 0; i < number_of_threads; i++) { |
670 | tasks.push_back(x: make_uniq<HashDistinctAggregateFinalizeTask>(args&: *pipeline, args: shared_from_this(), args&: gstate, args&: context, |
671 | args: op, args&: global_sources)); |
672 | } |
673 | D_ASSERT(!tasks.empty()); |
674 | SetTasks(std::move(tasks)); |
675 | } |
676 | |
677 | void FinishEvent() override { |
678 | //! Now that everything is added to the main ht, we can actually finalize |
679 | auto new_event = make_shared<HashAggregateFinalizeEvent>(args: op, args&: gstate, args: pipeline.get(), args&: context); |
680 | this->InsertEvent(replacement_event: std::move(new_event)); |
681 | } |
682 | |
683 | private: |
684 | vector<vector<unique_ptr<GlobalSourceState>>> CreateGlobalSources() { |
685 | vector<vector<unique_ptr<GlobalSourceState>>> grouping_sources; |
686 | grouping_sources.reserve(n: op.groupings.size()); |
687 | for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { |
688 | auto &grouping = op.groupings[grouping_idx]; |
689 | auto &data = *grouping.distinct_data; |
690 | |
691 | vector<unique_ptr<GlobalSourceState>> aggregate_sources; |
692 | aggregate_sources.reserve(n: op.grouped_aggregate_data.aggregates.size()); |
693 | |
694 | for (idx_t i = 0; i < op.grouped_aggregate_data.aggregates.size(); i++) { |
695 | auto &aggregate = op.grouped_aggregate_data.aggregates[i]; |
696 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
697 | |
698 | if (!aggr.IsDistinct()) { |
699 | aggregate_sources.push_back(x: nullptr); |
700 | continue; |
701 | } |
702 | |
703 | D_ASSERT(data.info.table_map.count(i)); |
704 | auto table_idx = data.info.table_map.at(k: i); |
705 | auto &radix_table_p = data.radix_tables[table_idx]; |
706 | aggregate_sources.push_back(x: radix_table_p->GetGlobalSourceState(context)); |
707 | } |
708 | grouping_sources.push_back(x: std::move(aggregate_sources)); |
709 | } |
710 | return grouping_sources; |
711 | } |
712 | }; |
713 | |
714 | //! DISTINCT COMBINE EVENT |
715 | |
716 | class HashDistinctCombineFinalizeEvent : public BasePipelineEvent { |
717 | public: |
718 | HashDistinctCombineFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p, |
719 | Pipeline &pipeline_p, ClientContext &client) |
720 | : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), client(client) { |
721 | } |
722 | |
723 | const PhysicalHashAggregate &op; |
724 | HashAggregateGlobalState &gstate; |
725 | ClientContext &client; |
726 | |
727 | public: |
728 | void Schedule() override { |
729 | vector<shared_ptr<Task>> tasks; |
730 | for (idx_t i = 0; i < op.groupings.size(); i++) { |
731 | auto &grouping = op.groupings[i]; |
732 | auto &distinct_data = *grouping.distinct_data; |
733 | auto &distinct_state = *gstate.grouping_states[i].distinct_state; |
734 | for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { |
735 | if (!distinct_data.radix_tables[table_idx]) { |
736 | continue; |
737 | } |
738 | distinct_data.radix_tables[table_idx]->ScheduleTasks(executor&: pipeline->executor, event: shared_from_this(), |
739 | state&: *distinct_state.radix_states[table_idx], tasks); |
740 | } |
741 | } |
742 | |
743 | D_ASSERT(!tasks.empty()); |
744 | SetTasks(std::move(tasks)); |
745 | } |
746 | |
747 | void FinishEvent() override { |
748 | //! Now that all tables are combined, it's time to do the distinct aggregations |
749 | auto new_event = make_shared<HashDistinctAggregateFinalizeEvent>(args: op, args&: gstate, args&: *pipeline, args&: client); |
750 | this->InsertEvent(replacement_event: std::move(new_event)); |
751 | } |
752 | }; |
753 | |
754 | //! FINALIZE |
755 | |
756 | SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, |
757 | GlobalSinkState &gstate_p) const { |
758 | auto &gstate = gstate_p.Cast<HashAggregateGlobalState>(); |
759 | D_ASSERT(distinct_collection_info); |
760 | |
761 | bool any_partitioned = false; |
762 | for (idx_t i = 0; i < groupings.size(); i++) { |
763 | auto &grouping = groupings[i]; |
764 | auto &distinct_data = *grouping.distinct_data; |
765 | auto &distinct_state = *gstate.grouping_states[i].distinct_state; |
766 | |
767 | for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { |
768 | if (!distinct_data.radix_tables[table_idx]) { |
769 | continue; |
770 | } |
771 | auto &radix_table = distinct_data.radix_tables[table_idx]; |
772 | auto &radix_state = *distinct_state.radix_states[table_idx]; |
773 | bool partitioned = radix_table->Finalize(context, gstate_p&: radix_state); |
774 | if (partitioned) { |
775 | any_partitioned = true; |
776 | } |
777 | } |
778 | } |
779 | if (any_partitioned) { |
780 | // If any of the groupings are partitioned then we first need to combine those, then aggregate |
781 | auto new_event = make_shared<HashDistinctCombineFinalizeEvent>(args: *this, args&: gstate, args&: pipeline, args&: context); |
782 | event.InsertEvent(replacement_event: std::move(new_event)); |
783 | } else { |
784 | // Hashtables aren't partitioned, they dont need to be joined first |
785 | // so we can already compute the aggregate |
786 | auto new_event = make_shared<HashDistinctAggregateFinalizeEvent>(args: *this, args&: gstate, args&: pipeline, args&: context); |
787 | event.InsertEvent(replacement_event: std::move(new_event)); |
788 | } |
789 | return SinkFinalizeType::READY; |
790 | } |
791 | |
792 | SinkFinalizeType PhysicalHashAggregate::FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context, |
793 | GlobalSinkState &gstate_p, bool check_distinct) const { |
794 | auto &gstate = gstate_p.Cast<HashAggregateGlobalState>(); |
795 | |
796 | if (check_distinct && distinct_collection_info) { |
797 | // There are distinct aggregates |
798 | // If these are partitioned those need to be combined first |
799 | // Then we Finalize again, skipping this step |
800 | return FinalizeDistinct(pipeline, event, context, gstate_p); |
801 | } |
802 | |
803 | bool any_partitioned = false; |
804 | for (idx_t i = 0; i < groupings.size(); i++) { |
805 | auto &grouping = groupings[i]; |
806 | auto &grouping_gstate = gstate.grouping_states[i]; |
807 | |
808 | bool is_partitioned = grouping.table_data.Finalize(context, gstate_p&: *grouping_gstate.table_state); |
809 | if (is_partitioned) { |
810 | any_partitioned = true; |
811 | } |
812 | } |
813 | if (any_partitioned) { |
814 | auto new_event = make_shared<HashAggregateMergeEvent>(args: *this, args&: gstate, args: &pipeline); |
815 | event.InsertEvent(replacement_event: std::move(new_event)); |
816 | } |
817 | return SinkFinalizeType::READY; |
818 | } |
819 | |
820 | SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, |
821 | GlobalSinkState &gstate_p) const { |
822 | return FinalizeInternal(pipeline, event, context, gstate_p, check_distinct: true); |
823 | } |
824 | |
825 | //===--------------------------------------------------------------------===// |
826 | // Source |
827 | //===--------------------------------------------------------------------===// |
828 | class PhysicalHashAggregateGlobalSourceState : public GlobalSourceState { |
829 | public: |
830 | PhysicalHashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op) |
831 | : op(op), state_index(0) { |
832 | for (auto &grouping : op.groupings) { |
833 | auto &rt = grouping.table_data; |
834 | radix_states.push_back(x: rt.GetGlobalSourceState(context)); |
835 | } |
836 | } |
837 | |
838 | const PhysicalHashAggregate &op; |
839 | mutex lock; |
840 | atomic<idx_t> state_index; |
841 | |
842 | vector<unique_ptr<GlobalSourceState>> radix_states; |
843 | |
844 | public: |
845 | idx_t MaxThreads() override { |
846 | // If there are no tables, we only need one thread. |
847 | if (op.groupings.empty()) { |
848 | return 1; |
849 | } |
850 | |
851 | auto &ht_state = op.sink_state->Cast<HashAggregateGlobalState>(); |
852 | idx_t count = 0; |
853 | for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) { |
854 | auto &grouping = op.groupings[sidx]; |
855 | auto &grouping_gstate = ht_state.grouping_states[sidx]; |
856 | count += grouping.table_data.Size(sink_state&: *grouping_gstate.table_state); |
857 | } |
858 | return MaxValue<idx_t>(a: 1, b: count / STANDARD_VECTOR_SIZE); |
859 | } |
860 | }; |
861 | |
862 | unique_ptr<GlobalSourceState> PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const { |
863 | return make_uniq<PhysicalHashAggregateGlobalSourceState>(args&: context, args: *this); |
864 | } |
865 | |
866 | class PhysicalHashAggregateLocalSourceState : public LocalSourceState { |
867 | public: |
868 | explicit PhysicalHashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) { |
869 | for (auto &grouping : op.groupings) { |
870 | auto &rt = grouping.table_data; |
871 | radix_states.push_back(x: rt.GetLocalSourceState(context)); |
872 | } |
873 | } |
874 | |
875 | vector<unique_ptr<LocalSourceState>> radix_states; |
876 | }; |
877 | |
878 | unique_ptr<LocalSourceState> PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context, |
879 | GlobalSourceState &gstate) const { |
880 | return make_uniq<PhysicalHashAggregateLocalSourceState>(args&: context, args: *this); |
881 | } |
882 | |
883 | SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, |
884 | OperatorSourceInput &input) const { |
885 | auto &sink_gstate = sink_state->Cast<HashAggregateGlobalState>(); |
886 | auto &gstate = input.global_state.Cast<PhysicalHashAggregateGlobalSourceState>(); |
887 | auto &lstate = input.local_state.Cast<PhysicalHashAggregateLocalSourceState>(); |
888 | while (true) { |
889 | idx_t radix_idx = gstate.state_index; |
890 | if (radix_idx >= groupings.size()) { |
891 | break; |
892 | } |
893 | auto &grouping = groupings[radix_idx]; |
894 | auto &radix_table = grouping.table_data; |
895 | auto &grouping_gstate = sink_gstate.grouping_states[radix_idx]; |
896 | |
897 | InterruptState interrupt_state; |
898 | OperatorSourceInput source_input {.global_state: *gstate.radix_states[radix_idx], .local_state: *lstate.radix_states[radix_idx], |
899 | .interrupt_state: interrupt_state}; |
900 | auto res = radix_table.GetData(context, chunk, sink_state&: *grouping_gstate.table_state, input&: source_input); |
901 | if (chunk.size() != 0) { |
902 | return SourceResultType::HAVE_MORE_OUTPUT; |
903 | } else if (res == SourceResultType::BLOCKED) { |
904 | throw InternalException("Unexpectedly Blocked from radix_table" ); |
905 | } |
906 | |
907 | // move to the next table |
908 | lock_guard<mutex> l(gstate.lock); |
909 | radix_idx++; |
910 | if (radix_idx > gstate.state_index) { |
911 | // we have not yet worked on the table |
912 | // move the global index forwards |
913 | gstate.state_index = radix_idx; |
914 | } |
915 | } |
916 | |
917 | return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; |
918 | } |
919 | |
920 | string PhysicalHashAggregate::ParamsToString() const { |
921 | string result; |
922 | auto &groups = grouped_aggregate_data.groups; |
923 | auto &aggregates = grouped_aggregate_data.aggregates; |
924 | for (idx_t i = 0; i < groups.size(); i++) { |
925 | if (i > 0) { |
926 | result += "\n" ; |
927 | } |
928 | result += groups[i]->GetName(); |
929 | } |
930 | for (idx_t i = 0; i < aggregates.size(); i++) { |
931 | auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>(); |
932 | if (i > 0 || !groups.empty()) { |
933 | result += "\n" ; |
934 | } |
935 | result += aggregates[i]->GetName(); |
936 | if (aggregate.filter) { |
937 | result += " Filter: " + aggregate.filter->GetName(); |
938 | } |
939 | } |
940 | return result; |
941 | } |
942 | |
943 | } // namespace duckdb |
944 | |