1 | #include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" |
2 | #include "duckdb/common/string_util.hpp" |
3 | #include "duckdb/parser/expression/function_expression.hpp" |
4 | #include "duckdb/parser/expression/subquery_expression.hpp" |
5 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
6 | #include "duckdb/planner/expression_binder.hpp" |
7 | |
8 | #include "duckdb/function/scalar_macro_function.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | void ExpressionBinder::ReplaceMacroParametersRecursive(unique_ptr<ParsedExpression> &expr) { |
13 | switch (expr->GetExpressionClass()) { |
14 | case ExpressionClass::COLUMN_REF: { |
15 | // if expr is a parameter, replace it with its argument |
16 | auto &colref = expr->Cast<ColumnRefExpression>(); |
17 | bool bind_macro_parameter = false; |
18 | if (colref.IsQualified()) { |
19 | bind_macro_parameter = false; |
20 | if (colref.GetTableName().find(s: DummyBinding::DUMMY_NAME) != string::npos) { |
21 | bind_macro_parameter = true; |
22 | } |
23 | } else { |
24 | bind_macro_parameter = macro_binding->HasMatchingBinding(column_name: colref.GetColumnName()); |
25 | } |
26 | if (bind_macro_parameter) { |
27 | D_ASSERT(macro_binding->HasMatchingBinding(colref.GetColumnName())); |
28 | expr = macro_binding->ParamToArg(colref); |
29 | } |
30 | return; |
31 | } |
32 | case ExpressionClass::SUBQUERY: { |
33 | // replacing parameters within a subquery is slightly different |
34 | auto &sq = (expr->Cast<SubqueryExpression>()).subquery; |
35 | ParsedExpressionIterator::EnumerateQueryNodeChildren( |
36 | node&: *sq->node, callback: [&](unique_ptr<ParsedExpression> &child) { ReplaceMacroParametersRecursive(expr&: child); }); |
37 | break; |
38 | } |
39 | default: // fall through |
40 | break; |
41 | } |
42 | // unfold child expressions |
43 | ParsedExpressionIterator::EnumerateChildren( |
44 | expr&: *expr, callback: [&](unique_ptr<ParsedExpression> &child) { ReplaceMacroParametersRecursive(expr&: child); }); |
45 | } |
46 | |
47 | BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacroCatalogEntry ¯o_func, idx_t depth, |
48 | unique_ptr<ParsedExpression> &expr) { |
49 | // recast function so we can access the scalar member function->expression |
50 | auto ¯o_def = macro_func.function->Cast<ScalarMacroFunction>(); |
51 | |
52 | // validate the arguments and separate positional and default arguments |
53 | vector<unique_ptr<ParsedExpression>> positionals; |
54 | unordered_map<string, unique_ptr<ParsedExpression>> defaults; |
55 | |
56 | string error = |
57 | MacroFunction::ValidateArguments(macro_function&: *macro_func.function, name: macro_func.name, function_expr&: function, positionals, defaults); |
58 | if (!error.empty()) { |
59 | throw BinderException(binder.FormatError(expr_context&: *expr, message: error)); |
60 | } |
61 | |
62 | // create a MacroBinding to bind this macro's parameters to its arguments |
63 | vector<LogicalType> types; |
64 | vector<string> names; |
65 | // positional parameters |
66 | for (idx_t i = 0; i < macro_def.parameters.size(); i++) { |
67 | types.emplace_back(args: LogicalType::SQLNULL); |
68 | auto ¶m = macro_def.parameters[i]->Cast<ColumnRefExpression>(); |
69 | names.push_back(x: param.GetColumnName()); |
70 | } |
71 | // default parameters |
72 | for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { |
73 | types.emplace_back(args: LogicalType::SQLNULL); |
74 | names.push_back(x: it->first); |
75 | // now push the defaults into the positionals |
76 | positionals.push_back(x: std::move(defaults[it->first])); |
77 | } |
78 | auto new_macro_binding = make_uniq<DummyBinding>(args&: types, args&: names, args&: macro_func.name); |
79 | new_macro_binding->arguments = &positionals; |
80 | macro_binding = new_macro_binding.get(); |
81 | |
82 | // replace current expression with stored macro expression, and replace params |
83 | expr = macro_def.expression->Copy(); |
84 | ReplaceMacroParametersRecursive(expr); |
85 | |
86 | // bind the unfolded macro |
87 | return BindExpression(expr_ptr&: expr, depth); |
88 | } |
89 | |
90 | } // namespace duckdb |
91 | |