1 | #include "duckdb/execution/operator/aggregate/physical_simple_aggregate.hpp" |
2 | |
3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
4 | #include "duckdb/execution/expression_executor.hpp" |
5 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
6 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
7 | |
8 | using namespace duckdb; |
9 | using namespace std; |
10 | |
11 | class PhysicalSimpleAggregateOperatorState : public PhysicalOperatorState { |
12 | public: |
13 | PhysicalSimpleAggregateOperatorState(PhysicalSimpleAggregate *parent, PhysicalOperator *child); |
14 | ~PhysicalSimpleAggregateOperatorState() { |
15 | assert(destructors.size() == aggregates.size()); |
16 | for (idx_t i = 0; i < destructors.size(); i++) { |
17 | if (!destructors[i]) { |
18 | continue; |
19 | } |
20 | Vector state_vector(Value::POINTER((uintptr_t)aggregates[i].get())); |
21 | state_vector.vector_type = VectorType::FLAT_VECTOR; |
22 | |
23 | destructors[i](state_vector, 1); |
24 | } |
25 | } |
26 | |
27 | //! The aggregate values |
28 | vector<unique_ptr<data_t[]>> aggregates; |
29 | |
30 | vector<aggregate_destructor_t> destructors; |
31 | |
32 | ExpressionExecutor child_executor; |
33 | //! The payload chunk |
34 | DataChunk payload_chunk; |
35 | }; |
36 | |
37 | PhysicalSimpleAggregate::PhysicalSimpleAggregate(vector<TypeId> types, vector<unique_ptr<Expression>> expressions) |
38 | : PhysicalOperator(PhysicalOperatorType::SIMPLE_AGGREGATE, types), aggregates(move(expressions)) { |
39 | } |
40 | |
41 | void PhysicalSimpleAggregate::GetChunkInternal(ClientContext &context, DataChunk &chunk, |
42 | PhysicalOperatorState *state_) { |
43 | auto state = reinterpret_cast<PhysicalSimpleAggregateOperatorState *>(state_); |
44 | while (true) { |
45 | // iterate over the child |
46 | children[0]->GetChunk(context, state->child_chunk, state->child_state.get()); |
47 | if (state->child_chunk.size() == 0) { |
48 | break; |
49 | } |
50 | |
51 | // now resolve the aggregates for each of the children |
52 | idx_t payload_idx = 0, payload_expr_idx = 0; |
53 | DataChunk &payload_chunk = state->payload_chunk; |
54 | payload_chunk.Reset(); |
55 | state->child_executor.SetChunk(state->child_chunk); |
56 | payload_chunk.SetCardinality(state->child_chunk); |
57 | for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { |
58 | auto &aggregate = (BoundAggregateExpression &)*aggregates[aggr_idx]; |
59 | idx_t payload_cnt = 0; |
60 | // resolve the child expression of the aggregate (if any) |
61 | if (aggregate.children.size() > 0) { |
62 | for (idx_t i = 0; i < aggregate.children.size(); ++i) { |
63 | state->child_executor.ExecuteExpression(payload_expr_idx, |
64 | payload_chunk.data[payload_idx + payload_cnt]); |
65 | payload_expr_idx++; |
66 | payload_cnt++; |
67 | } |
68 | } else { |
69 | payload_cnt++; |
70 | } |
71 | // perform the actual aggregation |
72 | aggregate.function.simple_update(&payload_chunk.data[payload_idx], payload_cnt, |
73 | state->aggregates[aggr_idx].get(), payload_chunk.size()); |
74 | payload_idx += payload_cnt; |
75 | } |
76 | } |
77 | // initialize the result chunk with the aggregate values |
78 | chunk.SetCardinality(1); |
79 | for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { |
80 | auto &aggregate = (BoundAggregateExpression &)*aggregates[aggr_idx]; |
81 | |
82 | Vector state_vector(Value::POINTER((uintptr_t)state->aggregates[aggr_idx].get())); |
83 | aggregate.function.finalize(state_vector, chunk.data[aggr_idx], 1); |
84 | } |
85 | state->finished = true; |
86 | } |
87 | |
88 | unique_ptr<PhysicalOperatorState> PhysicalSimpleAggregate::GetOperatorState() { |
89 | return make_unique<PhysicalSimpleAggregateOperatorState>(this, children[0].get()); |
90 | } |
91 | |
92 | PhysicalSimpleAggregateOperatorState::PhysicalSimpleAggregateOperatorState(PhysicalSimpleAggregate *parent, |
93 | PhysicalOperator *child) |
94 | : PhysicalOperatorState(child) { |
95 | vector<TypeId> payload_types; |
96 | for (auto &aggregate : parent->aggregates) { |
97 | assert(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); |
98 | auto &aggr = (BoundAggregateExpression &)*aggregate; |
99 | // initialize the payload chunk |
100 | if (aggr.children.size()) { |
101 | for (idx_t i = 0; i < aggr.children.size(); ++i) { |
102 | payload_types.push_back(aggr.children[i]->return_type); |
103 | child_executor.AddExpression(*aggr.children[i]); |
104 | } |
105 | } else { |
106 | // COUNT(*) |
107 | payload_types.push_back(TypeId::INT64); |
108 | } |
109 | // initialize the aggregate values |
110 | auto state = unique_ptr<data_t[]>(new data_t[aggr.function.state_size()]); |
111 | aggr.function.initialize(state.get()); |
112 | aggregates.push_back(move(state)); |
113 | destructors.push_back(aggr.function.destructor); |
114 | } |
115 | payload_chunk.Initialize(payload_types); |
116 | } |
117 | |