| 1 | #include "duckdb/parser/expression/window_expression.hpp" | 
|---|
| 2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" | 
|---|
| 3 | #include "duckdb/planner/expression/bound_cast_expression.hpp" | 
|---|
| 4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" | 
|---|
| 5 | #include "duckdb/planner/expression/bound_function_expression.hpp" | 
|---|
| 6 | #include "duckdb/planner/expression/bound_window_expression.hpp" | 
|---|
| 7 | #include "duckdb/planner/expression_binder/select_binder.hpp" | 
|---|
| 8 | #include "duckdb/planner/query_node/bound_select_node.hpp" | 
|---|
| 9 | #include "duckdb/planner/binder.hpp" | 
|---|
| 10 | #include "duckdb/main/config.hpp" | 
|---|
| 11 | #include "duckdb/function/scalar_function.hpp" | 
|---|
| 12 | #include "duckdb/function/function_binder.hpp" | 
|---|
| 13 |  | 
|---|
| 14 | #include "duckdb/catalog/catalog.hpp" | 
|---|
| 15 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" | 
|---|
| 16 |  | 
|---|
| 17 | namespace duckdb { | 
|---|
| 18 |  | 
|---|
| 19 | static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector<LogicalType> &child_types) { | 
|---|
| 20 |  | 
|---|
| 21 | idx_t param_count; | 
|---|
| 22 | switch (window_type) { | 
|---|
| 23 | case ExpressionType::WINDOW_RANK: | 
|---|
| 24 | case ExpressionType::WINDOW_RANK_DENSE: | 
|---|
| 25 | case ExpressionType::WINDOW_ROW_NUMBER: | 
|---|
| 26 | case ExpressionType::WINDOW_PERCENT_RANK: | 
|---|
| 27 | case ExpressionType::WINDOW_CUME_DIST: | 
|---|
| 28 | param_count = 0; | 
|---|
| 29 | break; | 
|---|
| 30 | case ExpressionType::WINDOW_NTILE: | 
|---|
| 31 | case ExpressionType::WINDOW_FIRST_VALUE: | 
|---|
| 32 | case ExpressionType::WINDOW_LAST_VALUE: | 
|---|
| 33 | case ExpressionType::WINDOW_LEAD: | 
|---|
| 34 | case ExpressionType::WINDOW_LAG: | 
|---|
| 35 | param_count = 1; | 
|---|
| 36 | break; | 
|---|
| 37 | case ExpressionType::WINDOW_NTH_VALUE: | 
|---|
| 38 | param_count = 2; | 
|---|
| 39 | break; | 
|---|
| 40 | default: | 
|---|
| 41 | throw InternalException( "Unrecognized window expression type "+ ExpressionTypeToString(type: window_type)); | 
|---|
| 42 | } | 
|---|
| 43 | if (child_types.size() != param_count) { | 
|---|
| 44 | throw BinderException( "%s needs %d parameter%s, got %d", ExpressionTypeToString(type: window_type), param_count, | 
|---|
| 45 | param_count == 1 ? "": "s", child_types.size()); | 
|---|
| 46 | } | 
|---|
| 47 | switch (window_type) { | 
|---|
| 48 | case ExpressionType::WINDOW_PERCENT_RANK: | 
|---|
| 49 | case ExpressionType::WINDOW_CUME_DIST: | 
|---|
| 50 | return LogicalType(LogicalTypeId::DOUBLE); | 
|---|
| 51 | case ExpressionType::WINDOW_ROW_NUMBER: | 
|---|
| 52 | case ExpressionType::WINDOW_RANK: | 
|---|
| 53 | case ExpressionType::WINDOW_RANK_DENSE: | 
|---|
| 54 | case ExpressionType::WINDOW_NTILE: | 
|---|
| 55 | return LogicalType::BIGINT; | 
|---|
| 56 | case ExpressionType::WINDOW_NTH_VALUE: | 
|---|
| 57 | case ExpressionType::WINDOW_FIRST_VALUE: | 
|---|
| 58 | case ExpressionType::WINDOW_LAST_VALUE: | 
|---|
| 59 | case ExpressionType::WINDOW_LEAD: | 
|---|
| 60 | case ExpressionType::WINDOW_LAG: | 
|---|
| 61 | return child_types[0]; | 
|---|
| 62 | default: | 
|---|
| 63 | throw InternalException( "Unrecognized window expression type "+ ExpressionTypeToString(type: window_type)); | 
|---|
| 64 | } | 
|---|
| 65 | } | 
|---|
| 66 |  | 
|---|
| 67 | static unique_ptr<Expression> GetExpression(unique_ptr<ParsedExpression> &expr) { | 
|---|
| 68 | if (!expr) { | 
|---|
| 69 | return nullptr; | 
|---|
| 70 | } | 
|---|
| 71 | D_ASSERT(expr.get()); | 
|---|
| 72 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); | 
|---|
| 73 | return std::move(BoundExpression::GetExpression(expr&: *expr)); | 
|---|
| 74 | } | 
|---|
| 75 |  | 
|---|
| 76 | static unique_ptr<Expression> CastWindowExpression(unique_ptr<ParsedExpression> &expr, const LogicalType &type) { | 
|---|
| 77 | if (!expr) { | 
|---|
| 78 | return nullptr; | 
|---|
| 79 | } | 
|---|
| 80 | D_ASSERT(expr.get()); | 
|---|
| 81 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); | 
|---|
| 82 |  | 
|---|
| 83 | auto &bound = BoundExpression::GetExpression(expr&: *expr); | 
|---|
| 84 | bound = BoundCastExpression::AddDefaultCastToType(expr: std::move(bound), target_type: type); | 
|---|
| 85 |  | 
|---|
| 86 | return std::move(bound); | 
|---|
| 87 | } | 
|---|
| 88 |  | 
|---|
| 89 | static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr<ParsedExpression> &expr, | 
|---|
| 90 | unique_ptr<ParsedExpression> &order_expr) { | 
|---|
| 91 |  | 
|---|
| 92 | vector<unique_ptr<Expression>> children; | 
|---|
| 93 |  | 
|---|
| 94 | D_ASSERT(order_expr.get()); | 
|---|
| 95 | D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); | 
|---|
| 96 | auto &bound_order = BoundExpression::GetExpression(expr&: *order_expr); | 
|---|
| 97 | children.emplace_back(args: bound_order->Copy()); | 
|---|
| 98 |  | 
|---|
| 99 | D_ASSERT(expr.get()); | 
|---|
| 100 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); | 
|---|
| 101 | auto &bound = BoundExpression::GetExpression(expr&: *expr); | 
|---|
| 102 | children.emplace_back(args: std::move(bound)); | 
|---|
| 103 |  | 
|---|
| 104 | string error; | 
|---|
| 105 | FunctionBinder function_binder(context); | 
|---|
| 106 | auto function = function_binder.BindScalarFunction(DEFAULT_SCHEMA, name, children: std::move(children), error, is_operator: true); | 
|---|
| 107 | if (!function) { | 
|---|
| 108 | throw BinderException(error); | 
|---|
| 109 | } | 
|---|
| 110 | bound = std::move(function); | 
|---|
| 111 | return bound->return_type; | 
|---|
| 112 | } | 
|---|
| 113 |  | 
|---|
| 114 | BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) { | 
|---|
| 115 | auto name = window.GetName(); | 
|---|
| 116 |  | 
|---|
| 117 | QueryErrorContext error_context(binder.GetRootStatement(), window.query_location); | 
|---|
| 118 | if (inside_window) { | 
|---|
| 119 | throw BinderException(error_context.FormatError(msg: "window function calls cannot be nested")); | 
|---|
| 120 | } | 
|---|
| 121 | if (depth > 0) { | 
|---|
| 122 | throw BinderException(error_context.FormatError(msg: "correlated columns in window functions not supported")); | 
|---|
| 123 | } | 
|---|
| 124 | // If we have range expressions, then only one order by clause is allowed. | 
|---|
| 125 | if ((window.start == WindowBoundary::EXPR_PRECEDING_RANGE || window.start == WindowBoundary::EXPR_FOLLOWING_RANGE || | 
|---|
| 126 | window.end == WindowBoundary::EXPR_PRECEDING_RANGE || window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) && | 
|---|
| 127 | window.orders.size() != 1) { | 
|---|
| 128 | throw BinderException(error_context.FormatError(msg: "RANGE frames must have only one ORDER BY expression")); | 
|---|
| 129 | } | 
|---|
| 130 | // bind inside the children of the window function | 
|---|
| 131 | // we set the inside_window flag to true to prevent binding nested window functions | 
|---|
| 132 | this->inside_window = true; | 
|---|
| 133 | string error; | 
|---|
| 134 | for (auto &child : window.children) { | 
|---|
| 135 | BindChild(expr&: child, depth, error); | 
|---|
| 136 | } | 
|---|
| 137 | for (auto &child : window.partitions) { | 
|---|
| 138 | BindChild(expr&: child, depth, error); | 
|---|
| 139 | } | 
|---|
| 140 | for (auto &order : window.orders) { | 
|---|
| 141 | BindChild(expr&: order.expression, depth, error); | 
|---|
| 142 | } | 
|---|
| 143 | BindChild(expr&: window.filter_expr, depth, error); | 
|---|
| 144 | BindChild(expr&: window.start_expr, depth, error); | 
|---|
| 145 | BindChild(expr&: window.end_expr, depth, error); | 
|---|
| 146 | BindChild(expr&: window.offset_expr, depth, error); | 
|---|
| 147 | BindChild(expr&: window.default_expr, depth, error); | 
|---|
| 148 |  | 
|---|
| 149 | this->inside_window = false; | 
|---|
| 150 | if (!error.empty()) { | 
|---|
| 151 | // failed to bind children of window function | 
|---|
| 152 | return BindResult(error); | 
|---|
| 153 | } | 
|---|
| 154 | // successfully bound all children: create bound window function | 
|---|
| 155 | vector<LogicalType> types; | 
|---|
| 156 | vector<unique_ptr<Expression>> children; | 
|---|
| 157 | for (auto &child : window.children) { | 
|---|
| 158 | D_ASSERT(child.get()); | 
|---|
| 159 | D_ASSERT(child->expression_class == ExpressionClass::BOUND_EXPRESSION); | 
|---|
| 160 | auto &bound = BoundExpression::GetExpression(expr&: *child); | 
|---|
| 161 | // Add casts for positional arguments | 
|---|
| 162 | const auto argno = children.size(); | 
|---|
| 163 | switch (window.type) { | 
|---|
| 164 | case ExpressionType::WINDOW_NTILE: | 
|---|
| 165 | // ntile(bigint) | 
|---|
| 166 | if (argno == 0) { | 
|---|
| 167 | bound = BoundCastExpression::AddCastToType(context, expr: std::move(bound), target_type: LogicalType::BIGINT); | 
|---|
| 168 | } | 
|---|
| 169 | break; | 
|---|
| 170 | case ExpressionType::WINDOW_NTH_VALUE: | 
|---|
| 171 | // nth_value(<expr>, index) | 
|---|
| 172 | if (argno == 1) { | 
|---|
| 173 | bound = BoundCastExpression::AddCastToType(context, expr: std::move(bound), target_type: LogicalType::BIGINT); | 
|---|
| 174 | } | 
|---|
| 175 | default: | 
|---|
| 176 | break; | 
|---|
| 177 | } | 
|---|
| 178 | types.push_back(x: bound->return_type); | 
|---|
| 179 | children.push_back(x: std::move(bound)); | 
|---|
| 180 | } | 
|---|
| 181 | //  Determine the function type. | 
|---|
| 182 | LogicalType sql_type; | 
|---|
| 183 | unique_ptr<AggregateFunction> aggregate; | 
|---|
| 184 | unique_ptr<FunctionData> bind_info; | 
|---|
| 185 | if (window.type == ExpressionType::WINDOW_AGGREGATE) { | 
|---|
| 186 | //  Look up the aggregate function in the catalog | 
|---|
| 187 | auto &func = Catalog::GetEntry<AggregateFunctionCatalogEntry>(context, catalog_name: window.catalog, schema_name: window.schema, | 
|---|
| 188 | name: window.function_name, error_context); | 
|---|
| 189 | D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); | 
|---|
| 190 |  | 
|---|
| 191 | // bind the aggregate | 
|---|
| 192 | string error; | 
|---|
| 193 | FunctionBinder function_binder(context); | 
|---|
| 194 | auto best_function = function_binder.BindFunction(name: func.name, functions&: func.functions, arguments: types, error); | 
|---|
| 195 | if (best_function == DConstants::INVALID_INDEX) { | 
|---|
| 196 | throw BinderException(binder.FormatError(expr_context&: window, message: error)); | 
|---|
| 197 | } | 
|---|
| 198 | // found a matching function! bind it as an aggregate | 
|---|
| 199 | auto bound_function = func.functions.GetFunctionByOffset(offset: best_function); | 
|---|
| 200 | auto bound_aggregate = function_binder.BindAggregateFunction(bound_function, children: std::move(children)); | 
|---|
| 201 | // create the aggregate | 
|---|
| 202 | aggregate = make_uniq<AggregateFunction>(args&: bound_aggregate->function); | 
|---|
| 203 | bind_info = std::move(bound_aggregate->bind_info); | 
|---|
| 204 | children = std::move(bound_aggregate->children); | 
|---|
| 205 | sql_type = bound_aggregate->return_type; | 
|---|
| 206 | } else { | 
|---|
| 207 | // fetch the child of the non-aggregate window function (if any) | 
|---|
| 208 | sql_type = ResolveWindowExpressionType(window_type: window.type, child_types: types); | 
|---|
| 209 | } | 
|---|
| 210 | auto result = make_uniq<BoundWindowExpression>(args&: window.type, args&: sql_type, args: std::move(aggregate), args: std::move(bind_info)); | 
|---|
| 211 | result->children = std::move(children); | 
|---|
| 212 | for (auto &child : window.partitions) { | 
|---|
| 213 | result->partitions.push_back(x: GetExpression(expr&: child)); | 
|---|
| 214 | } | 
|---|
| 215 | result->ignore_nulls = window.ignore_nulls; | 
|---|
| 216 |  | 
|---|
| 217 | // Convert RANGE boundary expressions to ORDER +/- expressions. | 
|---|
| 218 | // Note that PRECEEDING and FOLLOWING refer to the sequential order in the frame, | 
|---|
| 219 | // not the natural ordering of the type. This means that the offset arithmetic must be reversed | 
|---|
| 220 | // for ORDER BY DESC. | 
|---|
| 221 | auto &config = DBConfig::GetConfig(context); | 
|---|
| 222 | auto range_sense = OrderType::INVALID; | 
|---|
| 223 | LogicalType start_type = LogicalType::BIGINT; | 
|---|
| 224 | if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) { | 
|---|
| 225 | D_ASSERT(window.orders.size() == 1); | 
|---|
| 226 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); | 
|---|
| 227 | const auto name = (range_sense == OrderType::ASCENDING) ? "-": "+"; | 
|---|
| 228 | start_type = BindRangeExpression(context, name, expr&: window.start_expr, order_expr&: window.orders[0].expression); | 
|---|
| 229 | } else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) { | 
|---|
| 230 | D_ASSERT(window.orders.size() == 1); | 
|---|
| 231 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); | 
|---|
| 232 | const auto name = (range_sense == OrderType::ASCENDING) ? "+": "-"; | 
|---|
| 233 | start_type = BindRangeExpression(context, name, expr&: window.start_expr, order_expr&: window.orders[0].expression); | 
|---|
| 234 | } | 
|---|
| 235 |  | 
|---|
| 236 | LogicalType end_type = LogicalType::BIGINT; | 
|---|
| 237 | if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) { | 
|---|
| 238 | D_ASSERT(window.orders.size() == 1); | 
|---|
| 239 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); | 
|---|
| 240 | const auto name = (range_sense == OrderType::ASCENDING) ? "-": "+"; | 
|---|
| 241 | end_type = BindRangeExpression(context, name, expr&: window.end_expr, order_expr&: window.orders[0].expression); | 
|---|
| 242 | } else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) { | 
|---|
| 243 | D_ASSERT(window.orders.size() == 1); | 
|---|
| 244 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); | 
|---|
| 245 | const auto name = (range_sense == OrderType::ASCENDING) ? "+": "-"; | 
|---|
| 246 | end_type = BindRangeExpression(context, name, expr&: window.end_expr, order_expr&: window.orders[0].expression); | 
|---|
| 247 | } | 
|---|
| 248 |  | 
|---|
| 249 | // Cast ORDER and boundary expressions to the same type | 
|---|
| 250 | if (range_sense != OrderType::INVALID) { | 
|---|
| 251 | D_ASSERT(window.orders.size() == 1); | 
|---|
| 252 |  | 
|---|
| 253 | auto &order_expr = window.orders[0].expression; | 
|---|
| 254 | D_ASSERT(order_expr.get()); | 
|---|
| 255 | D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); | 
|---|
| 256 | auto &bound_order = BoundExpression::GetExpression(expr&: *order_expr); | 
|---|
| 257 | auto order_type = bound_order->return_type; | 
|---|
| 258 | if (window.start_expr) { | 
|---|
| 259 | order_type = LogicalType::MaxLogicalType(left: order_type, right: start_type); | 
|---|
| 260 | } | 
|---|
| 261 | if (window.end_expr) { | 
|---|
| 262 | order_type = LogicalType::MaxLogicalType(left: order_type, right: end_type); | 
|---|
| 263 | } | 
|---|
| 264 |  | 
|---|
| 265 | // Cast all three to match | 
|---|
| 266 | bound_order = BoundCastExpression::AddCastToType(context, expr: std::move(bound_order), target_type: order_type); | 
|---|
| 267 | start_type = end_type = order_type; | 
|---|
| 268 | } | 
|---|
| 269 |  | 
|---|
| 270 | for (auto &order : window.orders) { | 
|---|
| 271 | auto type = config.ResolveOrder(order_type: order.type); | 
|---|
| 272 | auto null_order = config.ResolveNullOrder(order_type: type, null_type: order.null_order); | 
|---|
| 273 | auto expression = GetExpression(expr&: order.expression); | 
|---|
| 274 | result->orders.emplace_back(args&: type, args&: null_order, args: std::move(expression)); | 
|---|
| 275 | } | 
|---|
| 276 |  | 
|---|
| 277 | result->filter_expr = CastWindowExpression(expr&: window.filter_expr, type: LogicalType::BOOLEAN); | 
|---|
| 278 |  | 
|---|
| 279 | result->start_expr = CastWindowExpression(expr&: window.start_expr, type: start_type); | 
|---|
| 280 | result->end_expr = CastWindowExpression(expr&: window.end_expr, type: end_type); | 
|---|
| 281 | result->offset_expr = CastWindowExpression(expr&: window.offset_expr, type: LogicalType::BIGINT); | 
|---|
| 282 | result->default_expr = CastWindowExpression(expr&: window.default_expr, type: result->return_type); | 
|---|
| 283 | result->start = window.start; | 
|---|
| 284 | result->end = window.end; | 
|---|
| 285 |  | 
|---|
| 286 | // create a BoundColumnRef that references this entry | 
|---|
| 287 | auto colref = make_uniq<BoundColumnRefExpression>(args: std::move(name), args&: result->return_type, | 
|---|
| 288 | args: ColumnBinding(node.window_index, node.windows.size()), args&: depth); | 
|---|
| 289 | // move the WINDOW expression into the set of bound windows | 
|---|
| 290 | node.windows.push_back(x: std::move(result)); | 
|---|
| 291 | return BindResult(std::move(colref)); | 
|---|
| 292 | } | 
|---|
| 293 |  | 
|---|
| 294 | } // namespace duckdb | 
|---|
| 295 |  | 
|---|