1#include "duckdb/optimizer/expression_rewriter.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/planner/expression_iterator.hpp"
5#include "duckdb/planner/operator/logical_filter.hpp"
6#include "duckdb/function/scalar/generic_functions.hpp"
7#include "duckdb/planner/expression/bound_constant_expression.hpp"
8#include "duckdb/planner/expression/bound_function_expression.hpp"
9
10namespace duckdb {
11
12unique_ptr<Expression> ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector<reference<Rule>> &rules,
13 unique_ptr<Expression> expr, bool &changes_made, bool is_root) {
14 for (auto &rule : rules) {
15 vector<reference<Expression>> bindings;
16 if (rule.get().root->Match(expr&: *expr, bindings)) {
17 // the rule matches! try to apply it
18 bool rule_made_change = false;
19 auto result = rule.get().Apply(op, bindings, fixed_point&: rule_made_change, is_root);
20 if (result) {
21 changes_made = true;
22 // the base node changed: the rule applied changes
23 // rerun on the new node
24 return ExpressionRewriter::ApplyRules(op, rules, expr: std::move(result), changes_made);
25 } else if (rule_made_change) {
26 changes_made = true;
27 // the base node didn't change, but changes were made, rerun
28 return expr;
29 }
30 // else nothing changed, continue to the next rule
31 continue;
32 }
33 }
34 // no changes could be made to this node
35 // recursively run on the children of this node
36 ExpressionIterator::EnumerateChildren(expression&: *expr, callback: [&](unique_ptr<Expression> &child) {
37 child = ExpressionRewriter::ApplyRules(op, rules, expr: std::move(child), changes_made);
38 });
39 return expr;
40}
41
42unique_ptr<Expression> ExpressionRewriter::ConstantOrNull(unique_ptr<Expression> child, Value value) {
43 vector<unique_ptr<Expression>> children;
44 children.push_back(x: make_uniq<BoundConstantExpression>(args&: value));
45 children.push_back(x: std::move(child));
46 return ConstantOrNull(children: std::move(children), value: std::move(value));
47}
48
49unique_ptr<Expression> ExpressionRewriter::ConstantOrNull(vector<unique_ptr<Expression>> children, Value value) {
50 auto type = value.type();
51 children.insert(position: children.begin(), x: make_uniq<BoundConstantExpression>(args&: value));
52 return make_uniq<BoundFunctionExpression>(args&: type, args: ConstantOrNull::GetFunction(return_type: type), args: std::move(children),
53 args: ConstantOrNull::Bind(value: std::move(value)));
54}
55
56void ExpressionRewriter::VisitOperator(LogicalOperator &op) {
57 VisitOperatorChildren(op);
58 this->op = &op;
59
60 to_apply_rules.clear();
61 for (auto &rule : rules) {
62 if (rule->logical_root && !rule->logical_root->Match(type: op.type)) {
63 // this rule does not apply to this type of LogicalOperator
64 continue;
65 }
66 to_apply_rules.push_back(x: *rule);
67 }
68 if (to_apply_rules.empty()) {
69 // no rules to apply on this node
70 return;
71 }
72
73 VisitOperatorExpressions(op);
74
75 // if it is a LogicalFilter, we split up filter conjunctions again
76 if (op.type == LogicalOperatorType::LOGICAL_FILTER) {
77 auto &filter = op.Cast<LogicalFilter>();
78 filter.SplitPredicates();
79 }
80}
81
82void ExpressionRewriter::VisitExpression(unique_ptr<Expression> *expression) {
83 bool changes_made;
84 do {
85 changes_made = false;
86 *expression = ExpressionRewriter::ApplyRules(op&: *op, rules: to_apply_rules, expr: std::move(*expression), changes_made, is_root: true);
87 } while (changes_made);
88}
89
90ClientContext &Rule::GetContext() const {
91 return rewriter.context;
92}
93
94} // namespace duckdb
95