| 1 | #include "duckdb/optimizer/rule/distributivity.hpp" |
| 2 | |
| 3 | #include "duckdb/optimizer/matcher/expression_matcher.hpp" |
| 4 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 6 | #include "duckdb/planner/expression_iterator.hpp" |
| 7 | #include "duckdb/planner/operator/logical_filter.hpp" |
| 8 | |
| 9 | namespace duckdb { |
| 10 | |
| 11 | DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
| 12 | // we match on an OR expression within a LogicalFilter node |
| 13 | root = make_uniq<ExpressionMatcher>(); |
| 14 | root->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::CONJUNCTION_OR); |
| 15 | } |
| 16 | |
| 17 | void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { |
| 18 | if (expr.type == ExpressionType::CONJUNCTION_AND) { |
| 19 | auto &and_expr = expr.Cast<BoundConjunctionExpression>(); |
| 20 | for (auto &child : and_expr.children) { |
| 21 | set.insert(x: *child); |
| 22 | } |
| 23 | } else { |
| 24 | set.insert(x: expr); |
| 25 | } |
| 26 | } |
| 27 | |
| 28 | unique_ptr<Expression> DistributivityRule::(BoundConjunctionExpression &conj, idx_t idx, |
| 29 | Expression &expr) { |
| 30 | auto &child = conj.children[idx]; |
| 31 | unique_ptr<Expression> result; |
| 32 | if (child->type == ExpressionType::CONJUNCTION_AND) { |
| 33 | // AND, remove expression from the list |
| 34 | auto &and_expr = child->Cast<BoundConjunctionExpression>(); |
| 35 | for (idx_t i = 0; i < and_expr.children.size(); i++) { |
| 36 | if (and_expr.children[i]->Equals(other: expr)) { |
| 37 | result = std::move(and_expr.children[i]); |
| 38 | and_expr.children.erase(position: and_expr.children.begin() + i); |
| 39 | break; |
| 40 | } |
| 41 | } |
| 42 | if (and_expr.children.size() == 1) { |
| 43 | conj.children[idx] = std::move(and_expr.children[0]); |
| 44 | } |
| 45 | } else { |
| 46 | // not an AND node! remove the entire expression |
| 47 | // this happens in the case of e.g. (X AND B) OR X |
| 48 | D_ASSERT(child->Equals(expr)); |
| 49 | result = std::move(child); |
| 50 | conj.children[idx] = nullptr; |
| 51 | } |
| 52 | D_ASSERT(result); |
| 53 | return result; |
| 54 | } |
| 55 | |
| 56 | unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
| 57 | bool &changes_made, bool is_root) { |
| 58 | auto &initial_or = bindings[0].get().Cast<BoundConjunctionExpression>(); |
| 59 | |
| 60 | // we want to find expressions that occur in each of the children of the OR |
| 61 | // i.e. (X AND A) OR (X AND B) => X occurs in all branches |
| 62 | // first, for the initial child, we create an expression set of which expressions occur |
| 63 | // this is our initial candidate set (in the example: [X, A]) |
| 64 | expression_set_t candidate_set; |
| 65 | AddExpressionSet(expr&: *initial_or.children[0], set&: candidate_set); |
| 66 | // now for each of the remaining children, we create a set again and intersect them |
| 67 | // in our example: the second set would be [X, B] |
| 68 | // the intersection would leave [X] |
| 69 | for (idx_t i = 1; i < initial_or.children.size(); i++) { |
| 70 | expression_set_t next_set; |
| 71 | AddExpressionSet(expr&: *initial_or.children[i], set&: next_set); |
| 72 | expression_set_t intersect_result; |
| 73 | for (auto &expr : candidate_set) { |
| 74 | if (next_set.find(x: expr) != next_set.end()) { |
| 75 | intersect_result.insert(x: expr); |
| 76 | } |
| 77 | } |
| 78 | candidate_set = intersect_result; |
| 79 | } |
| 80 | if (candidate_set.empty()) { |
| 81 | // nothing found: abort |
| 82 | return nullptr; |
| 83 | } |
| 84 | // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of |
| 85 | // the OR |
| 86 | auto new_root = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND); |
| 87 | for (auto &expr : candidate_set) { |
| 88 | D_ASSERT(initial_or.children.size() > 0); |
| 89 | |
| 90 | // extract the expression from the first child of the OR |
| 91 | auto result = ExtractExpression(conj&: initial_or, idx: 0, expr&: expr.get()); |
| 92 | // now for the subsequent expressions, simply remove the expression |
| 93 | for (idx_t i = 1; i < initial_or.children.size(); i++) { |
| 94 | ExtractExpression(conj&: initial_or, idx: i, expr&: *result); |
| 95 | } |
| 96 | // now we add the expression to the new root |
| 97 | new_root->children.push_back(x: std::move(result)); |
| 98 | } |
| 99 | |
| 100 | // check if we completely erased one of the children of the OR |
| 101 | // this happens if we have an OR in the form of "X OR (X AND A)" |
| 102 | // the left child will be completely empty, as it only contains common expressions |
| 103 | // in this case, any other children are not useful: |
| 104 | // X OR (X AND A) is the same as "X" |
| 105 | // since (1) only tuples that do not qualify "X" will not pass this predicate |
| 106 | // and (2) all tuples that qualify "X" will pass this predicate |
| 107 | for (idx_t i = 0; i < initial_or.children.size(); i++) { |
| 108 | if (!initial_or.children[i]) { |
| 109 | if (new_root->children.size() <= 1) { |
| 110 | return std::move(new_root->children[0]); |
| 111 | } else { |
| 112 | return std::move(new_root); |
| 113 | } |
| 114 | } |
| 115 | } |
| 116 | // finally we need to add the remaining expressions in the OR to the new root |
| 117 | if (initial_or.children.size() == 1) { |
| 118 | // one child: skip the OR entirely and only add the single child |
| 119 | new_root->children.push_back(x: std::move(initial_or.children[0])); |
| 120 | } else if (initial_or.children.size() > 1) { |
| 121 | // multiple children still remain: push them into a new OR and add that to the new root |
| 122 | auto new_or = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_OR); |
| 123 | for (auto &child : initial_or.children) { |
| 124 | new_or->children.push_back(x: std::move(child)); |
| 125 | } |
| 126 | new_root->children.push_back(x: std::move(new_or)); |
| 127 | } |
| 128 | // finally return the new root |
| 129 | if (new_root->children.size() == 1) { |
| 130 | return std::move(new_root->children[0]); |
| 131 | } |
| 132 | return std::move(new_root); |
| 133 | } |
| 134 | |
| 135 | } // namespace duckdb |
| 136 | |