1 | #include "duckdb/parser/expression/lambda_expression.hpp" |
2 | #include "duckdb/parser/expression/operator_expression.hpp" |
3 | #include "duckdb/planner/expression_binder.hpp" |
4 | #include "duckdb/planner/bind_context.hpp" |
5 | #include "duckdb/parser/expression/columnref_expression.hpp" |
6 | #include "duckdb/planner/binder.hpp" |
7 | #include "duckdb/parser/expression/function_expression.hpp" |
8 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
9 | #include "duckdb/planner/expression/bound_lambdaref_expression.hpp" |
10 | #include "duckdb/planner/expression/bound_lambda_expression.hpp" |
11 | #include "duckdb/planner/expression_iterator.hpp" |
12 | |
13 | namespace duckdb { |
14 | |
15 | BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, const bool is_lambda, |
16 | const LogicalType &list_child_type) { |
17 | |
18 | if (!is_lambda) { |
19 | // this is for binding JSON |
20 | auto lhs_expr = expr.lhs->Copy(); |
21 | OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(lhs_expr), expr.expr->Copy()); |
22 | return BindExpression(expr&: arrow_expr, depth); |
23 | } |
24 | |
25 | // binding the lambda expression |
26 | D_ASSERT(expr.lhs); |
27 | if (expr.lhs->expression_class != ExpressionClass::FUNCTION && |
28 | expr.lhs->expression_class != ExpressionClass::COLUMN_REF) { |
29 | throw BinderException( |
30 | "Invalid parameter list! Parameters must be comma-separated column names, e.g. x or (x, y)." ); |
31 | } |
32 | |
33 | // move the lambda parameters to the params vector |
34 | if (expr.lhs->expression_class == ExpressionClass::COLUMN_REF) { |
35 | expr.params.push_back(x: std::move(expr.lhs)); |
36 | } else { |
37 | auto &func_expr = expr.lhs->Cast<FunctionExpression>(); |
38 | for (idx_t i = 0; i < func_expr.children.size(); i++) { |
39 | expr.params.push_back(x: std::move(func_expr.children[i])); |
40 | } |
41 | } |
42 | D_ASSERT(!expr.params.empty()); |
43 | |
44 | // create dummy columns for the lambda parameters (lhs) |
45 | vector<LogicalType> column_types; |
46 | vector<string> column_names; |
47 | vector<string> params_strings; |
48 | |
49 | // positional parameters as column references |
50 | for (idx_t i = 0; i < expr.params.size(); i++) { |
51 | if (expr.params[i]->GetExpressionClass() != ExpressionClass::COLUMN_REF) { |
52 | throw BinderException("Parameter must be a column name." ); |
53 | } |
54 | |
55 | auto column_ref = expr.params[i]->Cast<ColumnRefExpression>(); |
56 | if (column_ref.IsQualified()) { |
57 | throw BinderException("Invalid parameter name '%s': must be unqualified" , column_ref.ToString()); |
58 | } |
59 | |
60 | column_types.emplace_back(args: list_child_type); |
61 | column_names.push_back(x: column_ref.GetColumnName()); |
62 | params_strings.push_back(x: expr.params[i]->ToString()); |
63 | } |
64 | |
65 | // base table alias |
66 | auto params_alias = StringUtil::Join(input: params_strings, separator: ", " ); |
67 | if (params_strings.size() > 1) { |
68 | params_alias = "(" + params_alias + ")" ; |
69 | } |
70 | |
71 | // create a lambda binding and push it to the lambda bindings vector |
72 | vector<DummyBinding> local_bindings; |
73 | if (!lambda_bindings) { |
74 | lambda_bindings = &local_bindings; |
75 | } |
76 | DummyBinding new_lambda_binding(column_types, column_names, params_alias); |
77 | lambda_bindings->push_back(x: new_lambda_binding); |
78 | |
79 | // bind the parameter expressions |
80 | for (idx_t i = 0; i < expr.params.size(); i++) { |
81 | auto result = BindExpression(expr_ptr&: expr.params[i], depth, root_expression: false); |
82 | if (result.HasError()) { |
83 | throw InternalException("Error during lambda binding: %s" , result.error); |
84 | } |
85 | } |
86 | |
87 | auto result = BindExpression(expr_ptr&: expr.expr, depth, root_expression: false); |
88 | lambda_bindings->pop_back(); |
89 | |
90 | // successfully bound a subtree of nested lambdas, set this to nullptr in case other parts of the |
91 | // query also contain lambdas |
92 | if (lambda_bindings->empty()) { |
93 | lambda_bindings = nullptr; |
94 | } |
95 | |
96 | if (result.HasError()) { |
97 | throw BinderException(result.error); |
98 | } |
99 | |
100 | return BindResult(make_uniq<BoundLambdaExpression>(args: ExpressionType::LAMBDA, args: LogicalType::LAMBDA, |
101 | args: std::move(result.expression), args: params_strings.size())); |
102 | } |
103 | |
104 | void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr<Expression> &original, |
105 | unique_ptr<Expression> &replacement, |
106 | vector<unique_ptr<Expression>> &captures, |
107 | LogicalType &list_child_type) { |
108 | |
109 | // check if the original expression is a lambda parameter |
110 | if (original->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { |
111 | |
112 | // determine if this is the lambda parameter |
113 | auto &bound_lambda_ref = original->Cast<BoundLambdaRefExpression>(); |
114 | auto alias = bound_lambda_ref.alias; |
115 | |
116 | if (lambda_bindings && bound_lambda_ref.lambda_index != lambda_bindings->size()) { |
117 | |
118 | D_ASSERT(bound_lambda_ref.lambda_index < lambda_bindings->size()); |
119 | auto &lambda_binding = (*lambda_bindings)[bound_lambda_ref.lambda_index]; |
120 | |
121 | D_ASSERT(lambda_binding.names.size() == 1); |
122 | D_ASSERT(lambda_binding.types.size() == 1); |
123 | // refers to a lambda parameter outside of the current lambda function |
124 | replacement = |
125 | make_uniq<BoundReferenceExpression>(args&: lambda_binding.names[0], args&: lambda_binding.types[0], |
126 | args: lambda_bindings->size() - bound_lambda_ref.lambda_index + 1); |
127 | |
128 | } else { |
129 | // refers to current lambda parameter |
130 | replacement = make_uniq<BoundReferenceExpression>(args&: alias, args&: list_child_type, args: 0); |
131 | } |
132 | |
133 | } else { |
134 | // always at least the current lambda parameter |
135 | idx_t index_offset = 1; |
136 | if (lambda_bindings) { |
137 | index_offset += lambda_bindings->size(); |
138 | } |
139 | |
140 | // this is not a lambda parameter, so we need to create a new argument for the arguments vector |
141 | replacement = make_uniq<BoundReferenceExpression>(args&: original->alias, args&: original->return_type, |
142 | args: captures.size() + index_offset + 1); |
143 | captures.push_back(x: std::move(original)); |
144 | } |
145 | } |
146 | |
147 | void ExpressionBinder::CaptureLambdaColumns(vector<unique_ptr<Expression>> &captures, LogicalType &list_child_type, |
148 | unique_ptr<Expression> &expr) { |
149 | |
150 | if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { |
151 | throw InvalidInputException("Subqueries are not supported in lambda expressions!" ); |
152 | } |
153 | |
154 | // these expression classes do not have children, transform them |
155 | if (expr->expression_class == ExpressionClass::BOUND_CONSTANT || |
156 | expr->expression_class == ExpressionClass::BOUND_COLUMN_REF || |
157 | expr->expression_class == ExpressionClass::BOUND_PARAMETER || |
158 | expr->expression_class == ExpressionClass::BOUND_LAMBDA_REF) { |
159 | |
160 | // move the expr because we are going to replace it |
161 | auto original = std::move(expr); |
162 | unique_ptr<Expression> replacement; |
163 | |
164 | TransformCapturedLambdaColumn(original, replacement, captures, list_child_type); |
165 | |
166 | // replace the expression |
167 | expr = std::move(replacement); |
168 | |
169 | } else { |
170 | // recursively enumerate the children of the expression |
171 | ExpressionIterator::EnumerateChildren( |
172 | expression&: *expr, callback: [&](unique_ptr<Expression> &child) { CaptureLambdaColumns(captures, list_child_type, expr&: child); }); |
173 | } |
174 | |
175 | expr->Verify(); |
176 | } |
177 | |
178 | } // namespace duckdb |
179 | |