1#include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp"
2
3#include "duckdb/execution/perfect_aggregate_hashtable.hpp"
4#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
5#include "duckdb/planner/expression/bound_reference_expression.hpp"
6#include "duckdb/storage/buffer_manager.hpp"
7
8namespace duckdb {
9
10PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &context, vector<LogicalType> types_p,
11 vector<unique_ptr<Expression>> aggregates_p,
12 vector<unique_ptr<Expression>> groups_p,
13 const vector<unique_ptr<BaseStatistics>> &group_stats,
14 vector<idx_t> required_bits_p, idx_t estimated_cardinality)
15 : PhysicalOperator(PhysicalOperatorType::PERFECT_HASH_GROUP_BY, std::move(types_p), estimated_cardinality),
16 groups(std::move(groups_p)), aggregates(std::move(aggregates_p)), required_bits(std::move(required_bits_p)) {
17 D_ASSERT(groups.size() == group_stats.size());
18 group_minima.reserve(n: group_stats.size());
19 for (auto &stats : group_stats) {
20 D_ASSERT(stats);
21 auto &nstats = *stats;
22 D_ASSERT(NumericStats::HasMin(nstats));
23 group_minima.push_back(x: NumericStats::Min(stats: nstats));
24 }
25 for (auto &expr : groups) {
26 group_types.push_back(x: expr->return_type);
27 }
28
29 vector<BoundAggregateExpression *> bindings;
30 vector<LogicalType> payload_types_filters;
31 for (auto &expr : aggregates) {
32 D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE);
33 D_ASSERT(expr->IsAggregate());
34 auto &aggr = expr->Cast<BoundAggregateExpression>();
35 bindings.push_back(x: &aggr);
36
37 D_ASSERT(!aggr.IsDistinct());
38 D_ASSERT(aggr.function.combine);
39 for (auto &child : aggr.children) {
40 payload_types.push_back(x: child->return_type);
41 }
42 if (aggr.filter) {
43 payload_types_filters.push_back(x: aggr.filter->return_type);
44 }
45 }
46 for (const auto &pay_filters : payload_types_filters) {
47 payload_types.push_back(x: pay_filters);
48 }
49 aggregate_objects = AggregateObject::CreateAggregateObjects(bindings);
50
51 // filter_indexes must be pre-built, not lazily instantiated in parallel...
52 idx_t aggregate_input_idx = 0;
53 for (auto &aggregate : aggregates) {
54 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
55 aggregate_input_idx += aggr.children.size();
56 }
57 for (auto &aggregate : aggregates) {
58 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
59 if (aggr.filter) {
60 auto &bound_ref_expr = aggr.filter->Cast<BoundReferenceExpression>();
61 auto it = filter_indexes.find(x: aggr.filter.get());
62 if (it == filter_indexes.end()) {
63 filter_indexes[aggr.filter.get()] = bound_ref_expr.index;
64 bound_ref_expr.index = aggregate_input_idx++;
65 } else {
66 ++aggregate_input_idx;
67 }
68 }
69 }
70}
71
72unique_ptr<PerfectAggregateHashTable> PhysicalPerfectHashAggregate::CreateHT(Allocator &allocator,
73 ClientContext &context) const {
74 return make_uniq<PerfectAggregateHashTable>(args&: context, args&: allocator, args: group_types, args: payload_types, args: aggregate_objects,
75 args: group_minima, args: required_bits);
76}
77
78//===--------------------------------------------------------------------===//
79// Sink
80//===--------------------------------------------------------------------===//
81class PerfectHashAggregateGlobalState : public GlobalSinkState {
82public:
83 PerfectHashAggregateGlobalState(const PhysicalPerfectHashAggregate &op, ClientContext &context)
84 : ht(op.CreateHT(allocator&: Allocator::Get(context), context)) {
85 }
86
87 //! The lock for updating the global aggregate state
88 mutex lock;
89 //! The global aggregate hash table
90 unique_ptr<PerfectAggregateHashTable> ht;
91};
92
93class PerfectHashAggregateLocalState : public LocalSinkState {
94public:
95 PerfectHashAggregateLocalState(const PhysicalPerfectHashAggregate &op, ExecutionContext &context)
96 : ht(op.CreateHT(allocator&: Allocator::Get(context&: context.client), context&: context.client)) {
97 group_chunk.InitializeEmpty(types: op.group_types);
98 if (!op.payload_types.empty()) {
99 aggregate_input_chunk.InitializeEmpty(types: op.payload_types);
100 }
101 }
102
103 //! The local aggregate hash table
104 unique_ptr<PerfectAggregateHashTable> ht;
105 DataChunk group_chunk;
106 DataChunk aggregate_input_chunk;
107};
108
109unique_ptr<GlobalSinkState> PhysicalPerfectHashAggregate::GetGlobalSinkState(ClientContext &context) const {
110 return make_uniq<PerfectHashAggregateGlobalState>(args: *this, args&: context);
111}
112
113unique_ptr<LocalSinkState> PhysicalPerfectHashAggregate::GetLocalSinkState(ExecutionContext &context) const {
114 return make_uniq<PerfectHashAggregateLocalState>(args: *this, args&: context);
115}
116
117SinkResultType PhysicalPerfectHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk,
118 OperatorSinkInput &input) const {
119 auto &lstate = input.local_state.Cast<PerfectHashAggregateLocalState>();
120 DataChunk &group_chunk = lstate.group_chunk;
121 DataChunk &aggregate_input_chunk = lstate.aggregate_input_chunk;
122
123 for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) {
124 auto &group = groups[group_idx];
125 D_ASSERT(group->type == ExpressionType::BOUND_REF);
126 auto &bound_ref_expr = group->Cast<BoundReferenceExpression>();
127 group_chunk.data[group_idx].Reference(other&: chunk.data[bound_ref_expr.index]);
128 }
129 idx_t aggregate_input_idx = 0;
130 for (auto &aggregate : aggregates) {
131 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
132 for (auto &child_expr : aggr.children) {
133 D_ASSERT(child_expr->type == ExpressionType::BOUND_REF);
134 auto &bound_ref_expr = child_expr->Cast<BoundReferenceExpression>();
135 aggregate_input_chunk.data[aggregate_input_idx++].Reference(other&: chunk.data[bound_ref_expr.index]);
136 }
137 }
138 for (auto &aggregate : aggregates) {
139 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
140 if (aggr.filter) {
141 auto it = filter_indexes.find(x: aggr.filter.get());
142 D_ASSERT(it != filter_indexes.end());
143 aggregate_input_chunk.data[aggregate_input_idx++].Reference(other&: chunk.data[it->second]);
144 }
145 }
146
147 group_chunk.SetCardinality(chunk.size());
148
149 aggregate_input_chunk.SetCardinality(chunk.size());
150
151 group_chunk.Verify();
152 aggregate_input_chunk.Verify();
153 D_ASSERT(aggregate_input_chunk.ColumnCount() == 0 || group_chunk.size() == aggregate_input_chunk.size());
154
155 lstate.ht->AddChunk(groups&: group_chunk, payload&: aggregate_input_chunk);
156 return SinkResultType::NEED_MORE_INPUT;
157}
158
159//===--------------------------------------------------------------------===//
160// Combine
161//===--------------------------------------------------------------------===//
162void PhysicalPerfectHashAggregate::Combine(ExecutionContext &context, GlobalSinkState &gstate_p,
163 LocalSinkState &lstate_p) const {
164 auto &lstate = lstate_p.Cast<PerfectHashAggregateLocalState>();
165 auto &gstate = gstate_p.Cast<PerfectHashAggregateGlobalState>();
166
167 lock_guard<mutex> l(gstate.lock);
168 gstate.ht->Combine(other&: *lstate.ht);
169}
170
171//===--------------------------------------------------------------------===//
172// Source
173//===--------------------------------------------------------------------===//
174class PerfectHashAggregateState : public GlobalSourceState {
175public:
176 PerfectHashAggregateState() : ht_scan_position(0) {
177 }
178
179 //! The current position to scan the HT for output tuples
180 idx_t ht_scan_position;
181};
182
183unique_ptr<GlobalSourceState> PhysicalPerfectHashAggregate::GetGlobalSourceState(ClientContext &context) const {
184 return make_uniq<PerfectHashAggregateState>();
185}
186
187SourceResultType PhysicalPerfectHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk,
188 OperatorSourceInput &input) const {
189 auto &state = input.global_state.Cast<PerfectHashAggregateState>();
190 auto &gstate = sink_state->Cast<PerfectHashAggregateGlobalState>();
191
192 gstate.ht->Scan(scan_position&: state.ht_scan_position, result&: chunk);
193
194 if (chunk.size() > 0) {
195 return SourceResultType::HAVE_MORE_OUTPUT;
196 } else {
197 return SourceResultType::FINISHED;
198 }
199}
200
201string PhysicalPerfectHashAggregate::ParamsToString() const {
202 string result;
203 for (idx_t i = 0; i < groups.size(); i++) {
204 if (i > 0) {
205 result += "\n";
206 }
207 result += groups[i]->GetName();
208 }
209 for (idx_t i = 0; i < aggregates.size(); i++) {
210 if (i > 0 || !groups.empty()) {
211 result += "\n";
212 }
213 result += aggregates[i]->GetName();
214 auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();
215 if (aggregate.filter) {
216 result += " Filter: " + aggregate.filter->GetName();
217 }
218 }
219 return result;
220}
221
222} // namespace duckdb
223