1#include "duckdb/optimizer/rule/distributivity.hpp"
2
3#include "duckdb/optimizer/matcher/expression_matcher.hpp"
4#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
5#include "duckdb/planner/expression/bound_constant_expression.hpp"
6#include "duckdb/planner/expression_iterator.hpp"
7#include "duckdb/planner/operator/logical_filter.hpp"
8
9using namespace duckdb;
10using namespace std;
11
12DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
13 // we match on an OR expression within a LogicalFilter node
14 root = make_unique<ExpressionMatcher>();
15 root->expr_type = make_unique<SpecificExpressionTypeMatcher>(ExpressionType::CONJUNCTION_OR);
16}
17
18void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) {
19 if (expr.type == ExpressionType::CONJUNCTION_AND) {
20 auto &and_expr = (BoundConjunctionExpression &)expr;
21 for (auto &child : and_expr.children) {
22 set.insert(child.get());
23 }
24 } else {
25 set.insert(&expr);
26 }
27}
28
29unique_ptr<Expression> DistributivityRule::ExtractExpression(BoundConjunctionExpression &conj, idx_t idx,
30 Expression &expr) {
31 auto &child = conj.children[idx];
32 unique_ptr<Expression> result;
33 if (child->type == ExpressionType::CONJUNCTION_AND) {
34 // AND, remove expression from the list
35 auto &and_expr = (BoundConjunctionExpression &)*child;
36 for (idx_t i = 0; i < and_expr.children.size(); i++) {
37 if (Expression::Equals(and_expr.children[i].get(), &expr)) {
38 result = move(and_expr.children[i]);
39 and_expr.children.erase(and_expr.children.begin() + i);
40 break;
41 }
42 }
43 if (and_expr.children.size() == 1) {
44 conj.children[idx] = move(and_expr.children[0]);
45 }
46 } else {
47 // not an AND node! remove the entire expression
48 // this happens in the case of e.g. (X AND B) OR X
49 assert(Expression::Equals(child.get(), &expr));
50 result = move(child);
51 conj.children[idx] = nullptr;
52 }
53 assert(result);
54 return result;
55}
56
57unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<Expression *> &bindings,
58 bool &changes_made) {
59 auto initial_or = (BoundConjunctionExpression *)bindings[0];
60
61 // we want to find expressions that occur in each of the children of the OR
62 // i.e. (X AND A) OR (X AND B) => X occurs in all branches
63 // first, for the initial child, we create an expression set of which expressions occur
64 // this is our initial candidate set (in the example: [X, A])
65 expression_set_t candidate_set;
66 AddExpressionSet(*initial_or->children[0], candidate_set);
67 // now for each of the remaining children, we create a set again and intersect them
68 // in our example: the second set would be [X, B]
69 // the intersection would leave [X]
70 for (idx_t i = 1; i < initial_or->children.size(); i++) {
71 expression_set_t next_set;
72 AddExpressionSet(*initial_or->children[i], next_set);
73 expression_set_t intersect_result;
74 for (auto &expr : candidate_set) {
75 if (next_set.find(expr) != next_set.end()) {
76 intersect_result.insert(expr);
77 }
78 }
79 candidate_set = intersect_result;
80 }
81 if (candidate_set.size() == 0) {
82 // nothing found: abort
83 return nullptr;
84 }
85 // now for each of the remaining expressions in the candidate set we know that it is contained in all branches of
86 // the OR
87 auto new_root = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND);
88 for (auto &expr : candidate_set) {
89 assert(initial_or->children.size() > 0);
90
91 // extract the expression from the first child of the OR
92 auto result = ExtractExpression(*initial_or, 0, (Expression &)*expr);
93 // now for the subsequent expressions, simply remove the expression
94 for (idx_t i = 1; i < initial_or->children.size(); i++) {
95 ExtractExpression(*initial_or, i, *result);
96 }
97 // now we add the expression to the new root
98 new_root->children.push_back(move(result));
99 // remove any expressions that were set to nullptr
100 for (idx_t i = 0; i < initial_or->children.size(); i++) {
101 if (!initial_or->children[i]) {
102 initial_or->children.erase(initial_or->children.begin() + i);
103 i--;
104 }
105 }
106 }
107 // finally we need to add the remaining expressions in the OR to the new root
108 if (initial_or->children.size() == 1) {
109 // one child: skip the OR entirely and only add the single child
110 new_root->children.push_back(move(initial_or->children[0]));
111 } else if (initial_or->children.size() > 1) {
112 // multiple children still remain: push them into a new OR and add that to the new root
113 auto new_or = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR);
114 for (auto &child : initial_or->children) {
115 new_or->children.push_back(move(child));
116 }
117 new_root->children.push_back(move(new_or));
118 }
119 // finally return the new root
120 if (new_root->children.size() == 1) {
121 return move(new_root->children[0]);
122 }
123 return move(new_root);
124}
125