1#include "duckdb/optimizer/filter_pushdown.hpp"
2#include "duckdb/planner/operator/logical_any_join.hpp"
3#include "duckdb/planner/operator/logical_comparison_join.hpp"
4#include "duckdb/planner/operator/logical_cross_product.hpp"
5#include "duckdb/planner/operator/logical_empty_result.hpp"
6
7namespace duckdb {
8
9using Filter = FilterPushdown::Filter;
10
11unique_ptr<LogicalOperator> FilterPushdown::PushdownInnerJoin(unique_ptr<LogicalOperator> op,
12 unordered_set<idx_t> &left_bindings,
13 unordered_set<idx_t> &right_bindings) {
14 auto &join = op->Cast<LogicalJoin>();
15 D_ASSERT(join.join_type == JoinType::INNER);
16 if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) {
17 return FinishPushdown(op: std::move(op));
18 }
19 // inner join: gather all the conditions of the inner join and add to the filter list
20 if (op->type == LogicalOperatorType::LOGICAL_ANY_JOIN) {
21 auto &any_join = join.Cast<LogicalAnyJoin>();
22 // any join: only one filter to add
23 if (AddFilter(expr: std::move(any_join.condition)) == FilterResult::UNSATISFIABLE) {
24 // filter statically evaluates to false, strip tree
25 return make_uniq<LogicalEmptyResult>(args: std::move(op));
26 }
27 } else if (op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN) {
28 // Don't mess with non-standard condition interpretations
29 return FinishPushdown(op: std::move(op));
30 } else {
31 // comparison join
32 D_ASSERT(op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN);
33 auto &comp_join = join.Cast<LogicalComparisonJoin>();
34 // turn the conditions into filters
35 for (auto &i : comp_join.conditions) {
36 auto condition = JoinCondition::CreateExpression(cond: std::move(i));
37 if (AddFilter(expr: std::move(condition)) == FilterResult::UNSATISFIABLE) {
38 // filter statically evaluates to false, strip tree
39 return make_uniq<LogicalEmptyResult>(args: std::move(op));
40 }
41 }
42 }
43 GenerateFilters();
44
45 // turn the inner join into a cross product
46 auto cross_product = make_uniq<LogicalCrossProduct>(args: std::move(op->children[0]), args: std::move(op->children[1]));
47 // then push down cross product
48 return PushdownCrossProduct(op: std::move(cross_product));
49}
50
51} // namespace duckdb
52