1#include "duckdb/optimizer/rule/distributivity.hpp"
2
3#include "duckdb/optimizer/matcher/expression_matcher.hpp"
4#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
5#include "duckdb/planner/expression/bound_constant_expression.hpp"
6#include "duckdb/planner/expression_iterator.hpp"
7#include "duckdb/planner/operator/logical_filter.hpp"
8
9namespace duckdb {
10
11DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
12 // we match on an OR expression within a LogicalFilter node
13 root = make_uniq<ExpressionMatcher>();
14 root->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::CONJUNCTION_OR);
15}
16
17void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) {
18 if (expr.type == ExpressionType::CONJUNCTION_AND) {
19 auto &and_expr = expr.Cast<BoundConjunctionExpression>();
20 for (auto &child : and_expr.children) {
21 set.insert(x: *child);
22 }
23 } else {
24 set.insert(x: expr);
25 }
26}
27
28unique_ptr<Expression> DistributivityRule::ExtractExpression(BoundConjunctionExpression &conj, idx_t idx,
29 Expression &expr) {
30 auto &child = conj.children[idx];
31 unique_ptr<Expression> result;
32 if (child->type == ExpressionType::CONJUNCTION_AND) {
33 // AND, remove expression from the list
34 auto &and_expr = child->Cast<BoundConjunctionExpression>();
35 for (idx_t i = 0; i < and_expr.children.size(); i++) {
36 if (and_expr.children[i]->Equals(other: expr)) {
37 result = std::move(and_expr.children[i]);
38 and_expr.children.erase(position: and_expr.children.begin() + i);
39 break;
40 }
41 }
42 if (and_expr.children.size() == 1) {
43 conj.children[idx] = std::move(and_expr.children[0]);
44 }
45 } else {
46 // not an AND node! remove the entire expression
47 // this happens in the case of e.g. (X AND B) OR X
48 D_ASSERT(child->Equals(expr));
49 result = std::move(child);
50 conj.children[idx] = nullptr;
51 }
52 D_ASSERT(result);
53 return result;
54}
55
56unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
57 bool &changes_made, bool is_root) {
58 auto &initial_or = bindings[0].get().Cast<BoundConjunctionExpression>();
59
60 // we want to find expressions that occur in each of the children of the OR
61 // i.e. (X AND A) OR (X AND B) => X occurs in all branches
62 // first, for the initial child, we create an expression set of which expressions occur
63 // this is our initial candidate set (in the example: [X, A])
64 expression_set_t candidate_set;
65 AddExpressionSet(expr&: *initial_or.children[0], set&: candidate_set);
66 // now for each of the remaining children, we create a set again and intersect them
67 // in our example: the second set would be [X, B]
68 // the intersection would leave [X]
69 for (idx_t i = 1; i < initial_or.children.size(); i++) {
70 expression_set_t next_set;
71 AddExpressionSet(expr&: *initial_or.children[i], set&: next_set);
72 expression_set_t intersect_result;
73 for (auto &expr : candidate_set) {
74 if (next_set.find(x: expr) != next_set.end()) {
75 intersect_result.insert(x: expr);
76 }
77 }
78 candidate_set = intersect_result;
79 }
80 if (candidate_set.empty()) {
81 // nothing found: abort
82 return nullptr;
83 }
84 // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of
85 // the OR
86 auto new_root = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND);
87 for (auto &expr : candidate_set) {
88 D_ASSERT(initial_or.children.size() > 0);
89
90 // extract the expression from the first child of the OR
91 auto result = ExtractExpression(conj&: initial_or, idx: 0, expr&: expr.get());
92 // now for the subsequent expressions, simply remove the expression
93 for (idx_t i = 1; i < initial_or.children.size(); i++) {
94 ExtractExpression(conj&: initial_or, idx: i, expr&: *result);
95 }
96 // now we add the expression to the new root
97 new_root->children.push_back(x: std::move(result));
98 }
99
100 // check if we completely erased one of the children of the OR
101 // this happens if we have an OR in the form of "X OR (X AND A)"
102 // the left child will be completely empty, as it only contains common expressions
103 // in this case, any other children are not useful:
104 // X OR (X AND A) is the same as "X"
105 // since (1) only tuples that do not qualify "X" will not pass this predicate
106 // and (2) all tuples that qualify "X" will pass this predicate
107 for (idx_t i = 0; i < initial_or.children.size(); i++) {
108 if (!initial_or.children[i]) {
109 if (new_root->children.size() <= 1) {
110 return std::move(new_root->children[0]);
111 } else {
112 return std::move(new_root);
113 }
114 }
115 }
116 // finally we need to add the remaining expressions in the OR to the new root
117 if (initial_or.children.size() == 1) {
118 // one child: skip the OR entirely and only add the single child
119 new_root->children.push_back(x: std::move(initial_or.children[0]));
120 } else if (initial_or.children.size() > 1) {
121 // multiple children still remain: push them into a new OR and add that to the new root
122 auto new_or = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_OR);
123 for (auto &child : initial_or.children) {
124 new_or->children.push_back(x: std::move(child));
125 }
126 new_root->children.push_back(x: std::move(new_or));
127 }
128 // finally return the new root
129 if (new_root->children.size() == 1) {
130 return std::move(new_root->children[0]);
131 }
132 return std::move(new_root);
133}
134
135} // namespace duckdb
136