1 | #include "duckdb/planner/expression_binder.hpp" |
2 | |
3 | #include "duckdb/parser/expression/list.hpp" |
4 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
5 | #include "duckdb/planner/binder.hpp" |
6 | #include "duckdb/planner/expression/list.hpp" |
7 | #include "duckdb/planner/expression_iterator.hpp" |
8 | |
9 | namespace duckdb { |
10 | |
11 | ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) |
12 | : binder(binder), context(context) { |
13 | if (replace_binder) { |
14 | stored_binder = &binder.GetActiveBinder(); |
15 | binder.SetActiveBinder(*this); |
16 | } else { |
17 | binder.PushExpressionBinder(binder&: *this); |
18 | } |
19 | } |
20 | |
21 | ExpressionBinder::~ExpressionBinder() { |
22 | if (binder.HasActiveBinder()) { |
23 | if (stored_binder) { |
24 | binder.SetActiveBinder(*stored_binder); |
25 | } else { |
26 | binder.PopExpressionBinder(); |
27 | } |
28 | } |
29 | } |
30 | |
31 | BindResult ExpressionBinder::BindExpression(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) { |
32 | auto &expr_ref = *expr; |
33 | switch (expr_ref.expression_class) { |
34 | case ExpressionClass::BETWEEN: |
35 | return BindExpression(expr&: expr_ref.Cast<BetweenExpression>(), depth); |
36 | case ExpressionClass::CASE: |
37 | return BindExpression(expr&: expr_ref.Cast<CaseExpression>(), depth); |
38 | case ExpressionClass::CAST: |
39 | return BindExpression(expr&: expr_ref.Cast<CastExpression>(), depth); |
40 | case ExpressionClass::COLLATE: |
41 | return BindExpression(expr&: expr_ref.Cast<CollateExpression>(), depth); |
42 | case ExpressionClass::COLUMN_REF: |
43 | return BindExpression(expr&: expr_ref.Cast<ColumnRefExpression>(), depth); |
44 | case ExpressionClass::COMPARISON: |
45 | return BindExpression(expr&: expr_ref.Cast<ComparisonExpression>(), depth); |
46 | case ExpressionClass::CONJUNCTION: |
47 | return BindExpression(expr&: expr_ref.Cast<ConjunctionExpression>(), depth); |
48 | case ExpressionClass::CONSTANT: |
49 | return BindExpression(expr&: expr_ref.Cast<ConstantExpression>(), depth); |
50 | case ExpressionClass::FUNCTION: { |
51 | auto &function = expr_ref.Cast<FunctionExpression>(); |
52 | if (function.function_name == "unnest" || function.function_name == "unlist" ) { |
53 | // special case, not in catalog |
54 | return BindUnnest(expr&: function, depth, root_expression); |
55 | } |
56 | // binding function expression has extra parameter needed for macro's |
57 | return BindExpression(expr&: function, depth, expr_ptr&: expr); |
58 | } |
59 | case ExpressionClass::LAMBDA: |
60 | return BindExpression(expr&: expr_ref.Cast<LambdaExpression>(), depth, is_lambda: false, list_child_type: LogicalTypeId::INVALID); |
61 | case ExpressionClass::OPERATOR: |
62 | return BindExpression(expr&: expr_ref.Cast<OperatorExpression>(), depth); |
63 | case ExpressionClass::SUBQUERY: |
64 | return BindExpression(expr&: expr_ref.Cast<SubqueryExpression>(), depth); |
65 | case ExpressionClass::PARAMETER: |
66 | return BindExpression(expr&: expr_ref.Cast<ParameterExpression>(), depth); |
67 | case ExpressionClass::POSITIONAL_REFERENCE: { |
68 | return BindPositionalReference(expr, depth, root_expression); |
69 | } |
70 | case ExpressionClass::STAR: |
71 | return BindResult(binder.FormatError(expr_context&: expr_ref, message: "STAR expression is not supported here" )); |
72 | default: |
73 | throw NotImplementedException("Unimplemented expression class" ); |
74 | } |
75 | } |
76 | |
77 | bool ExpressionBinder::BindCorrelatedColumns(unique_ptr<ParsedExpression> &expr) { |
78 | // try to bind in one of the outer queries, if the binding error occurred in a subquery |
79 | auto &active_binders = binder.GetActiveBinders(); |
80 | // make a copy of the set of binders, so we can restore it later |
81 | auto binders = active_binders; |
82 | active_binders.pop_back(); |
83 | idx_t depth = 1; |
84 | bool success = false; |
85 | while (!active_binders.empty()) { |
86 | auto &next_binder = active_binders.back().get(); |
87 | ExpressionBinder::QualifyColumnNames(binder&: next_binder.binder, expr); |
88 | auto bind_result = next_binder.Bind(expr, depth); |
89 | if (bind_result.empty()) { |
90 | success = true; |
91 | break; |
92 | } |
93 | depth++; |
94 | active_binders.pop_back(); |
95 | } |
96 | active_binders = binders; |
97 | return success; |
98 | } |
99 | |
100 | void ExpressionBinder::BindChild(unique_ptr<ParsedExpression> &expr, idx_t depth, string &error) { |
101 | if (expr) { |
102 | string bind_error = Bind(expr, depth); |
103 | if (error.empty()) { |
104 | error = bind_error; |
105 | } |
106 | } |
107 | } |
108 | |
109 | void ExpressionBinder::(Binder &binder, Expression &expr) { |
110 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
111 | auto &bound_colref = expr.Cast<BoundColumnRefExpression>(); |
112 | if (bound_colref.depth > 0) { |
113 | binder.AddCorrelatedColumn(info: CorrelatedColumnInfo(bound_colref)); |
114 | } |
115 | } |
116 | ExpressionIterator::EnumerateChildren(expr, |
117 | callback: [&](Expression &child) { ExtractCorrelatedExpressions(binder, expr&: child); }); |
118 | } |
119 | |
120 | bool ExpressionBinder::ContainsType(const LogicalType &type, LogicalTypeId target) { |
121 | if (type.id() == target) { |
122 | return true; |
123 | } |
124 | switch (type.id()) { |
125 | case LogicalTypeId::STRUCT: { |
126 | auto child_count = StructType::GetChildCount(type); |
127 | for (idx_t i = 0; i < child_count; i++) { |
128 | if (ContainsType(type: StructType::GetChildType(type, index: i), target)) { |
129 | return true; |
130 | } |
131 | } |
132 | return false; |
133 | } |
134 | case LogicalTypeId::UNION: { |
135 | auto member_count = UnionType::GetMemberCount(type); |
136 | for (idx_t i = 0; i < member_count; i++) { |
137 | if (ContainsType(type: UnionType::GetMemberType(type, index: i), target)) { |
138 | return true; |
139 | } |
140 | } |
141 | return false; |
142 | } |
143 | case LogicalTypeId::LIST: |
144 | case LogicalTypeId::MAP: |
145 | return ContainsType(type: ListType::GetChildType(type), target); |
146 | default: |
147 | return false; |
148 | } |
149 | } |
150 | |
151 | LogicalType ExpressionBinder::ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type) { |
152 | if (type.id() == target) { |
153 | return new_type; |
154 | } |
155 | switch (type.id()) { |
156 | case LogicalTypeId::STRUCT: { |
157 | // we make a copy of the child types of the struct here |
158 | auto child_types = StructType::GetChildTypes(type); |
159 | for (auto &child_type : child_types) { |
160 | child_type.second = ExchangeType(type: child_type.second, target, new_type); |
161 | } |
162 | return LogicalType::STRUCT(children: child_types); |
163 | } |
164 | case LogicalTypeId::UNION: { |
165 | auto member_types = UnionType::CopyMemberTypes(type); |
166 | for (auto &member_type : member_types) { |
167 | member_type.second = ExchangeType(type: member_type.second, target, new_type); |
168 | } |
169 | return LogicalType::UNION(members: std::move(member_types)); |
170 | } |
171 | case LogicalTypeId::LIST: |
172 | return LogicalType::LIST(child: ExchangeType(type: ListType::GetChildType(type), target, new_type)); |
173 | case LogicalTypeId::MAP: |
174 | return LogicalType::MAP(child: ExchangeType(type: ListType::GetChildType(type), target, new_type)); |
175 | default: |
176 | return type; |
177 | } |
178 | } |
179 | |
180 | bool ExpressionBinder::ContainsNullType(const LogicalType &type) { |
181 | return ContainsType(type, target: LogicalTypeId::SQLNULL); |
182 | } |
183 | |
184 | LogicalType ExpressionBinder::ExchangeNullType(const LogicalType &type) { |
185 | return ExchangeType(type, target: LogicalTypeId::SQLNULL, new_type: LogicalType::INTEGER); |
186 | } |
187 | |
188 | unique_ptr<Expression> ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, optional_ptr<LogicalType> result_type, |
189 | bool root_expression) { |
190 | // bind the main expression |
191 | auto error_msg = Bind(expr, depth: 0, root_expression); |
192 | if (!error_msg.empty()) { |
193 | // failed to bind: try to bind correlated columns in the expression (if any) |
194 | bool success = BindCorrelatedColumns(expr); |
195 | if (!success) { |
196 | throw BinderException(error_msg); |
197 | } |
198 | auto &bound_expr = expr->Cast<BoundExpression>(); |
199 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr.expr); |
200 | } |
201 | auto &bound_expr = expr->Cast<BoundExpression>(); |
202 | unique_ptr<Expression> result = std::move(bound_expr.expr); |
203 | if (target_type.id() != LogicalTypeId::INVALID) { |
204 | // the binder has a specific target type: add a cast to that type |
205 | result = BoundCastExpression::AddCastToType(context, expr: std::move(result), target_type); |
206 | } else { |
207 | if (!binder.can_contain_nulls) { |
208 | // SQL NULL type is only used internally in the binder |
209 | // cast to INTEGER if we encounter it outside of the binder |
210 | if (ContainsNullType(type: result->return_type)) { |
211 | auto exchanged_type = ExchangeNullType(type: result->return_type); |
212 | result = BoundCastExpression::AddCastToType(context, expr: std::move(result), target_type: exchanged_type); |
213 | } |
214 | } |
215 | if (result->return_type.id() == LogicalTypeId::UNKNOWN) { |
216 | throw ParameterNotResolvedException(); |
217 | } |
218 | } |
219 | if (result_type) { |
220 | *result_type = result->return_type; |
221 | } |
222 | return result; |
223 | } |
224 | |
225 | string ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) { |
226 | // bind the node, but only if it has not been bound yet |
227 | auto &expression = *expr; |
228 | auto alias = expression.alias; |
229 | if (expression.GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) { |
230 | // already bound, don't bind it again |
231 | return string(); |
232 | } |
233 | // bind the expression |
234 | BindResult result = BindExpression(expr, depth, root_expression); |
235 | if (result.HasError()) { |
236 | return result.error; |
237 | } |
238 | // successfully bound: replace the node with a BoundExpression |
239 | expr = make_uniq<BoundExpression>(args: std::move(result.expression)); |
240 | auto &be = expr->Cast<BoundExpression>(); |
241 | be.alias = alias; |
242 | if (!alias.empty()) { |
243 | be.expr->alias = alias; |
244 | } |
245 | return string(); |
246 | } |
247 | |
248 | } // namespace duckdb |
249 | |