1 | #include "duckdb/planner/subquery/flatten_dependent_join.hpp" |
2 | |
3 | #include "duckdb/planner/binder.hpp" |
4 | #include "duckdb/planner/expression/list.hpp" |
5 | #include "duckdb/planner/logical_operator_visitor.hpp" |
6 | #include "duckdb/planner/binder.hpp" |
7 | #include "duckdb/planner/operator/list.hpp" |
8 | #include "duckdb/planner/subquery/has_correlated_expressions.hpp" |
9 | #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" |
10 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
11 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
12 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
13 | |
14 | using namespace duckdb; |
15 | using namespace std; |
16 | |
17 | FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector<CorrelatedColumnInfo> &correlated) |
18 | : binder(binder), correlated_columns(correlated) { |
19 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
20 | auto &col = correlated_columns[i]; |
21 | correlated_map[col.binding] = i; |
22 | delim_types.push_back(col.type); |
23 | } |
24 | } |
25 | |
26 | bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator *op) { |
27 | assert(op); |
28 | // check if this entry has correlated expressions |
29 | HasCorrelatedExpressions visitor(correlated_columns); |
30 | visitor.VisitOperator(*op); |
31 | bool has_correlation = visitor.has_correlated_expressions; |
32 | // now visit the children of this entry and check if they have correlated expressions |
33 | for (auto &child : op->children) { |
34 | // we OR the property with its children such that has_correlation is true if either |
35 | // (1) this node has a correlated expression or |
36 | // (2) one of its children has a correlated expression |
37 | if (DetectCorrelatedExpressions(child.get())) { |
38 | has_correlation = true; |
39 | } |
40 | } |
41 | // set the entry in the map |
42 | has_correlated_expressions[op] = has_correlation; |
43 | return has_correlation; |
44 | } |
45 | |
46 | unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoin(unique_ptr<LogicalOperator> plan) { |
47 | auto result = PushDownDependentJoinInternal(move(plan)); |
48 | if (replacement_map.size() > 0) { |
49 | // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END" |
50 | RewriteCountAggregates aggr(replacement_map); |
51 | aggr.VisitOperator(*result); |
52 | } |
53 | return result; |
54 | } |
55 | |
56 | unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal(unique_ptr<LogicalOperator> plan) { |
57 | // first check if the logical operator has correlated expressions |
58 | auto entry = has_correlated_expressions.find(plan.get()); |
59 | assert(entry != has_correlated_expressions.end()); |
60 | if (!entry->second) { |
61 | // we reached a node without correlated expressions |
62 | // we can eliminate the dependent join now and create a simple cross product |
63 | auto cross_product = make_unique<LogicalCrossProduct>(); |
64 | // now create the duplicate eliminated scan for this node |
65 | auto delim_index = binder.GenerateTableIndex(); |
66 | this->base_binding = ColumnBinding(delim_index, 0); |
67 | auto delim_scan = make_unique<LogicalDelimGet>(delim_index, delim_types); |
68 | cross_product->children.push_back(move(delim_scan)); |
69 | cross_product->children.push_back(move(plan)); |
70 | return move(cross_product); |
71 | } |
72 | switch (plan->type) { |
73 | case LogicalOperatorType::FILTER: { |
74 | // filter |
75 | // first we flatten the dependent join in the child of the filter |
76 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
77 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
78 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
79 | rewriter.VisitOperator(*plan); |
80 | return plan; |
81 | } |
82 | case LogicalOperatorType::PROJECTION: { |
83 | // projection |
84 | // first we flatten the dependent join in the child of the projection |
85 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
86 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
87 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
88 | rewriter.VisitOperator(*plan); |
89 | // now we add all the columns of the delim_scan to the projection list |
90 | auto proj = (LogicalProjection *)plan.get(); |
91 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
92 | auto colref = make_unique<BoundColumnRefExpression>( |
93 | correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
94 | plan->expressions.push_back(move(colref)); |
95 | } |
96 | |
97 | base_binding.table_index = proj->table_index; |
98 | this->delim_offset = base_binding.column_index = plan->expressions.size() - correlated_columns.size(); |
99 | this->data_offset = 0; |
100 | return plan; |
101 | } |
102 | case LogicalOperatorType::AGGREGATE_AND_GROUP_BY: { |
103 | auto &aggr = (LogicalAggregate &)*plan; |
104 | // aggregate and group by |
105 | // first we flatten the dependent join in the child of the projection |
106 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
107 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
108 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
109 | rewriter.VisitOperator(*plan); |
110 | // now we add all the columns of the delim_scan to the grouping operators AND the projection list |
111 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
112 | auto colref = make_unique<BoundColumnRefExpression>( |
113 | correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
114 | aggr.groups.push_back(move(colref)); |
115 | } |
116 | if (aggr.groups.size() == correlated_columns.size()) { |
117 | // we have to perform a LEFT OUTER JOIN between the result of this aggregate and the delim scan |
118 | auto left_outer_join = make_unique<LogicalComparisonJoin>(JoinType::LEFT); |
119 | auto left_index = binder.GenerateTableIndex(); |
120 | auto delim_scan = make_unique<LogicalDelimGet>(left_index, delim_types); |
121 | left_outer_join->children.push_back(move(delim_scan)); |
122 | left_outer_join->children.push_back(move(plan)); |
123 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
124 | JoinCondition cond; |
125 | cond.left = |
126 | make_unique<BoundColumnRefExpression>(correlated_columns[i].type, ColumnBinding(left_index, i)); |
127 | cond.right = make_unique<BoundColumnRefExpression>( |
128 | correlated_columns[i].type, |
129 | ColumnBinding(aggr.group_index, (aggr.groups.size() - correlated_columns.size()) + i)); |
130 | cond.comparison = ExpressionType::COMPARE_EQUAL; |
131 | cond.null_values_are_equal = true; |
132 | left_outer_join->conditions.push_back(move(cond)); |
133 | } |
134 | // for any COUNT aggregate we replace references to the column with: CASE WHEN COUNT(*) IS NULL THEN 0 |
135 | // ELSE COUNT(*) END |
136 | for (idx_t i = 0; i < aggr.expressions.size(); i++) { |
137 | assert(aggr.expressions[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); |
138 | auto bound = (BoundAggregateExpression *)&*aggr.expressions[i]; |
139 | vector<SQLType> arguments; |
140 | if (bound->function == CountFun::GetFunction() || bound->function == CountStarFun::GetFunction()) { |
141 | // have to replace this ColumnBinding with the CASE expression |
142 | replacement_map[ColumnBinding(aggr.aggregate_index, i)] = i; |
143 | } |
144 | } |
145 | // now we update the delim_index |
146 | |
147 | base_binding.table_index = left_index; |
148 | this->delim_offset = base_binding.column_index = 0; |
149 | this->data_offset = 0; |
150 | return move(left_outer_join); |
151 | } else { |
152 | // update the delim_index |
153 | base_binding.table_index = aggr.group_index; |
154 | this->delim_offset = base_binding.column_index = aggr.groups.size() - correlated_columns.size(); |
155 | this->data_offset = aggr.groups.size(); |
156 | return plan; |
157 | } |
158 | } |
159 | case LogicalOperatorType::CROSS_PRODUCT: { |
160 | // cross product |
161 | // push into both sides of the plan |
162 | bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; |
163 | bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; |
164 | if (!right_has_correlation) { |
165 | // only left has correlation: push into left |
166 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
167 | return plan; |
168 | } |
169 | if (!left_has_correlation) { |
170 | // only right has correlation: push into right |
171 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
172 | return plan; |
173 | } |
174 | // both sides have correlation |
175 | // turn into an inner join |
176 | auto join = make_unique<LogicalComparisonJoin>(JoinType::INNER); |
177 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
178 | auto left_binding = this->base_binding; |
179 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
180 | // add the correlated columns to the join conditions |
181 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
182 | JoinCondition cond; |
183 | cond.left = make_unique<BoundColumnRefExpression>( |
184 | correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); |
185 | cond.right = make_unique<BoundColumnRefExpression>( |
186 | correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
187 | cond.comparison = ExpressionType::COMPARE_EQUAL; |
188 | cond.null_values_are_equal = true; |
189 | join->conditions.push_back(move(cond)); |
190 | } |
191 | join->children.push_back(move(plan->children[0])); |
192 | join->children.push_back(move(plan->children[1])); |
193 | return move(join); |
194 | } |
195 | case LogicalOperatorType::COMPARISON_JOIN: { |
196 | auto &join = (LogicalComparisonJoin &)*plan; |
197 | assert(plan->children.size() == 2); |
198 | // check the correlated expressions in the children of the join |
199 | bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; |
200 | bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; |
201 | |
202 | if (join.join_type == JoinType::INNER) { |
203 | // inner join |
204 | if (!right_has_correlation) { |
205 | // only left has correlation: push into left |
206 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
207 | return plan; |
208 | } |
209 | if (!left_has_correlation) { |
210 | // only right has correlation: push into right |
211 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
212 | return plan; |
213 | } |
214 | } else if (join.join_type == JoinType::LEFT) { |
215 | // left outer join |
216 | if (!right_has_correlation) { |
217 | // only left has correlation: push into left |
218 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
219 | return plan; |
220 | } |
221 | } else if (join.join_type == JoinType::MARK) { |
222 | if (right_has_correlation) { |
223 | throw Exception("MARK join with correlation in RHS not supported" ); |
224 | } |
225 | // push the child into the LHS |
226 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
227 | // rewrite expressions in the join conditions |
228 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
229 | rewriter.VisitOperator(*plan); |
230 | return plan; |
231 | } else { |
232 | throw Exception("Unsupported join type for flattening correlated subquery" ); |
233 | } |
234 | // both sides have correlation |
235 | // push into both sides |
236 | // NOTE: for OUTER JOINS it matters what the BASE BINDING is after the join |
237 | // for the LEFT OUTER JOIN, we want the LEFT side to be the base binding after we push |
238 | // because the RIGHT binding might contain NULL values |
239 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
240 | auto left_binding = this->base_binding; |
241 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
242 | auto right_binding = this->base_binding; |
243 | if (join.join_type == JoinType::LEFT) { |
244 | this->base_binding = left_binding; |
245 | } |
246 | // add the correlated columns to the join conditions |
247 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
248 | JoinCondition cond; |
249 | |
250 | cond.left = make_unique<BoundColumnRefExpression>( |
251 | correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); |
252 | cond.right = make_unique<BoundColumnRefExpression>( |
253 | correlated_columns[i].type, ColumnBinding(right_binding.table_index, right_binding.column_index + i)); |
254 | cond.comparison = ExpressionType::COMPARE_EQUAL; |
255 | cond.null_values_are_equal = true; |
256 | join.conditions.push_back(move(cond)); |
257 | } |
258 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
259 | RewriteCorrelatedExpressions rewriter(right_binding, correlated_map); |
260 | rewriter.VisitOperator(*plan); |
261 | return plan; |
262 | } |
263 | case LogicalOperatorType::LIMIT: { |
264 | auto &limit = (LogicalLimit &)*plan; |
265 | if (limit.offset > 0) { |
266 | throw ParserException("OFFSET not supported in correlated subquery" ); |
267 | } |
268 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
269 | if (limit.limit == 0) { |
270 | // limit = 0 means we return zero columns here |
271 | return plan; |
272 | } else { |
273 | // limit > 0 does nothing |
274 | return move(plan->children[0]); |
275 | } |
276 | } |
277 | case LogicalOperatorType::WINDOW: { |
278 | auto &window = (LogicalWindow &)*plan; |
279 | // push into children |
280 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
281 | // add the correlated columns to the PARTITION BY clauses in the Window |
282 | for (auto &expr : window.expressions) { |
283 | assert(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); |
284 | auto &w = (BoundWindowExpression &)*expr; |
285 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
286 | w.partitions.push_back(make_unique<BoundColumnRefExpression>( |
287 | correlated_columns[i].type, |
288 | ColumnBinding(base_binding.table_index, base_binding.column_index + i))); |
289 | } |
290 | } |
291 | return plan; |
292 | } |
293 | case LogicalOperatorType::EXCEPT: |
294 | case LogicalOperatorType::INTERSECT: |
295 | case LogicalOperatorType::UNION: { |
296 | auto &setop = (LogicalSetOperation &)*plan; |
297 | // set operator, push into both children |
298 | plan->children[0] = PushDownDependentJoin(move(plan->children[0])); |
299 | plan->children[1] = PushDownDependentJoin(move(plan->children[1])); |
300 | // we have to refer to the setop index now |
301 | base_binding.table_index = setop.table_index; |
302 | base_binding.column_index = setop.column_count; |
303 | setop.column_count += correlated_columns.size(); |
304 | return plan; |
305 | } |
306 | case LogicalOperatorType::DISTINCT: |
307 | plan->children[0] = PushDownDependentJoin(move(plan->children[0])); |
308 | return plan; |
309 | case LogicalOperatorType::ORDER_BY: |
310 | throw ParserException("ORDER BY not supported in correlated subquery" ); |
311 | default: |
312 | throw NotImplementedException("Logical operator type \"%s\" for dependent join" , |
313 | LogicalOperatorToString(plan->type).c_str()); |
314 | } |
315 | } |
316 | |