1 | #include "duckdb/optimizer/rule/distributivity.hpp" |
2 | |
3 | #include "duckdb/optimizer/matcher/expression_matcher.hpp" |
4 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
6 | #include "duckdb/planner/expression_iterator.hpp" |
7 | #include "duckdb/planner/operator/logical_filter.hpp" |
8 | |
9 | namespace duckdb { |
10 | |
11 | DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
12 | // we match on an OR expression within a LogicalFilter node |
13 | root = make_uniq<ExpressionMatcher>(); |
14 | root->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::CONJUNCTION_OR); |
15 | } |
16 | |
17 | void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { |
18 | if (expr.type == ExpressionType::CONJUNCTION_AND) { |
19 | auto &and_expr = expr.Cast<BoundConjunctionExpression>(); |
20 | for (auto &child : and_expr.children) { |
21 | set.insert(x: *child); |
22 | } |
23 | } else { |
24 | set.insert(x: expr); |
25 | } |
26 | } |
27 | |
28 | unique_ptr<Expression> DistributivityRule::(BoundConjunctionExpression &conj, idx_t idx, |
29 | Expression &expr) { |
30 | auto &child = conj.children[idx]; |
31 | unique_ptr<Expression> result; |
32 | if (child->type == ExpressionType::CONJUNCTION_AND) { |
33 | // AND, remove expression from the list |
34 | auto &and_expr = child->Cast<BoundConjunctionExpression>(); |
35 | for (idx_t i = 0; i < and_expr.children.size(); i++) { |
36 | if (and_expr.children[i]->Equals(other: expr)) { |
37 | result = std::move(and_expr.children[i]); |
38 | and_expr.children.erase(position: and_expr.children.begin() + i); |
39 | break; |
40 | } |
41 | } |
42 | if (and_expr.children.size() == 1) { |
43 | conj.children[idx] = std::move(and_expr.children[0]); |
44 | } |
45 | } else { |
46 | // not an AND node! remove the entire expression |
47 | // this happens in the case of e.g. (X AND B) OR X |
48 | D_ASSERT(child->Equals(expr)); |
49 | result = std::move(child); |
50 | conj.children[idx] = nullptr; |
51 | } |
52 | D_ASSERT(result); |
53 | return result; |
54 | } |
55 | |
56 | unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
57 | bool &changes_made, bool is_root) { |
58 | auto &initial_or = bindings[0].get().Cast<BoundConjunctionExpression>(); |
59 | |
60 | // we want to find expressions that occur in each of the children of the OR |
61 | // i.e. (X AND A) OR (X AND B) => X occurs in all branches |
62 | // first, for the initial child, we create an expression set of which expressions occur |
63 | // this is our initial candidate set (in the example: [X, A]) |
64 | expression_set_t candidate_set; |
65 | AddExpressionSet(expr&: *initial_or.children[0], set&: candidate_set); |
66 | // now for each of the remaining children, we create a set again and intersect them |
67 | // in our example: the second set would be [X, B] |
68 | // the intersection would leave [X] |
69 | for (idx_t i = 1; i < initial_or.children.size(); i++) { |
70 | expression_set_t next_set; |
71 | AddExpressionSet(expr&: *initial_or.children[i], set&: next_set); |
72 | expression_set_t intersect_result; |
73 | for (auto &expr : candidate_set) { |
74 | if (next_set.find(x: expr) != next_set.end()) { |
75 | intersect_result.insert(x: expr); |
76 | } |
77 | } |
78 | candidate_set = intersect_result; |
79 | } |
80 | if (candidate_set.empty()) { |
81 | // nothing found: abort |
82 | return nullptr; |
83 | } |
84 | // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of |
85 | // the OR |
86 | auto new_root = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND); |
87 | for (auto &expr : candidate_set) { |
88 | D_ASSERT(initial_or.children.size() > 0); |
89 | |
90 | // extract the expression from the first child of the OR |
91 | auto result = ExtractExpression(conj&: initial_or, idx: 0, expr&: expr.get()); |
92 | // now for the subsequent expressions, simply remove the expression |
93 | for (idx_t i = 1; i < initial_or.children.size(); i++) { |
94 | ExtractExpression(conj&: initial_or, idx: i, expr&: *result); |
95 | } |
96 | // now we add the expression to the new root |
97 | new_root->children.push_back(x: std::move(result)); |
98 | } |
99 | |
100 | // check if we completely erased one of the children of the OR |
101 | // this happens if we have an OR in the form of "X OR (X AND A)" |
102 | // the left child will be completely empty, as it only contains common expressions |
103 | // in this case, any other children are not useful: |
104 | // X OR (X AND A) is the same as "X" |
105 | // since (1) only tuples that do not qualify "X" will not pass this predicate |
106 | // and (2) all tuples that qualify "X" will pass this predicate |
107 | for (idx_t i = 0; i < initial_or.children.size(); i++) { |
108 | if (!initial_or.children[i]) { |
109 | if (new_root->children.size() <= 1) { |
110 | return std::move(new_root->children[0]); |
111 | } else { |
112 | return std::move(new_root); |
113 | } |
114 | } |
115 | } |
116 | // finally we need to add the remaining expressions in the OR to the new root |
117 | if (initial_or.children.size() == 1) { |
118 | // one child: skip the OR entirely and only add the single child |
119 | new_root->children.push_back(x: std::move(initial_or.children[0])); |
120 | } else if (initial_or.children.size() > 1) { |
121 | // multiple children still remain: push them into a new OR and add that to the new root |
122 | auto new_or = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_OR); |
123 | for (auto &child : initial_or.children) { |
124 | new_or->children.push_back(x: std::move(child)); |
125 | } |
126 | new_root->children.push_back(x: std::move(new_or)); |
127 | } |
128 | // finally return the new root |
129 | if (new_root->children.size() == 1) { |
130 | return std::move(new_root->children[0]); |
131 | } |
132 | return std::move(new_root); |
133 | } |
134 | |
135 | } // namespace duckdb |
136 | |