1#include "duckdb/optimizer/unnest_rewriter.hpp"
2
3#include "duckdb/common/pair.hpp"
4#include "duckdb/planner/operator/logical_delim_get.hpp"
5#include "duckdb/planner/operator/logical_delim_join.hpp"
6#include "duckdb/planner/operator/logical_unnest.hpp"
7#include "duckdb/planner/operator/logical_projection.hpp"
8#include "duckdb/planner/operator/logical_window.hpp"
9#include "duckdb/planner/expression/bound_unnest_expression.hpp"
10#include "duckdb/planner/expression/bound_columnref_expression.hpp"
11
12namespace duckdb {
13
14void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) {
15 VisitOperatorChildren(op);
16 VisitOperatorExpressions(op);
17}
18
19void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr<Expression> *expression) {
20
21 auto &expr = *expression;
22
23 if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) {
24
25 auto &bound_column_ref = expr->Cast<BoundColumnRefExpression>();
26 for (idx_t i = 0; i < replace_bindings.size(); i++) {
27 if (bound_column_ref.binding == replace_bindings[i].old_binding) {
28 bound_column_ref.binding = replace_bindings[i].new_binding;
29 break;
30 }
31 }
32 }
33
34 VisitExpressionChildren(expression&: **expression);
35}
36
37unique_ptr<LogicalOperator> UnnestRewriter::Optimize(unique_ptr<LogicalOperator> op) {
38
39 UnnestRewriterPlanUpdater updater;
40 vector<unique_ptr<LogicalOperator> *> candidates;
41 FindCandidates(op_ptr: &op, candidates);
42
43 // rewrite the plan and update the bindings
44 for (auto &candidate : candidates) {
45
46 // rearrange the logical operators
47 if (RewriteCandidate(candidate)) {
48 updater.overwritten_tbl_idx = overwritten_tbl_idx;
49 // update the bindings of the BOUND_UNNEST expression
50 UpdateBoundUnnestBindings(updater, candidate);
51 // update the sequence of LOGICAL_PROJECTION(s)
52 UpdateRHSBindings(plan_ptr: &op, candidate, updater);
53 // reset
54 delim_columns.clear();
55 lhs_bindings.clear();
56 }
57 }
58
59 return op;
60}
61
62void UnnestRewriter::FindCandidates(unique_ptr<LogicalOperator> *op_ptr,
63 vector<unique_ptr<LogicalOperator> *> &candidates) {
64 auto op = op_ptr->get();
65 // search children before adding, so that we add candidates bottom-up
66 for (auto &child : op->children) {
67 FindCandidates(op_ptr: &child, candidates);
68 }
69
70 // search for operator that has a LOGICAL_DELIM_JOIN as its child
71 if (op->children.size() != 1) {
72 return;
73 }
74 if (op->children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) {
75 return;
76 }
77
78 // found a delim join
79 auto &delim_join = op->children[0]->Cast<LogicalDelimJoin>();
80 // only support INNER delim joins
81 if (delim_join.join_type != JoinType::INNER) {
82 return;
83 }
84 // INNER delim join must have exactly one condition
85 if (delim_join.conditions.size() != 1) {
86 return;
87 }
88
89 // LHS child is a window
90 if (delim_join.children[0]->type != LogicalOperatorType::LOGICAL_WINDOW) {
91 return;
92 }
93
94 // RHS child must be projection(s) followed by an UNNEST
95 auto curr_op = &delim_join.children[1];
96 while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) {
97 if (curr_op->get()->children.size() != 1) {
98 break;
99 }
100 curr_op = &curr_op->get()->children[0];
101 }
102
103 if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST) {
104 candidates.push_back(x: op_ptr);
105 }
106}
107
108bool UnnestRewriter::RewriteCandidate(unique_ptr<LogicalOperator> *candidate) {
109
110 auto &topmost_op = (LogicalOperator &)**candidate;
111 if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION &&
112 topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW &&
113 topmost_op.type != LogicalOperatorType::LOGICAL_FILTER &&
114 topmost_op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY &&
115 topmost_op.type != LogicalOperatorType::LOGICAL_UNNEST) {
116 return false;
117 }
118
119 // get the LOGICAL_DELIM_JOIN, which is a child of the candidate
120 D_ASSERT(topmost_op.children.size() == 1);
121 auto &delim_join = *(topmost_op.children[0]);
122 D_ASSERT(delim_join.type == LogicalOperatorType::LOGICAL_DELIM_JOIN);
123 GetDelimColumns(op&: delim_join);
124
125 // LHS of the LOGICAL_DELIM_JOIN is a LOGICAL_WINDOW that contains a LOGICAL_PROJECTION
126 // this lhs_proj later becomes the child of the UNNEST
127 auto &window = *delim_join.children[0];
128 auto &lhs_op = window.children[0];
129 GetLHSExpressions(op&: *lhs_op);
130
131 // find the LOGICAL_UNNEST
132 // and get the path down to the LOGICAL_UNNEST
133 vector<unique_ptr<LogicalOperator> *> path_to_unnest;
134 auto curr_op = &(delim_join.children[1]);
135 while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) {
136 path_to_unnest.push_back(x: curr_op);
137 curr_op = &curr_op->get()->children[0];
138 }
139
140 // store the table index of the child of the LOGICAL_UNNEST
141 // then update the plan by making the lhs_proj the child of the LOGICAL_UNNEST
142 D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST);
143 auto &unnest = curr_op->get()->Cast<LogicalUnnest>();
144 D_ASSERT(unnest.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET);
145 overwritten_tbl_idx = unnest.children[0]->Cast<LogicalDelimGet>().table_index;
146
147 D_ASSERT(!unnest.children.empty());
148 auto &delim_get = unnest.children[0]->Cast<LogicalDelimGet>();
149 D_ASSERT(delim_get.chunk_types.size() > 1);
150 distinct_unnest_count = delim_get.chunk_types.size();
151 unnest.children[0] = std::move(lhs_op);
152
153 // replace the LOGICAL_DELIM_JOIN with its RHS child operator
154 topmost_op.children[0] = std::move(*path_to_unnest.front());
155 return true;
156}
157
158void UnnestRewriter::UpdateRHSBindings(unique_ptr<LogicalOperator> *plan_ptr, unique_ptr<LogicalOperator> *candidate,
159 UnnestRewriterPlanUpdater &updater) {
160
161 auto &topmost_op = (LogicalOperator &)**candidate;
162 idx_t shift = lhs_bindings.size();
163
164 vector<unique_ptr<LogicalOperator> *> path_to_unnest;
165 auto curr_op = &(topmost_op.children[0]);
166 while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) {
167
168 path_to_unnest.push_back(x: curr_op);
169 D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION);
170 auto &proj = curr_op->get()->Cast<LogicalProjection>();
171
172 // pop the unnest columns and the delim index
173 D_ASSERT(proj.expressions.size() > distinct_unnest_count);
174 for (idx_t i = 0; i < distinct_unnest_count; i++) {
175 proj.expressions.pop_back();
176 }
177
178 // store all shifted current bindings
179 idx_t tbl_idx = proj.table_index;
180 for (idx_t i = 0; i < proj.expressions.size(); i++) {
181 ReplaceBinding replace_binding(ColumnBinding(tbl_idx, i), ColumnBinding(tbl_idx, i + shift));
182 updater.replace_bindings.push_back(x: replace_binding);
183 }
184
185 curr_op = &curr_op->get()->children[0];
186 }
187
188 // update all bindings by shifting them
189 updater.VisitOperator(op&: *plan_ptr->get());
190 updater.replace_bindings.clear();
191
192 // update all bindings coming from the LHS to RHS bindings
193 D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION);
194 auto &top_proj = topmost_op.children[0]->Cast<LogicalProjection>();
195 for (idx_t i = 0; i < lhs_bindings.size(); i++) {
196 ReplaceBinding replace_binding(lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, i));
197 updater.replace_bindings.push_back(x: replace_binding);
198 }
199
200 // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan
201 D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST);
202 auto &unnest = curr_op->get()->Cast<LogicalUnnest>();
203 vector<unique_ptr<Expression>> temp_bound_unnests;
204 for (auto &temp_bound_unnest : unnest.expressions) {
205 temp_bound_unnests.push_back(x: std::move(temp_bound_unnest));
206 }
207 D_ASSERT(unnest.children.size() == 1);
208 auto temp_unnest_child = std::move(unnest.children[0]);
209 unnest.expressions.clear();
210 unnest.children.clear();
211 // update the bindings of the plan
212 updater.VisitOperator(op&: *plan_ptr->get());
213 updater.replace_bindings.clear();
214 // add the children again
215 for (auto &temp_bound_unnest : temp_bound_unnests) {
216 unnest.expressions.push_back(x: std::move(temp_bound_unnest));
217 }
218 unnest.children.push_back(x: std::move(temp_unnest_child));
219
220 // add the LHS expressions to each LOGICAL_PROJECTION
221 for (idx_t i = path_to_unnest.size(); i > 0; i--) {
222
223 D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION);
224 auto &proj = path_to_unnest[i - 1]->get()->Cast<LogicalProjection>();
225
226 // temporarily store the existing expressions
227 vector<unique_ptr<Expression>> existing_expressions;
228 for (idx_t expr_idx = 0; expr_idx < proj.expressions.size(); expr_idx++) {
229 existing_expressions.push_back(x: std::move(proj.expressions[expr_idx]));
230 }
231
232 proj.expressions.clear();
233
234 // add the new expressions
235 for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) {
236 auto new_expr = make_uniq<BoundColumnRefExpression>(
237 args&: lhs_bindings[expr_idx].alias, args&: lhs_bindings[expr_idx].type, args&: lhs_bindings[expr_idx].binding);
238 proj.expressions.push_back(x: std::move(new_expr));
239
240 // update the table index
241 lhs_bindings[expr_idx].binding.table_index = proj.table_index;
242 lhs_bindings[expr_idx].binding.column_index = expr_idx;
243 }
244
245 // add the existing expressions again
246 for (idx_t expr_idx = 0; expr_idx < existing_expressions.size(); expr_idx++) {
247 proj.expressions.push_back(x: std::move(existing_expressions[expr_idx]));
248 }
249 }
250}
251
252void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater,
253 unique_ptr<LogicalOperator> *candidate) {
254
255 auto &topmost_op = (LogicalOperator &)**candidate;
256
257 // traverse LOGICAL_PROJECTION(s)
258 auto curr_op = &(topmost_op.children[0]);
259 while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) {
260 curr_op = &curr_op->get()->children[0];
261 }
262
263 // found the LOGICAL_UNNEST
264 D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST);
265 auto &unnest = curr_op->get()->Cast<LogicalUnnest>();
266
267 D_ASSERT(unnest.children.size() == 1);
268 auto unnest_cols = unnest.children[0]->GetColumnBindings();
269
270 for (idx_t i = 0; i < delim_columns.size(); i++) {
271 auto delim_binding = delim_columns[i];
272
273 auto unnest_it = unnest_cols.begin();
274 while (unnest_it != unnest_cols.end()) {
275 auto unnest_binding = *unnest_it;
276
277 if (delim_binding.table_index == unnest_binding.table_index) {
278 unnest_binding.table_index = overwritten_tbl_idx;
279 unnest_binding.column_index++;
280 updater.replace_bindings.emplace_back(args&: unnest_binding, args&: delim_binding);
281 unnest_cols.erase(position: unnest_it);
282 break;
283 }
284 unnest_it++;
285 }
286 }
287
288 // update bindings
289 for (auto &unnest_expr : unnest.expressions) {
290 updater.VisitExpression(expression: &unnest_expr);
291 }
292 updater.replace_bindings.clear();
293}
294
295void UnnestRewriter::GetDelimColumns(LogicalOperator &op) {
296
297 D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN);
298 auto &delim_join = op.Cast<LogicalDelimJoin>();
299 for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) {
300 auto &expr = *delim_join.duplicate_eliminated_columns[i];
301 D_ASSERT(expr.type == ExpressionType::BOUND_COLUMN_REF);
302 auto &bound_colref_expr = expr.Cast<BoundColumnRefExpression>();
303 delim_columns.push_back(x: bound_colref_expr.binding);
304 }
305}
306
307void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) {
308
309 op.ResolveOperatorTypes();
310 auto col_bindings = op.GetColumnBindings();
311 D_ASSERT(op.types.size() == col_bindings.size());
312
313 bool set_alias = false;
314 // we can easily extract the alias for LOGICAL_PROJECTION(s)
315 if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) {
316 auto &proj = op.Cast<LogicalProjection>();
317 if (proj.expressions.size() == op.types.size()) {
318 set_alias = true;
319 }
320 }
321
322 for (idx_t i = 0; i < op.types.size(); i++) {
323 lhs_bindings.emplace_back(args&: col_bindings[i], args&: op.types[i]);
324 if (set_alias) {
325 auto &proj = op.Cast<LogicalProjection>();
326 lhs_bindings.back().alias = proj.expressions[i]->alias;
327 }
328 }
329}
330
331} // namespace duckdb
332