| 1 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
| 2 | #include "duckdb/common/pair.hpp" |
| 3 | #include "duckdb/common/operator/cast_operators.hpp" |
| 4 | #include "duckdb/parser/expression/function_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 6 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
| 7 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 8 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 9 | #include "duckdb/planner/expression_binder/aggregate_binder.hpp" |
| 10 | #include "duckdb/planner/expression_binder/base_select_binder.hpp" |
| 11 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
| 12 | #include "duckdb/execution/expression_executor.hpp" |
| 13 | #include "duckdb/function/scalar/generic_functions.hpp" |
| 14 | #include "duckdb/main/config.hpp" |
| 15 | #include "duckdb/function/function_binder.hpp" |
| 16 | #include "duckdb/planner/binder.hpp" |
| 17 | |
| 18 | namespace duckdb { |
| 19 | |
| 20 | static Value NegatePercentileValue(const Value &v, const bool desc) { |
| 21 | if (v.IsNull()) { |
| 22 | return v; |
| 23 | } |
| 24 | |
| 25 | const auto frac = v.GetValue<double>(); |
| 26 | if (frac < 0 || frac > 1) { |
| 27 | throw BinderException("PERCENTILEs can only take parameters in the range [0, 1]" ); |
| 28 | } |
| 29 | |
| 30 | if (!desc) { |
| 31 | return v; |
| 32 | } |
| 33 | |
| 34 | const auto &type = v.type(); |
| 35 | switch (type.id()) { |
| 36 | case LogicalTypeId::DECIMAL: { |
| 37 | // Negate DECIMALs as DECIMAL. |
| 38 | const auto integral = IntegralValue::Get(value: v); |
| 39 | const auto width = DecimalType::GetWidth(type); |
| 40 | const auto scale = DecimalType::GetScale(type); |
| 41 | switch (type.InternalType()) { |
| 42 | case PhysicalType::INT16: |
| 43 | return Value::DECIMAL(value: Cast::Operation<hugeint_t, int16_t>(input: -integral), width, scale); |
| 44 | case PhysicalType::INT32: |
| 45 | return Value::DECIMAL(value: Cast::Operation<hugeint_t, int32_t>(input: -integral), width, scale); |
| 46 | case PhysicalType::INT64: |
| 47 | return Value::DECIMAL(value: Cast::Operation<hugeint_t, int64_t>(input: -integral), width, scale); |
| 48 | case PhysicalType::INT128: |
| 49 | return Value::DECIMAL(value: -integral, width, scale); |
| 50 | default: |
| 51 | throw InternalException("Unknown DECIMAL type" ); |
| 52 | } |
| 53 | } |
| 54 | default: |
| 55 | // Everything else can just be a DOUBLE |
| 56 | return Value::DOUBLE(value: -v.GetValue<double>()); |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | static void NegatePercentileFractions(ClientContext &context, unique_ptr<ParsedExpression> &fractions, bool desc) { |
| 61 | D_ASSERT(fractions.get()); |
| 62 | D_ASSERT(fractions->expression_class == ExpressionClass::BOUND_EXPRESSION); |
| 63 | auto &bound = BoundExpression::GetExpression(expr&: *fractions); |
| 64 | |
| 65 | if (!bound->IsFoldable()) { |
| 66 | return; |
| 67 | } |
| 68 | |
| 69 | Value value = ExpressionExecutor::EvaluateScalar(context, expr: *bound); |
| 70 | if (value.type().id() == LogicalTypeId::LIST) { |
| 71 | vector<Value> values; |
| 72 | for (const auto &element_val : ListValue::GetChildren(value)) { |
| 73 | values.push_back(x: NegatePercentileValue(v: element_val, desc)); |
| 74 | } |
| 75 | if (values.empty()) { |
| 76 | throw BinderException("Empty list in percentile not allowed" ); |
| 77 | } |
| 78 | bound = make_uniq<BoundConstantExpression>(args: Value::LIST(values)); |
| 79 | } else { |
| 80 | bound = make_uniq<BoundConstantExpression>(args: NegatePercentileValue(v: value, desc)); |
| 81 | } |
| 82 | } |
| 83 | |
| 84 | BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFunctionCatalogEntry &func, idx_t depth) { |
| 85 | // first bind the child of the aggregate expression (if any) |
| 86 | this->bound_aggregate = true; |
| 87 | unique_ptr<Expression> bound_filter; |
| 88 | AggregateBinder aggregate_binder(binder, context); |
| 89 | string error, filter_error; |
| 90 | |
| 91 | // Now we bind the filter (if any) |
| 92 | if (aggr.filter) { |
| 93 | aggregate_binder.BindChild(expr&: aggr.filter, depth: 0, error); |
| 94 | } |
| 95 | |
| 96 | // Handle ordered-set aggregates by moving the single ORDER BY expression to the front of the children. |
| 97 | // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE |
| 98 | bool ordered_set_agg = false; |
| 99 | bool negate_fractions = false; |
| 100 | if (aggr.order_bys && aggr.order_bys->orders.size() == 1) { |
| 101 | const auto &func_name = aggr.function_name; |
| 102 | ordered_set_agg = (func_name == "quantile_cont" || func_name == "quantile_disc" || |
| 103 | (func_name == "mode" && aggr.children.empty())); |
| 104 | |
| 105 | if (ordered_set_agg) { |
| 106 | auto &config = DBConfig::GetConfig(context); |
| 107 | const auto &order = aggr.order_bys->orders[0]; |
| 108 | const auto sense = |
| 109 | (order.type == OrderType::ORDER_DEFAULT) ? config.options.default_order_type : order.type; |
| 110 | negate_fractions = (sense == OrderType::DESCENDING); |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | for (auto &child : aggr.children) { |
| 115 | aggregate_binder.BindChild(expr&: child, depth: 0, error); |
| 116 | // We have to negate the fractions for PERCENTILE_XXXX DESC |
| 117 | if (error.empty() && ordered_set_agg) { |
| 118 | NegatePercentileFractions(context, fractions&: child, desc: negate_fractions); |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | // Bind the ORDER BYs, if any |
| 123 | if (aggr.order_bys && !aggr.order_bys->orders.empty()) { |
| 124 | for (auto &order : aggr.order_bys->orders) { |
| 125 | aggregate_binder.BindChild(expr&: order.expression, depth: 0, error); |
| 126 | } |
| 127 | } |
| 128 | |
| 129 | if (!error.empty()) { |
| 130 | // failed to bind child |
| 131 | if (aggregate_binder.HasBoundColumns()) { |
| 132 | for (idx_t i = 0; i < aggr.children.size(); i++) { |
| 133 | // however, we bound columns! |
| 134 | // that means this aggregation belongs to this node |
| 135 | // check if we have to resolve any errors by binding with parent binders |
| 136 | bool success = aggregate_binder.BindCorrelatedColumns(expr&: aggr.children[i]); |
| 137 | // if there is still an error after this, we could not successfully bind the aggregate |
| 138 | if (!success) { |
| 139 | throw BinderException(error); |
| 140 | } |
| 141 | auto &bound_expr = BoundExpression::GetExpression(expr&: *aggr.children[i]); |
| 142 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr); |
| 143 | } |
| 144 | if (aggr.filter) { |
| 145 | bool success = aggregate_binder.BindCorrelatedColumns(expr&: aggr.filter); |
| 146 | // if there is still an error after this, we could not successfully bind the aggregate |
| 147 | if (!success) { |
| 148 | throw BinderException(error); |
| 149 | } |
| 150 | auto &bound_expr = BoundExpression::GetExpression(expr&: *aggr.filter); |
| 151 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr); |
| 152 | } |
| 153 | if (aggr.order_bys && !aggr.order_bys->orders.empty()) { |
| 154 | for (auto &order : aggr.order_bys->orders) { |
| 155 | bool success = aggregate_binder.BindCorrelatedColumns(expr&: order.expression); |
| 156 | if (!success) { |
| 157 | throw BinderException(error); |
| 158 | } |
| 159 | auto &bound_expr = BoundExpression::GetExpression(expr&: *order.expression); |
| 160 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr); |
| 161 | } |
| 162 | } |
| 163 | } else { |
| 164 | // we didn't bind columns, try again in children |
| 165 | return BindResult(error); |
| 166 | } |
| 167 | } else if (depth > 0 && !aggregate_binder.HasBoundColumns()) { |
| 168 | return BindResult("Aggregate with only constant parameters has to be bound in the root subquery" ); |
| 169 | } |
| 170 | if (!filter_error.empty()) { |
| 171 | return BindResult(filter_error); |
| 172 | } |
| 173 | |
| 174 | if (aggr.filter) { |
| 175 | auto &child = BoundExpression::GetExpression(expr&: *aggr.filter); |
| 176 | bound_filter = BoundCastExpression::AddCastToType(context, expr: std::move(child), target_type: LogicalType::BOOLEAN); |
| 177 | } |
| 178 | |
| 179 | // all children bound successfully |
| 180 | // extract the children and types |
| 181 | vector<LogicalType> types; |
| 182 | vector<LogicalType> arguments; |
| 183 | vector<unique_ptr<Expression>> children; |
| 184 | |
| 185 | if (ordered_set_agg) { |
| 186 | const bool order_sensitive = (aggr.function_name == "mode" ); |
| 187 | for (auto &order : aggr.order_bys->orders) { |
| 188 | auto &child = BoundExpression::GetExpression(expr&: *order.expression); |
| 189 | types.push_back(x: child->return_type); |
| 190 | arguments.push_back(x: child->return_type); |
| 191 | if (order_sensitive) { |
| 192 | children.push_back(x: child->Copy()); |
| 193 | } else { |
| 194 | children.push_back(x: std::move(child)); |
| 195 | } |
| 196 | } |
| 197 | if (!order_sensitive) { |
| 198 | aggr.order_bys->orders.clear(); |
| 199 | } |
| 200 | } |
| 201 | |
| 202 | for (idx_t i = 0; i < aggr.children.size(); i++) { |
| 203 | auto &child = BoundExpression::GetExpression(expr&: *aggr.children[i]); |
| 204 | types.push_back(x: child->return_type); |
| 205 | arguments.push_back(x: child->return_type); |
| 206 | children.push_back(x: std::move(child)); |
| 207 | } |
| 208 | |
| 209 | // bind the aggregate |
| 210 | FunctionBinder function_binder(context); |
| 211 | idx_t best_function = function_binder.BindFunction(name: func.name, functions&: func.functions, arguments: types, error); |
| 212 | if (best_function == DConstants::INVALID_INDEX) { |
| 213 | throw BinderException(binder.FormatError(expr_context&: aggr, message: error)); |
| 214 | } |
| 215 | // found a matching function! |
| 216 | auto bound_function = func.functions.GetFunctionByOffset(offset: best_function); |
| 217 | |
| 218 | // Bind any sort columns, unless the aggregate is order-insensitive |
| 219 | unique_ptr<BoundOrderModifier> order_bys; |
| 220 | if (!aggr.order_bys->orders.empty()) { |
| 221 | order_bys = make_uniq<BoundOrderModifier>(); |
| 222 | auto &config = DBConfig::GetConfig(context); |
| 223 | for (auto &order : aggr.order_bys->orders) { |
| 224 | auto &order_expr = BoundExpression::GetExpression(expr&: *order.expression); |
| 225 | const auto sense = config.ResolveOrder(order_type: order.type); |
| 226 | const auto null_order = config.ResolveNullOrder(order_type: sense, null_type: order.null_order); |
| 227 | order_bys->orders.emplace_back(args: sense, args: null_order, args: std::move(order_expr)); |
| 228 | } |
| 229 | } |
| 230 | |
| 231 | auto aggregate = |
| 232 | function_binder.BindAggregateFunction(bound_function, children: std::move(children), filter: std::move(bound_filter), |
| 233 | aggr_type: aggr.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT); |
| 234 | if (aggr.export_state) { |
| 235 | aggregate = ExportAggregateFunction::Bind(child_aggregate: std::move(aggregate)); |
| 236 | } |
| 237 | aggregate->order_bys = std::move(order_bys); |
| 238 | |
| 239 | // check for all the aggregates if this aggregate already exists |
| 240 | idx_t aggr_index; |
| 241 | auto entry = node.aggregate_map.find(x: *aggregate); |
| 242 | if (entry == node.aggregate_map.end()) { |
| 243 | // new aggregate: insert into aggregate list |
| 244 | aggr_index = node.aggregates.size(); |
| 245 | node.aggregate_map[*aggregate] = aggr_index; |
| 246 | node.aggregates.push_back(x: std::move(aggregate)); |
| 247 | } else { |
| 248 | // duplicate aggregate: simplify refer to this aggregate |
| 249 | aggr_index = entry->second; |
| 250 | } |
| 251 | |
| 252 | // now create a column reference referring to the aggregate |
| 253 | auto colref = make_uniq<BoundColumnRefExpression>( |
| 254 | args: aggr.alias.empty() ? node.aggregates[aggr_index]->ToString() : aggr.alias, |
| 255 | args&: node.aggregates[aggr_index]->return_type, args: ColumnBinding(node.aggregate_index, aggr_index), args&: depth); |
| 256 | // move the aggregate expression into the set of bound aggregates |
| 257 | return BindResult(std::move(colref)); |
| 258 | } |
| 259 | } // namespace duckdb |
| 260 | |