1 | #include "duckdb/planner/operator/logical_join.hpp" |
2 | |
3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
4 | #include "duckdb/planner/expression_iterator.hpp" |
5 | |
6 | using namespace duckdb; |
7 | using namespace std; |
8 | |
9 | LogicalJoin::LogicalJoin(JoinType join_type, LogicalOperatorType logical_type) |
10 | : LogicalOperator(logical_type), join_type(join_type) { |
11 | } |
12 | |
13 | vector<ColumnBinding> LogicalJoin::GetColumnBindings() { |
14 | auto left_bindings = MapBindings(children[0]->GetColumnBindings(), left_projection_map); |
15 | if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { |
16 | // for SEMI and ANTI join we only project the left hand side |
17 | return left_bindings; |
18 | } |
19 | if (join_type == JoinType::MARK) { |
20 | // for MARK join we project the left hand side plus the MARK column |
21 | left_bindings.push_back(ColumnBinding(mark_index, 0)); |
22 | return left_bindings; |
23 | } |
24 | // for other join types we project both the LHS and the RHS |
25 | auto right_bindings = MapBindings(children[1]->GetColumnBindings(), right_projection_map); |
26 | left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); |
27 | return left_bindings; |
28 | } |
29 | |
30 | void LogicalJoin::ResolveTypes() { |
31 | types = MapTypes(children[0]->types, left_projection_map); |
32 | if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { |
33 | // for SEMI and ANTI join we only project the left hand side |
34 | return; |
35 | } |
36 | if (join_type == JoinType::MARK) { |
37 | // for MARK join we project the left hand side, plus a BOOLEAN column indicating the MARK |
38 | types.push_back(TypeId::BOOL); |
39 | return; |
40 | } |
41 | // for any other join we project both sides |
42 | auto right_types = MapTypes(children[1]->types, right_projection_map); |
43 | types.insert(types.end(), right_types.begin(), right_types.end()); |
44 | } |
45 | |
46 | void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set<idx_t> &bindings) { |
47 | auto column_bindings = op.GetColumnBindings(); |
48 | for (auto binding : column_bindings) { |
49 | bindings.insert(binding.table_index); |
50 | } |
51 | } |
52 | |
53 | void LogicalJoin::GetExpressionBindings(Expression &expr, unordered_set<idx_t> &bindings) { |
54 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
55 | auto &colref = (BoundColumnRefExpression &)expr; |
56 | assert(colref.depth == 0); |
57 | bindings.insert(colref.binding.table_index); |
58 | } |
59 | ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { GetExpressionBindings(child, bindings); }); |
60 | } |
61 | |