1 | #include "duckdb/planner/operator/logical_join.hpp" |
2 | |
3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
4 | #include "duckdb/planner/expression_iterator.hpp" |
5 | #include "duckdb/common/field_writer.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | LogicalJoin::LogicalJoin(JoinType join_type, LogicalOperatorType logical_type) |
10 | : LogicalOperator(logical_type), join_type(join_type) { |
11 | } |
12 | |
13 | vector<ColumnBinding> LogicalJoin::GetColumnBindings() { |
14 | auto left_bindings = MapBindings(types: children[0]->GetColumnBindings(), projection_map: left_projection_map); |
15 | if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { |
16 | // for SEMI and ANTI join we only project the left hand side |
17 | return left_bindings; |
18 | } |
19 | if (join_type == JoinType::MARK) { |
20 | // for MARK join we project the left hand side plus the MARK column |
21 | left_bindings.emplace_back(args&: mark_index, args: 0); |
22 | return left_bindings; |
23 | } |
24 | // for other join types we project both the LHS and the RHS |
25 | auto right_bindings = MapBindings(types: children[1]->GetColumnBindings(), projection_map: right_projection_map); |
26 | left_bindings.insert(position: left_bindings.end(), first: right_bindings.begin(), last: right_bindings.end()); |
27 | return left_bindings; |
28 | } |
29 | |
30 | void LogicalJoin::ResolveTypes() { |
31 | types = MapTypes(types: children[0]->types, projection_map: left_projection_map); |
32 | if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { |
33 | // for SEMI and ANTI join we only project the left hand side |
34 | return; |
35 | } |
36 | if (join_type == JoinType::MARK) { |
37 | // for MARK join we project the left hand side, plus a BOOLEAN column indicating the MARK |
38 | types.emplace_back(args: LogicalType::BOOLEAN); |
39 | return; |
40 | } |
41 | // for any other join we project both sides |
42 | auto right_types = MapTypes(types: children[1]->types, projection_map: right_projection_map); |
43 | types.insert(position: types.end(), first: right_types.begin(), last: right_types.end()); |
44 | } |
45 | |
46 | void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set<idx_t> &bindings) { |
47 | auto column_bindings = op.GetColumnBindings(); |
48 | for (auto binding : column_bindings) { |
49 | bindings.insert(x: binding.table_index); |
50 | } |
51 | } |
52 | |
53 | void LogicalJoin::GetExpressionBindings(Expression &expr, unordered_set<idx_t> &bindings) { |
54 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
55 | auto &colref = expr.Cast<BoundColumnRefExpression>(); |
56 | D_ASSERT(colref.depth == 0); |
57 | bindings.insert(x: colref.binding.table_index); |
58 | } |
59 | ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { GetExpressionBindings(expr&: child, bindings); }); |
60 | } |
61 | |
62 | void LogicalJoin::Serialize(FieldWriter &writer) const { |
63 | writer.WriteField<JoinType>(element: join_type); |
64 | writer.WriteField<idx_t>(element: mark_index); |
65 | writer.WriteList<idx_t>(elements: left_projection_map); |
66 | writer.WriteList<idx_t>(elements: right_projection_map); |
67 | // writer.WriteSerializableList(join_stats); |
68 | } |
69 | |
70 | void LogicalJoin::Deserialize(LogicalJoin &join, LogicalDeserializationState &state, FieldReader &reader) { |
71 | join.join_type = reader.ReadRequired<JoinType>(); |
72 | join.mark_index = reader.ReadRequired<idx_t>(); |
73 | join.left_projection_map = reader.ReadRequiredList<idx_t>(); |
74 | join.right_projection_map = reader.ReadRequiredList<idx_t>(); |
75 | // join.join_stats = reader.ReadRequiredSerializableList<BaseStatistics>(reader.GetSource()); |
76 | } |
77 | |
78 | } // namespace duckdb |
79 | |