1 | #include "duckdb/parser/expression/window_expression.hpp" |
2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
3 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_window_expression.hpp" |
7 | #include "duckdb/planner/expression_binder/select_binder.hpp" |
8 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
9 | #include "duckdb/planner/binder.hpp" |
10 | #include "duckdb/main/config.hpp" |
11 | #include "duckdb/function/scalar_function.hpp" |
12 | #include "duckdb/function/function_binder.hpp" |
13 | |
14 | #include "duckdb/catalog/catalog.hpp" |
15 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
16 | |
17 | namespace duckdb { |
18 | |
19 | static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector<LogicalType> &child_types) { |
20 | |
21 | idx_t param_count; |
22 | switch (window_type) { |
23 | case ExpressionType::WINDOW_RANK: |
24 | case ExpressionType::WINDOW_RANK_DENSE: |
25 | case ExpressionType::WINDOW_ROW_NUMBER: |
26 | case ExpressionType::WINDOW_PERCENT_RANK: |
27 | case ExpressionType::WINDOW_CUME_DIST: |
28 | param_count = 0; |
29 | break; |
30 | case ExpressionType::WINDOW_NTILE: |
31 | case ExpressionType::WINDOW_FIRST_VALUE: |
32 | case ExpressionType::WINDOW_LAST_VALUE: |
33 | case ExpressionType::WINDOW_LEAD: |
34 | case ExpressionType::WINDOW_LAG: |
35 | param_count = 1; |
36 | break; |
37 | case ExpressionType::WINDOW_NTH_VALUE: |
38 | param_count = 2; |
39 | break; |
40 | default: |
41 | throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(type: window_type)); |
42 | } |
43 | if (child_types.size() != param_count) { |
44 | throw BinderException("%s needs %d parameter%s, got %d" , ExpressionTypeToString(type: window_type), param_count, |
45 | param_count == 1 ? "" : "s" , child_types.size()); |
46 | } |
47 | switch (window_type) { |
48 | case ExpressionType::WINDOW_PERCENT_RANK: |
49 | case ExpressionType::WINDOW_CUME_DIST: |
50 | return LogicalType(LogicalTypeId::DOUBLE); |
51 | case ExpressionType::WINDOW_ROW_NUMBER: |
52 | case ExpressionType::WINDOW_RANK: |
53 | case ExpressionType::WINDOW_RANK_DENSE: |
54 | case ExpressionType::WINDOW_NTILE: |
55 | return LogicalType::BIGINT; |
56 | case ExpressionType::WINDOW_NTH_VALUE: |
57 | case ExpressionType::WINDOW_FIRST_VALUE: |
58 | case ExpressionType::WINDOW_LAST_VALUE: |
59 | case ExpressionType::WINDOW_LEAD: |
60 | case ExpressionType::WINDOW_LAG: |
61 | return child_types[0]; |
62 | default: |
63 | throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(type: window_type)); |
64 | } |
65 | } |
66 | |
67 | static unique_ptr<Expression> GetExpression(unique_ptr<ParsedExpression> &expr) { |
68 | if (!expr) { |
69 | return nullptr; |
70 | } |
71 | D_ASSERT(expr.get()); |
72 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); |
73 | return std::move(BoundExpression::GetExpression(expr&: *expr)); |
74 | } |
75 | |
76 | static unique_ptr<Expression> CastWindowExpression(unique_ptr<ParsedExpression> &expr, const LogicalType &type) { |
77 | if (!expr) { |
78 | return nullptr; |
79 | } |
80 | D_ASSERT(expr.get()); |
81 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); |
82 | |
83 | auto &bound = BoundExpression::GetExpression(expr&: *expr); |
84 | bound = BoundCastExpression::AddDefaultCastToType(expr: std::move(bound), target_type: type); |
85 | |
86 | return std::move(bound); |
87 | } |
88 | |
89 | static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr<ParsedExpression> &expr, |
90 | unique_ptr<ParsedExpression> &order_expr) { |
91 | |
92 | vector<unique_ptr<Expression>> children; |
93 | |
94 | D_ASSERT(order_expr.get()); |
95 | D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); |
96 | auto &bound_order = BoundExpression::GetExpression(expr&: *order_expr); |
97 | children.emplace_back(args: bound_order->Copy()); |
98 | |
99 | D_ASSERT(expr.get()); |
100 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION); |
101 | auto &bound = BoundExpression::GetExpression(expr&: *expr); |
102 | children.emplace_back(args: std::move(bound)); |
103 | |
104 | string error; |
105 | FunctionBinder function_binder(context); |
106 | auto function = function_binder.BindScalarFunction(DEFAULT_SCHEMA, name, children: std::move(children), error, is_operator: true); |
107 | if (!function) { |
108 | throw BinderException(error); |
109 | } |
110 | bound = std::move(function); |
111 | return bound->return_type; |
112 | } |
113 | |
114 | BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) { |
115 | auto name = window.GetName(); |
116 | |
117 | QueryErrorContext error_context(binder.GetRootStatement(), window.query_location); |
118 | if (inside_window) { |
119 | throw BinderException(error_context.FormatError(msg: "window function calls cannot be nested" )); |
120 | } |
121 | if (depth > 0) { |
122 | throw BinderException(error_context.FormatError(msg: "correlated columns in window functions not supported" )); |
123 | } |
124 | // If we have range expressions, then only one order by clause is allowed. |
125 | if ((window.start == WindowBoundary::EXPR_PRECEDING_RANGE || window.start == WindowBoundary::EXPR_FOLLOWING_RANGE || |
126 | window.end == WindowBoundary::EXPR_PRECEDING_RANGE || window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) && |
127 | window.orders.size() != 1) { |
128 | throw BinderException(error_context.FormatError(msg: "RANGE frames must have only one ORDER BY expression" )); |
129 | } |
130 | // bind inside the children of the window function |
131 | // we set the inside_window flag to true to prevent binding nested window functions |
132 | this->inside_window = true; |
133 | string error; |
134 | for (auto &child : window.children) { |
135 | BindChild(expr&: child, depth, error); |
136 | } |
137 | for (auto &child : window.partitions) { |
138 | BindChild(expr&: child, depth, error); |
139 | } |
140 | for (auto &order : window.orders) { |
141 | BindChild(expr&: order.expression, depth, error); |
142 | } |
143 | BindChild(expr&: window.filter_expr, depth, error); |
144 | BindChild(expr&: window.start_expr, depth, error); |
145 | BindChild(expr&: window.end_expr, depth, error); |
146 | BindChild(expr&: window.offset_expr, depth, error); |
147 | BindChild(expr&: window.default_expr, depth, error); |
148 | |
149 | this->inside_window = false; |
150 | if (!error.empty()) { |
151 | // failed to bind children of window function |
152 | return BindResult(error); |
153 | } |
154 | // successfully bound all children: create bound window function |
155 | vector<LogicalType> types; |
156 | vector<unique_ptr<Expression>> children; |
157 | for (auto &child : window.children) { |
158 | D_ASSERT(child.get()); |
159 | D_ASSERT(child->expression_class == ExpressionClass::BOUND_EXPRESSION); |
160 | auto &bound = BoundExpression::GetExpression(expr&: *child); |
161 | // Add casts for positional arguments |
162 | const auto argno = children.size(); |
163 | switch (window.type) { |
164 | case ExpressionType::WINDOW_NTILE: |
165 | // ntile(bigint) |
166 | if (argno == 0) { |
167 | bound = BoundCastExpression::AddCastToType(context, expr: std::move(bound), target_type: LogicalType::BIGINT); |
168 | } |
169 | break; |
170 | case ExpressionType::WINDOW_NTH_VALUE: |
171 | // nth_value(<expr>, index) |
172 | if (argno == 1) { |
173 | bound = BoundCastExpression::AddCastToType(context, expr: std::move(bound), target_type: LogicalType::BIGINT); |
174 | } |
175 | default: |
176 | break; |
177 | } |
178 | types.push_back(x: bound->return_type); |
179 | children.push_back(x: std::move(bound)); |
180 | } |
181 | // Determine the function type. |
182 | LogicalType sql_type; |
183 | unique_ptr<AggregateFunction> aggregate; |
184 | unique_ptr<FunctionData> bind_info; |
185 | if (window.type == ExpressionType::WINDOW_AGGREGATE) { |
186 | // Look up the aggregate function in the catalog |
187 | auto &func = Catalog::GetEntry<AggregateFunctionCatalogEntry>(context, catalog_name: window.catalog, schema_name: window.schema, |
188 | name: window.function_name, error_context); |
189 | D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); |
190 | |
191 | // bind the aggregate |
192 | string error; |
193 | FunctionBinder function_binder(context); |
194 | auto best_function = function_binder.BindFunction(name: func.name, functions&: func.functions, arguments: types, error); |
195 | if (best_function == DConstants::INVALID_INDEX) { |
196 | throw BinderException(binder.FormatError(expr_context&: window, message: error)); |
197 | } |
198 | // found a matching function! bind it as an aggregate |
199 | auto bound_function = func.functions.GetFunctionByOffset(offset: best_function); |
200 | auto bound_aggregate = function_binder.BindAggregateFunction(bound_function, children: std::move(children)); |
201 | // create the aggregate |
202 | aggregate = make_uniq<AggregateFunction>(args&: bound_aggregate->function); |
203 | bind_info = std::move(bound_aggregate->bind_info); |
204 | children = std::move(bound_aggregate->children); |
205 | sql_type = bound_aggregate->return_type; |
206 | } else { |
207 | // fetch the child of the non-aggregate window function (if any) |
208 | sql_type = ResolveWindowExpressionType(window_type: window.type, child_types: types); |
209 | } |
210 | auto result = make_uniq<BoundWindowExpression>(args&: window.type, args&: sql_type, args: std::move(aggregate), args: std::move(bind_info)); |
211 | result->children = std::move(children); |
212 | for (auto &child : window.partitions) { |
213 | result->partitions.push_back(x: GetExpression(expr&: child)); |
214 | } |
215 | result->ignore_nulls = window.ignore_nulls; |
216 | |
217 | // Convert RANGE boundary expressions to ORDER +/- expressions. |
218 | // Note that PRECEEDING and FOLLOWING refer to the sequential order in the frame, |
219 | // not the natural ordering of the type. This means that the offset arithmetic must be reversed |
220 | // for ORDER BY DESC. |
221 | auto &config = DBConfig::GetConfig(context); |
222 | auto range_sense = OrderType::INVALID; |
223 | LogicalType start_type = LogicalType::BIGINT; |
224 | if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) { |
225 | D_ASSERT(window.orders.size() == 1); |
226 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); |
227 | const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+" ; |
228 | start_type = BindRangeExpression(context, name, expr&: window.start_expr, order_expr&: window.orders[0].expression); |
229 | } else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) { |
230 | D_ASSERT(window.orders.size() == 1); |
231 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); |
232 | const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-" ; |
233 | start_type = BindRangeExpression(context, name, expr&: window.start_expr, order_expr&: window.orders[0].expression); |
234 | } |
235 | |
236 | LogicalType end_type = LogicalType::BIGINT; |
237 | if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) { |
238 | D_ASSERT(window.orders.size() == 1); |
239 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); |
240 | const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+" ; |
241 | end_type = BindRangeExpression(context, name, expr&: window.end_expr, order_expr&: window.orders[0].expression); |
242 | } else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) { |
243 | D_ASSERT(window.orders.size() == 1); |
244 | range_sense = config.ResolveOrder(order_type: window.orders[0].type); |
245 | const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-" ; |
246 | end_type = BindRangeExpression(context, name, expr&: window.end_expr, order_expr&: window.orders[0].expression); |
247 | } |
248 | |
249 | // Cast ORDER and boundary expressions to the same type |
250 | if (range_sense != OrderType::INVALID) { |
251 | D_ASSERT(window.orders.size() == 1); |
252 | |
253 | auto &order_expr = window.orders[0].expression; |
254 | D_ASSERT(order_expr.get()); |
255 | D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION); |
256 | auto &bound_order = BoundExpression::GetExpression(expr&: *order_expr); |
257 | auto order_type = bound_order->return_type; |
258 | if (window.start_expr) { |
259 | order_type = LogicalType::MaxLogicalType(left: order_type, right: start_type); |
260 | } |
261 | if (window.end_expr) { |
262 | order_type = LogicalType::MaxLogicalType(left: order_type, right: end_type); |
263 | } |
264 | |
265 | // Cast all three to match |
266 | bound_order = BoundCastExpression::AddCastToType(context, expr: std::move(bound_order), target_type: order_type); |
267 | start_type = end_type = order_type; |
268 | } |
269 | |
270 | for (auto &order : window.orders) { |
271 | auto type = config.ResolveOrder(order_type: order.type); |
272 | auto null_order = config.ResolveNullOrder(order_type: type, null_type: order.null_order); |
273 | auto expression = GetExpression(expr&: order.expression); |
274 | result->orders.emplace_back(args&: type, args&: null_order, args: std::move(expression)); |
275 | } |
276 | |
277 | result->filter_expr = CastWindowExpression(expr&: window.filter_expr, type: LogicalType::BOOLEAN); |
278 | |
279 | result->start_expr = CastWindowExpression(expr&: window.start_expr, type: start_type); |
280 | result->end_expr = CastWindowExpression(expr&: window.end_expr, type: end_type); |
281 | result->offset_expr = CastWindowExpression(expr&: window.offset_expr, type: LogicalType::BIGINT); |
282 | result->default_expr = CastWindowExpression(expr&: window.default_expr, type: result->return_type); |
283 | result->start = window.start; |
284 | result->end = window.end; |
285 | |
286 | // create a BoundColumnRef that references this entry |
287 | auto colref = make_uniq<BoundColumnRefExpression>(args: std::move(name), args&: result->return_type, |
288 | args: ColumnBinding(node.window_index, node.windows.size()), args&: depth); |
289 | // move the WINDOW expression into the set of bound windows |
290 | node.windows.push_back(x: std::move(result)); |
291 | return BindResult(std::move(colref)); |
292 | } |
293 | |
294 | } // namespace duckdb |
295 | |