| 1 | #include "duckdb/optimizer/unnest_rewriter.hpp" |
| 2 | |
| 3 | #include "duckdb/common/pair.hpp" |
| 4 | #include "duckdb/planner/operator/logical_delim_get.hpp" |
| 5 | #include "duckdb/planner/operator/logical_delim_join.hpp" |
| 6 | #include "duckdb/planner/operator/logical_unnest.hpp" |
| 7 | #include "duckdb/planner/operator/logical_projection.hpp" |
| 8 | #include "duckdb/planner/operator/logical_window.hpp" |
| 9 | #include "duckdb/planner/expression/bound_unnest_expression.hpp" |
| 10 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 11 | |
| 12 | namespace duckdb { |
| 13 | |
| 14 | void UnnestRewriterPlanUpdater::VisitOperator(LogicalOperator &op) { |
| 15 | VisitOperatorChildren(op); |
| 16 | VisitOperatorExpressions(op); |
| 17 | } |
| 18 | |
| 19 | void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr<Expression> *expression) { |
| 20 | |
| 21 | auto &expr = *expression; |
| 22 | |
| 23 | if (expr->expression_class == ExpressionClass::BOUND_COLUMN_REF) { |
| 24 | |
| 25 | auto &bound_column_ref = expr->Cast<BoundColumnRefExpression>(); |
| 26 | for (idx_t i = 0; i < replace_bindings.size(); i++) { |
| 27 | if (bound_column_ref.binding == replace_bindings[i].old_binding) { |
| 28 | bound_column_ref.binding = replace_bindings[i].new_binding; |
| 29 | break; |
| 30 | } |
| 31 | } |
| 32 | } |
| 33 | |
| 34 | VisitExpressionChildren(expression&: **expression); |
| 35 | } |
| 36 | |
| 37 | unique_ptr<LogicalOperator> UnnestRewriter::Optimize(unique_ptr<LogicalOperator> op) { |
| 38 | |
| 39 | UnnestRewriterPlanUpdater updater; |
| 40 | vector<unique_ptr<LogicalOperator> *> candidates; |
| 41 | FindCandidates(op_ptr: &op, candidates); |
| 42 | |
| 43 | // rewrite the plan and update the bindings |
| 44 | for (auto &candidate : candidates) { |
| 45 | |
| 46 | // rearrange the logical operators |
| 47 | if (RewriteCandidate(candidate)) { |
| 48 | updater.overwritten_tbl_idx = overwritten_tbl_idx; |
| 49 | // update the bindings of the BOUND_UNNEST expression |
| 50 | UpdateBoundUnnestBindings(updater, candidate); |
| 51 | // update the sequence of LOGICAL_PROJECTION(s) |
| 52 | UpdateRHSBindings(plan_ptr: &op, candidate, updater); |
| 53 | // reset |
| 54 | delim_columns.clear(); |
| 55 | lhs_bindings.clear(); |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | return op; |
| 60 | } |
| 61 | |
| 62 | void UnnestRewriter::FindCandidates(unique_ptr<LogicalOperator> *op_ptr, |
| 63 | vector<unique_ptr<LogicalOperator> *> &candidates) { |
| 64 | auto op = op_ptr->get(); |
| 65 | // search children before adding, so that we add candidates bottom-up |
| 66 | for (auto &child : op->children) { |
| 67 | FindCandidates(op_ptr: &child, candidates); |
| 68 | } |
| 69 | |
| 70 | // search for operator that has a LOGICAL_DELIM_JOIN as its child |
| 71 | if (op->children.size() != 1) { |
| 72 | return; |
| 73 | } |
| 74 | if (op->children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { |
| 75 | return; |
| 76 | } |
| 77 | |
| 78 | // found a delim join |
| 79 | auto &delim_join = op->children[0]->Cast<LogicalDelimJoin>(); |
| 80 | // only support INNER delim joins |
| 81 | if (delim_join.join_type != JoinType::INNER) { |
| 82 | return; |
| 83 | } |
| 84 | // INNER delim join must have exactly one condition |
| 85 | if (delim_join.conditions.size() != 1) { |
| 86 | return; |
| 87 | } |
| 88 | |
| 89 | // LHS child is a window |
| 90 | if (delim_join.children[0]->type != LogicalOperatorType::LOGICAL_WINDOW) { |
| 91 | return; |
| 92 | } |
| 93 | |
| 94 | // RHS child must be projection(s) followed by an UNNEST |
| 95 | auto curr_op = &delim_join.children[1]; |
| 96 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
| 97 | if (curr_op->get()->children.size() != 1) { |
| 98 | break; |
| 99 | } |
| 100 | curr_op = &curr_op->get()->children[0]; |
| 101 | } |
| 102 | |
| 103 | if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST) { |
| 104 | candidates.push_back(x: op_ptr); |
| 105 | } |
| 106 | } |
| 107 | |
| 108 | bool UnnestRewriter::RewriteCandidate(unique_ptr<LogicalOperator> *candidate) { |
| 109 | |
| 110 | auto &topmost_op = (LogicalOperator &)**candidate; |
| 111 | if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && |
| 112 | topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && |
| 113 | topmost_op.type != LogicalOperatorType::LOGICAL_FILTER && |
| 114 | topmost_op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY && |
| 115 | topmost_op.type != LogicalOperatorType::LOGICAL_UNNEST) { |
| 116 | return false; |
| 117 | } |
| 118 | |
| 119 | // get the LOGICAL_DELIM_JOIN, which is a child of the candidate |
| 120 | D_ASSERT(topmost_op.children.size() == 1); |
| 121 | auto &delim_join = *(topmost_op.children[0]); |
| 122 | D_ASSERT(delim_join.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
| 123 | GetDelimColumns(op&: delim_join); |
| 124 | |
| 125 | // LHS of the LOGICAL_DELIM_JOIN is a LOGICAL_WINDOW that contains a LOGICAL_PROJECTION |
| 126 | // this lhs_proj later becomes the child of the UNNEST |
| 127 | auto &window = *delim_join.children[0]; |
| 128 | auto &lhs_op = window.children[0]; |
| 129 | GetLHSExpressions(op&: *lhs_op); |
| 130 | |
| 131 | // find the LOGICAL_UNNEST |
| 132 | // and get the path down to the LOGICAL_UNNEST |
| 133 | vector<unique_ptr<LogicalOperator> *> path_to_unnest; |
| 134 | auto curr_op = &(delim_join.children[1]); |
| 135 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
| 136 | path_to_unnest.push_back(x: curr_op); |
| 137 | curr_op = &curr_op->get()->children[0]; |
| 138 | } |
| 139 | |
| 140 | // store the table index of the child of the LOGICAL_UNNEST |
| 141 | // then update the plan by making the lhs_proj the child of the LOGICAL_UNNEST |
| 142 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
| 143 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
| 144 | D_ASSERT(unnest.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET); |
| 145 | overwritten_tbl_idx = unnest.children[0]->Cast<LogicalDelimGet>().table_index; |
| 146 | |
| 147 | D_ASSERT(!unnest.children.empty()); |
| 148 | auto &delim_get = unnest.children[0]->Cast<LogicalDelimGet>(); |
| 149 | D_ASSERT(delim_get.chunk_types.size() > 1); |
| 150 | distinct_unnest_count = delim_get.chunk_types.size(); |
| 151 | unnest.children[0] = std::move(lhs_op); |
| 152 | |
| 153 | // replace the LOGICAL_DELIM_JOIN with its RHS child operator |
| 154 | topmost_op.children[0] = std::move(*path_to_unnest.front()); |
| 155 | return true; |
| 156 | } |
| 157 | |
| 158 | void UnnestRewriter::UpdateRHSBindings(unique_ptr<LogicalOperator> *plan_ptr, unique_ptr<LogicalOperator> *candidate, |
| 159 | UnnestRewriterPlanUpdater &updater) { |
| 160 | |
| 161 | auto &topmost_op = (LogicalOperator &)**candidate; |
| 162 | idx_t shift = lhs_bindings.size(); |
| 163 | |
| 164 | vector<unique_ptr<LogicalOperator> *> path_to_unnest; |
| 165 | auto curr_op = &(topmost_op.children[0]); |
| 166 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
| 167 | |
| 168 | path_to_unnest.push_back(x: curr_op); |
| 169 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); |
| 170 | auto &proj = curr_op->get()->Cast<LogicalProjection>(); |
| 171 | |
| 172 | // pop the unnest columns and the delim index |
| 173 | D_ASSERT(proj.expressions.size() > distinct_unnest_count); |
| 174 | for (idx_t i = 0; i < distinct_unnest_count; i++) { |
| 175 | proj.expressions.pop_back(); |
| 176 | } |
| 177 | |
| 178 | // store all shifted current bindings |
| 179 | idx_t tbl_idx = proj.table_index; |
| 180 | for (idx_t i = 0; i < proj.expressions.size(); i++) { |
| 181 | ReplaceBinding replace_binding(ColumnBinding(tbl_idx, i), ColumnBinding(tbl_idx, i + shift)); |
| 182 | updater.replace_bindings.push_back(x: replace_binding); |
| 183 | } |
| 184 | |
| 185 | curr_op = &curr_op->get()->children[0]; |
| 186 | } |
| 187 | |
| 188 | // update all bindings by shifting them |
| 189 | updater.VisitOperator(op&: *plan_ptr->get()); |
| 190 | updater.replace_bindings.clear(); |
| 191 | |
| 192 | // update all bindings coming from the LHS to RHS bindings |
| 193 | D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); |
| 194 | auto &top_proj = topmost_op.children[0]->Cast<LogicalProjection>(); |
| 195 | for (idx_t i = 0; i < lhs_bindings.size(); i++) { |
| 196 | ReplaceBinding replace_binding(lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, i)); |
| 197 | updater.replace_bindings.push_back(x: replace_binding); |
| 198 | } |
| 199 | |
| 200 | // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan |
| 201 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
| 202 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
| 203 | vector<unique_ptr<Expression>> temp_bound_unnests; |
| 204 | for (auto &temp_bound_unnest : unnest.expressions) { |
| 205 | temp_bound_unnests.push_back(x: std::move(temp_bound_unnest)); |
| 206 | } |
| 207 | D_ASSERT(unnest.children.size() == 1); |
| 208 | auto temp_unnest_child = std::move(unnest.children[0]); |
| 209 | unnest.expressions.clear(); |
| 210 | unnest.children.clear(); |
| 211 | // update the bindings of the plan |
| 212 | updater.VisitOperator(op&: *plan_ptr->get()); |
| 213 | updater.replace_bindings.clear(); |
| 214 | // add the children again |
| 215 | for (auto &temp_bound_unnest : temp_bound_unnests) { |
| 216 | unnest.expressions.push_back(x: std::move(temp_bound_unnest)); |
| 217 | } |
| 218 | unnest.children.push_back(x: std::move(temp_unnest_child)); |
| 219 | |
| 220 | // add the LHS expressions to each LOGICAL_PROJECTION |
| 221 | for (idx_t i = path_to_unnest.size(); i > 0; i--) { |
| 222 | |
| 223 | D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); |
| 224 | auto &proj = path_to_unnest[i - 1]->get()->Cast<LogicalProjection>(); |
| 225 | |
| 226 | // temporarily store the existing expressions |
| 227 | vector<unique_ptr<Expression>> existing_expressions; |
| 228 | for (idx_t expr_idx = 0; expr_idx < proj.expressions.size(); expr_idx++) { |
| 229 | existing_expressions.push_back(x: std::move(proj.expressions[expr_idx])); |
| 230 | } |
| 231 | |
| 232 | proj.expressions.clear(); |
| 233 | |
| 234 | // add the new expressions |
| 235 | for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { |
| 236 | auto new_expr = make_uniq<BoundColumnRefExpression>( |
| 237 | args&: lhs_bindings[expr_idx].alias, args&: lhs_bindings[expr_idx].type, args&: lhs_bindings[expr_idx].binding); |
| 238 | proj.expressions.push_back(x: std::move(new_expr)); |
| 239 | |
| 240 | // update the table index |
| 241 | lhs_bindings[expr_idx].binding.table_index = proj.table_index; |
| 242 | lhs_bindings[expr_idx].binding.column_index = expr_idx; |
| 243 | } |
| 244 | |
| 245 | // add the existing expressions again |
| 246 | for (idx_t expr_idx = 0; expr_idx < existing_expressions.size(); expr_idx++) { |
| 247 | proj.expressions.push_back(x: std::move(existing_expressions[expr_idx])); |
| 248 | } |
| 249 | } |
| 250 | } |
| 251 | |
| 252 | void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, |
| 253 | unique_ptr<LogicalOperator> *candidate) { |
| 254 | |
| 255 | auto &topmost_op = (LogicalOperator &)**candidate; |
| 256 | |
| 257 | // traverse LOGICAL_PROJECTION(s) |
| 258 | auto curr_op = &(topmost_op.children[0]); |
| 259 | while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { |
| 260 | curr_op = &curr_op->get()->children[0]; |
| 261 | } |
| 262 | |
| 263 | // found the LOGICAL_UNNEST |
| 264 | D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST); |
| 265 | auto &unnest = curr_op->get()->Cast<LogicalUnnest>(); |
| 266 | |
| 267 | D_ASSERT(unnest.children.size() == 1); |
| 268 | auto unnest_cols = unnest.children[0]->GetColumnBindings(); |
| 269 | |
| 270 | for (idx_t i = 0; i < delim_columns.size(); i++) { |
| 271 | auto delim_binding = delim_columns[i]; |
| 272 | |
| 273 | auto unnest_it = unnest_cols.begin(); |
| 274 | while (unnest_it != unnest_cols.end()) { |
| 275 | auto unnest_binding = *unnest_it; |
| 276 | |
| 277 | if (delim_binding.table_index == unnest_binding.table_index) { |
| 278 | unnest_binding.table_index = overwritten_tbl_idx; |
| 279 | unnest_binding.column_index++; |
| 280 | updater.replace_bindings.emplace_back(args&: unnest_binding, args&: delim_binding); |
| 281 | unnest_cols.erase(position: unnest_it); |
| 282 | break; |
| 283 | } |
| 284 | unnest_it++; |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | // update bindings |
| 289 | for (auto &unnest_expr : unnest.expressions) { |
| 290 | updater.VisitExpression(expression: &unnest_expr); |
| 291 | } |
| 292 | updater.replace_bindings.clear(); |
| 293 | } |
| 294 | |
| 295 | void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { |
| 296 | |
| 297 | D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
| 298 | auto &delim_join = op.Cast<LogicalDelimJoin>(); |
| 299 | for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { |
| 300 | auto &expr = *delim_join.duplicate_eliminated_columns[i]; |
| 301 | D_ASSERT(expr.type == ExpressionType::BOUND_COLUMN_REF); |
| 302 | auto &bound_colref_expr = expr.Cast<BoundColumnRefExpression>(); |
| 303 | delim_columns.push_back(x: bound_colref_expr.binding); |
| 304 | } |
| 305 | } |
| 306 | |
| 307 | void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { |
| 308 | |
| 309 | op.ResolveOperatorTypes(); |
| 310 | auto col_bindings = op.GetColumnBindings(); |
| 311 | D_ASSERT(op.types.size() == col_bindings.size()); |
| 312 | |
| 313 | bool set_alias = false; |
| 314 | // we can easily extract the alias for LOGICAL_PROJECTION(s) |
| 315 | if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) { |
| 316 | auto &proj = op.Cast<LogicalProjection>(); |
| 317 | if (proj.expressions.size() == op.types.size()) { |
| 318 | set_alias = true; |
| 319 | } |
| 320 | } |
| 321 | |
| 322 | for (idx_t i = 0; i < op.types.size(); i++) { |
| 323 | lhs_bindings.emplace_back(args&: col_bindings[i], args&: op.types[i]); |
| 324 | if (set_alias) { |
| 325 | auto &proj = op.Cast<LogicalProjection>(); |
| 326 | lhs_bindings.back().alias = proj.expressions[i]->alias; |
| 327 | } |
| 328 | } |
| 329 | } |
| 330 | |
| 331 | } // namespace duckdb |
| 332 | |