1#include "duckdb/parser/expression/window_expression.hpp"
2#include "duckdb/planner/expression/bound_columnref_expression.hpp"
3#include "duckdb/planner/expression/bound_window_expression.hpp"
4#include "duckdb/planner/expression_binder/select_binder.hpp"
5#include "duckdb/planner/query_node/bound_select_node.hpp"
6
7#include "duckdb/catalog/catalog.hpp"
8#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
9
10using namespace duckdb;
11using namespace std;
12
13static SQLType ResolveWindowExpressionType(ExpressionType window_type, SQLType child_type) {
14 switch (window_type) {
15 case ExpressionType::WINDOW_PERCENT_RANK:
16 case ExpressionType::WINDOW_CUME_DIST:
17 return SQLType(SQLTypeId::DECIMAL);
18 case ExpressionType::WINDOW_ROW_NUMBER:
19 case ExpressionType::WINDOW_RANK:
20 case ExpressionType::WINDOW_RANK_DENSE:
21 case ExpressionType::WINDOW_NTILE:
22 return SQLType::BIGINT;
23 case ExpressionType::WINDOW_FIRST_VALUE:
24 case ExpressionType::WINDOW_LAST_VALUE:
25 assert(child_type.id != SQLTypeId::INVALID); // "Window function needs an expression"
26 return child_type;
27 case ExpressionType::WINDOW_LEAD:
28 default:
29 assert(window_type == ExpressionType::WINDOW_LAG || window_type == ExpressionType::WINDOW_LEAD);
30 assert(child_type.id != SQLTypeId::INVALID); // "Window function needs an expression"
31 return child_type;
32 }
33}
34
35static unique_ptr<Expression> GetExpression(unique_ptr<ParsedExpression> &expr) {
36 if (!expr) {
37 return nullptr;
38 }
39 assert(expr.get());
40 assert(expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
41 return move(((BoundExpression &)*expr).expr);
42}
43
44BindResult SelectBinder::BindWindow(WindowExpression &window, idx_t depth) {
45 if (inside_window) {
46 throw BinderException("window function calls cannot be nested");
47 }
48 if (depth > 0) {
49 throw BinderException("correlated columns in window functions not supported");
50 }
51 // bind inside the children of the window function
52 // we set the inside_window flag to true to prevent binding nested window functions
53 this->inside_window = true;
54 string error;
55 for (auto &child : window.children) {
56 BindChild(child, depth, error);
57 }
58 for (auto &child : window.partitions) {
59 BindChild(child, depth, error);
60 }
61 for (auto &order : window.orders) {
62 BindChild(order.expression, depth, error);
63 }
64 BindChild(window.start_expr, depth, error);
65 BindChild(window.end_expr, depth, error);
66 BindChild(window.offset_expr, depth, error);
67 BindChild(window.default_expr, depth, error);
68 this->inside_window = false;
69 if (!error.empty()) {
70 // failed to bind children of window function
71 return BindResult(error);
72 }
73 // successfully bound all children: create bound window function
74 vector<SQLType> types;
75 vector<unique_ptr<Expression>> children;
76 for (auto &child : window.children) {
77 assert(child.get());
78 assert(child->expression_class == ExpressionClass::BOUND_EXPRESSION);
79 auto &bound = (BoundExpression &)*child;
80 types.push_back(bound.sql_type);
81 children.push_back(GetExpression(child));
82 }
83 // Determine the function type.
84 SQLType sql_type;
85 unique_ptr<AggregateFunction> aggregate;
86 if (window.type == ExpressionType::WINDOW_AGGREGATE) {
87 // Look up the aggregate function in the catalog
88 auto func =
89 (AggregateFunctionCatalogEntry *)Catalog::GetCatalog(context).GetEntry<AggregateFunctionCatalogEntry>(
90 context, window.schema, window.function_name);
91 if (func->type != CatalogType::AGGREGATE_FUNCTION) {
92 throw BinderException("Unknown windowed aggregate");
93 }
94 // bind the aggregate
95 auto best_function = Function::BindFunction(func->name, func->functions, types);
96 // found a matching function!
97 auto &bound_function = func->functions[best_function];
98 // check if we need to add casts to the children
99 bound_function.CastToFunctionArguments(children, types);
100 // create the aggregate
101 aggregate = make_unique<AggregateFunction>(func->functions[best_function]);
102 sql_type = aggregate->return_type;
103 } else {
104 // fetch the child of the non-aggregate window function (if any)
105 sql_type = ResolveWindowExpressionType(window.type, types.empty() ? SQLType() : types[0]);
106 }
107 auto result = make_unique<BoundWindowExpression>(window.type, GetInternalType(sql_type), move(aggregate));
108 result->children = move(children);
109 for (auto &child : window.partitions) {
110 result->partitions.push_back(GetExpression(child));
111 }
112 for (auto &order : window.orders) {
113 BoundOrderByNode bound_order;
114 bound_order.expression = GetExpression(order.expression);
115 bound_order.type = order.type;
116 result->orders.push_back(move(bound_order));
117 }
118 result->start_expr = GetExpression(window.start_expr);
119 result->end_expr = GetExpression(window.end_expr);
120 result->offset_expr = GetExpression(window.offset_expr);
121 result->default_expr = GetExpression(window.default_expr);
122 result->start = window.start;
123 result->end = window.end;
124
125 // create a BoundColumnRef that references this entry
126 auto colref = make_unique<BoundColumnRefExpression>(window.GetName(), result->return_type,
127 ColumnBinding(node.window_index, node.windows.size()), depth);
128 // move the WINDOW expression into the set of bound windows
129 node.windows.push_back(move(result));
130 return BindResult(move(colref), sql_type);
131}
132