1#include "duckdb/optimizer/rule/conjunction_simplification.hpp"
2
3#include "duckdb/execution/expression_executor.hpp"
4#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
5#include "duckdb/planner/expression/bound_constant_expression.hpp"
6
7namespace duckdb {
8
9ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
10 // match on a ComparisonExpression that has a ConstantExpression as a check
11 auto op = make_uniq<ConjunctionExpressionMatcher>();
12 op->matchers.push_back(x: make_uniq<FoldableConstantMatcher>());
13 op->policy = SetMatcher::Policy::SOME;
14 root = std::move(op);
15}
16
17unique_ptr<Expression> ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj,
18 const Expression &expr) {
19 for (idx_t i = 0; i < conj.children.size(); i++) {
20 if (conj.children[i].get() == &expr) {
21 // erase the expression
22 conj.children.erase(position: conj.children.begin() + i);
23 break;
24 }
25 }
26 if (conj.children.size() == 1) {
27 // one expression remaining: simply return that expression and erase the conjunction
28 return std::move(conj.children[0]);
29 }
30 return nullptr;
31}
32
33unique_ptr<Expression> ConjunctionSimplificationRule::Apply(LogicalOperator &op,
34 vector<reference<Expression>> &bindings, bool &changes_made,
35 bool is_root) {
36 auto &conjunction = bindings[0].get().Cast<BoundConjunctionExpression>();
37 auto &constant_expr = bindings[1].get();
38 // the constant_expr is a scalar expression that we have to fold
39 // use an ExpressionExecutor to execute the expression
40 D_ASSERT(constant_expr.IsFoldable());
41 Value constant_value;
42 if (!ExpressionExecutor::TryEvaluateScalar(context&: GetContext(), expr: constant_expr, result&: constant_value)) {
43 return nullptr;
44 }
45 constant_value = constant_value.DefaultCastAs(target_type: LogicalType::BOOLEAN);
46 if (constant_value.IsNull()) {
47 // we can't simplify conjunctions with a constant NULL
48 return nullptr;
49 }
50 if (conjunction.type == ExpressionType::CONJUNCTION_AND) {
51 if (!BooleanValue::Get(value: constant_value)) {
52 // FALSE in AND, result of expression is false
53 return make_uniq<BoundConstantExpression>(args: Value::BOOLEAN(value: false));
54 } else {
55 // TRUE in AND, remove the expression from the set
56 return RemoveExpression(conj&: conjunction, expr: constant_expr);
57 }
58 } else {
59 D_ASSERT(conjunction.type == ExpressionType::CONJUNCTION_OR);
60 if (!BooleanValue::Get(value: constant_value)) {
61 // FALSE in OR, remove the expression from the set
62 return RemoveExpression(conj&: conjunction, expr: constant_expr);
63 } else {
64 // TRUE in OR, result of expression is true
65 return make_uniq<BoundConstantExpression>(args: Value::BOOLEAN(value: true));
66 }
67 }
68}
69
70} // namespace duckdb
71