1#include "duckdb/planner/joinside.hpp"
2
3#include "duckdb/planner/expression/bound_columnref_expression.hpp"
4#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
6#include "duckdb/planner/expression/bound_subquery_expression.hpp"
7#include "duckdb/planner/expression_iterator.hpp"
8#include "duckdb/common/field_writer.hpp"
9
10namespace duckdb {
11
12unique_ptr<Expression> JoinCondition::CreateExpression(JoinCondition cond) {
13 auto bound_comparison =
14 make_uniq<BoundComparisonExpression>(args&: cond.comparison, args: std::move(cond.left), args: std::move(cond.right));
15 return std::move(bound_comparison);
16}
17
18unique_ptr<Expression> JoinCondition::CreateExpression(vector<JoinCondition> conditions) {
19 unique_ptr<Expression> result;
20 for (auto &cond : conditions) {
21 auto expr = CreateExpression(cond: std::move(cond));
22 if (!result) {
23 result = std::move(expr);
24 } else {
25 auto conj = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND, args: std::move(expr),
26 args: std::move(result));
27 result = std::move(conj);
28 }
29 }
30 return result;
31}
32
33//! Serializes a JoinCondition to a stand-alone binary blob
34void JoinCondition::Serialize(Serializer &serializer) const {
35 FieldWriter writer(serializer);
36 writer.WriteOptional(element: left);
37 writer.WriteOptional(element: right);
38 writer.WriteField<ExpressionType>(element: comparison);
39 writer.Finalize();
40}
41
42//! Deserializes a blob back into a JoinCondition
43JoinCondition JoinCondition::Deserialize(Deserializer &source, PlanDeserializationState &state) {
44 auto result = JoinCondition();
45
46 FieldReader reader(source);
47 auto left = reader.ReadOptional<Expression>(default_value: nullptr, args&: state);
48 auto right = reader.ReadOptional<Expression>(default_value: nullptr, args&: state);
49 result.left = std::move(left);
50 result.right = std::move(right);
51 result.comparison = reader.ReadRequired<ExpressionType>();
52 reader.Finalize();
53 return result;
54}
55
56JoinSide JoinSide::CombineJoinSide(JoinSide left, JoinSide right) {
57 if (left == JoinSide::NONE) {
58 return right;
59 }
60 if (right == JoinSide::NONE) {
61 return left;
62 }
63 if (left != right) {
64 return JoinSide::BOTH;
65 }
66 return left;
67}
68
69JoinSide JoinSide::GetJoinSide(idx_t table_binding, const unordered_set<idx_t> &left_bindings,
70 const unordered_set<idx_t> &right_bindings) {
71 if (left_bindings.find(x: table_binding) != left_bindings.end()) {
72 // column references table on left side
73 D_ASSERT(right_bindings.find(table_binding) == right_bindings.end());
74 return JoinSide::LEFT;
75 } else {
76 // column references table on right side
77 D_ASSERT(right_bindings.find(table_binding) != right_bindings.end());
78 return JoinSide::RIGHT;
79 }
80}
81
82JoinSide JoinSide::GetJoinSide(Expression &expression, const unordered_set<idx_t> &left_bindings,
83 const unordered_set<idx_t> &right_bindings) {
84 if (expression.type == ExpressionType::BOUND_COLUMN_REF) {
85 auto &colref = expression.Cast<BoundColumnRefExpression>();
86 if (colref.depth > 0) {
87 throw Exception("Non-inner join on correlated columns not supported");
88 }
89 return GetJoinSide(table_binding: colref.binding.table_index, left_bindings, right_bindings);
90 }
91 D_ASSERT(expression.type != ExpressionType::BOUND_REF);
92 if (expression.type == ExpressionType::SUBQUERY) {
93 D_ASSERT(expression.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY);
94 auto &subquery = expression.Cast<BoundSubqueryExpression>();
95 JoinSide side = JoinSide::NONE;
96 if (subquery.child) {
97 side = GetJoinSide(expression&: *subquery.child, left_bindings, right_bindings);
98 }
99 // correlated subquery, check the side of each of correlated columns in the subquery
100 for (auto &corr : subquery.binder->correlated_columns) {
101 if (corr.depth > 1) {
102 // correlated column has depth > 1
103 // it does not refer to any table in the current set of bindings
104 return JoinSide::BOTH;
105 }
106 auto correlated_side = GetJoinSide(table_binding: corr.binding.table_index, left_bindings, right_bindings);
107 side = CombineJoinSide(left: side, right: correlated_side);
108 }
109 return side;
110 }
111 JoinSide join_side = JoinSide::NONE;
112 ExpressionIterator::EnumerateChildren(expr&: expression, callback: [&](Expression &child) {
113 auto child_side = GetJoinSide(expression&: child, left_bindings, right_bindings);
114 join_side = CombineJoinSide(left: child_side, right: join_side);
115 });
116 return join_side;
117}
118
119JoinSide JoinSide::GetJoinSide(const unordered_set<idx_t> &bindings, const unordered_set<idx_t> &left_bindings,
120 const unordered_set<idx_t> &right_bindings) {
121 JoinSide side = JoinSide::NONE;
122 for (auto binding : bindings) {
123 side = CombineJoinSide(left: side, right: GetJoinSide(table_binding: binding, left_bindings, right_bindings));
124 }
125 return side;
126}
127
128} // namespace duckdb
129