1#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
2#include "duckdb/common/pair.hpp"
3#include "duckdb/common/operator/cast_operators.hpp"
4#include "duckdb/parser/expression/function_expression.hpp"
5#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
6#include "duckdb/planner/expression/bound_cast_expression.hpp"
7#include "duckdb/planner/expression/bound_columnref_expression.hpp"
8#include "duckdb/planner/expression/bound_constant_expression.hpp"
9#include "duckdb/planner/expression_binder/aggregate_binder.hpp"
10#include "duckdb/planner/expression_binder/base_select_binder.hpp"
11#include "duckdb/planner/query_node/bound_select_node.hpp"
12#include "duckdb/execution/expression_executor.hpp"
13#include "duckdb/function/scalar/generic_functions.hpp"
14#include "duckdb/main/config.hpp"
15#include "duckdb/function/function_binder.hpp"
16#include "duckdb/planner/binder.hpp"
17
18namespace duckdb {
19
20static Value NegatePercentileValue(const Value &v, const bool desc) {
21 if (v.IsNull()) {
22 return v;
23 }
24
25 const auto frac = v.GetValue<double>();
26 if (frac < 0 || frac > 1) {
27 throw BinderException("PERCENTILEs can only take parameters in the range [0, 1]");
28 }
29
30 if (!desc) {
31 return v;
32 }
33
34 const auto &type = v.type();
35 switch (type.id()) {
36 case LogicalTypeId::DECIMAL: {
37 // Negate DECIMALs as DECIMAL.
38 const auto integral = IntegralValue::Get(value: v);
39 const auto width = DecimalType::GetWidth(type);
40 const auto scale = DecimalType::GetScale(type);
41 switch (type.InternalType()) {
42 case PhysicalType::INT16:
43 return Value::DECIMAL(value: Cast::Operation<hugeint_t, int16_t>(input: -integral), width, scale);
44 case PhysicalType::INT32:
45 return Value::DECIMAL(value: Cast::Operation<hugeint_t, int32_t>(input: -integral), width, scale);
46 case PhysicalType::INT64:
47 return Value::DECIMAL(value: Cast::Operation<hugeint_t, int64_t>(input: -integral), width, scale);
48 case PhysicalType::INT128:
49 return Value::DECIMAL(value: -integral, width, scale);
50 default:
51 throw InternalException("Unknown DECIMAL type");
52 }
53 }
54 default:
55 // Everything else can just be a DOUBLE
56 return Value::DOUBLE(value: -v.GetValue<double>());
57 }
58}
59
60static void NegatePercentileFractions(ClientContext &context, unique_ptr<ParsedExpression> &fractions, bool desc) {
61 D_ASSERT(fractions.get());
62 D_ASSERT(fractions->expression_class == ExpressionClass::BOUND_EXPRESSION);
63 auto &bound = BoundExpression::GetExpression(expr&: *fractions);
64
65 if (!bound->IsFoldable()) {
66 return;
67 }
68
69 Value value = ExpressionExecutor::EvaluateScalar(context, expr: *bound);
70 if (value.type().id() == LogicalTypeId::LIST) {
71 vector<Value> values;
72 for (const auto &element_val : ListValue::GetChildren(value)) {
73 values.push_back(x: NegatePercentileValue(v: element_val, desc));
74 }
75 if (values.empty()) {
76 throw BinderException("Empty list in percentile not allowed");
77 }
78 bound = make_uniq<BoundConstantExpression>(args: Value::LIST(values));
79 } else {
80 bound = make_uniq<BoundConstantExpression>(args: NegatePercentileValue(v: value, desc));
81 }
82}
83
84BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFunctionCatalogEntry &func, idx_t depth) {
85 // first bind the child of the aggregate expression (if any)
86 this->bound_aggregate = true;
87 unique_ptr<Expression> bound_filter;
88 AggregateBinder aggregate_binder(binder, context);
89 string error, filter_error;
90
91 // Now we bind the filter (if any)
92 if (aggr.filter) {
93 aggregate_binder.BindChild(expr&: aggr.filter, depth: 0, error);
94 }
95
96 // Handle ordered-set aggregates by moving the single ORDER BY expression to the front of the children.
97 // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE
98 bool ordered_set_agg = false;
99 bool negate_fractions = false;
100 if (aggr.order_bys && aggr.order_bys->orders.size() == 1) {
101 const auto &func_name = aggr.function_name;
102 ordered_set_agg = (func_name == "quantile_cont" || func_name == "quantile_disc" ||
103 (func_name == "mode" && aggr.children.empty()));
104
105 if (ordered_set_agg) {
106 auto &config = DBConfig::GetConfig(context);
107 const auto &order = aggr.order_bys->orders[0];
108 const auto sense =
109 (order.type == OrderType::ORDER_DEFAULT) ? config.options.default_order_type : order.type;
110 negate_fractions = (sense == OrderType::DESCENDING);
111 }
112 }
113
114 for (auto &child : aggr.children) {
115 aggregate_binder.BindChild(expr&: child, depth: 0, error);
116 // We have to negate the fractions for PERCENTILE_XXXX DESC
117 if (error.empty() && ordered_set_agg) {
118 NegatePercentileFractions(context, fractions&: child, desc: negate_fractions);
119 }
120 }
121
122 // Bind the ORDER BYs, if any
123 if (aggr.order_bys && !aggr.order_bys->orders.empty()) {
124 for (auto &order : aggr.order_bys->orders) {
125 aggregate_binder.BindChild(expr&: order.expression, depth: 0, error);
126 }
127 }
128
129 if (!error.empty()) {
130 // failed to bind child
131 if (aggregate_binder.HasBoundColumns()) {
132 for (idx_t i = 0; i < aggr.children.size(); i++) {
133 // however, we bound columns!
134 // that means this aggregation belongs to this node
135 // check if we have to resolve any errors by binding with parent binders
136 bool success = aggregate_binder.BindCorrelatedColumns(expr&: aggr.children[i]);
137 // if there is still an error after this, we could not successfully bind the aggregate
138 if (!success) {
139 throw BinderException(error);
140 }
141 auto &bound_expr = BoundExpression::GetExpression(expr&: *aggr.children[i]);
142 ExtractCorrelatedExpressions(binder, expr&: *bound_expr);
143 }
144 if (aggr.filter) {
145 bool success = aggregate_binder.BindCorrelatedColumns(expr&: aggr.filter);
146 // if there is still an error after this, we could not successfully bind the aggregate
147 if (!success) {
148 throw BinderException(error);
149 }
150 auto &bound_expr = BoundExpression::GetExpression(expr&: *aggr.filter);
151 ExtractCorrelatedExpressions(binder, expr&: *bound_expr);
152 }
153 if (aggr.order_bys && !aggr.order_bys->orders.empty()) {
154 for (auto &order : aggr.order_bys->orders) {
155 bool success = aggregate_binder.BindCorrelatedColumns(expr&: order.expression);
156 if (!success) {
157 throw BinderException(error);
158 }
159 auto &bound_expr = BoundExpression::GetExpression(expr&: *order.expression);
160 ExtractCorrelatedExpressions(binder, expr&: *bound_expr);
161 }
162 }
163 } else {
164 // we didn't bind columns, try again in children
165 return BindResult(error);
166 }
167 } else if (depth > 0 && !aggregate_binder.HasBoundColumns()) {
168 return BindResult("Aggregate with only constant parameters has to be bound in the root subquery");
169 }
170 if (!filter_error.empty()) {
171 return BindResult(filter_error);
172 }
173
174 if (aggr.filter) {
175 auto &child = BoundExpression::GetExpression(expr&: *aggr.filter);
176 bound_filter = BoundCastExpression::AddCastToType(context, expr: std::move(child), target_type: LogicalType::BOOLEAN);
177 }
178
179 // all children bound successfully
180 // extract the children and types
181 vector<LogicalType> types;
182 vector<LogicalType> arguments;
183 vector<unique_ptr<Expression>> children;
184
185 if (ordered_set_agg) {
186 const bool order_sensitive = (aggr.function_name == "mode");
187 for (auto &order : aggr.order_bys->orders) {
188 auto &child = BoundExpression::GetExpression(expr&: *order.expression);
189 types.push_back(x: child->return_type);
190 arguments.push_back(x: child->return_type);
191 if (order_sensitive) {
192 children.push_back(x: child->Copy());
193 } else {
194 children.push_back(x: std::move(child));
195 }
196 }
197 if (!order_sensitive) {
198 aggr.order_bys->orders.clear();
199 }
200 }
201
202 for (idx_t i = 0; i < aggr.children.size(); i++) {
203 auto &child = BoundExpression::GetExpression(expr&: *aggr.children[i]);
204 types.push_back(x: child->return_type);
205 arguments.push_back(x: child->return_type);
206 children.push_back(x: std::move(child));
207 }
208
209 // bind the aggregate
210 FunctionBinder function_binder(context);
211 idx_t best_function = function_binder.BindFunction(name: func.name, functions&: func.functions, arguments: types, error);
212 if (best_function == DConstants::INVALID_INDEX) {
213 throw BinderException(binder.FormatError(expr_context&: aggr, message: error));
214 }
215 // found a matching function!
216 auto bound_function = func.functions.GetFunctionByOffset(offset: best_function);
217
218 // Bind any sort columns, unless the aggregate is order-insensitive
219 unique_ptr<BoundOrderModifier> order_bys;
220 if (!aggr.order_bys->orders.empty()) {
221 order_bys = make_uniq<BoundOrderModifier>();
222 auto &config = DBConfig::GetConfig(context);
223 for (auto &order : aggr.order_bys->orders) {
224 auto &order_expr = BoundExpression::GetExpression(expr&: *order.expression);
225 const auto sense = config.ResolveOrder(order_type: order.type);
226 const auto null_order = config.ResolveNullOrder(order_type: sense, null_type: order.null_order);
227 order_bys->orders.emplace_back(args: sense, args: null_order, args: std::move(order_expr));
228 }
229 }
230
231 auto aggregate =
232 function_binder.BindAggregateFunction(bound_function, children: std::move(children), filter: std::move(bound_filter),
233 aggr_type: aggr.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT);
234 if (aggr.export_state) {
235 aggregate = ExportAggregateFunction::Bind(child_aggregate: std::move(aggregate));
236 }
237 aggregate->order_bys = std::move(order_bys);
238
239 // check for all the aggregates if this aggregate already exists
240 idx_t aggr_index;
241 auto entry = node.aggregate_map.find(x: *aggregate);
242 if (entry == node.aggregate_map.end()) {
243 // new aggregate: insert into aggregate list
244 aggr_index = node.aggregates.size();
245 node.aggregate_map[*aggregate] = aggr_index;
246 node.aggregates.push_back(x: std::move(aggregate));
247 } else {
248 // duplicate aggregate: simplify refer to this aggregate
249 aggr_index = entry->second;
250 }
251
252 // now create a column reference referring to the aggregate
253 auto colref = make_uniq<BoundColumnRefExpression>(
254 args: aggr.alias.empty() ? node.aggregates[aggr_index]->ToString() : aggr.alias,
255 args&: node.aggregates[aggr_index]->return_type, args: ColumnBinding(node.aggregate_index, aggr_index), args&: depth);
256 // move the aggregate expression into the set of bound aggregates
257 return BindResult(std::move(colref));
258}
259} // namespace duckdb
260