1#include "duckdb/execution/operator/aggregate/physical_simple_aggregate.hpp"
2
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4#include "duckdb/execution/expression_executor.hpp"
5#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
6#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
7
8using namespace duckdb;
9using namespace std;
10
11class PhysicalSimpleAggregateOperatorState : public PhysicalOperatorState {
12public:
13 PhysicalSimpleAggregateOperatorState(PhysicalSimpleAggregate *parent, PhysicalOperator *child);
14 ~PhysicalSimpleAggregateOperatorState() {
15 assert(destructors.size() == aggregates.size());
16 for (idx_t i = 0; i < destructors.size(); i++) {
17 if (!destructors[i]) {
18 continue;
19 }
20 Vector state_vector(Value::POINTER((uintptr_t)aggregates[i].get()));
21 state_vector.vector_type = VectorType::FLAT_VECTOR;
22
23 destructors[i](state_vector, 1);
24 }
25 }
26
27 //! The aggregate values
28 vector<unique_ptr<data_t[]>> aggregates;
29
30 vector<aggregate_destructor_t> destructors;
31
32 ExpressionExecutor child_executor;
33 //! The payload chunk
34 DataChunk payload_chunk;
35};
36
37PhysicalSimpleAggregate::PhysicalSimpleAggregate(vector<TypeId> types, vector<unique_ptr<Expression>> expressions)
38 : PhysicalOperator(PhysicalOperatorType::SIMPLE_AGGREGATE, types), aggregates(move(expressions)) {
39}
40
41void PhysicalSimpleAggregate::GetChunkInternal(ClientContext &context, DataChunk &chunk,
42 PhysicalOperatorState *state_) {
43 auto state = reinterpret_cast<PhysicalSimpleAggregateOperatorState *>(state_);
44 while (true) {
45 // iterate over the child
46 children[0]->GetChunk(context, state->child_chunk, state->child_state.get());
47 if (state->child_chunk.size() == 0) {
48 break;
49 }
50
51 // now resolve the aggregates for each of the children
52 idx_t payload_idx = 0, payload_expr_idx = 0;
53 DataChunk &payload_chunk = state->payload_chunk;
54 payload_chunk.Reset();
55 state->child_executor.SetChunk(state->child_chunk);
56 payload_chunk.SetCardinality(state->child_chunk);
57 for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
58 auto &aggregate = (BoundAggregateExpression &)*aggregates[aggr_idx];
59 idx_t payload_cnt = 0;
60 // resolve the child expression of the aggregate (if any)
61 if (aggregate.children.size() > 0) {
62 for (idx_t i = 0; i < aggregate.children.size(); ++i) {
63 state->child_executor.ExecuteExpression(payload_expr_idx,
64 payload_chunk.data[payload_idx + payload_cnt]);
65 payload_expr_idx++;
66 payload_cnt++;
67 }
68 } else {
69 payload_cnt++;
70 }
71 // perform the actual aggregation
72 aggregate.function.simple_update(&payload_chunk.data[payload_idx], payload_cnt,
73 state->aggregates[aggr_idx].get(), payload_chunk.size());
74 payload_idx += payload_cnt;
75 }
76 }
77 // initialize the result chunk with the aggregate values
78 chunk.SetCardinality(1);
79 for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
80 auto &aggregate = (BoundAggregateExpression &)*aggregates[aggr_idx];
81
82 Vector state_vector(Value::POINTER((uintptr_t)state->aggregates[aggr_idx].get()));
83 aggregate.function.finalize(state_vector, chunk.data[aggr_idx], 1);
84 }
85 state->finished = true;
86}
87
88unique_ptr<PhysicalOperatorState> PhysicalSimpleAggregate::GetOperatorState() {
89 return make_unique<PhysicalSimpleAggregateOperatorState>(this, children[0].get());
90}
91
92PhysicalSimpleAggregateOperatorState::PhysicalSimpleAggregateOperatorState(PhysicalSimpleAggregate *parent,
93 PhysicalOperator *child)
94 : PhysicalOperatorState(child) {
95 vector<TypeId> payload_types;
96 for (auto &aggregate : parent->aggregates) {
97 assert(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE);
98 auto &aggr = (BoundAggregateExpression &)*aggregate;
99 // initialize the payload chunk
100 if (aggr.children.size()) {
101 for (idx_t i = 0; i < aggr.children.size(); ++i) {
102 payload_types.push_back(aggr.children[i]->return_type);
103 child_executor.AddExpression(*aggr.children[i]);
104 }
105 } else {
106 // COUNT(*)
107 payload_types.push_back(TypeId::INT64);
108 }
109 // initialize the aggregate values
110 auto state = unique_ptr<data_t[]>(new data_t[aggr.function.state_size()]);
111 aggr.function.initialize(state.get());
112 aggregates.push_back(move(state));
113 destructors.push_back(aggr.function.destructor);
114 }
115 payload_chunk.Initialize(payload_types);
116}
117