1#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
2#include "duckdb/common/operator/subtract.hpp"
3#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
4#include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp"
5#include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp"
6#include "duckdb/execution/operator/projection/physical_projection.hpp"
7#include "duckdb/execution/physical_plan_generator.hpp"
8#include "duckdb/main/client_context.hpp"
9#include "duckdb/parser/expression/comparison_expression.hpp"
10#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
11#include "duckdb/planner/operator/logical_aggregate.hpp"
12#include "duckdb/function/function_binder.hpp"
13#include "duckdb/planner/expression/bound_reference_expression.hpp"
14
15namespace duckdb {
16
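// returns the number of bits required to represent the value "n"
// e.g. RequiredBitsForValue(5) == 3 and RequiredBitsForValue(0) == 0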
static uint32_t RequiredBitsForValue(uint32_t n) {
	idx_t required_bits = 0;
	while (n > 0) {
		n >>= 1;
		required_bits++;
	}
	return required_bits;
}

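// returns true if the aggregation can be executed with a perfect hash table: every group must be a small integer
// column whose min/max statistics span a range that fits within the configured "perfect_ht_threshold" number of
// bits, and the aggregates must support combining and must not be DISTINCT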
static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate &op, vector<idx_t> &bits_per_group) {
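	// perfect hash aggregates only support a single grouping set and no GROUPING functions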
	if (op.grouping_sets.size() > 1 || !op.grouping_functions.empty()) {
		return false;
	}
	idx_t perfect_hash_bits = 0;
	if (op.group_stats.empty()) {
		op.group_stats.resize(op.groups.size());
	}
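	// check every group: the type must be a small integer and the value range must fit in the perfect hash table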
	for (idx_t group_idx = 0; group_idx < op.groups.size(); group_idx++) {
		auto &group = op.groups[group_idx];
		auto &stats = op.group_stats[group_idx];

		switch (group->return_type.InternalType()) {
		case PhysicalType::INT8:
		case PhysicalType::INT16:
		case PhysicalType::INT32:
		case PhysicalType::INT64:
			break;
		default:
			// we only support simple integer types for perfect hashing
			return false;
		}
		// check if the group has stats available
		auto &group_type = group->return_type;
		if (!stats) {
			// no stats, but we might still be able to use perfect hashing if the type is small enough
			// for small types we can just set the stats to [type_min, type_max]
			switch (group_type.InternalType()) {
			case PhysicalType::INT8:
			case PhysicalType::INT16:
				break;
			default:
				// type is too large and there are no stats: skip perfect hashing
				return false;
			}
			// construct stats with the min and max value of the type
			stats = NumericStats::CreateUnknown(group_type).ToUnique();
			NumericStats::SetMin(*stats, Value::MinimumValue(group_type));
			NumericStats::SetMax(*stats, Value::MaximumValue(group_type));
		}
		auto &nstats = *stats;

		if (!NumericStats::HasMinMax(nstats)) {
			return false;
		}
		// we have a min and a max value for the stats: use that to figure out how many bits we have
		// we add two here, one for the NULL value, and one to make the computation one-indexed
		// (e.g. if min and max are the same, we still need one entry in total)
		int64_t range;
		switch (group_type.InternalType()) {
		case PhysicalType::INT8:
			range = int64_t(NumericStats::GetMax<int8_t>(nstats)) - int64_t(NumericStats::GetMin<int8_t>(nstats));
			break;
		case PhysicalType::INT16:
			range = int64_t(NumericStats::GetMax<int16_t>(nstats)) - int64_t(NumericStats::GetMin<int16_t>(nstats));
			break;
		case PhysicalType::INT32:
			range = int64_t(NumericStats::GetMax<int32_t>(nstats)) - int64_t(NumericStats::GetMin<int32_t>(nstats));
			break;
		case PhysicalType::INT64:
			if (!TrySubtractOperator::Operation(NumericStats::GetMax<int64_t>(nstats),
			                                    NumericStats::GetMin<int64_t>(nstats), range)) {
				return false;
			}
			break;
		default:
			throw InternalException("Unsupported type for perfect hash (should be caught before)");
		}
		// bail out on any range that does not fit in a 32-bit signed integer
		if (range >= NumericLimits<int32_t>::Maximum()) {
			return false;
		}
		range += 2;
		// figure out how many bits we need
		idx_t required_bits = RequiredBitsForValue(range);
		bits_per_group.push_back(required_bits);
		perfect_hash_bits += required_bits;
		// check if we have exceeded the bits for the hash
		if (perfect_hash_bits > ClientConfig::GetConfig(context).perfect_ht_threshold) {
			// too many bits for perfect hash
			return false;
		}
	}
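	// finally, every aggregate must support combining and must not be DISTINCT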
109 for (auto &expression : op.expressions) {
110 auto &aggregate = expression->Cast<BoundAggregateExpression>();
111 if (aggregate.IsDistinct() || !aggregate.function.combine) {
112 // distinct aggregates are not supported in perfect hash aggregates
113 return false;
114 }
115 }
116 return true;
117}
118
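// plans a LogicalAggregate: uses an ungrouped aggregate when there are no groups, a perfect hash aggregate when the
// group ranges are small enough, and a regular hash aggregate otherwise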
unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) {
	unique_ptr<PhysicalOperator> groupby;
	D_ASSERT(op.children.size() == 1);

	auto plan = CreatePlan(*op.children[0]);

	plan = ExtractAggregateExpressions(std::move(plan), op.expressions, op.groups);

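	// decide between an ungrouped aggregate, a perfect hash aggregate and a regular hash aggregate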
	if (op.groups.empty() && op.grouping_sets.size() <= 1) {
		// no groups, check if we can use a simple aggregation
		// special case: aggregate entire columns together
		bool use_simple_aggregation = true;
		for (auto &expression : op.expressions) {
			auto &aggregate = expression->Cast<BoundAggregateExpression>();
			if (!aggregate.function.simple_update) {
				// unsupported aggregate for simple aggregation: use hash aggregation
				use_simple_aggregation = false;
				break;
			}
		}
		if (use_simple_aggregation) {
			groupby = make_uniq_base<PhysicalOperator, PhysicalUngroupedAggregate>(op.types, std::move(op.expressions),
			                                                                       op.estimated_cardinality);
		} else {
			groupby = make_uniq_base<PhysicalOperator, PhysicalHashAggregate>(
			    context, op.types, std::move(op.expressions), op.estimated_cardinality);
		}
	} else {
		// groups! create a GROUP BY aggregator
		// use a perfect hash aggregate if possible
		vector<idx_t> required_bits;
		if (CanUsePerfectHashAggregate(context, op, required_bits)) {
			groupby = make_uniq_base<PhysicalOperator, PhysicalPerfectHashAggregate>(
			    context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.group_stats),
			    std::move(required_bits), op.estimated_cardinality);
		} else {
			groupby = make_uniq_base<PhysicalOperator, PhysicalHashAggregate>(
			    context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.grouping_sets),
			    std::move(op.grouping_functions), op.estimated_cardinality);
		}
	}
	groupby->children.push_back(std::move(plan));
	return groupby;
}

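// moves the group expressions and the aggregate inputs (children and FILTER clauses) into a projection below the
// aggregate and replaces them with bound references into that projection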
unique_ptr<PhysicalOperator>
PhysicalPlanGenerator::ExtractAggregateExpressions(unique_ptr<PhysicalOperator> child,
                                                   vector<unique_ptr<Expression>> &aggregates,
                                                   vector<unique_ptr<Expression>> &groups) {
	vector<unique_ptr<Expression>> expressions;
	vector<LogicalType> types;

	// bind sorted aggregates
	for (auto &aggr : aggregates) {
		auto &bound_aggr = aggr->Cast<BoundAggregateExpression>();
		if (bound_aggr.order_bys) {
			// sorted aggregate!
			FunctionBinder::BindSortedAggregate(context, bound_aggr, groups);
		}
	}
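	// move the group expressions into the projection and replace them with references to the projected columns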
179 for (auto &group : groups) {
180 auto ref = make_uniq<BoundReferenceExpression>(args&: group->return_type, args: expressions.size());
181 types.push_back(x: group->return_type);
182 expressions.push_back(x: std::move(group));
183 group = std::move(ref);
184 }
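	// do the same for the aggregate input expressions and any FILTER expressions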
185 for (auto &aggr : aggregates) {
186 auto &bound_aggr = aggr->Cast<BoundAggregateExpression>();
187 for (auto &child : bound_aggr.children) {
188 auto ref = make_uniq<BoundReferenceExpression>(args&: child->return_type, args: expressions.size());
189 types.push_back(x: child->return_type);
190 expressions.push_back(x: std::move(child));
191 child = std::move(ref);
192 }
193 if (bound_aggr.filter) {
194 auto &filter = bound_aggr.filter;
195 auto ref = make_uniq<BoundReferenceExpression>(args&: filter->return_type, args: expressions.size());
196 types.push_back(x: filter->return_type);
197 expressions.push_back(x: std::move(filter));
198 bound_aggr.filter = std::move(ref);
199 }
200 }
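	// if nothing was extracted we can return the child directly without adding a projection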
	if (expressions.empty()) {
		return child;
	}
	auto projection =
	    make_uniq<PhysicalProjection>(std::move(types), std::move(expressions), child->estimated_cardinality);
	projection->children.push_back(std::move(child));
	return std::move(projection);
}

} // namespace duckdb