1#include "duckdb/parser/expression/lambda_expression.hpp"
2#include "duckdb/parser/expression/operator_expression.hpp"
3#include "duckdb/planner/expression_binder.hpp"
4#include "duckdb/planner/bind_context.hpp"
5#include "duckdb/parser/expression/columnref_expression.hpp"
6#include "duckdb/planner/binder.hpp"
7#include "duckdb/parser/expression/function_expression.hpp"
8#include "duckdb/planner/expression/bound_reference_expression.hpp"
9#include "duckdb/planner/expression/bound_lambdaref_expression.hpp"
10#include "duckdb/planner/expression/bound_lambda_expression.hpp"
11#include "duckdb/planner/expression_iterator.hpp"
12
13namespace duckdb {
14
15BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, const bool is_lambda,
16 const LogicalType &list_child_type) {
17
18 if (!is_lambda) {
19 // this is for binding JSON
20 auto lhs_expr = expr.lhs->Copy();
21 OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(lhs_expr), expr.expr->Copy());
22 return BindExpression(expr&: arrow_expr, depth);
23 }
24
25 // binding the lambda expression
26 D_ASSERT(expr.lhs);
27 if (expr.lhs->expression_class != ExpressionClass::FUNCTION &&
28 expr.lhs->expression_class != ExpressionClass::COLUMN_REF) {
29 throw BinderException(
30 "Invalid parameter list! Parameters must be comma-separated column names, e.g. x or (x, y).");
31 }
32
33 // move the lambda parameters to the params vector
34 if (expr.lhs->expression_class == ExpressionClass::COLUMN_REF) {
35 expr.params.push_back(x: std::move(expr.lhs));
36 } else {
37 auto &func_expr = expr.lhs->Cast<FunctionExpression>();
38 for (idx_t i = 0; i < func_expr.children.size(); i++) {
39 expr.params.push_back(x: std::move(func_expr.children[i]));
40 }
41 }
42 D_ASSERT(!expr.params.empty());
43
44 // create dummy columns for the lambda parameters (lhs)
45 vector<LogicalType> column_types;
46 vector<string> column_names;
47 vector<string> params_strings;
48
49 // positional parameters as column references
50 for (idx_t i = 0; i < expr.params.size(); i++) {
51 if (expr.params[i]->GetExpressionClass() != ExpressionClass::COLUMN_REF) {
52 throw BinderException("Parameter must be a column name.");
53 }
54
55 auto column_ref = expr.params[i]->Cast<ColumnRefExpression>();
56 if (column_ref.IsQualified()) {
57 throw BinderException("Invalid parameter name '%s': must be unqualified", column_ref.ToString());
58 }
59
60 column_types.emplace_back(args: list_child_type);
61 column_names.push_back(x: column_ref.GetColumnName());
62 params_strings.push_back(x: expr.params[i]->ToString());
63 }
64
65 // base table alias
66 auto params_alias = StringUtil::Join(input: params_strings, separator: ", ");
67 if (params_strings.size() > 1) {
68 params_alias = "(" + params_alias + ")";
69 }
70
71 // create a lambda binding and push it to the lambda bindings vector
72 vector<DummyBinding> local_bindings;
73 if (!lambda_bindings) {
74 lambda_bindings = &local_bindings;
75 }
76 DummyBinding new_lambda_binding(column_types, column_names, params_alias);
77 lambda_bindings->push_back(x: new_lambda_binding);
78
79 // bind the parameter expressions
80 for (idx_t i = 0; i < expr.params.size(); i++) {
81 auto result = BindExpression(expr_ptr&: expr.params[i], depth, root_expression: false);
82 if (result.HasError()) {
83 throw InternalException("Error during lambda binding: %s", result.error);
84 }
85 }
86
87 auto result = BindExpression(expr_ptr&: expr.expr, depth, root_expression: false);
88 lambda_bindings->pop_back();
89
90 // successfully bound a subtree of nested lambdas, set this to nullptr in case other parts of the
91 // query also contain lambdas
92 if (lambda_bindings->empty()) {
93 lambda_bindings = nullptr;
94 }
95
96 if (result.HasError()) {
97 throw BinderException(result.error);
98 }
99
100 return BindResult(make_uniq<BoundLambdaExpression>(args: ExpressionType::LAMBDA, args: LogicalType::LAMBDA,
101 args: std::move(result.expression), args: params_strings.size()));
102}
103
104void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr<Expression> &original,
105 unique_ptr<Expression> &replacement,
106 vector<unique_ptr<Expression>> &captures,
107 LogicalType &list_child_type) {
108
109 // check if the original expression is a lambda parameter
110 if (original->expression_class == ExpressionClass::BOUND_LAMBDA_REF) {
111
112 // determine if this is the lambda parameter
113 auto &bound_lambda_ref = original->Cast<BoundLambdaRefExpression>();
114 auto alias = bound_lambda_ref.alias;
115
116 if (lambda_bindings && bound_lambda_ref.lambda_index != lambda_bindings->size()) {
117
118 D_ASSERT(bound_lambda_ref.lambda_index < lambda_bindings->size());
119 auto &lambda_binding = (*lambda_bindings)[bound_lambda_ref.lambda_index];
120
121 D_ASSERT(lambda_binding.names.size() == 1);
122 D_ASSERT(lambda_binding.types.size() == 1);
123 // refers to a lambda parameter outside of the current lambda function
124 replacement =
125 make_uniq<BoundReferenceExpression>(args&: lambda_binding.names[0], args&: lambda_binding.types[0],
126 args: lambda_bindings->size() - bound_lambda_ref.lambda_index + 1);
127
128 } else {
129 // refers to current lambda parameter
130 replacement = make_uniq<BoundReferenceExpression>(args&: alias, args&: list_child_type, args: 0);
131 }
132
133 } else {
134 // always at least the current lambda parameter
135 idx_t index_offset = 1;
136 if (lambda_bindings) {
137 index_offset += lambda_bindings->size();
138 }
139
140 // this is not a lambda parameter, so we need to create a new argument for the arguments vector
141 replacement = make_uniq<BoundReferenceExpression>(args&: original->alias, args&: original->return_type,
142 args: captures.size() + index_offset + 1);
143 captures.push_back(x: std::move(original));
144 }
145}
146
147void ExpressionBinder::CaptureLambdaColumns(vector<unique_ptr<Expression>> &captures, LogicalType &list_child_type,
148 unique_ptr<Expression> &expr) {
149
150 if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) {
151 throw InvalidInputException("Subqueries are not supported in lambda expressions!");
152 }
153
154 // these expression classes do not have children, transform them
155 if (expr->expression_class == ExpressionClass::BOUND_CONSTANT ||
156 expr->expression_class == ExpressionClass::BOUND_COLUMN_REF ||
157 expr->expression_class == ExpressionClass::BOUND_PARAMETER ||
158 expr->expression_class == ExpressionClass::BOUND_LAMBDA_REF) {
159
160 // move the expr because we are going to replace it
161 auto original = std::move(expr);
162 unique_ptr<Expression> replacement;
163
164 TransformCapturedLambdaColumn(original, replacement, captures, list_child_type);
165
166 // replace the expression
167 expr = std::move(replacement);
168
169 } else {
170 // recursively enumerate the children of the expression
171 ExpressionIterator::EnumerateChildren(
172 expression&: *expr, callback: [&](unique_ptr<Expression> &child) { CaptureLambdaColumns(captures, list_child_type, expr&: child); });
173 }
174
175 expr->Verify();
176}
177
178} // namespace duckdb
179