| 1 | #include "duckdb/optimizer/rule/distributivity.hpp" | 
|---|
| 2 |  | 
|---|
| 3 | #include "duckdb/optimizer/matcher/expression_matcher.hpp" | 
|---|
| 4 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" | 
|---|
| 5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" | 
|---|
| 6 | #include "duckdb/planner/expression_iterator.hpp" | 
|---|
| 7 | #include "duckdb/planner/operator/logical_filter.hpp" | 
|---|
| 8 |  | 
|---|
| 9 | using namespace duckdb; | 
|---|
| 10 | using namespace std; | 
|---|
| 11 |  | 
|---|
| 12 | DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { | 
|---|
| 13 | // we match on an OR expression within a LogicalFilter node | 
|---|
| 14 | root = make_unique<ExpressionMatcher>(); | 
|---|
| 15 | root->expr_type = make_unique<SpecificExpressionTypeMatcher>(ExpressionType::CONJUNCTION_OR); | 
|---|
| 16 | } | 
|---|
| 17 |  | 
|---|
| 18 | void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { | 
|---|
| 19 | if (expr.type == ExpressionType::CONJUNCTION_AND) { | 
|---|
| 20 | auto &and_expr = (BoundConjunctionExpression &)expr; | 
|---|
| 21 | for (auto &child : and_expr.children) { | 
|---|
| 22 | set.insert(child.get()); | 
|---|
| 23 | } | 
|---|
| 24 | } else { | 
|---|
| 25 | set.insert(&expr); | 
|---|
| 26 | } | 
|---|
| 27 | } | 
|---|
| 28 |  | 
|---|
| 29 | unique_ptr<Expression> DistributivityRule::(BoundConjunctionExpression &conj, idx_t idx, | 
|---|
| 30 | Expression &expr) { | 
|---|
| 31 | auto &child = conj.children[idx]; | 
|---|
| 32 | unique_ptr<Expression> result; | 
|---|
| 33 | if (child->type == ExpressionType::CONJUNCTION_AND) { | 
|---|
| 34 | // AND, remove expression from the list | 
|---|
| 35 | auto &and_expr = (BoundConjunctionExpression &)*child; | 
|---|
| 36 | for (idx_t i = 0; i < and_expr.children.size(); i++) { | 
|---|
| 37 | if (Expression::Equals(and_expr.children[i].get(), &expr)) { | 
|---|
| 38 | result = move(and_expr.children[i]); | 
|---|
| 39 | and_expr.children.erase(and_expr.children.begin() + i); | 
|---|
| 40 | break; | 
|---|
| 41 | } | 
|---|
| 42 | } | 
|---|
| 43 | if (and_expr.children.size() == 1) { | 
|---|
| 44 | conj.children[idx] = move(and_expr.children[0]); | 
|---|
| 45 | } | 
|---|
| 46 | } else { | 
|---|
| 47 | // not an AND node! remove the entire expression | 
|---|
| 48 | // this happens in the case of e.g. (X AND B) OR X | 
|---|
| 49 | assert(Expression::Equals(child.get(), &expr)); | 
|---|
| 50 | result = move(child); | 
|---|
| 51 | conj.children[idx] = nullptr; | 
|---|
| 52 | } | 
|---|
| 53 | assert(result); | 
|---|
| 54 | return result; | 
|---|
| 55 | } | 
|---|
| 56 |  | 
|---|
| 57 | unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, | 
|---|
| 58 | bool &changes_made) { | 
|---|
| 59 | auto initial_or = (BoundConjunctionExpression *)bindings[0]; | 
|---|
| 60 |  | 
|---|
| 61 | // we want to find expressions that occur in each of the children of the OR | 
|---|
| 62 | // i.e. (X AND A) OR (X AND B) => X occurs in all branches | 
|---|
| 63 | // first, for the initial child, we create an expression set of which expressions occur | 
|---|
| 64 | // this is our initial candidate set (in the example: [X, A]) | 
|---|
| 65 | expression_set_t candidate_set; | 
|---|
| 66 | AddExpressionSet(*initial_or->children[0], candidate_set); | 
|---|
| 67 | // now for each of the remaining children, we create a set again and intersect them | 
|---|
| 68 | // in our example: the second set would be [X, B] | 
|---|
| 69 | // the intersection would leave [X] | 
|---|
| 70 | for (idx_t i = 1; i < initial_or->children.size(); i++) { | 
|---|
| 71 | expression_set_t next_set; | 
|---|
| 72 | AddExpressionSet(*initial_or->children[i], next_set); | 
|---|
| 73 | expression_set_t intersect_result; | 
|---|
| 74 | for (auto &expr : candidate_set) { | 
|---|
| 75 | if (next_set.find(expr) != next_set.end()) { | 
|---|
| 76 | intersect_result.insert(expr); | 
|---|
| 77 | } | 
|---|
| 78 | } | 
|---|
| 79 | candidate_set = intersect_result; | 
|---|
| 80 | } | 
|---|
| 81 | if (candidate_set.size() == 0) { | 
|---|
| 82 | // nothing found: abort | 
|---|
| 83 | return nullptr; | 
|---|
| 84 | } | 
|---|
| 85 | // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of | 
|---|
| 86 | // the OR | 
|---|
| 87 | auto new_root = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND); | 
|---|
| 88 | for (auto &expr : candidate_set) { | 
|---|
| 89 | assert(initial_or->children.size() > 0); | 
|---|
| 90 |  | 
|---|
| 91 | // extract the expression from the first child of the OR | 
|---|
| 92 | auto result = ExtractExpression(*initial_or, 0, (Expression &)*expr); | 
|---|
| 93 | // now for the subsequent expressions, simply remove the expression | 
|---|
| 94 | for (idx_t i = 1; i < initial_or->children.size(); i++) { | 
|---|
| 95 | ExtractExpression(*initial_or, i, *result); | 
|---|
| 96 | } | 
|---|
| 97 | // now we add the expression to the new root | 
|---|
| 98 | new_root->children.push_back(move(result)); | 
|---|
| 99 | // remove any expressions that were set to nullptr | 
|---|
| 100 | for (idx_t i = 0; i < initial_or->children.size(); i++) { | 
|---|
| 101 | if (!initial_or->children[i]) { | 
|---|
| 102 | initial_or->children.erase(initial_or->children.begin() + i); | 
|---|
| 103 | i--; | 
|---|
| 104 | } | 
|---|
| 105 | } | 
|---|
| 106 | } | 
|---|
| 107 | // finally we need to add the remaining expressions in the OR to the new root | 
|---|
| 108 | if (initial_or->children.size() == 1) { | 
|---|
| 109 | // one child: skip the OR entirely and only add the single child | 
|---|
| 110 | new_root->children.push_back(move(initial_or->children[0])); | 
|---|
| 111 | } else if (initial_or->children.size() > 1) { | 
|---|
| 112 | // multiple children still remain: push them into a new OR and add that to the new root | 
|---|
| 113 | auto new_or = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR); | 
|---|
| 114 | for (auto &child : initial_or->children) { | 
|---|
| 115 | new_or->children.push_back(move(child)); | 
|---|
| 116 | } | 
|---|
| 117 | new_root->children.push_back(move(new_or)); | 
|---|
| 118 | } | 
|---|
| 119 | // finally return the new root | 
|---|
| 120 | if (new_root->children.size() == 1) { | 
|---|
| 121 | return move(new_root->children[0]); | 
|---|
| 122 | } | 
|---|
| 123 | return move(new_root); | 
|---|
| 124 | } | 
|---|
| 125 |  | 
|---|