| 1 | #include "duckdb/planner/joinside.hpp" |
| 2 | |
| 3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 4 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
| 6 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" |
| 7 | #include "duckdb/planner/expression_iterator.hpp" |
| 8 | #include "duckdb/common/field_writer.hpp" |
| 9 | |
| 10 | namespace duckdb { |
| 11 | |
| 12 | unique_ptr<Expression> JoinCondition::CreateExpression(JoinCondition cond) { |
| 13 | auto bound_comparison = |
| 14 | make_uniq<BoundComparisonExpression>(args&: cond.comparison, args: std::move(cond.left), args: std::move(cond.right)); |
| 15 | return std::move(bound_comparison); |
| 16 | } |
| 17 | |
| 18 | unique_ptr<Expression> JoinCondition::CreateExpression(vector<JoinCondition> conditions) { |
| 19 | unique_ptr<Expression> result; |
| 20 | for (auto &cond : conditions) { |
| 21 | auto expr = CreateExpression(cond: std::move(cond)); |
| 22 | if (!result) { |
| 23 | result = std::move(expr); |
| 24 | } else { |
| 25 | auto conj = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND, args: std::move(expr), |
| 26 | args: std::move(result)); |
| 27 | result = std::move(conj); |
| 28 | } |
| 29 | } |
| 30 | return result; |
| 31 | } |
| 32 | |
| 33 | //! Serializes a JoinCondition to a stand-alone binary blob |
| 34 | void JoinCondition::Serialize(Serializer &serializer) const { |
| 35 | FieldWriter writer(serializer); |
| 36 | writer.WriteOptional(element: left); |
| 37 | writer.WriteOptional(element: right); |
| 38 | writer.WriteField<ExpressionType>(element: comparison); |
| 39 | writer.Finalize(); |
| 40 | } |
| 41 | |
| 42 | //! Deserializes a blob back into a JoinCondition |
| 43 | JoinCondition JoinCondition::Deserialize(Deserializer &source, PlanDeserializationState &state) { |
| 44 | auto result = JoinCondition(); |
| 45 | |
| 46 | FieldReader reader(source); |
| 47 | auto left = reader.ReadOptional<Expression>(default_value: nullptr, args&: state); |
| 48 | auto right = reader.ReadOptional<Expression>(default_value: nullptr, args&: state); |
| 49 | result.left = std::move(left); |
| 50 | result.right = std::move(right); |
| 51 | result.comparison = reader.ReadRequired<ExpressionType>(); |
| 52 | reader.Finalize(); |
| 53 | return result; |
| 54 | } |
| 55 | |
| 56 | JoinSide JoinSide::CombineJoinSide(JoinSide left, JoinSide right) { |
| 57 | if (left == JoinSide::NONE) { |
| 58 | return right; |
| 59 | } |
| 60 | if (right == JoinSide::NONE) { |
| 61 | return left; |
| 62 | } |
| 63 | if (left != right) { |
| 64 | return JoinSide::BOTH; |
| 65 | } |
| 66 | return left; |
| 67 | } |
| 68 | |
| 69 | JoinSide JoinSide::GetJoinSide(idx_t table_binding, const unordered_set<idx_t> &left_bindings, |
| 70 | const unordered_set<idx_t> &right_bindings) { |
| 71 | if (left_bindings.find(x: table_binding) != left_bindings.end()) { |
| 72 | // column references table on left side |
| 73 | D_ASSERT(right_bindings.find(table_binding) == right_bindings.end()); |
| 74 | return JoinSide::LEFT; |
| 75 | } else { |
| 76 | // column references table on right side |
| 77 | D_ASSERT(right_bindings.find(table_binding) != right_bindings.end()); |
| 78 | return JoinSide::RIGHT; |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | JoinSide JoinSide::GetJoinSide(Expression &expression, const unordered_set<idx_t> &left_bindings, |
| 83 | const unordered_set<idx_t> &right_bindings) { |
| 84 | if (expression.type == ExpressionType::BOUND_COLUMN_REF) { |
| 85 | auto &colref = expression.Cast<BoundColumnRefExpression>(); |
| 86 | if (colref.depth > 0) { |
| 87 | throw Exception("Non-inner join on correlated columns not supported" ); |
| 88 | } |
| 89 | return GetJoinSide(table_binding: colref.binding.table_index, left_bindings, right_bindings); |
| 90 | } |
| 91 | D_ASSERT(expression.type != ExpressionType::BOUND_REF); |
| 92 | if (expression.type == ExpressionType::SUBQUERY) { |
| 93 | D_ASSERT(expression.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); |
| 94 | auto &subquery = expression.Cast<BoundSubqueryExpression>(); |
| 95 | JoinSide side = JoinSide::NONE; |
| 96 | if (subquery.child) { |
| 97 | side = GetJoinSide(expression&: *subquery.child, left_bindings, right_bindings); |
| 98 | } |
| 99 | // correlated subquery, check the side of each of correlated columns in the subquery |
| 100 | for (auto &corr : subquery.binder->correlated_columns) { |
| 101 | if (corr.depth > 1) { |
| 102 | // correlated column has depth > 1 |
| 103 | // it does not refer to any table in the current set of bindings |
| 104 | return JoinSide::BOTH; |
| 105 | } |
| 106 | auto correlated_side = GetJoinSide(table_binding: corr.binding.table_index, left_bindings, right_bindings); |
| 107 | side = CombineJoinSide(left: side, right: correlated_side); |
| 108 | } |
| 109 | return side; |
| 110 | } |
| 111 | JoinSide join_side = JoinSide::NONE; |
| 112 | ExpressionIterator::EnumerateChildren(expr&: expression, callback: [&](Expression &child) { |
| 113 | auto child_side = GetJoinSide(expression&: child, left_bindings, right_bindings); |
| 114 | join_side = CombineJoinSide(left: child_side, right: join_side); |
| 115 | }); |
| 116 | return join_side; |
| 117 | } |
| 118 | |
| 119 | JoinSide JoinSide::GetJoinSide(const unordered_set<idx_t> &bindings, const unordered_set<idx_t> &left_bindings, |
| 120 | const unordered_set<idx_t> &right_bindings) { |
| 121 | JoinSide side = JoinSide::NONE; |
| 122 | for (auto binding : bindings) { |
| 123 | side = CombineJoinSide(left: side, right: GetJoinSide(table_binding: binding, left_bindings, right_bindings)); |
| 124 | } |
| 125 | return side; |
| 126 | } |
| 127 | |
| 128 | } // namespace duckdb |
| 129 | |