1 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
2 | #include "duckdb/common/pair.hpp" |
3 | #include "duckdb/common/operator/cast_operators.hpp" |
4 | #include "duckdb/parser/expression/function_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
8 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
9 | #include "duckdb/planner/expression_binder/aggregate_binder.hpp" |
10 | #include "duckdb/planner/expression_binder/base_select_binder.hpp" |
11 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
12 | #include "duckdb/execution/expression_executor.hpp" |
13 | #include "duckdb/function/scalar/generic_functions.hpp" |
14 | #include "duckdb/main/config.hpp" |
15 | #include "duckdb/function/function_binder.hpp" |
16 | #include "duckdb/planner/binder.hpp" |
17 | |
18 | namespace duckdb { |
19 | |
20 | static Value NegatePercentileValue(const Value &v, const bool desc) { |
21 | if (v.IsNull()) { |
22 | return v; |
23 | } |
24 | |
25 | const auto frac = v.GetValue<double>(); |
26 | if (frac < 0 || frac > 1) { |
27 | throw BinderException("PERCENTILEs can only take parameters in the range [0, 1]" ); |
28 | } |
29 | |
30 | if (!desc) { |
31 | return v; |
32 | } |
33 | |
34 | const auto &type = v.type(); |
35 | switch (type.id()) { |
36 | case LogicalTypeId::DECIMAL: { |
37 | // Negate DECIMALs as DECIMAL. |
38 | const auto integral = IntegralValue::Get(value: v); |
39 | const auto width = DecimalType::GetWidth(type); |
40 | const auto scale = DecimalType::GetScale(type); |
41 | switch (type.InternalType()) { |
42 | case PhysicalType::INT16: |
43 | return Value::DECIMAL(value: Cast::Operation<hugeint_t, int16_t>(input: -integral), width, scale); |
44 | case PhysicalType::INT32: |
45 | return Value::DECIMAL(value: Cast::Operation<hugeint_t, int32_t>(input: -integral), width, scale); |
46 | case PhysicalType::INT64: |
47 | return Value::DECIMAL(value: Cast::Operation<hugeint_t, int64_t>(input: -integral), width, scale); |
48 | case PhysicalType::INT128: |
49 | return Value::DECIMAL(value: -integral, width, scale); |
50 | default: |
51 | throw InternalException("Unknown DECIMAL type" ); |
52 | } |
53 | } |
54 | default: |
55 | // Everything else can just be a DOUBLE |
56 | return Value::DOUBLE(value: -v.GetValue<double>()); |
57 | } |
58 | } |
59 | |
60 | static void NegatePercentileFractions(ClientContext &context, unique_ptr<ParsedExpression> &fractions, bool desc) { |
61 | D_ASSERT(fractions.get()); |
62 | D_ASSERT(fractions->expression_class == ExpressionClass::BOUND_EXPRESSION); |
63 | auto &bound = BoundExpression::GetExpression(expr&: *fractions); |
64 | |
65 | if (!bound->IsFoldable()) { |
66 | return; |
67 | } |
68 | |
69 | Value value = ExpressionExecutor::EvaluateScalar(context, expr: *bound); |
70 | if (value.type().id() == LogicalTypeId::LIST) { |
71 | vector<Value> values; |
72 | for (const auto &element_val : ListValue::GetChildren(value)) { |
73 | values.push_back(x: NegatePercentileValue(v: element_val, desc)); |
74 | } |
75 | if (values.empty()) { |
76 | throw BinderException("Empty list in percentile not allowed" ); |
77 | } |
78 | bound = make_uniq<BoundConstantExpression>(args: Value::LIST(values)); |
79 | } else { |
80 | bound = make_uniq<BoundConstantExpression>(args: NegatePercentileValue(v: value, desc)); |
81 | } |
82 | } |
83 | |
84 | BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFunctionCatalogEntry &func, idx_t depth) { |
85 | // first bind the child of the aggregate expression (if any) |
86 | this->bound_aggregate = true; |
87 | unique_ptr<Expression> bound_filter; |
88 | AggregateBinder aggregate_binder(binder, context); |
89 | string error, filter_error; |
90 | |
91 | // Now we bind the filter (if any) |
92 | if (aggr.filter) { |
93 | aggregate_binder.BindChild(expr&: aggr.filter, depth: 0, error); |
94 | } |
95 | |
96 | // Handle ordered-set aggregates by moving the single ORDER BY expression to the front of the children. |
97 | // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE |
98 | bool ordered_set_agg = false; |
99 | bool negate_fractions = false; |
100 | if (aggr.order_bys && aggr.order_bys->orders.size() == 1) { |
101 | const auto &func_name = aggr.function_name; |
102 | ordered_set_agg = (func_name == "quantile_cont" || func_name == "quantile_disc" || |
103 | (func_name == "mode" && aggr.children.empty())); |
104 | |
105 | if (ordered_set_agg) { |
106 | auto &config = DBConfig::GetConfig(context); |
107 | const auto &order = aggr.order_bys->orders[0]; |
108 | const auto sense = |
109 | (order.type == OrderType::ORDER_DEFAULT) ? config.options.default_order_type : order.type; |
110 | negate_fractions = (sense == OrderType::DESCENDING); |
111 | } |
112 | } |
113 | |
114 | for (auto &child : aggr.children) { |
115 | aggregate_binder.BindChild(expr&: child, depth: 0, error); |
116 | // We have to negate the fractions for PERCENTILE_XXXX DESC |
117 | if (error.empty() && ordered_set_agg) { |
118 | NegatePercentileFractions(context, fractions&: child, desc: negate_fractions); |
119 | } |
120 | } |
121 | |
122 | // Bind the ORDER BYs, if any |
123 | if (aggr.order_bys && !aggr.order_bys->orders.empty()) { |
124 | for (auto &order : aggr.order_bys->orders) { |
125 | aggregate_binder.BindChild(expr&: order.expression, depth: 0, error); |
126 | } |
127 | } |
128 | |
129 | if (!error.empty()) { |
130 | // failed to bind child |
131 | if (aggregate_binder.HasBoundColumns()) { |
132 | for (idx_t i = 0; i < aggr.children.size(); i++) { |
133 | // however, we bound columns! |
134 | // that means this aggregation belongs to this node |
135 | // check if we have to resolve any errors by binding with parent binders |
136 | bool success = aggregate_binder.BindCorrelatedColumns(expr&: aggr.children[i]); |
137 | // if there is still an error after this, we could not successfully bind the aggregate |
138 | if (!success) { |
139 | throw BinderException(error); |
140 | } |
141 | auto &bound_expr = BoundExpression::GetExpression(expr&: *aggr.children[i]); |
142 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr); |
143 | } |
144 | if (aggr.filter) { |
145 | bool success = aggregate_binder.BindCorrelatedColumns(expr&: aggr.filter); |
146 | // if there is still an error after this, we could not successfully bind the aggregate |
147 | if (!success) { |
148 | throw BinderException(error); |
149 | } |
150 | auto &bound_expr = BoundExpression::GetExpression(expr&: *aggr.filter); |
151 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr); |
152 | } |
153 | if (aggr.order_bys && !aggr.order_bys->orders.empty()) { |
154 | for (auto &order : aggr.order_bys->orders) { |
155 | bool success = aggregate_binder.BindCorrelatedColumns(expr&: order.expression); |
156 | if (!success) { |
157 | throw BinderException(error); |
158 | } |
159 | auto &bound_expr = BoundExpression::GetExpression(expr&: *order.expression); |
160 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr); |
161 | } |
162 | } |
163 | } else { |
164 | // we didn't bind columns, try again in children |
165 | return BindResult(error); |
166 | } |
167 | } else if (depth > 0 && !aggregate_binder.HasBoundColumns()) { |
168 | return BindResult("Aggregate with only constant parameters has to be bound in the root subquery" ); |
169 | } |
170 | if (!filter_error.empty()) { |
171 | return BindResult(filter_error); |
172 | } |
173 | |
174 | if (aggr.filter) { |
175 | auto &child = BoundExpression::GetExpression(expr&: *aggr.filter); |
176 | bound_filter = BoundCastExpression::AddCastToType(context, expr: std::move(child), target_type: LogicalType::BOOLEAN); |
177 | } |
178 | |
179 | // all children bound successfully |
180 | // extract the children and types |
181 | vector<LogicalType> types; |
182 | vector<LogicalType> arguments; |
183 | vector<unique_ptr<Expression>> children; |
184 | |
185 | if (ordered_set_agg) { |
186 | const bool order_sensitive = (aggr.function_name == "mode" ); |
187 | for (auto &order : aggr.order_bys->orders) { |
188 | auto &child = BoundExpression::GetExpression(expr&: *order.expression); |
189 | types.push_back(x: child->return_type); |
190 | arguments.push_back(x: child->return_type); |
191 | if (order_sensitive) { |
192 | children.push_back(x: child->Copy()); |
193 | } else { |
194 | children.push_back(x: std::move(child)); |
195 | } |
196 | } |
197 | if (!order_sensitive) { |
198 | aggr.order_bys->orders.clear(); |
199 | } |
200 | } |
201 | |
202 | for (idx_t i = 0; i < aggr.children.size(); i++) { |
203 | auto &child = BoundExpression::GetExpression(expr&: *aggr.children[i]); |
204 | types.push_back(x: child->return_type); |
205 | arguments.push_back(x: child->return_type); |
206 | children.push_back(x: std::move(child)); |
207 | } |
208 | |
209 | // bind the aggregate |
210 | FunctionBinder function_binder(context); |
211 | idx_t best_function = function_binder.BindFunction(name: func.name, functions&: func.functions, arguments: types, error); |
212 | if (best_function == DConstants::INVALID_INDEX) { |
213 | throw BinderException(binder.FormatError(expr_context&: aggr, message: error)); |
214 | } |
215 | // found a matching function! |
216 | auto bound_function = func.functions.GetFunctionByOffset(offset: best_function); |
217 | |
218 | // Bind any sort columns, unless the aggregate is order-insensitive |
219 | unique_ptr<BoundOrderModifier> order_bys; |
220 | if (!aggr.order_bys->orders.empty()) { |
221 | order_bys = make_uniq<BoundOrderModifier>(); |
222 | auto &config = DBConfig::GetConfig(context); |
223 | for (auto &order : aggr.order_bys->orders) { |
224 | auto &order_expr = BoundExpression::GetExpression(expr&: *order.expression); |
225 | const auto sense = config.ResolveOrder(order_type: order.type); |
226 | const auto null_order = config.ResolveNullOrder(order_type: sense, null_type: order.null_order); |
227 | order_bys->orders.emplace_back(args: sense, args: null_order, args: std::move(order_expr)); |
228 | } |
229 | } |
230 | |
231 | auto aggregate = |
232 | function_binder.BindAggregateFunction(bound_function, children: std::move(children), filter: std::move(bound_filter), |
233 | aggr_type: aggr.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT); |
234 | if (aggr.export_state) { |
235 | aggregate = ExportAggregateFunction::Bind(child_aggregate: std::move(aggregate)); |
236 | } |
237 | aggregate->order_bys = std::move(order_bys); |
238 | |
239 | // check for all the aggregates if this aggregate already exists |
240 | idx_t aggr_index; |
241 | auto entry = node.aggregate_map.find(x: *aggregate); |
242 | if (entry == node.aggregate_map.end()) { |
243 | // new aggregate: insert into aggregate list |
244 | aggr_index = node.aggregates.size(); |
245 | node.aggregate_map[*aggregate] = aggr_index; |
246 | node.aggregates.push_back(x: std::move(aggregate)); |
247 | } else { |
248 | // duplicate aggregate: simplify refer to this aggregate |
249 | aggr_index = entry->second; |
250 | } |
251 | |
252 | // now create a column reference referring to the aggregate |
253 | auto colref = make_uniq<BoundColumnRefExpression>( |
254 | args: aggr.alias.empty() ? node.aggregates[aggr_index]->ToString() : aggr.alias, |
255 | args&: node.aggregates[aggr_index]->return_type, args: ColumnBinding(node.aggregate_index, aggr_index), args&: depth); |
256 | // move the aggregate expression into the set of bound aggregates |
257 | return BindResult(std::move(colref)); |
258 | } |
259 | } // namespace duckdb |
260 | |