1 | #include "duckdb/planner/joinside.hpp" |
2 | |
3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" |
7 | #include "duckdb/planner/expression_iterator.hpp" |
8 | #include "duckdb/common/field_writer.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | unique_ptr<Expression> JoinCondition::CreateExpression(JoinCondition cond) { |
13 | auto bound_comparison = |
14 | make_uniq<BoundComparisonExpression>(args&: cond.comparison, args: std::move(cond.left), args: std::move(cond.right)); |
15 | return std::move(bound_comparison); |
16 | } |
17 | |
18 | unique_ptr<Expression> JoinCondition::CreateExpression(vector<JoinCondition> conditions) { |
19 | unique_ptr<Expression> result; |
20 | for (auto &cond : conditions) { |
21 | auto expr = CreateExpression(cond: std::move(cond)); |
22 | if (!result) { |
23 | result = std::move(expr); |
24 | } else { |
25 | auto conj = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND, args: std::move(expr), |
26 | args: std::move(result)); |
27 | result = std::move(conj); |
28 | } |
29 | } |
30 | return result; |
31 | } |
32 | |
33 | //! Serializes a JoinCondition to a stand-alone binary blob |
34 | void JoinCondition::Serialize(Serializer &serializer) const { |
35 | FieldWriter writer(serializer); |
36 | writer.WriteOptional(element: left); |
37 | writer.WriteOptional(element: right); |
38 | writer.WriteField<ExpressionType>(element: comparison); |
39 | writer.Finalize(); |
40 | } |
41 | |
42 | //! Deserializes a blob back into a JoinCondition |
43 | JoinCondition JoinCondition::Deserialize(Deserializer &source, PlanDeserializationState &state) { |
44 | auto result = JoinCondition(); |
45 | |
46 | FieldReader reader(source); |
47 | auto left = reader.ReadOptional<Expression>(default_value: nullptr, args&: state); |
48 | auto right = reader.ReadOptional<Expression>(default_value: nullptr, args&: state); |
49 | result.left = std::move(left); |
50 | result.right = std::move(right); |
51 | result.comparison = reader.ReadRequired<ExpressionType>(); |
52 | reader.Finalize(); |
53 | return result; |
54 | } |
55 | |
56 | JoinSide JoinSide::CombineJoinSide(JoinSide left, JoinSide right) { |
57 | if (left == JoinSide::NONE) { |
58 | return right; |
59 | } |
60 | if (right == JoinSide::NONE) { |
61 | return left; |
62 | } |
63 | if (left != right) { |
64 | return JoinSide::BOTH; |
65 | } |
66 | return left; |
67 | } |
68 | |
69 | JoinSide JoinSide::GetJoinSide(idx_t table_binding, const unordered_set<idx_t> &left_bindings, |
70 | const unordered_set<idx_t> &right_bindings) { |
71 | if (left_bindings.find(x: table_binding) != left_bindings.end()) { |
72 | // column references table on left side |
73 | D_ASSERT(right_bindings.find(table_binding) == right_bindings.end()); |
74 | return JoinSide::LEFT; |
75 | } else { |
76 | // column references table on right side |
77 | D_ASSERT(right_bindings.find(table_binding) != right_bindings.end()); |
78 | return JoinSide::RIGHT; |
79 | } |
80 | } |
81 | |
82 | JoinSide JoinSide::GetJoinSide(Expression &expression, const unordered_set<idx_t> &left_bindings, |
83 | const unordered_set<idx_t> &right_bindings) { |
84 | if (expression.type == ExpressionType::BOUND_COLUMN_REF) { |
85 | auto &colref = expression.Cast<BoundColumnRefExpression>(); |
86 | if (colref.depth > 0) { |
87 | throw Exception("Non-inner join on correlated columns not supported" ); |
88 | } |
89 | return GetJoinSide(table_binding: colref.binding.table_index, left_bindings, right_bindings); |
90 | } |
91 | D_ASSERT(expression.type != ExpressionType::BOUND_REF); |
92 | if (expression.type == ExpressionType::SUBQUERY) { |
93 | D_ASSERT(expression.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); |
94 | auto &subquery = expression.Cast<BoundSubqueryExpression>(); |
95 | JoinSide side = JoinSide::NONE; |
96 | if (subquery.child) { |
97 | side = GetJoinSide(expression&: *subquery.child, left_bindings, right_bindings); |
98 | } |
99 | // correlated subquery, check the side of each of correlated columns in the subquery |
100 | for (auto &corr : subquery.binder->correlated_columns) { |
101 | if (corr.depth > 1) { |
102 | // correlated column has depth > 1 |
103 | // it does not refer to any table in the current set of bindings |
104 | return JoinSide::BOTH; |
105 | } |
106 | auto correlated_side = GetJoinSide(table_binding: corr.binding.table_index, left_bindings, right_bindings); |
107 | side = CombineJoinSide(left: side, right: correlated_side); |
108 | } |
109 | return side; |
110 | } |
111 | JoinSide join_side = JoinSide::NONE; |
112 | ExpressionIterator::EnumerateChildren(expr&: expression, callback: [&](Expression &child) { |
113 | auto child_side = GetJoinSide(expression&: child, left_bindings, right_bindings); |
114 | join_side = CombineJoinSide(left: child_side, right: join_side); |
115 | }); |
116 | return join_side; |
117 | } |
118 | |
119 | JoinSide JoinSide::GetJoinSide(const unordered_set<idx_t> &bindings, const unordered_set<idx_t> &left_bindings, |
120 | const unordered_set<idx_t> &right_bindings) { |
121 | JoinSide side = JoinSide::NONE; |
122 | for (auto binding : bindings) { |
123 | side = CombineJoinSide(left: side, right: GetJoinSide(table_binding: binding, left_bindings, right_bindings)); |
124 | } |
125 | return side; |
126 | } |
127 | |
128 | } // namespace duckdb |
129 | |