1 | #include "duckdb/catalog/catalog.hpp" |
2 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
3 | #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" |
4 | #include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" |
5 | #include "duckdb/execution/expression_executor.hpp" |
6 | #include "duckdb/function/function_binder.hpp" |
7 | #include "duckdb/parser/expression/function_expression.hpp" |
8 | #include "duckdb/parser/expression/lambda_expression.hpp" |
9 | #include "duckdb/planner/binder.hpp" |
10 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
11 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
12 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
13 | #include "duckdb/planner/expression/bound_lambda_expression.hpp" |
14 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
15 | #include "duckdb/planner/expression_binder.hpp" |
16 | |
17 | namespace duckdb { |
18 | |
19 | BindResult ExpressionBinder::BindExpression(FunctionExpression &function, idx_t depth, |
20 | unique_ptr<ParsedExpression> &expr_ptr) { |
21 | // lookup the function in the catalog |
22 | QueryErrorContext error_context(binder.root_statement, function.query_location); |
23 | auto func = Catalog::GetEntry(context, type: CatalogType::SCALAR_FUNCTION_ENTRY, catalog: function.catalog, schema: function.schema, |
24 | name: function.function_name, if_not_found: OnEntryNotFound::RETURN_NULL, error_context); |
25 | if (!func) { |
26 | // function was not found - check if we this is a table function |
27 | auto table_func = |
28 | Catalog::GetEntry(context, type: CatalogType::TABLE_FUNCTION_ENTRY, catalog: function.catalog, schema: function.schema, |
29 | name: function.function_name, if_not_found: OnEntryNotFound::RETURN_NULL, error_context); |
30 | if (table_func) { |
31 | throw BinderException(binder.FormatError( |
32 | expr_context&: function, |
33 | message: StringUtil::Format(fmt_str: "Function \"%s\" is a table function but it was used as a scalar function. This " |
34 | "function has to be called in a FROM clause (similar to a table)." , |
35 | params: function.function_name))); |
36 | } |
37 | // not a table function - check if the schema is set |
38 | if (!function.schema.empty()) { |
39 | // the schema is set - check if we can turn this the schema into a column ref |
40 | string error; |
41 | unique_ptr<ColumnRefExpression> colref; |
42 | if (function.catalog.empty()) { |
43 | colref = make_uniq<ColumnRefExpression>(args&: function.schema); |
44 | } else { |
45 | colref = make_uniq<ColumnRefExpression>(args&: function.schema, args&: function.catalog); |
46 | } |
47 | auto new_colref = QualifyColumnName(colref&: *colref, error_message&: error); |
48 | bool is_col = error.empty() ? true : false; |
49 | bool is_col_alias = QualifyColumnAlias(colref: *colref); |
50 | |
51 | if (is_col || is_col_alias) { |
52 | // we can! transform this into a function call on the column |
53 | // i.e. "x.lower()" becomes "lower(x)" |
54 | function.children.insert(position: function.children.begin(), x: std::move(colref)); |
55 | function.catalog = INVALID_CATALOG; |
56 | function.schema = INVALID_SCHEMA; |
57 | } |
58 | } |
59 | // rebind the function |
60 | func = Catalog::GetEntry(context, type: CatalogType::SCALAR_FUNCTION_ENTRY, catalog: function.catalog, schema: function.schema, |
61 | name: function.function_name, if_not_found: OnEntryNotFound::THROW_EXCEPTION, error_context); |
62 | } |
63 | |
64 | if (func->type != CatalogType::AGGREGATE_FUNCTION_ENTRY && |
65 | (function.distinct || function.filter || !function.order_bys->orders.empty())) { |
66 | throw InvalidInputException("Function \"%s\" is a %s. \"DISTINCT\", \"FILTER\", and \"ORDER BY\" are only " |
67 | "applicable to aggregate functions." , |
68 | function.function_name, CatalogTypeToString(type: func->type)); |
69 | } |
70 | |
71 | switch (func->type) { |
72 | case CatalogType::SCALAR_FUNCTION_ENTRY: |
73 | // scalar function |
74 | |
75 | // check for lambda parameters, ignore ->> operator (JSON extension) |
76 | if (function.function_name != "->>" ) { |
77 | for (auto &child : function.children) { |
78 | if (child->expression_class == ExpressionClass::LAMBDA) { |
79 | return BindLambdaFunction(expr&: function, function&: func->Cast<ScalarFunctionCatalogEntry>(), depth); |
80 | } |
81 | } |
82 | } |
83 | |
84 | // other scalar function |
85 | return BindFunction(expr&: function, function&: func->Cast<ScalarFunctionCatalogEntry>(), depth); |
86 | |
87 | case CatalogType::MACRO_ENTRY: |
88 | // macro function |
89 | return BindMacro(expr&: function, macro&: func->Cast<ScalarMacroCatalogEntry>(), depth, expr_ptr); |
90 | default: |
91 | // aggregate function |
92 | return BindAggregate(expr&: function, function&: func->Cast<AggregateFunctionCatalogEntry>(), depth); |
93 | } |
94 | } |
95 | |
96 | BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, idx_t depth) { |
97 | |
98 | // bind the children of the function expression |
99 | string error; |
100 | |
101 | // bind of each child |
102 | for (idx_t i = 0; i < function.children.size(); i++) { |
103 | BindChild(expr&: function.children[i], depth, error); |
104 | } |
105 | |
106 | if (!error.empty()) { |
107 | return BindResult(error); |
108 | } |
109 | if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { |
110 | return BindResult(make_uniq<BoundConstantExpression>(args: Value(LogicalType::SQLNULL))); |
111 | } |
112 | |
113 | // all children bound successfully |
114 | // extract the children and types |
115 | vector<unique_ptr<Expression>> children; |
116 | for (idx_t i = 0; i < function.children.size(); i++) { |
117 | auto &child = BoundExpression::GetExpression(expr&: *function.children[i]); |
118 | children.push_back(x: std::move(child)); |
119 | } |
120 | |
121 | FunctionBinder function_binder(context); |
122 | unique_ptr<Expression> result = |
123 | function_binder.BindScalarFunction(function&: func, children: std::move(children), error, is_operator: function.is_operator, binder: &binder); |
124 | if (!result) { |
125 | throw BinderException(binder.FormatError(expr_context&: function, message: error)); |
126 | } |
127 | return BindResult(std::move(result)); |
128 | } |
129 | |
130 | BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, |
131 | idx_t depth) { |
132 | |
133 | // bind the children of the function expression |
134 | string error; |
135 | |
136 | if (function.children.size() != 2) { |
137 | throw BinderException("Invalid function arguments!" ); |
138 | } |
139 | D_ASSERT(function.children[1]->GetExpressionClass() == ExpressionClass::LAMBDA); |
140 | |
141 | // bind the list parameter |
142 | BindChild(expr&: function.children[0], depth, error); |
143 | if (!error.empty()) { |
144 | return BindResult(error); |
145 | } |
146 | |
147 | // get the logical type of the children of the list |
148 | auto &list_child = BoundExpression::GetExpression(expr&: *function.children[0]); |
149 | if (list_child->return_type.id() != LogicalTypeId::LIST && list_child->return_type.id() != LogicalTypeId::SQLNULL && |
150 | list_child->return_type.id() != LogicalTypeId::UNKNOWN) { |
151 | throw BinderException(" Invalid LIST argument to " + function.function_name + "!" ); |
152 | } |
153 | |
154 | LogicalType list_child_type = list_child->return_type.id(); |
155 | if (list_child->return_type.id() != LogicalTypeId::SQLNULL && |
156 | list_child->return_type.id() != LogicalTypeId::UNKNOWN) { |
157 | list_child_type = ListType::GetChildType(type: list_child->return_type); |
158 | } |
159 | |
160 | // bind the lambda parameter |
161 | auto &lambda_expr = function.children[1]->Cast<LambdaExpression>(); |
162 | BindResult bind_lambda_result = BindExpression(expr&: lambda_expr, depth, is_lambda: true, list_child_type); |
163 | |
164 | if (bind_lambda_result.HasError()) { |
165 | error = bind_lambda_result.error; |
166 | } else { |
167 | // successfully bound: replace the node with a BoundExpression |
168 | auto alias = function.children[1]->alias; |
169 | bind_lambda_result.expression->alias = alias; |
170 | if (!alias.empty()) { |
171 | bind_lambda_result.expression->alias = alias; |
172 | } |
173 | function.children[1] = make_uniq<BoundExpression>(args: std::move(bind_lambda_result.expression)); |
174 | } |
175 | |
176 | if (!error.empty()) { |
177 | return BindResult(error); |
178 | } |
179 | if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { |
180 | return BindResult(make_uniq<BoundConstantExpression>(args: Value(LogicalType::SQLNULL))); |
181 | } |
182 | |
183 | // all children bound successfully |
184 | // extract the children and types |
185 | vector<unique_ptr<Expression>> children; |
186 | for (idx_t i = 0; i < function.children.size(); i++) { |
187 | auto &child = BoundExpression::GetExpression(expr&: *function.children[i]); |
188 | children.push_back(x: std::move(child)); |
189 | } |
190 | |
191 | // capture the (lambda) columns |
192 | auto &bound_lambda_expr = children.back()->Cast<BoundLambdaExpression>(); |
193 | CaptureLambdaColumns(captures&: bound_lambda_expr.captures, list_child_type, expr&: bound_lambda_expr.lambda_expr); |
194 | |
195 | FunctionBinder function_binder(context); |
196 | unique_ptr<Expression> result = |
197 | function_binder.BindScalarFunction(function&: func, children: std::move(children), error, is_operator: function.is_operator, binder: &binder); |
198 | if (!result) { |
199 | throw BinderException(binder.FormatError(expr_context&: function, message: error)); |
200 | } |
201 | |
202 | auto &bound_function_expr = result->Cast<BoundFunctionExpression>(); |
203 | D_ASSERT(bound_function_expr.children.size() == 2); |
204 | |
205 | // remove the lambda expression from the children |
206 | auto lambda = std::move(bound_function_expr.children.back()); |
207 | bound_function_expr.children.pop_back(); |
208 | auto &bound_lambda = lambda->Cast<BoundLambdaExpression>(); |
209 | |
210 | // push back (in reverse order) any nested lambda parameters so that we can later use them in the lambda expression |
211 | // (rhs) |
212 | if (lambda_bindings) { |
213 | for (idx_t i = lambda_bindings->size(); i > 0; i--) { |
214 | |
215 | idx_t lambda_index = lambda_bindings->size() - i + 1; |
216 | auto &binding = (*lambda_bindings)[i - 1]; |
217 | |
218 | D_ASSERT(binding.names.size() == 1); |
219 | D_ASSERT(binding.types.size() == 1); |
220 | |
221 | auto bound_lambda_param = |
222 | make_uniq<BoundReferenceExpression>(args&: binding.names[0], args&: binding.types[0], args&: lambda_index); |
223 | bound_function_expr.children.push_back(x: std::move(bound_lambda_param)); |
224 | } |
225 | } |
226 | |
227 | // push back the captures into the children vector and the correct return types into the bound_function arguments |
228 | for (auto &capture : bound_lambda.captures) { |
229 | bound_function_expr.children.push_back(x: std::move(capture)); |
230 | } |
231 | |
232 | return BindResult(std::move(result)); |
233 | } |
234 | |
235 | BindResult ExpressionBinder::BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, |
236 | idx_t depth) { |
237 | return BindResult(binder.FormatError(expr_context&: expr, message: UnsupportedAggregateMessage())); |
238 | } |
239 | |
240 | BindResult ExpressionBinder::BindUnnest(FunctionExpression &expr, idx_t depth, bool root_expression) { |
241 | return BindResult(binder.FormatError(expr_context&: expr, message: UnsupportedUnnestMessage())); |
242 | } |
243 | |
244 | string ExpressionBinder::UnsupportedAggregateMessage() { |
245 | return "Aggregate functions are not supported here" ; |
246 | } |
247 | |
248 | string ExpressionBinder::UnsupportedUnnestMessage() { |
249 | return "UNNEST not supported here" ; |
250 | } |
251 | |
252 | } // namespace duckdb |
253 | |