| 1 | #include "duckdb/parser/expression/lambda_expression.hpp" |
| 2 | #include "duckdb/parser/expression/operator_expression.hpp" |
| 3 | #include "duckdb/planner/expression_binder.hpp" |
| 4 | #include "duckdb/planner/bind_context.hpp" |
| 5 | #include "duckdb/parser/expression/columnref_expression.hpp" |
| 6 | #include "duckdb/planner/binder.hpp" |
| 7 | #include "duckdb/parser/expression/function_expression.hpp" |
| 8 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
| 9 | #include "duckdb/planner/expression/bound_lambdaref_expression.hpp" |
| 10 | #include "duckdb/planner/expression/bound_lambda_expression.hpp" |
| 11 | #include "duckdb/planner/expression_iterator.hpp" |
| 12 | |
| 13 | namespace duckdb { |
| 14 | |
| 15 | BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, const bool is_lambda, |
| 16 | const LogicalType &list_child_type) { |
| 17 | |
| 18 | if (!is_lambda) { |
| 19 | // this is for binding JSON |
| 20 | auto lhs_expr = expr.lhs->Copy(); |
| 21 | OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(lhs_expr), expr.expr->Copy()); |
| 22 | return BindExpression(expr&: arrow_expr, depth); |
| 23 | } |
| 24 | |
| 25 | // binding the lambda expression |
| 26 | D_ASSERT(expr.lhs); |
| 27 | if (expr.lhs->expression_class != ExpressionClass::FUNCTION && |
| 28 | expr.lhs->expression_class != ExpressionClass::COLUMN_REF) { |
| 29 | throw BinderException( |
| 30 | "Invalid parameter list! Parameters must be comma-separated column names, e.g. x or (x, y)." ); |
| 31 | } |
| 32 | |
| 33 | // move the lambda parameters to the params vector |
| 34 | if (expr.lhs->expression_class == ExpressionClass::COLUMN_REF) { |
| 35 | expr.params.push_back(x: std::move(expr.lhs)); |
| 36 | } else { |
| 37 | auto &func_expr = expr.lhs->Cast<FunctionExpression>(); |
| 38 | for (idx_t i = 0; i < func_expr.children.size(); i++) { |
| 39 | expr.params.push_back(x: std::move(func_expr.children[i])); |
| 40 | } |
| 41 | } |
| 42 | D_ASSERT(!expr.params.empty()); |
| 43 | |
| 44 | // create dummy columns for the lambda parameters (lhs) |
| 45 | vector<LogicalType> column_types; |
| 46 | vector<string> column_names; |
| 47 | vector<string> params_strings; |
| 48 | |
| 49 | // positional parameters as column references |
| 50 | for (idx_t i = 0; i < expr.params.size(); i++) { |
| 51 | if (expr.params[i]->GetExpressionClass() != ExpressionClass::COLUMN_REF) { |
| 52 | throw BinderException("Parameter must be a column name." ); |
| 53 | } |
| 54 | |
| 55 | auto column_ref = expr.params[i]->Cast<ColumnRefExpression>(); |
| 56 | if (column_ref.IsQualified()) { |
| 57 | throw BinderException("Invalid parameter name '%s': must be unqualified" , column_ref.ToString()); |
| 58 | } |
| 59 | |
| 60 | column_types.emplace_back(args: list_child_type); |
| 61 | column_names.push_back(x: column_ref.GetColumnName()); |
| 62 | params_strings.push_back(x: expr.params[i]->ToString()); |
| 63 | } |
| 64 | |
| 65 | // base table alias |
| 66 | auto params_alias = StringUtil::Join(input: params_strings, separator: ", " ); |
| 67 | if (params_strings.size() > 1) { |
| 68 | params_alias = "(" + params_alias + ")" ; |
| 69 | } |
| 70 | |
| 71 | // create a lambda binding and push it to the lambda bindings vector |
| 72 | vector<DummyBinding> local_bindings; |
| 73 | if (!lambda_bindings) { |
| 74 | lambda_bindings = &local_bindings; |
| 75 | } |
| 76 | DummyBinding new_lambda_binding(column_types, column_names, params_alias); |
| 77 | lambda_bindings->push_back(x: new_lambda_binding); |
| 78 | |
| 79 | // bind the parameter expressions |
| 80 | for (idx_t i = 0; i < expr.params.size(); i++) { |
| 81 | auto result = BindExpression(expr_ptr&: expr.params[i], depth, root_expression: false); |
| 82 | if (result.HasError()) { |
| 83 | throw InternalException("Error during lambda binding: %s" , result.error); |
| 84 | } |
| 85 | } |
| 86 | |
| 87 | auto result = BindExpression(expr_ptr&: expr.expr, depth, root_expression: false); |
| 88 | lambda_bindings->pop_back(); |
| 89 | |
| 90 | // successfully bound a subtree of nested lambdas, set this to nullptr in case other parts of the |
| 91 | // query also contain lambdas |
| 92 | if (lambda_bindings->empty()) { |
| 93 | lambda_bindings = nullptr; |
| 94 | } |
| 95 | |
| 96 | if (result.HasError()) { |
| 97 | throw BinderException(result.error); |
| 98 | } |
| 99 | |
| 100 | return BindResult(make_uniq<BoundLambdaExpression>(args: ExpressionType::LAMBDA, args: LogicalType::LAMBDA, |
| 101 | args: std::move(result.expression), args: params_strings.size())); |
| 102 | } |
| 103 | |
| 104 | void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr<Expression> &original, |
| 105 | unique_ptr<Expression> &replacement, |
| 106 | vector<unique_ptr<Expression>> &captures, |
| 107 | LogicalType &list_child_type) { |
| 108 | |
| 109 | // check if the original expression is a lambda parameter |
| 110 | if (original->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { |
| 111 | |
| 112 | // determine if this is the lambda parameter |
| 113 | auto &bound_lambda_ref = original->Cast<BoundLambdaRefExpression>(); |
| 114 | auto alias = bound_lambda_ref.alias; |
| 115 | |
| 116 | if (lambda_bindings && bound_lambda_ref.lambda_index != lambda_bindings->size()) { |
| 117 | |
| 118 | D_ASSERT(bound_lambda_ref.lambda_index < lambda_bindings->size()); |
| 119 | auto &lambda_binding = (*lambda_bindings)[bound_lambda_ref.lambda_index]; |
| 120 | |
| 121 | D_ASSERT(lambda_binding.names.size() == 1); |
| 122 | D_ASSERT(lambda_binding.types.size() == 1); |
| 123 | // refers to a lambda parameter outside of the current lambda function |
| 124 | replacement = |
| 125 | make_uniq<BoundReferenceExpression>(args&: lambda_binding.names[0], args&: lambda_binding.types[0], |
| 126 | args: lambda_bindings->size() - bound_lambda_ref.lambda_index + 1); |
| 127 | |
| 128 | } else { |
| 129 | // refers to current lambda parameter |
| 130 | replacement = make_uniq<BoundReferenceExpression>(args&: alias, args&: list_child_type, args: 0); |
| 131 | } |
| 132 | |
| 133 | } else { |
| 134 | // always at least the current lambda parameter |
| 135 | idx_t index_offset = 1; |
| 136 | if (lambda_bindings) { |
| 137 | index_offset += lambda_bindings->size(); |
| 138 | } |
| 139 | |
| 140 | // this is not a lambda parameter, so we need to create a new argument for the arguments vector |
| 141 | replacement = make_uniq<BoundReferenceExpression>(args&: original->alias, args&: original->return_type, |
| 142 | args: captures.size() + index_offset + 1); |
| 143 | captures.push_back(x: std::move(original)); |
| 144 | } |
| 145 | } |
| 146 | |
| 147 | void ExpressionBinder::CaptureLambdaColumns(vector<unique_ptr<Expression>> &captures, LogicalType &list_child_type, |
| 148 | unique_ptr<Expression> &expr) { |
| 149 | |
| 150 | if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { |
| 151 | throw InvalidInputException("Subqueries are not supported in lambda expressions!" ); |
| 152 | } |
| 153 | |
| 154 | // these expression classes do not have children, transform them |
| 155 | if (expr->expression_class == ExpressionClass::BOUND_CONSTANT || |
| 156 | expr->expression_class == ExpressionClass::BOUND_COLUMN_REF || |
| 157 | expr->expression_class == ExpressionClass::BOUND_PARAMETER || |
| 158 | expr->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { |
| 159 | |
| 160 | // move the expr because we are going to replace it |
| 161 | auto original = std::move(expr); |
| 162 | unique_ptr<Expression> replacement; |
| 163 | |
| 164 | TransformCapturedLambdaColumn(original, replacement, captures, list_child_type); |
| 165 | |
| 166 | // replace the expression |
| 167 | expr = std::move(replacement); |
| 168 | |
| 169 | } else { |
| 170 | // recursively enumerate the children of the expression |
| 171 | ExpressionIterator::EnumerateChildren( |
| 172 | expression&: *expr, callback: [&](unique_ptr<Expression> &child) { CaptureLambdaColumns(captures, list_child_type, expr&: child); }); |
| 173 | } |
| 174 | |
| 175 | expr->Verify(); |
| 176 | } |
| 177 | |
| 178 | } // namespace duckdb |
| 179 | |