1#include "duckdb/parser/expression/window_expression.hpp"
2#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
3#include "duckdb/planner/expression/bound_cast_expression.hpp"
4#include "duckdb/planner/expression/bound_columnref_expression.hpp"
5#include "duckdb/planner/expression/bound_function_expression.hpp"
6#include "duckdb/planner/expression/bound_window_expression.hpp"
7#include "duckdb/planner/expression_binder/select_binder.hpp"
8#include "duckdb/planner/query_node/bound_select_node.hpp"
9#include "duckdb/planner/binder.hpp"
10#include "duckdb/main/config.hpp"
11#include "duckdb/function/scalar_function.hpp"
12#include "duckdb/function/function_binder.hpp"
13
14#include "duckdb/catalog/catalog.hpp"
15#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
16
17namespace duckdb {
18
19static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector<LogicalType> &child_types) {
20
21 idx_t param_count;
22 switch (window_type) {
23 case ExpressionType::WINDOW_RANK:
24 case ExpressionType::WINDOW_RANK_DENSE:
25 case ExpressionType::WINDOW_ROW_NUMBER:
26 case ExpressionType::WINDOW_PERCENT_RANK:
27 case ExpressionType::WINDOW_CUME_DIST:
28 param_count = 0;
29 break;
30 case ExpressionType::WINDOW_NTILE:
31 case ExpressionType::WINDOW_FIRST_VALUE:
32 case ExpressionType::WINDOW_LAST_VALUE:
33 case ExpressionType::WINDOW_LEAD:
34 case ExpressionType::WINDOW_LAG:
35 param_count = 1;
36 break;
37 case ExpressionType::WINDOW_NTH_VALUE:
38 param_count = 2;
39 break;
40 default:
41 throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(type: window_type));
42 }
43 if (child_types.size() != param_count) {
44 throw BinderException("%s needs %d parameter%s, got %d", ExpressionTypeToString(type: window_type), param_count,
45 param_count == 1 ? "" : "s", child_types.size());
46 }
47 switch (window_type) {
48 case ExpressionType::WINDOW_PERCENT_RANK:
49 case ExpressionType::WINDOW_CUME_DIST:
50 return LogicalType(LogicalTypeId::DOUBLE);
51 case ExpressionType::WINDOW_ROW_NUMBER:
52 case ExpressionType::WINDOW_RANK:
53 case ExpressionType::WINDOW_RANK_DENSE:
54 case ExpressionType::WINDOW_NTILE:
55 return LogicalType::BIGINT;
56 case ExpressionType::WINDOW_NTH_VALUE:
57 case ExpressionType::WINDOW_FIRST_VALUE:
58 case ExpressionType::WINDOW_LAST_VALUE:
59 case ExpressionType::WINDOW_LEAD:
60 case ExpressionType::WINDOW_LAG:
61 return child_types[0];
62 default:
63 throw InternalException("Unrecognized window expression type " + ExpressionTypeToString(type: window_type));
64 }
65}
66
67static unique_ptr<Expression> GetExpression(unique_ptr<ParsedExpression> &expr) {
68 if (!expr) {
69 return nullptr;
70 }
71 D_ASSERT(expr.get());
72 D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
73 return std::move(BoundExpression::GetExpression(expr&: *expr));
74}
75
76static unique_ptr<Expression> CastWindowExpression(unique_ptr<ParsedExpression> &expr, const LogicalType &type) {
77 if (!expr) {
78 return nullptr;
79 }
80 D_ASSERT(expr.get());
81 D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
82
83 auto &bound = BoundExpression::GetExpression(expr&: *expr);
84 bound = BoundCastExpression::AddDefaultCastToType(expr: std::move(bound), target_type: type);
85
86 return std::move(bound);
87}
88
89static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr<ParsedExpression> &expr,
90 unique_ptr<ParsedExpression> &order_expr) {
91
92 vector<unique_ptr<Expression>> children;
93
94 D_ASSERT(order_expr.get());
95 D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
96 auto &bound_order = BoundExpression::GetExpression(expr&: *order_expr);
97 children.emplace_back(args: bound_order->Copy());
98
99 D_ASSERT(expr.get());
100 D_ASSERT(expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
101 auto &bound = BoundExpression::GetExpression(expr&: *expr);
102 children.emplace_back(args: std::move(bound));
103
104 string error;
105 FunctionBinder function_binder(context);
106 auto function = function_binder.BindScalarFunction(DEFAULT_SCHEMA, name, children: std::move(children), error, is_operator: true);
107 if (!function) {
108 throw BinderException(error);
109 }
110 bound = std::move(function);
111 return bound->return_type;
112}
113
114BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) {
115 auto name = window.GetName();
116
117 QueryErrorContext error_context(binder.GetRootStatement(), window.query_location);
118 if (inside_window) {
119 throw BinderException(error_context.FormatError(msg: "window function calls cannot be nested"));
120 }
121 if (depth > 0) {
122 throw BinderException(error_context.FormatError(msg: "correlated columns in window functions not supported"));
123 }
124 // If we have range expressions, then only one order by clause is allowed.
125 if ((window.start == WindowBoundary::EXPR_PRECEDING_RANGE || window.start == WindowBoundary::EXPR_FOLLOWING_RANGE ||
126 window.end == WindowBoundary::EXPR_PRECEDING_RANGE || window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) &&
127 window.orders.size() != 1) {
128 throw BinderException(error_context.FormatError(msg: "RANGE frames must have only one ORDER BY expression"));
129 }
130 // bind inside the children of the window function
131 // we set the inside_window flag to true to prevent binding nested window functions
132 this->inside_window = true;
133 string error;
134 for (auto &child : window.children) {
135 BindChild(expr&: child, depth, error);
136 }
137 for (auto &child : window.partitions) {
138 BindChild(expr&: child, depth, error);
139 }
140 for (auto &order : window.orders) {
141 BindChild(expr&: order.expression, depth, error);
142 }
143 BindChild(expr&: window.filter_expr, depth, error);
144 BindChild(expr&: window.start_expr, depth, error);
145 BindChild(expr&: window.end_expr, depth, error);
146 BindChild(expr&: window.offset_expr, depth, error);
147 BindChild(expr&: window.default_expr, depth, error);
148
149 this->inside_window = false;
150 if (!error.empty()) {
151 // failed to bind children of window function
152 return BindResult(error);
153 }
154 // successfully bound all children: create bound window function
155 vector<LogicalType> types;
156 vector<unique_ptr<Expression>> children;
157 for (auto &child : window.children) {
158 D_ASSERT(child.get());
159 D_ASSERT(child->expression_class == ExpressionClass::BOUND_EXPRESSION);
160 auto &bound = BoundExpression::GetExpression(expr&: *child);
161 // Add casts for positional arguments
162 const auto argno = children.size();
163 switch (window.type) {
164 case ExpressionType::WINDOW_NTILE:
165 // ntile(bigint)
166 if (argno == 0) {
167 bound = BoundCastExpression::AddCastToType(context, expr: std::move(bound), target_type: LogicalType::BIGINT);
168 }
169 break;
170 case ExpressionType::WINDOW_NTH_VALUE:
171 // nth_value(<expr>, index)
172 if (argno == 1) {
173 bound = BoundCastExpression::AddCastToType(context, expr: std::move(bound), target_type: LogicalType::BIGINT);
174 }
175 default:
176 break;
177 }
178 types.push_back(x: bound->return_type);
179 children.push_back(x: std::move(bound));
180 }
181 // Determine the function type.
182 LogicalType sql_type;
183 unique_ptr<AggregateFunction> aggregate;
184 unique_ptr<FunctionData> bind_info;
185 if (window.type == ExpressionType::WINDOW_AGGREGATE) {
186 // Look up the aggregate function in the catalog
187 auto &func = Catalog::GetEntry<AggregateFunctionCatalogEntry>(context, catalog_name: window.catalog, schema_name: window.schema,
188 name: window.function_name, error_context);
189 D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY);
190
191 // bind the aggregate
192 string error;
193 FunctionBinder function_binder(context);
194 auto best_function = function_binder.BindFunction(name: func.name, functions&: func.functions, arguments: types, error);
195 if (best_function == DConstants::INVALID_INDEX) {
196 throw BinderException(binder.FormatError(expr_context&: window, message: error));
197 }
198 // found a matching function! bind it as an aggregate
199 auto bound_function = func.functions.GetFunctionByOffset(offset: best_function);
200 auto bound_aggregate = function_binder.BindAggregateFunction(bound_function, children: std::move(children));
201 // create the aggregate
202 aggregate = make_uniq<AggregateFunction>(args&: bound_aggregate->function);
203 bind_info = std::move(bound_aggregate->bind_info);
204 children = std::move(bound_aggregate->children);
205 sql_type = bound_aggregate->return_type;
206 } else {
207 // fetch the child of the non-aggregate window function (if any)
208 sql_type = ResolveWindowExpressionType(window_type: window.type, child_types: types);
209 }
210 auto result = make_uniq<BoundWindowExpression>(args&: window.type, args&: sql_type, args: std::move(aggregate), args: std::move(bind_info));
211 result->children = std::move(children);
212 for (auto &child : window.partitions) {
213 result->partitions.push_back(x: GetExpression(expr&: child));
214 }
215 result->ignore_nulls = window.ignore_nulls;
216
217 // Convert RANGE boundary expressions to ORDER +/- expressions.
218 // Note that PRECEEDING and FOLLOWING refer to the sequential order in the frame,
219 // not the natural ordering of the type. This means that the offset arithmetic must be reversed
220 // for ORDER BY DESC.
221 auto &config = DBConfig::GetConfig(context);
222 auto range_sense = OrderType::INVALID;
223 LogicalType start_type = LogicalType::BIGINT;
224 if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) {
225 D_ASSERT(window.orders.size() == 1);
226 range_sense = config.ResolveOrder(order_type: window.orders[0].type);
227 const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+";
228 start_type = BindRangeExpression(context, name, expr&: window.start_expr, order_expr&: window.orders[0].expression);
229 } else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) {
230 D_ASSERT(window.orders.size() == 1);
231 range_sense = config.ResolveOrder(order_type: window.orders[0].type);
232 const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-";
233 start_type = BindRangeExpression(context, name, expr&: window.start_expr, order_expr&: window.orders[0].expression);
234 }
235
236 LogicalType end_type = LogicalType::BIGINT;
237 if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) {
238 D_ASSERT(window.orders.size() == 1);
239 range_sense = config.ResolveOrder(order_type: window.orders[0].type);
240 const auto name = (range_sense == OrderType::ASCENDING) ? "-" : "+";
241 end_type = BindRangeExpression(context, name, expr&: window.end_expr, order_expr&: window.orders[0].expression);
242 } else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) {
243 D_ASSERT(window.orders.size() == 1);
244 range_sense = config.ResolveOrder(order_type: window.orders[0].type);
245 const auto name = (range_sense == OrderType::ASCENDING) ? "+" : "-";
246 end_type = BindRangeExpression(context, name, expr&: window.end_expr, order_expr&: window.orders[0].expression);
247 }
248
249 // Cast ORDER and boundary expressions to the same type
250 if (range_sense != OrderType::INVALID) {
251 D_ASSERT(window.orders.size() == 1);
252
253 auto &order_expr = window.orders[0].expression;
254 D_ASSERT(order_expr.get());
255 D_ASSERT(order_expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
256 auto &bound_order = BoundExpression::GetExpression(expr&: *order_expr);
257 auto order_type = bound_order->return_type;
258 if (window.start_expr) {
259 order_type = LogicalType::MaxLogicalType(left: order_type, right: start_type);
260 }
261 if (window.end_expr) {
262 order_type = LogicalType::MaxLogicalType(left: order_type, right: end_type);
263 }
264
265 // Cast all three to match
266 bound_order = BoundCastExpression::AddCastToType(context, expr: std::move(bound_order), target_type: order_type);
267 start_type = end_type = order_type;
268 }
269
270 for (auto &order : window.orders) {
271 auto type = config.ResolveOrder(order_type: order.type);
272 auto null_order = config.ResolveNullOrder(order_type: type, null_type: order.null_order);
273 auto expression = GetExpression(expr&: order.expression);
274 result->orders.emplace_back(args&: type, args&: null_order, args: std::move(expression));
275 }
276
277 result->filter_expr = CastWindowExpression(expr&: window.filter_expr, type: LogicalType::BOOLEAN);
278
279 result->start_expr = CastWindowExpression(expr&: window.start_expr, type: start_type);
280 result->end_expr = CastWindowExpression(expr&: window.end_expr, type: end_type);
281 result->offset_expr = CastWindowExpression(expr&: window.offset_expr, type: LogicalType::BIGINT);
282 result->default_expr = CastWindowExpression(expr&: window.default_expr, type: result->return_type);
283 result->start = window.start;
284 result->end = window.end;
285
286 // create a BoundColumnRef that references this entry
287 auto colref = make_uniq<BoundColumnRefExpression>(args: std::move(name), args&: result->return_type,
288 args: ColumnBinding(node.window_index, node.windows.size()), args&: depth);
289 // move the WINDOW expression into the set of bound windows
290 node.windows.push_back(x: std::move(result));
291 return BindResult(std::move(colref));
292}
293
294} // namespace duckdb
295