1#include "duckdb/planner/expression_binder.hpp"
2
3#include "duckdb/parser/expression/list.hpp"
4#include "duckdb/parser/parsed_expression_iterator.hpp"
5#include "duckdb/planner/binder.hpp"
6#include "duckdb/planner/expression/list.hpp"
7#include "duckdb/planner/expression_iterator.hpp"
8
9namespace duckdb {
10
11ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder)
12 : binder(binder), context(context) {
13 if (replace_binder) {
14 stored_binder = &binder.GetActiveBinder();
15 binder.SetActiveBinder(*this);
16 } else {
17 binder.PushExpressionBinder(binder&: *this);
18 }
19}
20
21ExpressionBinder::~ExpressionBinder() {
22 if (binder.HasActiveBinder()) {
23 if (stored_binder) {
24 binder.SetActiveBinder(*stored_binder);
25 } else {
26 binder.PopExpressionBinder();
27 }
28 }
29}
30
31BindResult ExpressionBinder::BindExpression(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) {
32 auto &expr_ref = *expr;
33 switch (expr_ref.expression_class) {
34 case ExpressionClass::BETWEEN:
35 return BindExpression(expr&: expr_ref.Cast<BetweenExpression>(), depth);
36 case ExpressionClass::CASE:
37 return BindExpression(expr&: expr_ref.Cast<CaseExpression>(), depth);
38 case ExpressionClass::CAST:
39 return BindExpression(expr&: expr_ref.Cast<CastExpression>(), depth);
40 case ExpressionClass::COLLATE:
41 return BindExpression(expr&: expr_ref.Cast<CollateExpression>(), depth);
42 case ExpressionClass::COLUMN_REF:
43 return BindExpression(expr&: expr_ref.Cast<ColumnRefExpression>(), depth);
44 case ExpressionClass::COMPARISON:
45 return BindExpression(expr&: expr_ref.Cast<ComparisonExpression>(), depth);
46 case ExpressionClass::CONJUNCTION:
47 return BindExpression(expr&: expr_ref.Cast<ConjunctionExpression>(), depth);
48 case ExpressionClass::CONSTANT:
49 return BindExpression(expr&: expr_ref.Cast<ConstantExpression>(), depth);
50 case ExpressionClass::FUNCTION: {
51 auto &function = expr_ref.Cast<FunctionExpression>();
52 if (function.function_name == "unnest" || function.function_name == "unlist") {
53 // special case, not in catalog
54 return BindUnnest(expr&: function, depth, root_expression);
55 }
56 // binding function expression has extra parameter needed for macro's
57 return BindExpression(expr&: function, depth, expr_ptr&: expr);
58 }
59 case ExpressionClass::LAMBDA:
60 return BindExpression(expr&: expr_ref.Cast<LambdaExpression>(), depth, is_lambda: false, list_child_type: LogicalTypeId::INVALID);
61 case ExpressionClass::OPERATOR:
62 return BindExpression(expr&: expr_ref.Cast<OperatorExpression>(), depth);
63 case ExpressionClass::SUBQUERY:
64 return BindExpression(expr&: expr_ref.Cast<SubqueryExpression>(), depth);
65 case ExpressionClass::PARAMETER:
66 return BindExpression(expr&: expr_ref.Cast<ParameterExpression>(), depth);
67 case ExpressionClass::POSITIONAL_REFERENCE: {
68 return BindPositionalReference(expr, depth, root_expression);
69 }
70 case ExpressionClass::STAR:
71 return BindResult(binder.FormatError(expr_context&: expr_ref, message: "STAR expression is not supported here"));
72 default:
73 throw NotImplementedException("Unimplemented expression class");
74 }
75}
76
77bool ExpressionBinder::BindCorrelatedColumns(unique_ptr<ParsedExpression> &expr) {
78 // try to bind in one of the outer queries, if the binding error occurred in a subquery
79 auto &active_binders = binder.GetActiveBinders();
80 // make a copy of the set of binders, so we can restore it later
81 auto binders = active_binders;
82 active_binders.pop_back();
83 idx_t depth = 1;
84 bool success = false;
85 while (!active_binders.empty()) {
86 auto &next_binder = active_binders.back().get();
87 ExpressionBinder::QualifyColumnNames(binder&: next_binder.binder, expr);
88 auto bind_result = next_binder.Bind(expr, depth);
89 if (bind_result.empty()) {
90 success = true;
91 break;
92 }
93 depth++;
94 active_binders.pop_back();
95 }
96 active_binders = binders;
97 return success;
98}
99
100void ExpressionBinder::BindChild(unique_ptr<ParsedExpression> &expr, idx_t depth, string &error) {
101 if (expr) {
102 string bind_error = Bind(expr, depth);
103 if (error.empty()) {
104 error = bind_error;
105 }
106 }
107}
108
109void ExpressionBinder::ExtractCorrelatedExpressions(Binder &binder, Expression &expr) {
110 if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
111 auto &bound_colref = expr.Cast<BoundColumnRefExpression>();
112 if (bound_colref.depth > 0) {
113 binder.AddCorrelatedColumn(info: CorrelatedColumnInfo(bound_colref));
114 }
115 }
116 ExpressionIterator::EnumerateChildren(expr,
117 callback: [&](Expression &child) { ExtractCorrelatedExpressions(binder, expr&: child); });
118}
119
120bool ExpressionBinder::ContainsType(const LogicalType &type, LogicalTypeId target) {
121 if (type.id() == target) {
122 return true;
123 }
124 switch (type.id()) {
125 case LogicalTypeId::STRUCT: {
126 auto child_count = StructType::GetChildCount(type);
127 for (idx_t i = 0; i < child_count; i++) {
128 if (ContainsType(type: StructType::GetChildType(type, index: i), target)) {
129 return true;
130 }
131 }
132 return false;
133 }
134 case LogicalTypeId::UNION: {
135 auto member_count = UnionType::GetMemberCount(type);
136 for (idx_t i = 0; i < member_count; i++) {
137 if (ContainsType(type: UnionType::GetMemberType(type, index: i), target)) {
138 return true;
139 }
140 }
141 return false;
142 }
143 case LogicalTypeId::LIST:
144 case LogicalTypeId::MAP:
145 return ContainsType(type: ListType::GetChildType(type), target);
146 default:
147 return false;
148 }
149}
150
151LogicalType ExpressionBinder::ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type) {
152 if (type.id() == target) {
153 return new_type;
154 }
155 switch (type.id()) {
156 case LogicalTypeId::STRUCT: {
157 // we make a copy of the child types of the struct here
158 auto child_types = StructType::GetChildTypes(type);
159 for (auto &child_type : child_types) {
160 child_type.second = ExchangeType(type: child_type.second, target, new_type);
161 }
162 return LogicalType::STRUCT(children: child_types);
163 }
164 case LogicalTypeId::UNION: {
165 auto member_types = UnionType::CopyMemberTypes(type);
166 for (auto &member_type : member_types) {
167 member_type.second = ExchangeType(type: member_type.second, target, new_type);
168 }
169 return LogicalType::UNION(members: std::move(member_types));
170 }
171 case LogicalTypeId::LIST:
172 return LogicalType::LIST(child: ExchangeType(type: ListType::GetChildType(type), target, new_type));
173 case LogicalTypeId::MAP:
174 return LogicalType::MAP(child: ExchangeType(type: ListType::GetChildType(type), target, new_type));
175 default:
176 return type;
177 }
178}
179
180bool ExpressionBinder::ContainsNullType(const LogicalType &type) {
181 return ContainsType(type, target: LogicalTypeId::SQLNULL);
182}
183
184LogicalType ExpressionBinder::ExchangeNullType(const LogicalType &type) {
185 return ExchangeType(type, target: LogicalTypeId::SQLNULL, new_type: LogicalType::INTEGER);
186}
187
188unique_ptr<Expression> ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, optional_ptr<LogicalType> result_type,
189 bool root_expression) {
190 // bind the main expression
191 auto error_msg = Bind(expr, depth: 0, root_expression);
192 if (!error_msg.empty()) {
193 // failed to bind: try to bind correlated columns in the expression (if any)
194 bool success = BindCorrelatedColumns(expr);
195 if (!success) {
196 throw BinderException(error_msg);
197 }
198 auto &bound_expr = expr->Cast<BoundExpression>();
199 ExtractCorrelatedExpressions(binder, expr&: *bound_expr.expr);
200 }
201 auto &bound_expr = expr->Cast<BoundExpression>();
202 unique_ptr<Expression> result = std::move(bound_expr.expr);
203 if (target_type.id() != LogicalTypeId::INVALID) {
204 // the binder has a specific target type: add a cast to that type
205 result = BoundCastExpression::AddCastToType(context, expr: std::move(result), target_type);
206 } else {
207 if (!binder.can_contain_nulls) {
208 // SQL NULL type is only used internally in the binder
209 // cast to INTEGER if we encounter it outside of the binder
210 if (ContainsNullType(type: result->return_type)) {
211 auto exchanged_type = ExchangeNullType(type: result->return_type);
212 result = BoundCastExpression::AddCastToType(context, expr: std::move(result), target_type: exchanged_type);
213 }
214 }
215 if (result->return_type.id() == LogicalTypeId::UNKNOWN) {
216 throw ParameterNotResolvedException();
217 }
218 }
219 if (result_type) {
220 *result_type = result->return_type;
221 }
222 return result;
223}
224
225string ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) {
226 // bind the node, but only if it has not been bound yet
227 auto &expression = *expr;
228 auto alias = expression.alias;
229 if (expression.GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) {
230 // already bound, don't bind it again
231 return string();
232 }
233 // bind the expression
234 BindResult result = BindExpression(expr, depth, root_expression);
235 if (result.HasError()) {
236 return result.error;
237 }
238 // successfully bound: replace the node with a BoundExpression
239 expr = make_uniq<BoundExpression>(args: std::move(result.expression));
240 auto &be = expr->Cast<BoundExpression>();
241 be.alias = alias;
242 if (!alias.empty()) {
243 be.expr->alias = alias;
244 }
245 return string();
246}
247
248} // namespace duckdb
249