1 | #include "duckdb/optimizer/rule/distributivity.hpp" |
2 | |
3 | #include "duckdb/optimizer/matcher/expression_matcher.hpp" |
4 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
6 | #include "duckdb/planner/expression_iterator.hpp" |
7 | #include "duckdb/planner/operator/logical_filter.hpp" |
8 | |
9 | using namespace duckdb; |
10 | using namespace std; |
11 | |
12 | DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
13 | // we match on an OR expression within a LogicalFilter node |
14 | root = make_unique<ExpressionMatcher>(); |
15 | root->expr_type = make_unique<SpecificExpressionTypeMatcher>(ExpressionType::CONJUNCTION_OR); |
16 | } |
17 | |
18 | void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { |
19 | if (expr.type == ExpressionType::CONJUNCTION_AND) { |
20 | auto &and_expr = (BoundConjunctionExpression &)expr; |
21 | for (auto &child : and_expr.children) { |
22 | set.insert(child.get()); |
23 | } |
24 | } else { |
25 | set.insert(&expr); |
26 | } |
27 | } |
28 | |
29 | unique_ptr<Expression> DistributivityRule::(BoundConjunctionExpression &conj, idx_t idx, |
30 | Expression &expr) { |
31 | auto &child = conj.children[idx]; |
32 | unique_ptr<Expression> result; |
33 | if (child->type == ExpressionType::CONJUNCTION_AND) { |
34 | // AND, remove expression from the list |
35 | auto &and_expr = (BoundConjunctionExpression &)*child; |
36 | for (idx_t i = 0; i < and_expr.children.size(); i++) { |
37 | if (Expression::Equals(and_expr.children[i].get(), &expr)) { |
38 | result = move(and_expr.children[i]); |
39 | and_expr.children.erase(and_expr.children.begin() + i); |
40 | break; |
41 | } |
42 | } |
43 | if (and_expr.children.size() == 1) { |
44 | conj.children[idx] = move(and_expr.children[0]); |
45 | } |
46 | } else { |
47 | // not an AND node! remove the entire expression |
48 | // this happens in the case of e.g. (X AND B) OR X |
49 | assert(Expression::Equals(child.get(), &expr)); |
50 | result = move(child); |
51 | conj.children[idx] = nullptr; |
52 | } |
53 | assert(result); |
54 | return result; |
55 | } |
56 | |
57 | unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, |
58 | bool &changes_made) { |
59 | auto initial_or = (BoundConjunctionExpression *)bindings[0]; |
60 | |
61 | // we want to find expressions that occur in each of the children of the OR |
62 | // i.e. (X AND A) OR (X AND B) => X occurs in all branches |
63 | // first, for the initial child, we create an expression set of which expressions occur |
64 | // this is our initial candidate set (in the example: [X, A]) |
65 | expression_set_t candidate_set; |
66 | AddExpressionSet(*initial_or->children[0], candidate_set); |
67 | // now for each of the remaining children, we create a set again and intersect them |
68 | // in our example: the second set would be [X, B] |
69 | // the intersection would leave [X] |
70 | for (idx_t i = 1; i < initial_or->children.size(); i++) { |
71 | expression_set_t next_set; |
72 | AddExpressionSet(*initial_or->children[i], next_set); |
73 | expression_set_t intersect_result; |
74 | for (auto &expr : candidate_set) { |
75 | if (next_set.find(expr) != next_set.end()) { |
76 | intersect_result.insert(expr); |
77 | } |
78 | } |
79 | candidate_set = intersect_result; |
80 | } |
81 | if (candidate_set.size() == 0) { |
82 | // nothing found: abort |
83 | return nullptr; |
84 | } |
85 | // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of |
86 | // the OR |
87 | auto new_root = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND); |
88 | for (auto &expr : candidate_set) { |
89 | assert(initial_or->children.size() > 0); |
90 | |
91 | // extract the expression from the first child of the OR |
92 | auto result = ExtractExpression(*initial_or, 0, (Expression &)*expr); |
93 | // now for the subsequent expressions, simply remove the expression |
94 | for (idx_t i = 1; i < initial_or->children.size(); i++) { |
95 | ExtractExpression(*initial_or, i, *result); |
96 | } |
97 | // now we add the expression to the new root |
98 | new_root->children.push_back(move(result)); |
99 | // remove any expressions that were set to nullptr |
100 | for (idx_t i = 0; i < initial_or->children.size(); i++) { |
101 | if (!initial_or->children[i]) { |
102 | initial_or->children.erase(initial_or->children.begin() + i); |
103 | i--; |
104 | } |
105 | } |
106 | } |
107 | // finally we need to add the remaining expressions in the OR to the new root |
108 | if (initial_or->children.size() == 1) { |
109 | // one child: skip the OR entirely and only add the single child |
110 | new_root->children.push_back(move(initial_or->children[0])); |
111 | } else if (initial_or->children.size() > 1) { |
112 | // multiple children still remain: push them into a new OR and add that to the new root |
113 | auto new_or = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR); |
114 | for (auto &child : initial_or->children) { |
115 | new_or->children.push_back(move(child)); |
116 | } |
117 | new_root->children.push_back(move(new_or)); |
118 | } |
119 | // finally return the new root |
120 | if (new_root->children.size() == 1) { |
121 | return move(new_root->children[0]); |
122 | } |
123 | return move(new_root); |
124 | } |
125 | |