1 | #include "duckdb/optimizer/rule/conjunction_simplification.hpp" |
2 | |
3 | #include "duckdb/execution/expression_executor.hpp" |
4 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
10 | // match on a ComparisonExpression that has a ConstantExpression as a check |
11 | auto op = make_uniq<ConjunctionExpressionMatcher>(); |
12 | op->matchers.push_back(x: make_uniq<FoldableConstantMatcher>()); |
13 | op->policy = SetMatcher::Policy::SOME; |
14 | root = std::move(op); |
15 | } |
16 | |
17 | unique_ptr<Expression> ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj, |
18 | const Expression &expr) { |
19 | for (idx_t i = 0; i < conj.children.size(); i++) { |
20 | if (conj.children[i].get() == &expr) { |
21 | // erase the expression |
22 | conj.children.erase(position: conj.children.begin() + i); |
23 | break; |
24 | } |
25 | } |
26 | if (conj.children.size() == 1) { |
27 | // one expression remaining: simply return that expression and erase the conjunction |
28 | return std::move(conj.children[0]); |
29 | } |
30 | return nullptr; |
31 | } |
32 | |
33 | unique_ptr<Expression> ConjunctionSimplificationRule::Apply(LogicalOperator &op, |
34 | vector<reference<Expression>> &bindings, bool &changes_made, |
35 | bool is_root) { |
36 | auto &conjunction = bindings[0].get().Cast<BoundConjunctionExpression>(); |
37 | auto &constant_expr = bindings[1].get(); |
38 | // the constant_expr is a scalar expression that we have to fold |
39 | // use an ExpressionExecutor to execute the expression |
40 | D_ASSERT(constant_expr.IsFoldable()); |
41 | Value constant_value; |
42 | if (!ExpressionExecutor::TryEvaluateScalar(context&: GetContext(), expr: constant_expr, result&: constant_value)) { |
43 | return nullptr; |
44 | } |
45 | constant_value = constant_value.DefaultCastAs(target_type: LogicalType::BOOLEAN); |
46 | if (constant_value.IsNull()) { |
47 | // we can't simplify conjunctions with a constant NULL |
48 | return nullptr; |
49 | } |
50 | if (conjunction.type == ExpressionType::CONJUNCTION_AND) { |
51 | if (!BooleanValue::Get(value: constant_value)) { |
52 | // FALSE in AND, result of expression is false |
53 | return make_uniq<BoundConstantExpression>(args: Value::BOOLEAN(value: false)); |
54 | } else { |
55 | // TRUE in AND, remove the expression from the set |
56 | return RemoveExpression(conj&: conjunction, expr: constant_expr); |
57 | } |
58 | } else { |
59 | D_ASSERT(conjunction.type == ExpressionType::CONJUNCTION_OR); |
60 | if (!BooleanValue::Get(value: constant_value)) { |
61 | // FALSE in OR, remove the expression from the set |
62 | return RemoveExpression(conj&: conjunction, expr: constant_expr); |
63 | } else { |
64 | // TRUE in OR, result of expression is true |
65 | return make_uniq<BoundConstantExpression>(args: Value::BOOLEAN(value: true)); |
66 | } |
67 | } |
68 | } |
69 | |
70 | } // namespace duckdb |
71 | |