| 1 | #include "duckdb/optimizer/expression_rewriter.hpp" | 
| 2 |  | 
| 3 | #include "duckdb/common/exception.hpp" | 
| 4 | #include "duckdb/planner/expression_iterator.hpp" | 
| 5 | #include "duckdb/planner/operator/logical_filter.hpp" | 
| 6 | #include "duckdb/function/scalar/generic_functions.hpp" | 
| 7 | #include "duckdb/planner/expression/bound_constant_expression.hpp" | 
| 8 | #include "duckdb/planner/expression/bound_function_expression.hpp" | 
| 9 |  | 
| 10 | namespace duckdb { | 
| 11 |  | 
| 12 | unique_ptr<Expression> ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector<reference<Rule>> &rules, | 
| 13 |                                                       unique_ptr<Expression> expr, bool &changes_made, bool is_root) { | 
| 14 | 	for (auto &rule : rules) { | 
| 15 | 		vector<reference<Expression>> bindings; | 
| 16 | 		if (rule.get().root->Match(expr&: *expr, bindings)) { | 
| 17 | 			// the rule matches! try to apply it | 
| 18 | 			bool rule_made_change = false; | 
| 19 | 			auto result = rule.get().Apply(op, bindings, fixed_point&: rule_made_change, is_root); | 
| 20 | 			if (result) { | 
| 21 | 				changes_made = true; | 
| 22 | 				// the base node changed: the rule applied changes | 
| 23 | 				// rerun on the new node | 
| 24 | 				return ExpressionRewriter::ApplyRules(op, rules, expr: std::move(result), changes_made); | 
| 25 | 			} else if (rule_made_change) { | 
| 26 | 				changes_made = true; | 
| 27 | 				// the base node didn't change, but changes were made, rerun | 
| 28 | 				return expr; | 
| 29 | 			} | 
| 30 | 			// else nothing changed, continue to the next rule | 
| 31 | 			continue; | 
| 32 | 		} | 
| 33 | 	} | 
| 34 | 	// no changes could be made to this node | 
| 35 | 	// recursively run on the children of this node | 
| 36 | 	ExpressionIterator::EnumerateChildren(expression&: *expr, callback: [&](unique_ptr<Expression> &child) { | 
| 37 | 		child = ExpressionRewriter::ApplyRules(op, rules, expr: std::move(child), changes_made); | 
| 38 | 	}); | 
| 39 | 	return expr; | 
| 40 | } | 
| 41 |  | 
| 42 | unique_ptr<Expression> ExpressionRewriter::ConstantOrNull(unique_ptr<Expression> child, Value value) { | 
| 43 | 	vector<unique_ptr<Expression>> children; | 
| 44 | 	children.push_back(x: make_uniq<BoundConstantExpression>(args&: value)); | 
| 45 | 	children.push_back(x: std::move(child)); | 
| 46 | 	return ConstantOrNull(children: std::move(children), value: std::move(value)); | 
| 47 | } | 
| 48 |  | 
| 49 | unique_ptr<Expression> ExpressionRewriter::ConstantOrNull(vector<unique_ptr<Expression>> children, Value value) { | 
| 50 | 	auto type = value.type(); | 
| 51 | 	children.insert(position: children.begin(), x: make_uniq<BoundConstantExpression>(args&: value)); | 
| 52 | 	return make_uniq<BoundFunctionExpression>(args&: type, args: ConstantOrNull::GetFunction(return_type: type), args: std::move(children), | 
| 53 | 	                                          args: ConstantOrNull::Bind(value: std::move(value))); | 
| 54 | } | 
| 55 |  | 
| 56 | void ExpressionRewriter::VisitOperator(LogicalOperator &op) { | 
| 57 | 	VisitOperatorChildren(op); | 
| 58 | 	this->op = &op; | 
| 59 |  | 
| 60 | 	to_apply_rules.clear(); | 
| 61 | 	for (auto &rule : rules) { | 
| 62 | 		if (rule->logical_root && !rule->logical_root->Match(type: op.type)) { | 
| 63 | 			// this rule does not apply to this type of LogicalOperator | 
| 64 | 			continue; | 
| 65 | 		} | 
| 66 | 		to_apply_rules.push_back(x: *rule); | 
| 67 | 	} | 
| 68 | 	if (to_apply_rules.empty()) { | 
| 69 | 		// no rules to apply on this node | 
| 70 | 		return; | 
| 71 | 	} | 
| 72 |  | 
| 73 | 	VisitOperatorExpressions(op); | 
| 74 |  | 
| 75 | 	// if it is a LogicalFilter, we split up filter conjunctions again | 
| 76 | 	if (op.type == LogicalOperatorType::LOGICAL_FILTER) { | 
| 77 | 		auto &filter = op.Cast<LogicalFilter>(); | 
| 78 | 		filter.SplitPredicates(); | 
| 79 | 	} | 
| 80 | } | 
| 81 |  | 
| 82 | void ExpressionRewriter::VisitExpression(unique_ptr<Expression> *expression) { | 
| 83 | 	bool changes_made; | 
| 84 | 	do { | 
| 85 | 		changes_made = false; | 
| 86 | 		*expression = ExpressionRewriter::ApplyRules(op&: *op, rules: to_apply_rules, expr: std::move(*expression), changes_made, is_root: true); | 
| 87 | 	} while (changes_made); | 
| 88 | } | 
| 89 |  | 
| 90 | ClientContext &Rule::GetContext() const { | 
| 91 | 	return rewriter.context; | 
| 92 | } | 
| 93 |  | 
| 94 | } // namespace duckdb | 
| 95 |  |