1#include "duckdb/planner/operator/logical_join.hpp"
2
3#include "duckdb/planner/expression/bound_columnref_expression.hpp"
4#include "duckdb/planner/expression_iterator.hpp"
5#include "duckdb/common/field_writer.hpp"
6
7namespace duckdb {
8
9LogicalJoin::LogicalJoin(JoinType join_type, LogicalOperatorType logical_type)
10 : LogicalOperator(logical_type), join_type(join_type) {
11}
12
13vector<ColumnBinding> LogicalJoin::GetColumnBindings() {
14 auto left_bindings = MapBindings(types: children[0]->GetColumnBindings(), projection_map: left_projection_map);
15 if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) {
16 // for SEMI and ANTI join we only project the left hand side
17 return left_bindings;
18 }
19 if (join_type == JoinType::MARK) {
20 // for MARK join we project the left hand side plus the MARK column
21 left_bindings.emplace_back(args&: mark_index, args: 0);
22 return left_bindings;
23 }
24 // for other join types we project both the LHS and the RHS
25 auto right_bindings = MapBindings(types: children[1]->GetColumnBindings(), projection_map: right_projection_map);
26 left_bindings.insert(position: left_bindings.end(), first: right_bindings.begin(), last: right_bindings.end());
27 return left_bindings;
28}
29
30void LogicalJoin::ResolveTypes() {
31 types = MapTypes(types: children[0]->types, projection_map: left_projection_map);
32 if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) {
33 // for SEMI and ANTI join we only project the left hand side
34 return;
35 }
36 if (join_type == JoinType::MARK) {
37 // for MARK join we project the left hand side, plus a BOOLEAN column indicating the MARK
38 types.emplace_back(args: LogicalType::BOOLEAN);
39 return;
40 }
41 // for any other join we project both sides
42 auto right_types = MapTypes(types: children[1]->types, projection_map: right_projection_map);
43 types.insert(position: types.end(), first: right_types.begin(), last: right_types.end());
44}
45
46void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set<idx_t> &bindings) {
47 auto column_bindings = op.GetColumnBindings();
48 for (auto binding : column_bindings) {
49 bindings.insert(x: binding.table_index);
50 }
51}
52
53void LogicalJoin::GetExpressionBindings(Expression &expr, unordered_set<idx_t> &bindings) {
54 if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
55 auto &colref = expr.Cast<BoundColumnRefExpression>();
56 D_ASSERT(colref.depth == 0);
57 bindings.insert(x: colref.binding.table_index);
58 }
59 ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { GetExpressionBindings(expr&: child, bindings); });
60}
61
62void LogicalJoin::Serialize(FieldWriter &writer) const {
63 writer.WriteField<JoinType>(element: join_type);
64 writer.WriteField<idx_t>(element: mark_index);
65 writer.WriteList<idx_t>(elements: left_projection_map);
66 writer.WriteList<idx_t>(elements: right_projection_map);
67 // writer.WriteSerializableList(join_stats);
68}
69
70void LogicalJoin::Deserialize(LogicalJoin &join, LogicalDeserializationState &state, FieldReader &reader) {
71 join.join_type = reader.ReadRequired<JoinType>();
72 join.mark_index = reader.ReadRequired<idx_t>();
73 join.left_projection_map = reader.ReadRequiredList<idx_t>();
74 join.right_projection_map = reader.ReadRequiredList<idx_t>();
75 // join.join_stats = reader.ReadRequiredSerializableList<BaseStatistics>(reader.GetSource());
76}
77
78} // namespace duckdb
79