1 | #include "duckdb/planner/subquery/flatten_dependent_join.hpp" |
2 | |
3 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
4 | #include "duckdb/common/operator/add.hpp" |
5 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
6 | #include "duckdb/planner/binder.hpp" |
7 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
8 | #include "duckdb/planner/expression/list.hpp" |
9 | #include "duckdb/planner/logical_operator_visitor.hpp" |
10 | #include "duckdb/planner/operator/list.hpp" |
11 | #include "duckdb/planner/subquery/has_correlated_expressions.hpp" |
12 | #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" |
13 | |
14 | namespace duckdb { |
15 | |
16 | FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector<CorrelatedColumnInfo> &correlated, |
17 | bool perform_delim, bool any_join) |
18 | : binder(binder), delim_offset(DConstants::INVALID_INDEX), correlated_columns(correlated), |
19 | perform_delim(perform_delim), any_join(any_join) { |
20 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
21 | auto &col = correlated_columns[i]; |
22 | correlated_map[col.binding] = i; |
23 | delim_types.push_back(x: col.type); |
24 | } |
25 | } |
26 | |
27 | bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator *op, bool lateral) { |
28 | D_ASSERT(op); |
29 | // check if this entry has correlated expressions |
30 | HasCorrelatedExpressions visitor(correlated_columns, lateral); |
31 | visitor.VisitOperator(op&: *op); |
32 | bool has_correlation = visitor.has_correlated_expressions; |
33 | // now visit the children of this entry and check if they have correlated expressions |
34 | for (auto &child : op->children) { |
35 | // we OR the property with its children such that has_correlation is true if either |
36 | // (1) this node has a correlated expression or |
37 | // (2) one of its children has a correlated expression |
38 | if (DetectCorrelatedExpressions(op: child.get(), lateral)) { |
39 | has_correlation = true; |
40 | } |
41 | } |
42 | // set the entry in the map |
43 | has_correlated_expressions[op] = has_correlation; |
44 | return has_correlation; |
45 | } |
46 | |
47 | unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoin(unique_ptr<LogicalOperator> plan) { |
48 | bool propagate_null_values = true; |
49 | auto result = PushDownDependentJoinInternal(plan: std::move(plan), parent_propagate_null_values&: propagate_null_values); |
50 | if (!replacement_map.empty()) { |
51 | // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END" |
52 | RewriteCountAggregates aggr(replacement_map); |
53 | aggr.VisitOperator(op&: *result); |
54 | } |
55 | return result; |
56 | } |
57 | |
58 | bool SubqueryDependentFilter(Expression *expr) { |
59 | if (expr->expression_class == ExpressionClass::BOUND_CONJUNCTION && |
60 | expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { |
61 | auto &bound_conjuction = expr->Cast<BoundConjunctionExpression>(); |
62 | for (auto &child : bound_conjuction.children) { |
63 | if (SubqueryDependentFilter(expr: child.get())) { |
64 | return true; |
65 | } |
66 | } |
67 | } |
68 | if (expr->expression_class == ExpressionClass::BOUND_SUBQUERY) { |
69 | return true; |
70 | } |
71 | return false; |
72 | } |
73 | unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal(unique_ptr<LogicalOperator> plan, |
74 | bool &parent_propagate_null_values) { |
75 | // first check if the logical operator has correlated expressions |
76 | auto entry = has_correlated_expressions.find(x: plan.get()); |
77 | D_ASSERT(entry != has_correlated_expressions.end()); |
78 | if (!entry->second) { |
79 | // we reached a node without correlated expressions |
80 | // we can eliminate the dependent join now and create a simple cross product |
81 | // now create the duplicate eliminated scan for this node |
82 | auto left_columns = plan->GetColumnBindings().size(); |
83 | auto delim_index = binder.GenerateTableIndex(); |
84 | this->base_binding = ColumnBinding(delim_index, 0); |
85 | this->delim_offset = left_columns; |
86 | this->data_offset = 0; |
87 | auto delim_scan = make_uniq<LogicalDelimGet>(args&: delim_index, args&: delim_types); |
88 | return LogicalCrossProduct::Create(left: std::move(plan), right: std::move(delim_scan)); |
89 | } |
90 | switch (plan->type) { |
91 | case LogicalOperatorType::LOGICAL_UNNEST: |
92 | case LogicalOperatorType::LOGICAL_FILTER: { |
93 | // filter |
94 | // first we flatten the dependent join in the child of the filter |
95 | for (auto &expr : plan->expressions) { |
96 | any_join |= SubqueryDependentFilter(expr: expr.get()); |
97 | } |
98 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
99 | |
100 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
101 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
102 | rewriter.VisitOperator(op&: *plan); |
103 | return plan; |
104 | } |
105 | case LogicalOperatorType::LOGICAL_PROJECTION: { |
106 | // projection |
107 | // first we flatten the dependent join in the child of the projection |
108 | for (auto &expr : plan->expressions) { |
109 | parent_propagate_null_values &= expr->PropagatesNullValues(); |
110 | } |
111 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
112 | |
113 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
114 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
115 | rewriter.VisitOperator(op&: *plan); |
116 | // now we add all the columns of the delim_scan to the projection list |
117 | auto &proj = plan->Cast<LogicalProjection>(); |
118 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
119 | auto &col = correlated_columns[i]; |
120 | auto colref = make_uniq<BoundColumnRefExpression>( |
121 | args: col.name, args: col.type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
122 | plan->expressions.push_back(x: std::move(colref)); |
123 | } |
124 | |
125 | base_binding.table_index = proj.table_index; |
126 | this->delim_offset = base_binding.column_index = plan->expressions.size() - correlated_columns.size(); |
127 | this->data_offset = 0; |
128 | return plan; |
129 | } |
130 | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { |
131 | auto &aggr = plan->Cast<LogicalAggregate>(); |
132 | // aggregate and group by |
133 | // first we flatten the dependent join in the child of the projection |
134 | for (auto &expr : plan->expressions) { |
135 | parent_propagate_null_values &= expr->PropagatesNullValues(); |
136 | } |
137 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
138 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
139 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
140 | rewriter.VisitOperator(op&: *plan); |
141 | // now we add all the columns of the delim_scan to the grouping operators AND the projection list |
142 | idx_t delim_table_index; |
143 | idx_t delim_column_offset; |
144 | idx_t delim_data_offset; |
145 | auto new_group_count = perform_delim ? correlated_columns.size() : 1; |
146 | for (idx_t i = 0; i < new_group_count; i++) { |
147 | auto &col = correlated_columns[i]; |
148 | auto colref = make_uniq<BoundColumnRefExpression>( |
149 | args: col.name, args: col.type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
150 | for (auto &set : aggr.grouping_sets) { |
151 | set.insert(x: aggr.groups.size()); |
152 | } |
153 | aggr.groups.push_back(x: std::move(colref)); |
154 | } |
155 | if (!perform_delim) { |
156 | // if we are not performing the duplicate elimination, we have only added the row_id column to the grouping |
157 | // operators in this case, we push a FIRST aggregate for each of the remaining expressions |
158 | delim_table_index = aggr.aggregate_index; |
159 | delim_column_offset = aggr.expressions.size(); |
160 | delim_data_offset = aggr.groups.size(); |
161 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
162 | auto &col = correlated_columns[i]; |
163 | auto first_aggregate = FirstFun::GetFunction(type: col.type); |
164 | auto colref = make_uniq<BoundColumnRefExpression>( |
165 | args: col.name, args: col.type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
166 | vector<unique_ptr<Expression>> aggr_children; |
167 | aggr_children.push_back(x: std::move(colref)); |
168 | auto first_fun = |
169 | make_uniq<BoundAggregateExpression>(args: std::move(first_aggregate), args: std::move(aggr_children), args: nullptr, |
170 | args: nullptr, args: AggregateType::NON_DISTINCT); |
171 | aggr.expressions.push_back(x: std::move(first_fun)); |
172 | } |
173 | } else { |
174 | delim_table_index = aggr.group_index; |
175 | delim_column_offset = aggr.groups.size() - correlated_columns.size(); |
176 | delim_data_offset = aggr.groups.size(); |
177 | } |
178 | if (aggr.groups.size() == new_group_count) { |
179 | // we have to perform a LEFT OUTER JOIN between the result of this aggregate and the delim scan |
180 | // FIXME: this does not always have to be a LEFT OUTER JOIN, depending on whether aggr.expressions return |
181 | // NULL or a value |
182 | unique_ptr<LogicalComparisonJoin> join = make_uniq<LogicalComparisonJoin>(args: JoinType::INNER); |
183 | for (auto &aggr_exp : aggr.expressions) { |
184 | auto &b_aggr_exp = aggr_exp->Cast<BoundAggregateExpression>(); |
185 | if (!b_aggr_exp.PropagatesNullValues() || any_join || !parent_propagate_null_values) { |
186 | join = make_uniq<LogicalComparisonJoin>(args: JoinType::LEFT); |
187 | break; |
188 | } |
189 | } |
190 | auto left_index = binder.GenerateTableIndex(); |
191 | auto delim_scan = make_uniq<LogicalDelimGet>(args&: left_index, args&: delim_types); |
192 | join->children.push_back(x: std::move(delim_scan)); |
193 | join->children.push_back(x: std::move(plan)); |
194 | for (idx_t i = 0; i < new_group_count; i++) { |
195 | auto &col = correlated_columns[i]; |
196 | JoinCondition cond; |
197 | cond.left = make_uniq<BoundColumnRefExpression>(args: col.name, args: col.type, args: ColumnBinding(left_index, i)); |
198 | cond.right = make_uniq<BoundColumnRefExpression>( |
199 | args: correlated_columns[i].type, args: ColumnBinding(delim_table_index, delim_column_offset + i)); |
200 | cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; |
201 | join->conditions.push_back(x: std::move(cond)); |
202 | } |
203 | // for any COUNT aggregate we replace references to the column with: CASE WHEN COUNT(*) IS NULL THEN 0 |
204 | // ELSE COUNT(*) END |
205 | for (idx_t i = 0; i < aggr.expressions.size(); i++) { |
206 | D_ASSERT(aggr.expressions[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); |
207 | auto &bound = aggr.expressions[i]->Cast<BoundAggregateExpression>(); |
208 | vector<LogicalType> arguments; |
209 | if (bound.function == CountFun::GetFunction() || bound.function == CountStarFun::GetFunction()) { |
210 | // have to replace this ColumnBinding with the CASE expression |
211 | replacement_map[ColumnBinding(aggr.aggregate_index, i)] = i; |
212 | } |
213 | } |
214 | // now we update the delim_index |
215 | base_binding.table_index = left_index; |
216 | this->delim_offset = base_binding.column_index = 0; |
217 | this->data_offset = 0; |
218 | return std::move(join); |
219 | } else { |
220 | // update the delim_index |
221 | base_binding.table_index = delim_table_index; |
222 | this->delim_offset = base_binding.column_index = delim_column_offset; |
223 | this->data_offset = delim_data_offset; |
224 | return plan; |
225 | } |
226 | } |
227 | case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { |
228 | // cross product |
229 | // push into both sides of the plan |
230 | bool left_has_correlation = has_correlated_expressions.find(x: plan->children[0].get())->second; |
231 | bool right_has_correlation = has_correlated_expressions.find(x: plan->children[1].get())->second; |
232 | if (!right_has_correlation) { |
233 | // only left has correlation: push into left |
234 | plan->children[0] = |
235 | PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
236 | return plan; |
237 | } |
238 | if (!left_has_correlation) { |
239 | // only right has correlation: push into right |
240 | plan->children[1] = |
241 | PushDownDependentJoinInternal(plan: std::move(plan->children[1]), parent_propagate_null_values); |
242 | return plan; |
243 | } |
244 | // both sides have correlation |
245 | // turn into an inner join |
246 | auto join = make_uniq<LogicalComparisonJoin>(args: JoinType::INNER); |
247 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
248 | auto left_binding = this->base_binding; |
249 | plan->children[1] = PushDownDependentJoinInternal(plan: std::move(plan->children[1]), parent_propagate_null_values); |
250 | // add the correlated columns to the join conditions |
251 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
252 | JoinCondition cond; |
253 | cond.left = make_uniq<BoundColumnRefExpression>( |
254 | args: correlated_columns[i].type, args: ColumnBinding(left_binding.table_index, left_binding.column_index + i)); |
255 | cond.right = make_uniq<BoundColumnRefExpression>( |
256 | args: correlated_columns[i].type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
257 | cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; |
258 | join->conditions.push_back(x: std::move(cond)); |
259 | } |
260 | join->children.push_back(x: std::move(plan->children[0])); |
261 | join->children.push_back(x: std::move(plan->children[1])); |
262 | return std::move(join); |
263 | } |
264 | case LogicalOperatorType::LOGICAL_ANY_JOIN: |
265 | case LogicalOperatorType::LOGICAL_ASOF_JOIN: |
266 | case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { |
267 | auto &join = plan->Cast<LogicalJoin>(); |
268 | D_ASSERT(plan->children.size() == 2); |
269 | // check the correlated expressions in the children of the join |
270 | bool left_has_correlation = has_correlated_expressions.find(x: plan->children[0].get())->second; |
271 | bool right_has_correlation = has_correlated_expressions.find(x: plan->children[1].get())->second; |
272 | |
273 | if (join.join_type == JoinType::INNER) { |
274 | // inner join |
275 | if (!right_has_correlation) { |
276 | // only left has correlation: push into left |
277 | plan->children[0] = |
278 | PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
279 | return plan; |
280 | } |
281 | if (!left_has_correlation) { |
282 | // only right has correlation: push into right |
283 | plan->children[1] = |
284 | PushDownDependentJoinInternal(plan: std::move(plan->children[1]), parent_propagate_null_values); |
285 | return plan; |
286 | } |
287 | } else if (join.join_type == JoinType::LEFT) { |
288 | // left outer join |
289 | if (!right_has_correlation) { |
290 | // only left has correlation: push into left |
291 | plan->children[0] = |
292 | PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
293 | return plan; |
294 | } |
295 | } else if (join.join_type == JoinType::RIGHT) { |
296 | // left outer join |
297 | if (!left_has_correlation) { |
298 | // only right has correlation: push into right |
299 | plan->children[1] = |
300 | PushDownDependentJoinInternal(plan: std::move(plan->children[1]), parent_propagate_null_values); |
301 | return plan; |
302 | } |
303 | } else if (join.join_type == JoinType::MARK) { |
304 | if (right_has_correlation) { |
305 | throw Exception("MARK join with correlation in RHS not supported" ); |
306 | } |
307 | // push the child into the LHS |
308 | plan->children[0] = |
309 | PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
310 | // rewrite expressions in the join conditions |
311 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
312 | rewriter.VisitOperator(op&: *plan); |
313 | return plan; |
314 | } else { |
315 | throw Exception("Unsupported join type for flattening correlated subquery" ); |
316 | } |
317 | // both sides have correlation |
318 | // push into both sides |
319 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
320 | auto left_binding = this->base_binding; |
321 | plan->children[1] = PushDownDependentJoinInternal(plan: std::move(plan->children[1]), parent_propagate_null_values); |
322 | auto right_binding = this->base_binding; |
323 | // NOTE: for OUTER JOINS it matters what the BASE BINDING is after the join |
324 | // for the LEFT OUTER JOIN, we want the LEFT side to be the base binding after we push |
325 | // because the RIGHT binding might contain NULL values |
326 | if (join.join_type == JoinType::LEFT) { |
327 | this->base_binding = left_binding; |
328 | } else if (join.join_type == JoinType::RIGHT) { |
329 | this->base_binding = right_binding; |
330 | } |
331 | // add the correlated columns to the join conditions |
332 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
333 | auto left = make_uniq<BoundColumnRefExpression>( |
334 | args: correlated_columns[i].type, args: ColumnBinding(left_binding.table_index, left_binding.column_index + i)); |
335 | auto right = make_uniq<BoundColumnRefExpression>( |
336 | args: correlated_columns[i].type, args: ColumnBinding(right_binding.table_index, right_binding.column_index + i)); |
337 | |
338 | if (join.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || |
339 | join.type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { |
340 | JoinCondition cond; |
341 | cond.left = std::move(left); |
342 | cond.right = std::move(right); |
343 | cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; |
344 | |
345 | auto &comparison_join = join.Cast<LogicalComparisonJoin>(); |
346 | comparison_join.conditions.push_back(x: std::move(cond)); |
347 | } else { |
348 | auto &any_join = join.Cast<LogicalAnyJoin>(); |
349 | auto comparison = make_uniq<BoundComparisonExpression>(args: ExpressionType::COMPARE_NOT_DISTINCT_FROM, |
350 | args: std::move(left), args: std::move(right)); |
351 | auto conjunction = make_uniq<BoundConjunctionExpression>( |
352 | args: ExpressionType::CONJUNCTION_AND, args: std::move(comparison), args: std::move(any_join.condition)); |
353 | any_join.condition = std::move(conjunction); |
354 | } |
355 | } |
356 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
357 | RewriteCorrelatedExpressions rewriter(right_binding, correlated_map); |
358 | rewriter.VisitOperator(op&: *plan); |
359 | return plan; |
360 | } |
361 | case LogicalOperatorType::LOGICAL_LIMIT: { |
362 | auto &limit = plan->Cast<LogicalLimit>(); |
363 | if (limit.limit || limit.offset) { |
364 | throw ParserException("Non-constant limit or offset not supported in correlated subquery" ); |
365 | } |
366 | auto rownum_alias = "limit_rownum" ; |
367 | unique_ptr<LogicalOperator> child; |
368 | unique_ptr<LogicalOrder> order_by; |
369 | |
370 | // check if the direct child of this LIMIT node is an ORDER BY node, if so, keep it separate |
371 | // this is done for an optimization to avoid having to compute the total order |
372 | if (plan->children[0]->type == LogicalOperatorType::LOGICAL_ORDER_BY) { |
373 | order_by = unique_ptr_cast<LogicalOperator, LogicalOrder>(src: std::move(plan->children[0])); |
374 | child = PushDownDependentJoinInternal(plan: std::move(order_by->children[0]), parent_propagate_null_values); |
375 | } else { |
376 | child = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
377 | } |
378 | auto child_column_count = child->GetColumnBindings().size(); |
379 | // we push a row_number() OVER (PARTITION BY [correlated columns]) |
380 | auto window_index = binder.GenerateTableIndex(); |
381 | auto window = make_uniq<LogicalWindow>(args&: window_index); |
382 | auto row_number = |
383 | make_uniq<BoundWindowExpression>(args: ExpressionType::WINDOW_ROW_NUMBER, args: LogicalType::BIGINT, args: nullptr, args: nullptr); |
384 | auto partition_count = perform_delim ? correlated_columns.size() : 1; |
385 | for (idx_t i = 0; i < partition_count; i++) { |
386 | auto &col = correlated_columns[i]; |
387 | auto colref = make_uniq<BoundColumnRefExpression>( |
388 | args: col.name, args: col.type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
389 | row_number->partitions.push_back(x: std::move(colref)); |
390 | } |
391 | if (order_by) { |
392 | // optimization: if there is an ORDER BY node followed by a LIMIT |
393 | // rather than computing the entire order, we push the ORDER BY expressions into the row_num computation |
394 | // this way, the order only needs to be computed per partition |
395 | row_number->orders = std::move(order_by->orders); |
396 | } |
397 | row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; |
398 | row_number->end = WindowBoundary::CURRENT_ROW_ROWS; |
399 | window->expressions.push_back(x: std::move(row_number)); |
400 | window->children.push_back(x: std::move(child)); |
401 | |
402 | // add a filter based on the row_number |
403 | // the filter we add is "row_number > offset AND row_number <= offset + limit" |
404 | auto filter = make_uniq<LogicalFilter>(); |
405 | unique_ptr<Expression> condition; |
406 | auto row_num_ref = |
407 | make_uniq<BoundColumnRefExpression>(args&: rownum_alias, args: LogicalType::BIGINT, args: ColumnBinding(window_index, 0)); |
408 | |
409 | int64_t upper_bound_limit = NumericLimits<int64_t>::Maximum(); |
410 | TryAddOperator::Operation(left: limit.offset_val, right: limit.limit_val, result&: upper_bound_limit); |
411 | auto upper_bound = make_uniq<BoundConstantExpression>(args: Value::BIGINT(value: upper_bound_limit)); |
412 | condition = make_uniq<BoundComparisonExpression>(args: ExpressionType::COMPARE_LESSTHANOREQUALTO, args: row_num_ref->Copy(), |
413 | args: std::move(upper_bound)); |
414 | // we only need to add "row_number >= offset + 1" if offset is bigger than 0 |
415 | if (limit.offset_val > 0) { |
416 | auto lower_bound = make_uniq<BoundConstantExpression>(args: Value::BIGINT(value: limit.offset_val)); |
417 | auto lower_comp = make_uniq<BoundComparisonExpression>(args: ExpressionType::COMPARE_GREATERTHAN, |
418 | args: row_num_ref->Copy(), args: std::move(lower_bound)); |
419 | auto conj = make_uniq<BoundConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND, args: std::move(lower_comp), |
420 | args: std::move(condition)); |
421 | condition = std::move(conj); |
422 | } |
423 | filter->expressions.push_back(x: std::move(condition)); |
424 | filter->children.push_back(x: std::move(window)); |
425 | // we prune away the row_number after the filter clause using the projection map |
426 | for (idx_t i = 0; i < child_column_count; i++) { |
427 | filter->projection_map.push_back(x: i); |
428 | } |
429 | return std::move(filter); |
430 | } |
431 | case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: { |
432 | // NOTE: limit percent could be supported in a manner similar to the LIMIT above |
433 | // but instead of filtering by an exact number of rows, the limit should be expressed as |
434 | // COUNT computed over the partition multiplied by the percentage |
435 | throw ParserException("Limit percent operator not supported in correlated subquery" ); |
436 | } |
437 | case LogicalOperatorType::LOGICAL_WINDOW: { |
438 | auto &window = plan->Cast<LogicalWindow>(); |
439 | // push into children |
440 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
441 | // add the correlated columns to the PARTITION BY clauses in the Window |
442 | for (auto &expr : window.expressions) { |
443 | D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); |
444 | auto &w = expr->Cast<BoundWindowExpression>(); |
445 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
446 | w.partitions.push_back(x: make_uniq<BoundColumnRefExpression>( |
447 | args: correlated_columns[i].type, |
448 | args: ColumnBinding(base_binding.table_index, base_binding.column_index + i))); |
449 | } |
450 | } |
451 | return plan; |
452 | } |
453 | case LogicalOperatorType::LOGICAL_EXCEPT: |
454 | case LogicalOperatorType::LOGICAL_INTERSECT: |
455 | case LogicalOperatorType::LOGICAL_UNION: { |
456 | auto &setop = plan->Cast<LogicalSetOperation>(); |
457 | // set operator, push into both children |
458 | #ifdef DEBUG |
459 | plan->children[0]->ResolveOperatorTypes(); |
460 | plan->children[1]->ResolveOperatorTypes(); |
461 | D_ASSERT(plan->children[0]->types == plan->children[1]->types); |
462 | #endif |
463 | plan->children[0] = PushDownDependentJoin(plan: std::move(plan->children[0])); |
464 | plan->children[1] = PushDownDependentJoin(plan: std::move(plan->children[1])); |
465 | #ifdef DEBUG |
466 | D_ASSERT(plan->children[0]->GetColumnBindings().size() == plan->children[1]->GetColumnBindings().size()); |
467 | plan->children[0]->ResolveOperatorTypes(); |
468 | plan->children[1]->ResolveOperatorTypes(); |
469 | D_ASSERT(plan->children[0]->types == plan->children[1]->types); |
470 | #endif |
471 | // we have to refer to the setop index now |
472 | base_binding.table_index = setop.table_index; |
473 | base_binding.column_index = setop.column_count; |
474 | setop.column_count += correlated_columns.size(); |
475 | return plan; |
476 | } |
477 | case LogicalOperatorType::LOGICAL_DISTINCT: { |
478 | auto &distinct = plan->Cast<LogicalDistinct>(); |
479 | // push down into child |
480 | distinct.children[0] = PushDownDependentJoin(plan: std::move(distinct.children[0])); |
481 | // add all correlated columns to the distinct targets |
482 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
483 | distinct.distinct_targets.push_back(x: make_uniq<BoundColumnRefExpression>( |
484 | args: correlated_columns[i].type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i))); |
485 | } |
486 | return plan; |
487 | } |
488 | case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { |
489 | // expression get |
490 | // first we flatten the dependent join in the child |
491 | plan->children[0] = PushDownDependentJoinInternal(plan: std::move(plan->children[0]), parent_propagate_null_values); |
492 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
493 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
494 | rewriter.VisitOperator(op&: *plan); |
495 | // now we add all the correlated columns to each of the expressions of the expression scan |
496 | auto &expr_get = plan->Cast<LogicalExpressionGet>(); |
497 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
498 | for (auto &expr_list : expr_get.expressions) { |
499 | auto colref = make_uniq<BoundColumnRefExpression>( |
500 | args: correlated_columns[i].type, args: ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
501 | expr_list.push_back(x: std::move(colref)); |
502 | } |
503 | expr_get.expr_types.push_back(x: correlated_columns[i].type); |
504 | } |
505 | |
506 | base_binding.table_index = expr_get.table_index; |
507 | this->delim_offset = base_binding.column_index = expr_get.expr_types.size() - correlated_columns.size(); |
508 | this->data_offset = 0; |
509 | return plan; |
510 | } |
511 | case LogicalOperatorType::LOGICAL_PIVOT: |
512 | throw BinderException("PIVOT is not supported in correlated subqueries yet" ); |
513 | case LogicalOperatorType::LOGICAL_ORDER_BY: |
514 | plan->children[0] = PushDownDependentJoin(plan: std::move(plan->children[0])); |
515 | return plan; |
516 | case LogicalOperatorType::LOGICAL_GET: { |
517 | auto &get = plan->Cast<LogicalGet>(); |
518 | if (get.children.size() != 1) { |
519 | throw InternalException("Flatten dependent joins - logical get encountered without children" ); |
520 | } |
521 | plan->children[0] = PushDownDependentJoin(plan: std::move(plan->children[0])); |
522 | for (idx_t i = 0; i < (perform_delim ? correlated_columns.size() : 1); i++) { |
523 | get.projected_input.push_back(x: this->delim_offset + i); |
524 | } |
525 | this->delim_offset = get.returned_types.size(); |
526 | this->data_offset = 0; |
527 | return plan; |
528 | } |
529 | case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { |
530 | throw BinderException("Recursive CTEs not supported in correlated subquery" ); |
531 | } |
532 | case LogicalOperatorType::LOGICAL_DELIM_JOIN: { |
533 | throw BinderException("Nested lateral joins or lateral joins in correlated subqueries are not (yet) supported" ); |
534 | } |
535 | case LogicalOperatorType::LOGICAL_SAMPLE: |
536 | throw BinderException("Sampling in correlated subqueries is not (yet) supported" ); |
537 | default: |
538 | throw InternalException("Logical operator type \"%s\" for dependent join" , LogicalOperatorToString(type: plan->type)); |
539 | } |
540 | } |
541 | |
542 | } // namespace duckdb |
543 | |