1#include "duckdb/planner/subquery/flatten_dependent_join.hpp"
2
3#include "duckdb/planner/binder.hpp"
4#include "duckdb/planner/expression/list.hpp"
5#include "duckdb/planner/logical_operator_visitor.hpp"
6#include "duckdb/planner/binder.hpp"
7#include "duckdb/planner/operator/list.hpp"
8#include "duckdb/planner/subquery/has_correlated_expressions.hpp"
9#include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp"
10#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
11#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
12#include "duckdb/function/aggregate/distributive_functions.hpp"
13
14using namespace duckdb;
15using namespace std;
16
17FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector<CorrelatedColumnInfo> &correlated)
18 : binder(binder), correlated_columns(correlated) {
19 for (idx_t i = 0; i < correlated_columns.size(); i++) {
20 auto &col = correlated_columns[i];
21 correlated_map[col.binding] = i;
22 delim_types.push_back(col.type);
23 }
24}
25
26bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator *op) {
27 assert(op);
28 // check if this entry has correlated expressions
29 HasCorrelatedExpressions visitor(correlated_columns);
30 visitor.VisitOperator(*op);
31 bool has_correlation = visitor.has_correlated_expressions;
32 // now visit the children of this entry and check if they have correlated expressions
33 for (auto &child : op->children) {
34 // we OR the property with its children such that has_correlation is true if either
35 // (1) this node has a correlated expression or
36 // (2) one of its children has a correlated expression
37 if (DetectCorrelatedExpressions(child.get())) {
38 has_correlation = true;
39 }
40 }
41 // set the entry in the map
42 has_correlated_expressions[op] = has_correlation;
43 return has_correlation;
44}
45
46unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoin(unique_ptr<LogicalOperator> plan) {
47 auto result = PushDownDependentJoinInternal(move(plan));
48 if (replacement_map.size() > 0) {
49 // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END"
50 RewriteCountAggregates aggr(replacement_map);
51 aggr.VisitOperator(*result);
52 }
53 return result;
54}
55
56unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal(unique_ptr<LogicalOperator> plan) {
57 // first check if the logical operator has correlated expressions
58 auto entry = has_correlated_expressions.find(plan.get());
59 assert(entry != has_correlated_expressions.end());
60 if (!entry->second) {
61 // we reached a node without correlated expressions
62 // we can eliminate the dependent join now and create a simple cross product
63 auto cross_product = make_unique<LogicalCrossProduct>();
64 // now create the duplicate eliminated scan for this node
65 auto delim_index = binder.GenerateTableIndex();
66 this->base_binding = ColumnBinding(delim_index, 0);
67 auto delim_scan = make_unique<LogicalDelimGet>(delim_index, delim_types);
68 cross_product->children.push_back(move(delim_scan));
69 cross_product->children.push_back(move(plan));
70 return move(cross_product);
71 }
72 switch (plan->type) {
73 case LogicalOperatorType::FILTER: {
74 // filter
75 // first we flatten the dependent join in the child of the filter
76 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
77 // then we replace any correlated expressions with the corresponding entry in the correlated_map
78 RewriteCorrelatedExpressions rewriter(base_binding, correlated_map);
79 rewriter.VisitOperator(*plan);
80 return plan;
81 }
82 case LogicalOperatorType::PROJECTION: {
83 // projection
84 // first we flatten the dependent join in the child of the projection
85 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
86 // then we replace any correlated expressions with the corresponding entry in the correlated_map
87 RewriteCorrelatedExpressions rewriter(base_binding, correlated_map);
88 rewriter.VisitOperator(*plan);
89 // now we add all the columns of the delim_scan to the projection list
90 auto proj = (LogicalProjection *)plan.get();
91 for (idx_t i = 0; i < correlated_columns.size(); i++) {
92 auto colref = make_unique<BoundColumnRefExpression>(
93 correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i));
94 plan->expressions.push_back(move(colref));
95 }
96
97 base_binding.table_index = proj->table_index;
98 this->delim_offset = base_binding.column_index = plan->expressions.size() - correlated_columns.size();
99 this->data_offset = 0;
100 return plan;
101 }
102 case LogicalOperatorType::AGGREGATE_AND_GROUP_BY: {
103 auto &aggr = (LogicalAggregate &)*plan;
104 // aggregate and group by
105 // first we flatten the dependent join in the child of the projection
106 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
107 // then we replace any correlated expressions with the corresponding entry in the correlated_map
108 RewriteCorrelatedExpressions rewriter(base_binding, correlated_map);
109 rewriter.VisitOperator(*plan);
110 // now we add all the columns of the delim_scan to the grouping operators AND the projection list
111 for (idx_t i = 0; i < correlated_columns.size(); i++) {
112 auto colref = make_unique<BoundColumnRefExpression>(
113 correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i));
114 aggr.groups.push_back(move(colref));
115 }
116 if (aggr.groups.size() == correlated_columns.size()) {
117 // we have to perform a LEFT OUTER JOIN between the result of this aggregate and the delim scan
118 auto left_outer_join = make_unique<LogicalComparisonJoin>(JoinType::LEFT);
119 auto left_index = binder.GenerateTableIndex();
120 auto delim_scan = make_unique<LogicalDelimGet>(left_index, delim_types);
121 left_outer_join->children.push_back(move(delim_scan));
122 left_outer_join->children.push_back(move(plan));
123 for (idx_t i = 0; i < correlated_columns.size(); i++) {
124 JoinCondition cond;
125 cond.left =
126 make_unique<BoundColumnRefExpression>(correlated_columns[i].type, ColumnBinding(left_index, i));
127 cond.right = make_unique<BoundColumnRefExpression>(
128 correlated_columns[i].type,
129 ColumnBinding(aggr.group_index, (aggr.groups.size() - correlated_columns.size()) + i));
130 cond.comparison = ExpressionType::COMPARE_EQUAL;
131 cond.null_values_are_equal = true;
132 left_outer_join->conditions.push_back(move(cond));
133 }
134 // for any COUNT aggregate we replace references to the column with: CASE WHEN COUNT(*) IS NULL THEN 0
135 // ELSE COUNT(*) END
136 for (idx_t i = 0; i < aggr.expressions.size(); i++) {
137 assert(aggr.expressions[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE);
138 auto bound = (BoundAggregateExpression *)&*aggr.expressions[i];
139 vector<SQLType> arguments;
140 if (bound->function == CountFun::GetFunction() || bound->function == CountStarFun::GetFunction()) {
141 // have to replace this ColumnBinding with the CASE expression
142 replacement_map[ColumnBinding(aggr.aggregate_index, i)] = i;
143 }
144 }
145 // now we update the delim_index
146
147 base_binding.table_index = left_index;
148 this->delim_offset = base_binding.column_index = 0;
149 this->data_offset = 0;
150 return move(left_outer_join);
151 } else {
152 // update the delim_index
153 base_binding.table_index = aggr.group_index;
154 this->delim_offset = base_binding.column_index = aggr.groups.size() - correlated_columns.size();
155 this->data_offset = aggr.groups.size();
156 return plan;
157 }
158 }
159 case LogicalOperatorType::CROSS_PRODUCT: {
160 // cross product
161 // push into both sides of the plan
162 bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second;
163 bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second;
164 if (!right_has_correlation) {
165 // only left has correlation: push into left
166 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
167 return plan;
168 }
169 if (!left_has_correlation) {
170 // only right has correlation: push into right
171 plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1]));
172 return plan;
173 }
174 // both sides have correlation
175 // turn into an inner join
176 auto join = make_unique<LogicalComparisonJoin>(JoinType::INNER);
177 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
178 auto left_binding = this->base_binding;
179 plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1]));
180 // add the correlated columns to the join conditions
181 for (idx_t i = 0; i < correlated_columns.size(); i++) {
182 JoinCondition cond;
183 cond.left = make_unique<BoundColumnRefExpression>(
184 correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i));
185 cond.right = make_unique<BoundColumnRefExpression>(
186 correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i));
187 cond.comparison = ExpressionType::COMPARE_EQUAL;
188 cond.null_values_are_equal = true;
189 join->conditions.push_back(move(cond));
190 }
191 join->children.push_back(move(plan->children[0]));
192 join->children.push_back(move(plan->children[1]));
193 return move(join);
194 }
195 case LogicalOperatorType::COMPARISON_JOIN: {
196 auto &join = (LogicalComparisonJoin &)*plan;
197 assert(plan->children.size() == 2);
198 // check the correlated expressions in the children of the join
199 bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second;
200 bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second;
201
202 if (join.join_type == JoinType::INNER) {
203 // inner join
204 if (!right_has_correlation) {
205 // only left has correlation: push into left
206 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
207 return plan;
208 }
209 if (!left_has_correlation) {
210 // only right has correlation: push into right
211 plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1]));
212 return plan;
213 }
214 } else if (join.join_type == JoinType::LEFT) {
215 // left outer join
216 if (!right_has_correlation) {
217 // only left has correlation: push into left
218 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
219 return plan;
220 }
221 } else if (join.join_type == JoinType::MARK) {
222 if (right_has_correlation) {
223 throw Exception("MARK join with correlation in RHS not supported");
224 }
225 // push the child into the LHS
226 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
227 // rewrite expressions in the join conditions
228 RewriteCorrelatedExpressions rewriter(base_binding, correlated_map);
229 rewriter.VisitOperator(*plan);
230 return plan;
231 } else {
232 throw Exception("Unsupported join type for flattening correlated subquery");
233 }
234 // both sides have correlation
235 // push into both sides
236 // NOTE: for OUTER JOINS it matters what the BASE BINDING is after the join
237 // for the LEFT OUTER JOIN, we want the LEFT side to be the base binding after we push
238 // because the RIGHT binding might contain NULL values
239 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
240 auto left_binding = this->base_binding;
241 plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1]));
242 auto right_binding = this->base_binding;
243 if (join.join_type == JoinType::LEFT) {
244 this->base_binding = left_binding;
245 }
246 // add the correlated columns to the join conditions
247 for (idx_t i = 0; i < correlated_columns.size(); i++) {
248 JoinCondition cond;
249
250 cond.left = make_unique<BoundColumnRefExpression>(
251 correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i));
252 cond.right = make_unique<BoundColumnRefExpression>(
253 correlated_columns[i].type, ColumnBinding(right_binding.table_index, right_binding.column_index + i));
254 cond.comparison = ExpressionType::COMPARE_EQUAL;
255 cond.null_values_are_equal = true;
256 join.conditions.push_back(move(cond));
257 }
258 // then we replace any correlated expressions with the corresponding entry in the correlated_map
259 RewriteCorrelatedExpressions rewriter(right_binding, correlated_map);
260 rewriter.VisitOperator(*plan);
261 return plan;
262 }
263 case LogicalOperatorType::LIMIT: {
264 auto &limit = (LogicalLimit &)*plan;
265 if (limit.offset > 0) {
266 throw ParserException("OFFSET not supported in correlated subquery");
267 }
268 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
269 if (limit.limit == 0) {
270 // limit = 0 means we return zero columns here
271 return plan;
272 } else {
273 // limit > 0 does nothing
274 return move(plan->children[0]);
275 }
276 }
277 case LogicalOperatorType::WINDOW: {
278 auto &window = (LogicalWindow &)*plan;
279 // push into children
280 plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0]));
281 // add the correlated columns to the PARTITION BY clauses in the Window
282 for (auto &expr : window.expressions) {
283 assert(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW);
284 auto &w = (BoundWindowExpression &)*expr;
285 for (idx_t i = 0; i < correlated_columns.size(); i++) {
286 w.partitions.push_back(make_unique<BoundColumnRefExpression>(
287 correlated_columns[i].type,
288 ColumnBinding(base_binding.table_index, base_binding.column_index + i)));
289 }
290 }
291 return plan;
292 }
293 case LogicalOperatorType::EXCEPT:
294 case LogicalOperatorType::INTERSECT:
295 case LogicalOperatorType::UNION: {
296 auto &setop = (LogicalSetOperation &)*plan;
297 // set operator, push into both children
298 plan->children[0] = PushDownDependentJoin(move(plan->children[0]));
299 plan->children[1] = PushDownDependentJoin(move(plan->children[1]));
300 // we have to refer to the setop index now
301 base_binding.table_index = setop.table_index;
302 base_binding.column_index = setop.column_count;
303 setop.column_count += correlated_columns.size();
304 return plan;
305 }
306 case LogicalOperatorType::DISTINCT:
307 plan->children[0] = PushDownDependentJoin(move(plan->children[0]));
308 return plan;
309 case LogicalOperatorType::ORDER_BY:
310 throw ParserException("ORDER BY not supported in correlated subquery");
311 default:
312 throw NotImplementedException("Logical operator type \"%s\" for dependent join",
313 LogicalOperatorToString(plan->type).c_str());
314 }
315}
316