| 1 | #include "duckdb/planner/expression_binder.hpp" |
| 2 | |
| 3 | #include "duckdb/parser/expression/list.hpp" |
| 4 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
| 5 | #include "duckdb/planner/binder.hpp" |
| 6 | #include "duckdb/planner/expression/list.hpp" |
| 7 | #include "duckdb/planner/expression_iterator.hpp" |
| 8 | |
| 9 | namespace duckdb { |
| 10 | |
| 11 | ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) |
| 12 | : binder(binder), context(context) { |
| 13 | if (replace_binder) { |
| 14 | stored_binder = &binder.GetActiveBinder(); |
| 15 | binder.SetActiveBinder(*this); |
| 16 | } else { |
| 17 | binder.PushExpressionBinder(binder&: *this); |
| 18 | } |
| 19 | } |
| 20 | |
| 21 | ExpressionBinder::~ExpressionBinder() { |
| 22 | if (binder.HasActiveBinder()) { |
| 23 | if (stored_binder) { |
| 24 | binder.SetActiveBinder(*stored_binder); |
| 25 | } else { |
| 26 | binder.PopExpressionBinder(); |
| 27 | } |
| 28 | } |
| 29 | } |
| 30 | |
| 31 | BindResult ExpressionBinder::BindExpression(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) { |
| 32 | auto &expr_ref = *expr; |
| 33 | switch (expr_ref.expression_class) { |
| 34 | case ExpressionClass::BETWEEN: |
| 35 | return BindExpression(expr&: expr_ref.Cast<BetweenExpression>(), depth); |
| 36 | case ExpressionClass::CASE: |
| 37 | return BindExpression(expr&: expr_ref.Cast<CaseExpression>(), depth); |
| 38 | case ExpressionClass::CAST: |
| 39 | return BindExpression(expr&: expr_ref.Cast<CastExpression>(), depth); |
| 40 | case ExpressionClass::COLLATE: |
| 41 | return BindExpression(expr&: expr_ref.Cast<CollateExpression>(), depth); |
| 42 | case ExpressionClass::COLUMN_REF: |
| 43 | return BindExpression(expr&: expr_ref.Cast<ColumnRefExpression>(), depth); |
| 44 | case ExpressionClass::COMPARISON: |
| 45 | return BindExpression(expr&: expr_ref.Cast<ComparisonExpression>(), depth); |
| 46 | case ExpressionClass::CONJUNCTION: |
| 47 | return BindExpression(expr&: expr_ref.Cast<ConjunctionExpression>(), depth); |
| 48 | case ExpressionClass::CONSTANT: |
| 49 | return BindExpression(expr&: expr_ref.Cast<ConstantExpression>(), depth); |
| 50 | case ExpressionClass::FUNCTION: { |
| 51 | auto &function = expr_ref.Cast<FunctionExpression>(); |
| 52 | if (function.function_name == "unnest" || function.function_name == "unlist" ) { |
| 53 | // special case, not in catalog |
| 54 | return BindUnnest(expr&: function, depth, root_expression); |
| 55 | } |
| 56 | // binding function expression has extra parameter needed for macro's |
| 57 | return BindExpression(expr&: function, depth, expr_ptr&: expr); |
| 58 | } |
| 59 | case ExpressionClass::LAMBDA: |
| 60 | return BindExpression(expr&: expr_ref.Cast<LambdaExpression>(), depth, is_lambda: false, list_child_type: LogicalTypeId::INVALID); |
| 61 | case ExpressionClass::OPERATOR: |
| 62 | return BindExpression(expr&: expr_ref.Cast<OperatorExpression>(), depth); |
| 63 | case ExpressionClass::SUBQUERY: |
| 64 | return BindExpression(expr&: expr_ref.Cast<SubqueryExpression>(), depth); |
| 65 | case ExpressionClass::PARAMETER: |
| 66 | return BindExpression(expr&: expr_ref.Cast<ParameterExpression>(), depth); |
| 67 | case ExpressionClass::POSITIONAL_REFERENCE: { |
| 68 | return BindPositionalReference(expr, depth, root_expression); |
| 69 | } |
| 70 | case ExpressionClass::STAR: |
| 71 | return BindResult(binder.FormatError(expr_context&: expr_ref, message: "STAR expression is not supported here" )); |
| 72 | default: |
| 73 | throw NotImplementedException("Unimplemented expression class" ); |
| 74 | } |
| 75 | } |
| 76 | |
| 77 | bool ExpressionBinder::BindCorrelatedColumns(unique_ptr<ParsedExpression> &expr) { |
| 78 | // try to bind in one of the outer queries, if the binding error occurred in a subquery |
| 79 | auto &active_binders = binder.GetActiveBinders(); |
| 80 | // make a copy of the set of binders, so we can restore it later |
| 81 | auto binders = active_binders; |
| 82 | active_binders.pop_back(); |
| 83 | idx_t depth = 1; |
| 84 | bool success = false; |
| 85 | while (!active_binders.empty()) { |
| 86 | auto &next_binder = active_binders.back().get(); |
| 87 | ExpressionBinder::QualifyColumnNames(binder&: next_binder.binder, expr); |
| 88 | auto bind_result = next_binder.Bind(expr, depth); |
| 89 | if (bind_result.empty()) { |
| 90 | success = true; |
| 91 | break; |
| 92 | } |
| 93 | depth++; |
| 94 | active_binders.pop_back(); |
| 95 | } |
| 96 | active_binders = binders; |
| 97 | return success; |
| 98 | } |
| 99 | |
| 100 | void ExpressionBinder::BindChild(unique_ptr<ParsedExpression> &expr, idx_t depth, string &error) { |
| 101 | if (expr) { |
| 102 | string bind_error = Bind(expr, depth); |
| 103 | if (error.empty()) { |
| 104 | error = bind_error; |
| 105 | } |
| 106 | } |
| 107 | } |
| 108 | |
| 109 | void ExpressionBinder::(Binder &binder, Expression &expr) { |
| 110 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
| 111 | auto &bound_colref = expr.Cast<BoundColumnRefExpression>(); |
| 112 | if (bound_colref.depth > 0) { |
| 113 | binder.AddCorrelatedColumn(info: CorrelatedColumnInfo(bound_colref)); |
| 114 | } |
| 115 | } |
| 116 | ExpressionIterator::EnumerateChildren(expr, |
| 117 | callback: [&](Expression &child) { ExtractCorrelatedExpressions(binder, expr&: child); }); |
| 118 | } |
| 119 | |
| 120 | bool ExpressionBinder::ContainsType(const LogicalType &type, LogicalTypeId target) { |
| 121 | if (type.id() == target) { |
| 122 | return true; |
| 123 | } |
| 124 | switch (type.id()) { |
| 125 | case LogicalTypeId::STRUCT: { |
| 126 | auto child_count = StructType::GetChildCount(type); |
| 127 | for (idx_t i = 0; i < child_count; i++) { |
| 128 | if (ContainsType(type: StructType::GetChildType(type, index: i), target)) { |
| 129 | return true; |
| 130 | } |
| 131 | } |
| 132 | return false; |
| 133 | } |
| 134 | case LogicalTypeId::UNION: { |
| 135 | auto member_count = UnionType::GetMemberCount(type); |
| 136 | for (idx_t i = 0; i < member_count; i++) { |
| 137 | if (ContainsType(type: UnionType::GetMemberType(type, index: i), target)) { |
| 138 | return true; |
| 139 | } |
| 140 | } |
| 141 | return false; |
| 142 | } |
| 143 | case LogicalTypeId::LIST: |
| 144 | case LogicalTypeId::MAP: |
| 145 | return ContainsType(type: ListType::GetChildType(type), target); |
| 146 | default: |
| 147 | return false; |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | LogicalType ExpressionBinder::ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type) { |
| 152 | if (type.id() == target) { |
| 153 | return new_type; |
| 154 | } |
| 155 | switch (type.id()) { |
| 156 | case LogicalTypeId::STRUCT: { |
| 157 | // we make a copy of the child types of the struct here |
| 158 | auto child_types = StructType::GetChildTypes(type); |
| 159 | for (auto &child_type : child_types) { |
| 160 | child_type.second = ExchangeType(type: child_type.second, target, new_type); |
| 161 | } |
| 162 | return LogicalType::STRUCT(children: child_types); |
| 163 | } |
| 164 | case LogicalTypeId::UNION: { |
| 165 | auto member_types = UnionType::CopyMemberTypes(type); |
| 166 | for (auto &member_type : member_types) { |
| 167 | member_type.second = ExchangeType(type: member_type.second, target, new_type); |
| 168 | } |
| 169 | return LogicalType::UNION(members: std::move(member_types)); |
| 170 | } |
| 171 | case LogicalTypeId::LIST: |
| 172 | return LogicalType::LIST(child: ExchangeType(type: ListType::GetChildType(type), target, new_type)); |
| 173 | case LogicalTypeId::MAP: |
| 174 | return LogicalType::MAP(child: ExchangeType(type: ListType::GetChildType(type), target, new_type)); |
| 175 | default: |
| 176 | return type; |
| 177 | } |
| 178 | } |
| 179 | |
| 180 | bool ExpressionBinder::ContainsNullType(const LogicalType &type) { |
| 181 | return ContainsType(type, target: LogicalTypeId::SQLNULL); |
| 182 | } |
| 183 | |
| 184 | LogicalType ExpressionBinder::ExchangeNullType(const LogicalType &type) { |
| 185 | return ExchangeType(type, target: LogicalTypeId::SQLNULL, new_type: LogicalType::INTEGER); |
| 186 | } |
| 187 | |
| 188 | unique_ptr<Expression> ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, optional_ptr<LogicalType> result_type, |
| 189 | bool root_expression) { |
| 190 | // bind the main expression |
| 191 | auto error_msg = Bind(expr, depth: 0, root_expression); |
| 192 | if (!error_msg.empty()) { |
| 193 | // failed to bind: try to bind correlated columns in the expression (if any) |
| 194 | bool success = BindCorrelatedColumns(expr); |
| 195 | if (!success) { |
| 196 | throw BinderException(error_msg); |
| 197 | } |
| 198 | auto &bound_expr = expr->Cast<BoundExpression>(); |
| 199 | ExtractCorrelatedExpressions(binder, expr&: *bound_expr.expr); |
| 200 | } |
| 201 | auto &bound_expr = expr->Cast<BoundExpression>(); |
| 202 | unique_ptr<Expression> result = std::move(bound_expr.expr); |
| 203 | if (target_type.id() != LogicalTypeId::INVALID) { |
| 204 | // the binder has a specific target type: add a cast to that type |
| 205 | result = BoundCastExpression::AddCastToType(context, expr: std::move(result), target_type); |
| 206 | } else { |
| 207 | if (!binder.can_contain_nulls) { |
| 208 | // SQL NULL type is only used internally in the binder |
| 209 | // cast to INTEGER if we encounter it outside of the binder |
| 210 | if (ContainsNullType(type: result->return_type)) { |
| 211 | auto exchanged_type = ExchangeNullType(type: result->return_type); |
| 212 | result = BoundCastExpression::AddCastToType(context, expr: std::move(result), target_type: exchanged_type); |
| 213 | } |
| 214 | } |
| 215 | if (result->return_type.id() == LogicalTypeId::UNKNOWN) { |
| 216 | throw ParameterNotResolvedException(); |
| 217 | } |
| 218 | } |
| 219 | if (result_type) { |
| 220 | *result_type = result->return_type; |
| 221 | } |
| 222 | return result; |
| 223 | } |
| 224 | |
| 225 | string ExpressionBinder::Bind(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) { |
| 226 | // bind the node, but only if it has not been bound yet |
| 227 | auto &expression = *expr; |
| 228 | auto alias = expression.alias; |
| 229 | if (expression.GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) { |
| 230 | // already bound, don't bind it again |
| 231 | return string(); |
| 232 | } |
| 233 | // bind the expression |
| 234 | BindResult result = BindExpression(expr, depth, root_expression); |
| 235 | if (result.HasError()) { |
| 236 | return result.error; |
| 237 | } |
| 238 | // successfully bound: replace the node with a BoundExpression |
| 239 | expr = make_uniq<BoundExpression>(args: std::move(result.expression)); |
| 240 | auto &be = expr->Cast<BoundExpression>(); |
| 241 | be.alias = alias; |
| 242 | if (!alias.empty()) { |
| 243 | be.expr->alias = alias; |
| 244 | } |
| 245 | return string(); |
| 246 | } |
| 247 | |
| 248 | } // namespace duckdb |
| 249 | |