| 1 | #include "duckdb/optimizer/rule/distributivity.hpp" | 
| 2 |  | 
| 3 | #include "duckdb/optimizer/matcher/expression_matcher.hpp" | 
| 4 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" | 
| 5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" | 
| 6 | #include "duckdb/planner/expression_iterator.hpp" | 
| 7 | #include "duckdb/planner/operator/logical_filter.hpp" | 
| 8 |  | 
| 9 | using namespace duckdb; | 
| 10 | using namespace std; | 
| 11 |  | 
| 12 | DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewriter) { | 
| 13 | 	// we match on an OR expression within a LogicalFilter node | 
| 14 | 	root = make_unique<ExpressionMatcher>(); | 
| 15 | 	root->expr_type = make_unique<SpecificExpressionTypeMatcher>(ExpressionType::CONJUNCTION_OR); | 
| 16 | } | 
| 17 |  | 
| 18 | void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { | 
| 19 | 	if (expr.type == ExpressionType::CONJUNCTION_AND) { | 
| 20 | 		auto &and_expr = (BoundConjunctionExpression &)expr; | 
| 21 | 		for (auto &child : and_expr.children) { | 
| 22 | 			set.insert(child.get()); | 
| 23 | 		} | 
| 24 | 	} else { | 
| 25 | 		set.insert(&expr); | 
| 26 | 	} | 
| 27 | } | 
| 28 |  | 
| 29 | unique_ptr<Expression> DistributivityRule::(BoundConjunctionExpression &conj, idx_t idx, | 
| 30 |                                                              Expression &expr) { | 
| 31 | 	auto &child = conj.children[idx]; | 
| 32 | 	unique_ptr<Expression> result; | 
| 33 | 	if (child->type == ExpressionType::CONJUNCTION_AND) { | 
| 34 | 		// AND, remove expression from the list | 
| 35 | 		auto &and_expr = (BoundConjunctionExpression &)*child; | 
| 36 | 		for (idx_t i = 0; i < and_expr.children.size(); i++) { | 
| 37 | 			if (Expression::Equals(and_expr.children[i].get(), &expr)) { | 
| 38 | 				result = move(and_expr.children[i]); | 
| 39 | 				and_expr.children.erase(and_expr.children.begin() + i); | 
| 40 | 				break; | 
| 41 | 			} | 
| 42 | 		} | 
| 43 | 		if (and_expr.children.size() == 1) { | 
| 44 | 			conj.children[idx] = move(and_expr.children[0]); | 
| 45 | 		} | 
| 46 | 	} else { | 
| 47 | 		// not an AND node! remove the entire expression | 
| 48 | 		// this happens in the case of e.g. (X AND B) OR X | 
| 49 | 		assert(Expression::Equals(child.get(), &expr)); | 
| 50 | 		result = move(child); | 
| 51 | 		conj.children[idx] = nullptr; | 
| 52 | 	} | 
| 53 | 	assert(result); | 
| 54 | 	return result; | 
| 55 | } | 
| 56 |  | 
| 57 | unique_ptr<Expression> DistributivityRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, | 
| 58 |                                                  bool &changes_made) { | 
| 59 | 	auto initial_or = (BoundConjunctionExpression *)bindings[0]; | 
| 60 |  | 
| 61 | 	// we want to find expressions that occur in each of the children of the OR | 
| 62 | 	// i.e. (X AND A) OR (X AND B) => X occurs in all branches | 
| 63 | 	// first, for the initial child, we create an expression set of which expressions occur | 
| 64 | 	// this is our initial candidate set (in the example: [X, A]) | 
| 65 | 	expression_set_t candidate_set; | 
| 66 | 	AddExpressionSet(*initial_or->children[0], candidate_set); | 
| 67 | 	// now for each of the remaining children, we create a set again and intersect them | 
| 68 | 	// in our example: the second set would be [X, B] | 
| 69 | 	// the intersection would leave [X] | 
| 70 | 	for (idx_t i = 1; i < initial_or->children.size(); i++) { | 
| 71 | 		expression_set_t next_set; | 
| 72 | 		AddExpressionSet(*initial_or->children[i], next_set); | 
| 73 | 		expression_set_t intersect_result; | 
| 74 | 		for (auto &expr : candidate_set) { | 
| 75 | 			if (next_set.find(expr) != next_set.end()) { | 
| 76 | 				intersect_result.insert(expr); | 
| 77 | 			} | 
| 78 | 		} | 
| 79 | 		candidate_set = intersect_result; | 
| 80 | 	} | 
| 81 | 	if (candidate_set.size() == 0) { | 
| 82 | 		// nothing found: abort | 
| 83 | 		return nullptr; | 
| 84 | 	} | 
| 85 | 	// now for each of the remaining expressions in the candidate set we know that it is contained in all branches of | 
| 86 | 	// the OR | 
| 87 | 	auto new_root = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND); | 
| 88 | 	for (auto &expr : candidate_set) { | 
| 89 | 		assert(initial_or->children.size() > 0); | 
| 90 |  | 
| 91 | 		// extract the expression from the first child of the OR | 
| 92 | 		auto result = ExtractExpression(*initial_or, 0, (Expression &)*expr); | 
| 93 | 		// now for the subsequent expressions, simply remove the expression | 
| 94 | 		for (idx_t i = 1; i < initial_or->children.size(); i++) { | 
| 95 | 			ExtractExpression(*initial_or, i, *result); | 
| 96 | 		} | 
| 97 | 		// now we add the expression to the new root | 
| 98 | 		new_root->children.push_back(move(result)); | 
| 99 | 		// remove any expressions that were set to nullptr | 
| 100 | 		for (idx_t i = 0; i < initial_or->children.size(); i++) { | 
| 101 | 			if (!initial_or->children[i]) { | 
| 102 | 				initial_or->children.erase(initial_or->children.begin() + i); | 
| 103 | 				i--; | 
| 104 | 			} | 
| 105 | 		} | 
| 106 | 	} | 
| 107 | 	// finally we need to add the remaining expressions in the OR to the new root | 
| 108 | 	if (initial_or->children.size() == 1) { | 
| 109 | 		// one child: skip the OR entirely and only add the single child | 
| 110 | 		new_root->children.push_back(move(initial_or->children[0])); | 
| 111 | 	} else if (initial_or->children.size() > 1) { | 
| 112 | 		// multiple children still remain: push them into a new OR and add that to the new root | 
| 113 | 		auto new_or = make_unique<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_OR); | 
| 114 | 		for (auto &child : initial_or->children) { | 
| 115 | 			new_or->children.push_back(move(child)); | 
| 116 | 		} | 
| 117 | 		new_root->children.push_back(move(new_or)); | 
| 118 | 	} | 
| 119 | 	// finally return the new root | 
| 120 | 	if (new_root->children.size() == 1) { | 
| 121 | 		return move(new_root->children[0]); | 
| 122 | 	} | 
| 123 | 	return move(new_root); | 
| 124 | } | 
| 125 |  |