1#include "duckdb/planner/expression_binder.hpp"
2
3#include "duckdb/parser/expression/columnref_expression.hpp"
4#include "duckdb/parser/expression/subquery_expression.hpp"
5#include "duckdb/parser/parsed_expression_iterator.hpp"
6#include "duckdb/planner/binder.hpp"
7#include "duckdb/planner/expression/bound_cast_expression.hpp"
8#include "duckdb/planner/expression/bound_default_expression.hpp"
9#include "duckdb/planner/expression/bound_parameter_expression.hpp"
10#include "duckdb/planner/expression/bound_subquery_expression.hpp"
11#include "duckdb/planner/expression_iterator.hpp"
12
13using namespace duckdb;
14using namespace std;
15
16ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder)
17 : binder(binder), context(context), stored_binder(nullptr) {
18 if (replace_binder) {
19 stored_binder = binder.GetActiveBinder();
20 binder.SetActiveBinder(this);
21 } else {
22 binder.PushExpressionBinder(this);
23 }
24}
25
26ExpressionBinder::~ExpressionBinder() {
27 if (binder.HasActiveBinder()) {
28 if (stored_binder) {
29 binder.SetActiveBinder(stored_binder);
30 } else {
31 binder.PopExpressionBinder();
32 }
33 }
34}
35
36BindResult ExpressionBinder::BindExpression(ParsedExpression &expr, idx_t depth, bool root_expression) {
37 switch (expr.expression_class) {
38 case ExpressionClass::CASE:
39 return BindExpression((CaseExpression &)expr, depth);
40 case ExpressionClass::CAST:
41 return BindExpression((CastExpression &)expr, depth);
42 case ExpressionClass::COLLATE:
43 return BindExpression((CollateExpression &)expr, depth);
44 case ExpressionClass::COLUMN_REF:
45 return BindExpression((ColumnRefExpression &)expr, depth);
46 case ExpressionClass::COMPARISON:
47 return BindExpression((ComparisonExpression &)expr, depth);
48 case ExpressionClass::CONJUNCTION:
49 return BindExpression((ConjunctionExpression &)expr, depth);
50 case ExpressionClass::CONSTANT:
51 return BindExpression((ConstantExpression &)expr, depth);
52 case ExpressionClass::FUNCTION:
53 return BindExpression((FunctionExpression &)expr, depth);
54 case ExpressionClass::OPERATOR:
55 return BindExpression((OperatorExpression &)expr, depth);
56 case ExpressionClass::SUBQUERY:
57 return BindExpression((SubqueryExpression &)expr, depth);
58 case ExpressionClass::PARAMETER:
59 return BindExpression((ParameterExpression &)expr, depth);
60 default:
61 throw NotImplementedException("Unimplemented expression class");
62 }
63}
64
65bool ExpressionBinder::BindCorrelatedColumns(unique_ptr<ParsedExpression> &expr) {
66 // try to bind in one of the outer queries, if the binding error occurred in a subquery
67 auto &active_binders = binder.GetActiveBinders();
68 // make a copy of the set of binders, so we can restore it later
69 auto binders = active_binders;
70 active_binders.pop_back();
71 idx_t depth = 1;
72 bool success = false;
73 while (active_binders.size() > 0) {
74 auto &next_binder = active_binders.back();
75 ExpressionBinder::BindTableNames(next_binder->binder, *expr);
76 auto bind_result = next_binder->Bind(&expr, depth);
77 if (bind_result.empty()) {
78 success = true;
79 break;
80 }
81 depth++;
82 active_binders.pop_back();
83 }
84 active_binders = binders;
85 return success;
86}
87
88void ExpressionBinder::BindChild(unique_ptr<ParsedExpression> &expr, idx_t depth, string &error) {
89 if (expr.get()) {
90 string bind_error = Bind(&expr, depth);
91 if (error.empty()) {
92 error = bind_error;
93 }
94 }
95}
96
97void ExpressionBinder::ExtractCorrelatedExpressions(Binder &binder, Expression &expr) {
98 if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
99 auto &bound_colref = (BoundColumnRefExpression &)expr;
100 if (bound_colref.depth > 0) {
101 binder.AddCorrelatedColumn(CorrelatedColumnInfo(bound_colref));
102 }
103 }
104 ExpressionIterator::EnumerateChildren(expr,
105 [&](Expression &child) { ExtractCorrelatedExpressions(binder, child); });
106}
107
108unique_ptr<Expression> ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, SQLType *result_type,
109 bool root_expression) {
110 // bind the main expression
111 auto error_msg = Bind(&expr, 0, root_expression);
112 if (!error_msg.empty()) {
113 // failed to bind: try to bind correlated columns in the expression (if any)
114 bool success = BindCorrelatedColumns(expr);
115 if (!success) {
116 throw BinderException(error_msg);
117 }
118 auto bound_expr = (BoundExpression *)expr.get();
119 ExtractCorrelatedExpressions(binder, *bound_expr->expr);
120 }
121 assert(expr->expression_class == ExpressionClass::BOUND_EXPRESSION);
122 auto bound_expr = (BoundExpression *)expr.get();
123 unique_ptr<Expression> result = move(bound_expr->expr);
124 if (target_type.id != SQLTypeId::INVALID) {
125 // the binder has a specific target type: add a cast to that type
126 result = BoundCastExpression::AddCastToType(move(result), bound_expr->sql_type, target_type);
127 } else {
128 if (bound_expr->sql_type.id == SQLTypeId::SQLNULL) {
129 // SQL NULL type is only used internally in the binder
130 // cast to INTEGER if we encounter it outside of the binder
131 bound_expr->sql_type = SQLType::INTEGER;
132 result = BoundCastExpression::AddCastToType(move(result), bound_expr->sql_type, bound_expr->sql_type);
133 }
134 }
135 if (result_type) {
136 *result_type = bound_expr->sql_type;
137 }
138 return result;
139}
140
141string ExpressionBinder::Bind(unique_ptr<ParsedExpression> *expr, idx_t depth, bool root_expression) {
142 // bind the node, but only if it has not been bound yet
143 auto &expression = **expr;
144 auto alias = expression.alias;
145 if (expression.GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) {
146 // already bound, don't bind it again
147 return string();
148 }
149 // bind the expression
150 BindResult result = BindExpression(**expr, depth, root_expression);
151 if (result.HasError()) {
152 return result.error;
153 } else {
154 // successfully bound: replace the node with a BoundExpression
155 *expr = make_unique<BoundExpression>(move(result.expression), move(*expr), result.sql_type);
156 auto be = (BoundExpression *)expr->get();
157 assert(be);
158 be->alias = alias;
159 if (!alias.empty()) {
160 be->expr->alias = alias;
161 }
162 return string();
163 }
164}
165
166void ExpressionBinder::BindTableNames(Binder &binder, ParsedExpression &expr) {
167 if (expr.type == ExpressionType::COLUMN_REF) {
168 auto &colref = (ColumnRefExpression &)expr;
169 if (colref.table_name.empty()) {
170 // no table name: find a binding that contains this
171 colref.table_name = binder.bind_context.GetMatchingBinding(colref.column_name);
172 }
173 }
174 ParsedExpressionIterator::EnumerateChildren(
175 expr, [&](const ParsedExpression &child) { BindTableNames(binder, (ParsedExpression &)child); });
176}
177