| 1 | #include "duckdb/optimizer/filter_pushdown.hpp" |
| 2 | #include "duckdb/planner/operator/logical_comparison_join.hpp" |
| 3 | #include "duckdb/planner/operator/logical_cross_product.hpp" |
| 4 | |
| 5 | namespace duckdb { |
| 6 | |
| 7 | using Filter = FilterPushdown::Filter; |
| 8 | |
| 9 | unique_ptr<LogicalOperator> FilterPushdown::PushdownCrossProduct(unique_ptr<LogicalOperator> op) { |
| 10 | D_ASSERT(op->type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT); |
| 11 | FilterPushdown left_pushdown(optimizer), right_pushdown(optimizer); |
| 12 | vector<unique_ptr<Expression>> join_expressions; |
| 13 | unordered_set<idx_t> left_bindings, right_bindings; |
| 14 | if (!filters.empty()) { |
| 15 | // check to see into which side we should push the filters |
| 16 | // first get the LHS and RHS bindings |
| 17 | LogicalJoin::GetTableReferences(op&: *op->children[0], bindings&: left_bindings); |
| 18 | LogicalJoin::GetTableReferences(op&: *op->children[1], bindings&: right_bindings); |
| 19 | // now check the set of filters |
| 20 | for (auto &f : filters) { |
| 21 | auto side = JoinSide::GetJoinSide(bindings: f->bindings, left_bindings, right_bindings); |
| 22 | if (side == JoinSide::LEFT) { |
| 23 | // bindings match left side: push into left |
| 24 | left_pushdown.filters.push_back(x: std::move(f)); |
| 25 | } else if (side == JoinSide::RIGHT) { |
| 26 | // bindings match right side: push into right |
| 27 | right_pushdown.filters.push_back(x: std::move(f)); |
| 28 | } else { |
| 29 | D_ASSERT(side == JoinSide::BOTH || side == JoinSide::NONE); |
| 30 | // bindings match both: turn into join condition |
| 31 | join_expressions.push_back(x: std::move(f->filter)); |
| 32 | } |
| 33 | } |
| 34 | } |
| 35 | |
| 36 | op->children[0] = left_pushdown.Rewrite(op: std::move(op->children[0])); |
| 37 | op->children[1] = right_pushdown.Rewrite(op: std::move(op->children[1])); |
| 38 | |
| 39 | if (!join_expressions.empty()) { |
| 40 | // join conditions found: turn into inner join |
| 41 | // extract join conditions |
| 42 | vector<JoinCondition> conditions; |
| 43 | vector<unique_ptr<Expression>> arbitrary_expressions; |
| 44 | auto join_type = JoinType::INNER; |
| 45 | LogicalComparisonJoin::ExtractJoinConditions(type: join_type, left_child&: op->children[0], right_child&: op->children[1], left_bindings, |
| 46 | right_bindings, expressions&: join_expressions, conditions, |
| 47 | arbitrary_expressions); |
| 48 | // create the join from the join conditions |
| 49 | return LogicalComparisonJoin::CreateJoin(type: JoinType::INNER, ref_type: JoinRefType::REGULAR, left_child: std::move(op->children[0]), |
| 50 | right_child: std::move(op->children[1]), conditions: std::move(conditions), |
| 51 | arbitrary_expressions: std::move(arbitrary_expressions)); |
| 52 | } else { |
| 53 | // no join conditions found: keep as cross product |
| 54 | return op; |
| 55 | } |
| 56 | } |
| 57 | |
| 58 | } // namespace duckdb |
| 59 | |