#include "duckdb/optimizer/rule/equal_or_null_simplification.hpp"

#include "duckdb/planner/expression/bound_comparison_expression.hpp"
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
#include "duckdb/planner/expression/bound_operator_expression.hpp"
#include "duckdb/planner/logical_operator.hpp"
6
7namespace duckdb {
8
9EqualOrNullSimplification::EqualOrNullSimplification(ExpressionRewriter &rewriter) : Rule(rewriter) {
10 // match on OR conjunction
11 auto op = make_uniq<ConjunctionExpressionMatcher>();
12 op->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::CONJUNCTION_OR);
13 op->policy = SetMatcher::Policy::SOME;
14
15 // equi comparison on one side
16 auto equal_child = make_uniq<ComparisonExpressionMatcher>();
17 equal_child->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::COMPARE_EQUAL);
18 equal_child->policy = SetMatcher::Policy::SOME;
19 op->matchers.push_back(x: std::move(equal_child));
20
21 // AND conjuction on the other
22 auto and_child = make_uniq<ConjunctionExpressionMatcher>();
23 and_child->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::CONJUNCTION_AND);
24 and_child->policy = SetMatcher::Policy::SOME;
25
26 // IS NULL tests inside AND
27 auto isnull_child = make_uniq<ExpressionMatcher>();
28 isnull_child->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::OPERATOR_IS_NULL);
29 // I could try to use std::make_uniq for a copy, but it's available from C++14 only
30 auto isnull_child2 = make_uniq<ExpressionMatcher>();
31 isnull_child2->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::OPERATOR_IS_NULL);
32 and_child->matchers.push_back(x: std::move(isnull_child));
33 and_child->matchers.push_back(x: std::move(isnull_child2));
34
35 op->matchers.push_back(x: std::move(and_child));
36 root = std::move(op);
37}
38
39// a=b OR (a IS NULL AND b IS NULL) to a IS NOT DISTINCT FROM b
40static unique_ptr<Expression> TryRewriteEqualOrIsNull(Expression &equal_expr, Expression &and_expr) {
41 if (equal_expr.type != ExpressionType::COMPARE_EQUAL || and_expr.type != ExpressionType::CONJUNCTION_AND) {
42 return nullptr;
43 }
44
45 auto &equal_cast = equal_expr.Cast<BoundComparisonExpression>();
46 auto &and_cast = and_expr.Cast<BoundConjunctionExpression>();
47
48 if (and_cast.children.size() != 2) {
49 return nullptr;
50 }
51
52 // Make sure on the AND conjuction the relevant conditions appear
53 auto &a_exp = *equal_cast.left;
54 auto &b_exp = *equal_cast.right;
55 bool a_is_null_found = false;
56 bool b_is_null_found = false;
57
58 for (const auto &item : and_cast.children) {
59 auto &next_exp = *item;
60
61 if (next_exp.type == ExpressionType::OPERATOR_IS_NULL) {
62 auto &next_exp_cast = next_exp.Cast<BoundOperatorExpression>();
63 auto &child = *next_exp_cast.children[0];
64
65 // Test for equality on both 'a' and 'b' expressions
66 if (Expression::Equals(left: child, right: a_exp)) {
67 a_is_null_found = true;
68 } else if (Expression::Equals(left: child, right: b_exp)) {
69 b_is_null_found = true;
70 } else {
71 return nullptr;
72 }
73 } else {
74 return nullptr;
75 }
76 }
77 if (a_is_null_found && b_is_null_found) {
78 return make_uniq<BoundComparisonExpression>(args: ExpressionType::COMPARE_NOT_DISTINCT_FROM,
79 args: std::move(equal_cast.left), args: std::move(equal_cast.right));
80 }
81 return nullptr;
82}
83
84unique_ptr<Expression> EqualOrNullSimplification::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
85 bool &changes_made, bool is_root) {
86 const Expression &or_exp = bindings[0].get();
87
88 if (or_exp.type != ExpressionType::CONJUNCTION_OR) {
89 return nullptr;
90 }
91
92 const auto &or_exp_cast = or_exp.Cast<BoundConjunctionExpression>();
93
94 if (or_exp_cast.children.size() != 2) {
95 return nullptr;
96 }
97
98 auto &left_exp = *or_exp_cast.children[0];
99 auto &right_exp = *or_exp_cast.children[1];
100 // Test for: a=b OR (a IS NULL AND b IS NULL)
101 auto first_try = TryRewriteEqualOrIsNull(equal_expr&: left_exp, and_expr&: right_exp);
102 if (first_try) {
103 return first_try;
104 }
105 // Test for: (a IS NULL AND b IS NULL) OR a=b
106 return TryRewriteEqualOrIsNull(equal_expr&: right_exp, and_expr&: left_exp);
107}
108
109} // namespace duckdb
110