1#include "duckdb/planner/expression/bound_columnref_expression.hpp"
2#include "duckdb/planner/expression/bound_comparison_expression.hpp"
3#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
4#include "duckdb/planner/expression/bound_constant_expression.hpp"
5#include "duckdb/planner/expression/bound_operator_expression.hpp"
6#include "duckdb/planner/expression/bound_subquery_expression.hpp"
7#include "duckdb/planner/expression_iterator.hpp"
8#include "duckdb/planner/binder.hpp"
9#include "duckdb/planner/operator/logical_any_join.hpp"
10#include "duckdb/planner/operator/logical_asof_join.hpp"
11#include "duckdb/planner/operator/logical_comparison_join.hpp"
12#include "duckdb/planner/operator/logical_cross_product.hpp"
13#include "duckdb/planner/operator/logical_filter.hpp"
14#include "duckdb/planner/operator/logical_positional_join.hpp"
15#include "duckdb/planner/tableref/bound_joinref.hpp"
16#include "duckdb/main/client_context.hpp"
17#include "duckdb/planner/expression_binder/lateral_binder.hpp"
18
19namespace duckdb {
20
21//! Create a JoinCondition from a comparison
22static bool CreateJoinCondition(Expression &expr, const unordered_set<idx_t> &left_bindings,
23 const unordered_set<idx_t> &right_bindings, vector<JoinCondition> &conditions) {
24 // comparison
25 auto &comparison = expr.Cast<BoundComparisonExpression>();
26 auto left_side = JoinSide::GetJoinSide(expression&: *comparison.left, left_bindings, right_bindings);
27 auto right_side = JoinSide::GetJoinSide(expression&: *comparison.right, left_bindings, right_bindings);
28 if (left_side != JoinSide::BOTH && right_side != JoinSide::BOTH) {
29 // join condition can be divided in a left/right side
30 JoinCondition condition;
31 condition.comparison = expr.type;
32 auto left = std::move(comparison.left);
33 auto right = std::move(comparison.right);
34 if (left_side == JoinSide::RIGHT) {
35 // left = right, right = left, flip the comparison symbol and reverse sides
36 swap(a&: left, b&: right);
37 condition.comparison = FlipComparisonExpression(type: expr.type);
38 }
39 condition.left = std::move(left);
40 condition.right = std::move(right);
41 conditions.push_back(x: std::move(condition));
42 return true;
43 }
44 return false;
45}
46
47void LogicalComparisonJoin::ExtractJoinConditions(JoinType type, unique_ptr<LogicalOperator> &left_child,
48 unique_ptr<LogicalOperator> &right_child,
49 const unordered_set<idx_t> &left_bindings,
50 const unordered_set<idx_t> &right_bindings,
51 vector<unique_ptr<Expression>> &expressions,
52 vector<JoinCondition> &conditions,
53 vector<unique_ptr<Expression>> &arbitrary_expressions) {
54 for (auto &expr : expressions) {
55 auto total_side = JoinSide::GetJoinSide(expression&: *expr, left_bindings, right_bindings);
56 if (total_side != JoinSide::BOTH) {
57 // join condition does not reference both sides, add it as filter under the join
58 if (type == JoinType::LEFT && total_side == JoinSide::RIGHT) {
59 // filter is on RHS and the join is a LEFT OUTER join, we can push it in the right child
60 if (right_child->type != LogicalOperatorType::LOGICAL_FILTER) {
61 // not a filter yet, push a new empty filter
62 auto filter = make_uniq<LogicalFilter>();
63 filter->AddChild(child: std::move(right_child));
64 right_child = std::move(filter);
65 }
66 // push the expression into the filter
67 auto &filter = right_child->Cast<LogicalFilter>();
68 filter.expressions.push_back(x: std::move(expr));
69 continue;
70 }
71 } else if ((expr->type >= ExpressionType::COMPARE_EQUAL &&
72 expr->type <= ExpressionType::COMPARE_GREATERTHANOREQUALTO) ||
73 expr->type == ExpressionType::COMPARE_DISTINCT_FROM ||
74 expr->type == ExpressionType::COMPARE_NOT_DISTINCT_FROM) {
75 // comparison, check if we can create a comparison JoinCondition
76 if (CreateJoinCondition(expr&: *expr, left_bindings, right_bindings, conditions)) {
77 // successfully created the join condition
78 continue;
79 }
80 }
81 arbitrary_expressions.push_back(x: std::move(expr));
82 }
83}
84
85void LogicalComparisonJoin::ExtractJoinConditions(JoinType type, unique_ptr<LogicalOperator> &left_child,
86 unique_ptr<LogicalOperator> &right_child,
87 vector<unique_ptr<Expression>> &expressions,
88 vector<JoinCondition> &conditions,
89 vector<unique_ptr<Expression>> &arbitrary_expressions) {
90 unordered_set<idx_t> left_bindings, right_bindings;
91 LogicalJoin::GetTableReferences(op&: *left_child, bindings&: left_bindings);
92 LogicalJoin::GetTableReferences(op&: *right_child, bindings&: right_bindings);
93 return ExtractJoinConditions(type, left_child, right_child, left_bindings, right_bindings, expressions, conditions,
94 arbitrary_expressions);
95}
96
97void LogicalComparisonJoin::ExtractJoinConditions(JoinType type, unique_ptr<LogicalOperator> &left_child,
98 unique_ptr<LogicalOperator> &right_child,
99 unique_ptr<Expression> condition, vector<JoinCondition> &conditions,
100 vector<unique_ptr<Expression>> &arbitrary_expressions) {
101 // split the expressions by the AND clause
102 vector<unique_ptr<Expression>> expressions;
103 expressions.push_back(x: std::move(condition));
104 LogicalFilter::SplitPredicates(expressions);
105 return ExtractJoinConditions(type, left_child, right_child, expressions, conditions, arbitrary_expressions);
106}
107
108unique_ptr<LogicalOperator> LogicalComparisonJoin::CreateJoin(JoinType type, JoinRefType reftype,
109 unique_ptr<LogicalOperator> left_child,
110 unique_ptr<LogicalOperator> right_child,
111 vector<JoinCondition> conditions,
112 vector<unique_ptr<Expression>> arbitrary_expressions) {
113 // Validate the conditions
114 bool need_to_consider_arbitrary_expressions = true;
115 switch (reftype) {
116 case JoinRefType::ASOF: {
117 need_to_consider_arbitrary_expressions = false;
118 auto asof_idx = conditions.size();
119 for (size_t c = 0; c < conditions.size(); ++c) {
120 auto &cond = conditions[c];
121 switch (cond.comparison) {
122 case ExpressionType::COMPARE_EQUAL:
123 case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
124 break;
125 case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
126 if (asof_idx < conditions.size()) {
127 throw BinderException("Multiple ASOF JOIN inequalities");
128 }
129 asof_idx = c;
130 break;
131 default:
132 throw BinderException("Invalid ASOF JOIN comparison");
133 }
134 }
135 if (asof_idx == conditions.size()) {
136 throw BinderException("Missing ASOF JOIN inequality");
137 }
138 break;
139 }
140 default:
141 break;
142 }
143
144 if (type == JoinType::INNER && reftype == JoinRefType::REGULAR) {
145 // for inner joins we can push arbitrary expressions as a filter
146 // here we prefer to create a comparison join if possible
147 // that way we can use the much faster hash join to process the main join
148 // rather than doing a nested loop join to handle arbitrary expressions
149
150 // for left and full outer joins we HAVE to process all join conditions
151 // because pushing a filter will lead to an incorrect result, as non-matching tuples cannot be filtered out
152 need_to_consider_arbitrary_expressions = false;
153 }
154 if ((need_to_consider_arbitrary_expressions && !arbitrary_expressions.empty()) || conditions.empty()) {
155 if (arbitrary_expressions.empty()) {
156 // all conditions were pushed down, add TRUE predicate
157 arbitrary_expressions.push_back(x: make_uniq<BoundConstantExpression>(args: Value::BOOLEAN(value: true)));
158 }
159 for (auto &condition : conditions) {
160 arbitrary_expressions.push_back(x: JoinCondition::CreateExpression(cond: std::move(condition)));
161 }
162 // if we get here we could not create any JoinConditions
163 // turn this into an arbitrary expression join
164 auto any_join = make_uniq<LogicalAnyJoin>(args&: type);
165 // create the condition
166 any_join->children.push_back(x: std::move(left_child));
167 any_join->children.push_back(x: std::move(right_child));
168 // AND all the arbitrary expressions together
169 // do the same with any remaining conditions
170 any_join->condition = std::move(arbitrary_expressions[0]);
171 for (idx_t i = 1; i < arbitrary_expressions.size(); i++) {
172 any_join->condition = make_uniq<BoundConjunctionExpression>(
173 args: ExpressionType::CONJUNCTION_AND, args: std::move(any_join->condition), args: std::move(arbitrary_expressions[i]));
174 }
175 return std::move(any_join);
176 } else {
177 // we successfully converted expressions into JoinConditions
178 // create a LogicalComparisonJoin
179 unique_ptr<LogicalComparisonJoin> comp_join;
180 if (reftype == JoinRefType::ASOF) {
181 comp_join = make_uniq<LogicalAsOfJoin>(args&: type);
182 } else {
183 comp_join = make_uniq<LogicalComparisonJoin>(args&: type);
184 }
185 comp_join->conditions = std::move(conditions);
186 comp_join->children.push_back(x: std::move(left_child));
187 comp_join->children.push_back(x: std::move(right_child));
188 if (!arbitrary_expressions.empty()) {
189 // we have some arbitrary expressions as well
190 // add them to a filter
191 auto filter = make_uniq<LogicalFilter>();
192 for (auto &expr : arbitrary_expressions) {
193 filter->expressions.push_back(x: std::move(expr));
194 }
195 LogicalFilter::SplitPredicates(expressions&: filter->expressions);
196 filter->children.push_back(x: std::move(comp_join));
197 return std::move(filter);
198 }
199 return std::move(comp_join);
200 }
201}
202
203static bool HasCorrelatedColumns(Expression &expression) {
204 if (expression.type == ExpressionType::BOUND_COLUMN_REF) {
205 auto &colref = expression.Cast<BoundColumnRefExpression>();
206 if (colref.depth > 0) {
207 return true;
208 }
209 }
210 bool has_correlated_columns = false;
211 ExpressionIterator::EnumerateChildren(expression, callback: [&](Expression &child) {
212 if (HasCorrelatedColumns(expression&: child)) {
213 has_correlated_columns = true;
214 }
215 });
216 return has_correlated_columns;
217}
218
219unique_ptr<LogicalOperator> LogicalComparisonJoin::CreateJoin(JoinType type, JoinRefType reftype,
220 unique_ptr<LogicalOperator> left_child,
221 unique_ptr<LogicalOperator> right_child,
222 unique_ptr<Expression> condition) {
223 vector<JoinCondition> conditions;
224 vector<unique_ptr<Expression>> arbitrary_expressions;
225 LogicalComparisonJoin::ExtractJoinConditions(type, left_child, right_child, condition: std::move(condition), conditions,
226 arbitrary_expressions);
227 return LogicalComparisonJoin::CreateJoin(type, reftype, left_child: std::move(left_child), right_child: std::move(right_child),
228 conditions: std::move(conditions), arbitrary_expressions: std::move(arbitrary_expressions));
229}
230
231unique_ptr<LogicalOperator> Binder::CreatePlan(BoundJoinRef &ref) {
232 auto left = CreatePlan(ref&: *ref.left);
233 auto right = CreatePlan(ref&: *ref.right);
234 if (!ref.lateral && !ref.correlated_columns.empty()) {
235 // non-lateral join with correlated columns
236 // this happens if there is a join (or cross product) in a correlated subquery
237 // due to the lateral binder the expression depth of all correlated columns in the "ref.correlated_columns" set
238 // is 1 too high
239 // we reduce expression depth of all columns in the "ref.correlated_columns" set by 1
240 LateralBinder::ReduceExpressionDepth(op&: *right, info: ref.correlated_columns);
241 }
242 if (ref.type == JoinType::RIGHT && ref.ref_type != JoinRefType::ASOF &&
243 ClientConfig::GetConfig(context).enable_optimizer) {
244 // we turn any right outer joins into left outer joins for optimization purposes
245 // they are the same but with sides flipped, so treating them the same simplifies life
246 ref.type = JoinType::LEFT;
247 std::swap(a&: left, b&: right);
248 }
249 if (ref.lateral) {
250 // lateral join
251 return PlanLateralJoin(left: std::move(left), right: std::move(right), correlated_columns&: ref.correlated_columns, join_type: ref.type,
252 condition: std::move(ref.condition));
253 }
254 switch (ref.ref_type) {
255 case JoinRefType::CROSS:
256 return LogicalCrossProduct::Create(left: std::move(left), right: std::move(right));
257 case JoinRefType::POSITIONAL:
258 return LogicalPositionalJoin::Create(left: std::move(left), right: std::move(right));
259 default:
260 break;
261 }
262 if (ref.type == JoinType::INNER && (ref.condition->HasSubquery() || HasCorrelatedColumns(expression&: *ref.condition)) &&
263 ref.ref_type == JoinRefType::REGULAR) {
264 // inner join, generate a cross product + filter
265 // this will be later turned into a proper join by the join order optimizer
266 auto root = LogicalCrossProduct::Create(left: std::move(left), right: std::move(right));
267
268 auto filter = make_uniq<LogicalFilter>(args: std::move(ref.condition));
269 // visit the expressions in the filter
270 for (auto &expression : filter->expressions) {
271 PlanSubqueries(expr&: expression, root);
272 }
273 filter->AddChild(child: std::move(root));
274 return std::move(filter);
275 }
276
277 // now create the join operator from the join condition
278 auto result = LogicalComparisonJoin::CreateJoin(type: ref.type, reftype: ref.ref_type, left_child: std::move(left), right_child: std::move(right),
279 condition: std::move(ref.condition));
280
281 optional_ptr<LogicalOperator> join;
282 if (result->type == LogicalOperatorType::LOGICAL_FILTER) {
283 join = result->children[0].get();
284 } else {
285 join = result.get();
286 }
287 for (auto &child : join->children) {
288 if (child->type == LogicalOperatorType::LOGICAL_FILTER) {
289 auto &filter = child->Cast<LogicalFilter>();
290 for (auto &expr : filter.expressions) {
291 PlanSubqueries(expr, root&: filter.children[0]);
292 }
293 }
294 }
295
296 // we visit the expressions depending on the type of join
297 switch (join->type) {
298 case LogicalOperatorType::LOGICAL_ASOF_JOIN:
299 case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: {
300 // comparison join
301 // in this join we visit the expressions on the LHS with the LHS as root node
302 // and the expressions on the RHS with the RHS as root node
303 auto &comp_join = join->Cast<LogicalComparisonJoin>();
304 for (idx_t i = 0; i < comp_join.conditions.size(); i++) {
305 PlanSubqueries(expr&: comp_join.conditions[i].left, root&: comp_join.children[0]);
306 PlanSubqueries(expr&: comp_join.conditions[i].right, root&: comp_join.children[1]);
307 }
308 break;
309 }
310 case LogicalOperatorType::LOGICAL_ANY_JOIN: {
311 auto &any_join = join->Cast<LogicalAnyJoin>();
312 // for the any join we just visit the condition
313 if (any_join.condition->HasSubquery()) {
314 throw NotImplementedException("Cannot perform non-inner join on subquery!");
315 }
316 break;
317 }
318 default:
319 break;
320 }
321 return result;
322}
323
324} // namespace duckdb
325