1 | #include "duckdb/optimizer/unnest_rewriter.hpp" |
2 | |
3 | #include "duckdb/common/pair.hpp" |
4 | #include "duckdb/planner/operator/logical_delim_get.hpp" |
5 | #include "duckdb/planner/operator/logical_delim_join.hpp" |
6 | #include "duckdb/planner/operator/logical_unnest.hpp" |
7 | #include "duckdb/planner/operator/logical_projection.hpp" |
8 | #include "duckdb/planner/operator/logical_window.hpp" |
9 | #include "duckdb/planner/expression/bound_unnest_expression.hpp" |
10 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
11 | |
12 | namespace duckdb { |
13 | |
14 | void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) { |
15 | VisitOperatorChildren(op); |
16 | VisitOperatorExpressions(op); |
17 | } |
18 | |
19 | void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr<Expression> *expression) { |
20 | |
21 | auto &expr = *expression; |
22 | |
23 | if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { |
24 | |
25 | auto &bound_column_ref = expr->Cast<BoundColumnRefExpression>(); |
26 | for (idx_t i = 0; i < replace_bindings.size(); i++) { |
27 | if (bound_column_ref.binding == replace_bindings[i].old_binding) { |
28 | bound_column_ref.binding = replace_bindings[i].new_binding; |
29 | break; |
30 | } |
31 | } |
32 | } |
33 | |
34 | VisitExpressionChildren(expression&: **expression); |
35 | } |
36 | |
37 | unique_ptr<LogicalOperator> UnnestRewriter::Optimize(unique_ptr<LogicalOperator> op) { |
38 | |
39 | UnnestRewriterPlanUpdater updater; |
40 | vector<unique_ptr<LogicalOperator> *> candidates; |
41 | FindCandidates(op_ptr: &op, candidates); |
42 | |
43 | // rewrite the plan and update the bindings |
44 | for (auto &candidate : candidates) { |
45 | |
46 | // rearrange the logical operators |
47 | if (RewriteCandidate(candidate)) { |
48 | updater.overwritten_tbl_idx = overwritten_tbl_idx; |
49 | // update the bindings of the BOUND_UNNEST expression |
50 | UpdateBoundUnnestBindings(updater, candidate); |
51 | // update the sequence of LOGICAL_PROJECTION(s) |
52 | UpdateRHSBindings(plan_ptr: &op, candidate, updater); |
53 | // reset |
54 | delim_columns.clear(); |
55 | lhs_bindings.clear(); |
56 | } |
57 | } |
58 | |
59 | return op; |
60 | } |
61 | |
62 | void UnnestRewriter::FindCandidates(unique_ptr<LogicalOperator> *op_ptr, |
63 | vector<unique_ptr<LogicalOperator> *> &candidates) { |
64 | auto op = op_ptr->get(); |
65 | // search children before adding, so that we add candidates bottom-up |
66 | for (auto &child : op->children) { |
67 | FindCandidates(op_ptr: &child, candidates); |
68 | } |
69 | |
70 | // search for operator that has a LOGICAL_DELIM_JOIN as its child |
71 | if (op->children.size() != 1) { |
72 | return; |
73 | } |
74 | if (op->children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { |
75 | return; |
76 | } |
77 | |
78 | // found a delim join |
79 | auto &delim_join = op->children[0]->Cast<LogicalDelimJoin>(); |
80 | // only support INNER delim joins |
81 | if (delim_join.join_type != JoinType::INNER) { |
82 | return; |
83 | } |
84 | // INNER delim join must have exactly one condition |
85 | if (delim_join.conditions.size() != 1) { |
86 | return; |
87 | } |
88 | |
89 | // LHS child is a window |
90 | if (delim_join.children[0]->type != LogicalOperatorType::LOGICAL_WINDOW) { |
91 | return; |
92 | } |
93 | |
94 | // RHS child must be projection(s) followed by an UNNEST |
95 | auto curr_op = &delim_join.children[1]; |
96 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
97 | if (curr_op->get()->children.size() != 1) { |
98 | break; |
99 | } |
100 | curr_op = &curr_op->get()->children[0]; |
101 | } |
102 | |
103 | if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST) { |
104 | candidates.push_back(x: op_ptr); |
105 | } |
106 | } |
107 | |
108 | bool UnnestRewriter::RewriteCandidate(unique_ptr<LogicalOperator> *candidate) { |
109 | |
110 | auto &topmost_op = (LogicalOperator &)**candidate; |
111 | if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && |
112 | topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && |
113 | topmost_op.type != LogicalOperatorType::LOGICAL_FILTER && |
114 | topmost_op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY && |
115 | topmost_op.type != LogicalOperatorType::LOGICAL_UNNEST) { |
116 | return false; |
117 | } |
118 | |
119 | // get the LOGICAL_DELIM_JOIN, which is a child of the candidate |
120 | D_ASSERT(topmost_op.children.size() == 1); |
121 | auto &delim_join = *(topmost_op.children[0]); |
122 | D_ASSERT(delim_join.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
123 | GetDelimColumns(op&: delim_join); |
124 | |
125 | // LHS of the LOGICAL_DELIM_JOIN is a LOGICAL_WINDOW that contains a LOGICAL_PROJECTION |
126 | // this lhs_proj later becomes the child of the UNNEST |
127 | auto &window = *delim_join.children[0]; |
128 | auto &lhs_op = window.children[0]; |
129 | GetLHSExpressions(op&: *lhs_op); |
130 | |
131 | // find the LOGICAL_UNNEST |
132 | // and get the path down to the LOGICAL_UNNEST |
133 | vector<unique_ptr<LogicalOperator> *> path_to_unnest; |
134 | auto curr_op = &(delim_join.children[1]); |
135 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
136 | path_to_unnest.push_back(x: curr_op); |
137 | curr_op = &curr_op->get()->children[0]; |
138 | } |
139 | |
140 | // store the table index of the child of the LOGICAL_UNNEST |
141 | // then update the plan by making the lhs_proj the child of the LOGICAL_UNNEST |
142 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
143 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
144 | D_ASSERT(unnest.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET); |
145 | overwritten_tbl_idx = unnest.children[0]->Cast<LogicalDelimGet>().table_index; |
146 | |
147 | D_ASSERT(!unnest.children.empty()); |
148 | auto &delim_get = unnest.children[0]->Cast<LogicalDelimGet>(); |
149 | D_ASSERT(delim_get.chunk_types.size() > 1); |
150 | distinct_unnest_count = delim_get.chunk_types.size(); |
151 | unnest.children[0] = std::move(lhs_op); |
152 | |
153 | // replace the LOGICAL_DELIM_JOIN with its RHS child operator |
154 | topmost_op.children[0] = std::move(*path_to_unnest.front()); |
155 | return true; |
156 | } |
157 | |
158 | void UnnestRewriter::UpdateRHSBindings(unique_ptr<LogicalOperator> *plan_ptr, unique_ptr<LogicalOperator> *candidate, |
159 | UnnestRewriterPlanUpdater &updater) { |
160 | |
161 | auto &topmost_op = (LogicalOperator &)**candidate; |
162 | idx_t shift = lhs_bindings.size(); |
163 | |
164 | vector<unique_ptr<LogicalOperator> *> path_to_unnest; |
165 | auto curr_op = &(topmost_op.children[0]); |
166 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
167 | |
168 | path_to_unnest.push_back(x: curr_op); |
169 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); |
170 | auto &proj = curr_op->get()->Cast<LogicalProjection>(); |
171 | |
172 | // pop the unnest columns and the delim index |
173 | D_ASSERT(proj.expressions.size() > distinct_unnest_count); |
174 | for (idx_t i = 0; i < distinct_unnest_count; i++) { |
175 | proj.expressions.pop_back(); |
176 | } |
177 | |
178 | // store all shifted current bindings |
179 | idx_t tbl_idx = proj.table_index; |
180 | for (idx_t i = 0; i < proj.expressions.size(); i++) { |
181 | ReplaceBinding replace_binding(ColumnBinding(tbl_idx, i), ColumnBinding(tbl_idx, i + shift)); |
182 | updater.replace_bindings.push_back(x: replace_binding); |
183 | } |
184 | |
185 | curr_op = &curr_op->get()->children[0]; |
186 | } |
187 | |
188 | // update all bindings by shifting them |
189 | updater.VisitOperator(op&: *plan_ptr->get()); |
190 | updater.replace_bindings.clear(); |
191 | |
192 | // update all bindings coming from the LHS to RHS bindings |
193 | D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); |
194 | auto &top_proj = topmost_op.children[0]->Cast<LogicalProjection>(); |
195 | for (idx_t i = 0; i < lhs_bindings.size(); i++) { |
196 | ReplaceBinding replace_binding(lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, i)); |
197 | updater.replace_bindings.push_back(x: replace_binding); |
198 | } |
199 | |
200 | // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan |
201 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
202 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
203 | vector<unique_ptr<Expression>> temp_bound_unnests; |
204 | for (auto &temp_bound_unnest : unnest.expressions) { |
205 | temp_bound_unnests.push_back(x: std::move(temp_bound_unnest)); |
206 | } |
207 | D_ASSERT(unnest.children.size() == 1); |
208 | auto temp_unnest_child = std::move(unnest.children[0]); |
209 | unnest.expressions.clear(); |
210 | unnest.children.clear(); |
211 | // update the bindings of the plan |
212 | updater.VisitOperator(op&: *plan_ptr->get()); |
213 | updater.replace_bindings.clear(); |
214 | // add the children again |
215 | for (auto &temp_bound_unnest : temp_bound_unnests) { |
216 | unnest.expressions.push_back(x: std::move(temp_bound_unnest)); |
217 | } |
218 | unnest.children.push_back(x: std::move(temp_unnest_child)); |
219 | |
220 | // add the LHS expressions to each LOGICAL_PROJECTION |
221 | for (idx_t i = path_to_unnest.size(); i > 0; i--) { |
222 | |
223 | D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); |
224 | auto &proj = path_to_unnest[i - 1]->get()->Cast<LogicalProjection>(); |
225 | |
226 | // temporarily store the existing expressions |
227 | vector<unique_ptr<Expression>> existing_expressions; |
228 | for (idx_t expr_idx = 0; expr_idx < proj.expressions.size(); expr_idx++) { |
229 | existing_expressions.push_back(x: std::move(proj.expressions[expr_idx])); |
230 | } |
231 | |
232 | proj.expressions.clear(); |
233 | |
234 | // add the new expressions |
235 | for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { |
236 | auto new_expr = make_uniq<BoundColumnRefExpression>( |
237 | args&: lhs_bindings[expr_idx].alias, args&: lhs_bindings[expr_idx].type, args&: lhs_bindings[expr_idx].binding); |
238 | proj.expressions.push_back(x: std::move(new_expr)); |
239 | |
240 | // update the table index |
241 | lhs_bindings[expr_idx].binding.table_index = proj.table_index; |
242 | lhs_bindings[expr_idx].binding.column_index = expr_idx; |
243 | } |
244 | |
245 | // add the existing expressions again |
246 | for (idx_t expr_idx = 0; expr_idx < existing_expressions.size(); expr_idx++) { |
247 | proj.expressions.push_back(x: std::move(existing_expressions[expr_idx])); |
248 | } |
249 | } |
250 | } |
251 | |
252 | void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, |
253 | unique_ptr<LogicalOperator> *candidate) { |
254 | |
255 | auto &topmost_op = (LogicalOperator &)**candidate; |
256 | |
257 | // traverse LOGICAL_PROJECTION(s) |
258 | auto curr_op = &(topmost_op.children[0]); |
259 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
260 | curr_op = &curr_op->get()->children[0]; |
261 | } |
262 | |
263 | // found the LOGICAL_UNNEST |
264 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
265 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
266 | |
267 | D_ASSERT(unnest.children.size() == 1); |
268 | auto unnest_cols = unnest.children[0]->GetColumnBindings(); |
269 | |
270 | for (idx_t i = 0; i < delim_columns.size(); i++) { |
271 | auto delim_binding = delim_columns[i]; |
272 | |
273 | auto unnest_it = unnest_cols.begin(); |
274 | while (unnest_it != unnest_cols.end()) { |
275 | auto unnest_binding = *unnest_it; |
276 | |
277 | if (delim_binding.table_index == unnest_binding.table_index) { |
278 | unnest_binding.table_index = overwritten_tbl_idx; |
279 | unnest_binding.column_index++; |
280 | updater.replace_bindings.emplace_back(args&: unnest_binding, args&: delim_binding); |
281 | unnest_cols.erase(position: unnest_it); |
282 | break; |
283 | } |
284 | unnest_it++; |
285 | } |
286 | } |
287 | |
288 | // update bindings |
289 | for (auto &unnest_expr : unnest.expressions) { |
290 | updater.VisitExpression(expression: &unnest_expr); |
291 | } |
292 | updater.replace_bindings.clear(); |
293 | } |
294 | |
295 | void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { |
296 | |
297 | D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
298 | auto &delim_join = op.Cast<LogicalDelimJoin>(); |
299 | for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { |
300 | auto &expr = *delim_join.duplicate_eliminated_columns[i]; |
301 | D_ASSERT(expr.type == ExpressionType::BOUND_COLUMN_REF); |
302 | auto &bound_colref_expr = expr.Cast<BoundColumnRefExpression>(); |
303 | delim_columns.push_back(x: bound_colref_expr.binding); |
304 | } |
305 | } |
306 | |
307 | void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { |
308 | |
309 | op.ResolveOperatorTypes(); |
310 | auto col_bindings = op.GetColumnBindings(); |
311 | D_ASSERT(op.types.size() == col_bindings.size()); |
312 | |
313 | bool set_alias = false; |
314 | // we can easily extract the alias for LOGICAL_PROJECTION(s) |
315 | if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) { |
316 | auto &proj = op.Cast<LogicalProjection>(); |
317 | if (proj.expressions.size() == op.types.size()) { |
318 | set_alias = true; |
319 | } |
320 | } |
321 | |
322 | for (idx_t i = 0; i < op.types.size(); i++) { |
323 | lhs_bindings.emplace_back(args&: col_bindings[i], args&: op.types[i]); |
324 | if (set_alias) { |
325 | auto &proj = op.Cast<LogicalProjection>(); |
326 | lhs_bindings.back().alias = proj.expressions[i]->alias; |
327 | } |
328 | } |
329 | } |
330 | |
331 | } // namespace duckdb |
332 | |