| 1 | #include "duckdb/planner/operator/logical_join.hpp" |
| 2 | |
| 3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 4 | #include "duckdb/planner/expression_iterator.hpp" |
| 5 | |
| 6 | using namespace duckdb; |
| 7 | using namespace std; |
| 8 | |
| 9 | LogicalJoin::LogicalJoin(JoinType join_type, LogicalOperatorType logical_type) |
| 10 | : LogicalOperator(logical_type), join_type(join_type) { |
| 11 | } |
| 12 | |
| 13 | vector<ColumnBinding> LogicalJoin::GetColumnBindings() { |
| 14 | auto left_bindings = MapBindings(children[0]->GetColumnBindings(), left_projection_map); |
| 15 | if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { |
| 16 | // for SEMI and ANTI join we only project the left hand side |
| 17 | return left_bindings; |
| 18 | } |
| 19 | if (join_type == JoinType::MARK) { |
| 20 | // for MARK join we project the left hand side plus the MARK column |
| 21 | left_bindings.push_back(ColumnBinding(mark_index, 0)); |
| 22 | return left_bindings; |
| 23 | } |
| 24 | // for other join types we project both the LHS and the RHS |
| 25 | auto right_bindings = MapBindings(children[1]->GetColumnBindings(), right_projection_map); |
| 26 | left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); |
| 27 | return left_bindings; |
| 28 | } |
| 29 | |
| 30 | void LogicalJoin::ResolveTypes() { |
| 31 | types = MapTypes(children[0]->types, left_projection_map); |
| 32 | if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { |
| 33 | // for SEMI and ANTI join we only project the left hand side |
| 34 | return; |
| 35 | } |
| 36 | if (join_type == JoinType::MARK) { |
| 37 | // for MARK join we project the left hand side, plus a BOOLEAN column indicating the MARK |
| 38 | types.push_back(TypeId::BOOL); |
| 39 | return; |
| 40 | } |
| 41 | // for any other join we project both sides |
| 42 | auto right_types = MapTypes(children[1]->types, right_projection_map); |
| 43 | types.insert(types.end(), right_types.begin(), right_types.end()); |
| 44 | } |
| 45 | |
| 46 | void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set<idx_t> &bindings) { |
| 47 | auto column_bindings = op.GetColumnBindings(); |
| 48 | for (auto binding : column_bindings) { |
| 49 | bindings.insert(binding.table_index); |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | void LogicalJoin::GetExpressionBindings(Expression &expr, unordered_set<idx_t> &bindings) { |
| 54 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
| 55 | auto &colref = (BoundColumnRefExpression &)expr; |
| 56 | assert(colref.depth == 0); |
| 57 | bindings.insert(colref.binding.table_index); |
| 58 | } |
| 59 | ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { GetExpressionBindings(child, bindings); }); |
| 60 | } |
| 61 | |