| 1 | #include "duckdb/planner/subquery/flatten_dependent_join.hpp" |
| 2 | |
| 3 | #include "duckdb/planner/binder.hpp" |
| 4 | #include "duckdb/planner/expression/list.hpp" |
| 5 | #include "duckdb/planner/logical_operator_visitor.hpp" |
| 6 | #include "duckdb/planner/binder.hpp" |
| 7 | #include "duckdb/planner/operator/list.hpp" |
| 8 | #include "duckdb/planner/subquery/has_correlated_expressions.hpp" |
| 9 | #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" |
| 10 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 11 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
| 12 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
| 13 | |
| 14 | using namespace duckdb; |
| 15 | using namespace std; |
| 16 | |
| 17 | FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector<CorrelatedColumnInfo> &correlated) |
| 18 | : binder(binder), correlated_columns(correlated) { |
| 19 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 20 | auto &col = correlated_columns[i]; |
| 21 | correlated_map[col.binding] = i; |
| 22 | delim_types.push_back(col.type); |
| 23 | } |
| 24 | } |
| 25 | |
| 26 | bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator *op) { |
| 27 | assert(op); |
| 28 | // check if this entry has correlated expressions |
| 29 | HasCorrelatedExpressions visitor(correlated_columns); |
| 30 | visitor.VisitOperator(*op); |
| 31 | bool has_correlation = visitor.has_correlated_expressions; |
| 32 | // now visit the children of this entry and check if they have correlated expressions |
| 33 | for (auto &child : op->children) { |
| 34 | // we OR the property with its children such that has_correlation is true if either |
| 35 | // (1) this node has a correlated expression or |
| 36 | // (2) one of its children has a correlated expression |
| 37 | if (DetectCorrelatedExpressions(child.get())) { |
| 38 | has_correlation = true; |
| 39 | } |
| 40 | } |
| 41 | // set the entry in the map |
| 42 | has_correlated_expressions[op] = has_correlation; |
| 43 | return has_correlation; |
| 44 | } |
| 45 | |
| 46 | unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoin(unique_ptr<LogicalOperator> plan) { |
| 47 | auto result = PushDownDependentJoinInternal(move(plan)); |
| 48 | if (replacement_map.size() > 0) { |
| 49 | // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END" |
| 50 | RewriteCountAggregates aggr(replacement_map); |
| 51 | aggr.VisitOperator(*result); |
| 52 | } |
| 53 | return result; |
| 54 | } |
| 55 | |
| 56 | unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal(unique_ptr<LogicalOperator> plan) { |
| 57 | // first check if the logical operator has correlated expressions |
| 58 | auto entry = has_correlated_expressions.find(plan.get()); |
| 59 | assert(entry != has_correlated_expressions.end()); |
| 60 | if (!entry->second) { |
| 61 | // we reached a node without correlated expressions |
| 62 | // we can eliminate the dependent join now and create a simple cross product |
| 63 | auto cross_product = make_unique<LogicalCrossProduct>(); |
| 64 | // now create the duplicate eliminated scan for this node |
| 65 | auto delim_index = binder.GenerateTableIndex(); |
| 66 | this->base_binding = ColumnBinding(delim_index, 0); |
| 67 | auto delim_scan = make_unique<LogicalDelimGet>(delim_index, delim_types); |
| 68 | cross_product->children.push_back(move(delim_scan)); |
| 69 | cross_product->children.push_back(move(plan)); |
| 70 | return move(cross_product); |
| 71 | } |
| 72 | switch (plan->type) { |
| 73 | case LogicalOperatorType::FILTER: { |
| 74 | // filter |
| 75 | // first we flatten the dependent join in the child of the filter |
| 76 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 77 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
| 78 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
| 79 | rewriter.VisitOperator(*plan); |
| 80 | return plan; |
| 81 | } |
| 82 | case LogicalOperatorType::PROJECTION: { |
| 83 | // projection |
| 84 | // first we flatten the dependent join in the child of the projection |
| 85 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 86 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
| 87 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
| 88 | rewriter.VisitOperator(*plan); |
| 89 | // now we add all the columns of the delim_scan to the projection list |
| 90 | auto proj = (LogicalProjection *)plan.get(); |
| 91 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 92 | auto colref = make_unique<BoundColumnRefExpression>( |
| 93 | correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
| 94 | plan->expressions.push_back(move(colref)); |
| 95 | } |
| 96 | |
| 97 | base_binding.table_index = proj->table_index; |
| 98 | this->delim_offset = base_binding.column_index = plan->expressions.size() - correlated_columns.size(); |
| 99 | this->data_offset = 0; |
| 100 | return plan; |
| 101 | } |
| 102 | case LogicalOperatorType::AGGREGATE_AND_GROUP_BY: { |
| 103 | auto &aggr = (LogicalAggregate &)*plan; |
| 104 | // aggregate and group by |
| 105 | // first we flatten the dependent join in the child of the projection |
| 106 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 107 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
| 108 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
| 109 | rewriter.VisitOperator(*plan); |
| 110 | // now we add all the columns of the delim_scan to the grouping operators AND the projection list |
| 111 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 112 | auto colref = make_unique<BoundColumnRefExpression>( |
| 113 | correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
| 114 | aggr.groups.push_back(move(colref)); |
| 115 | } |
| 116 | if (aggr.groups.size() == correlated_columns.size()) { |
| 117 | // we have to perform a LEFT OUTER JOIN between the result of this aggregate and the delim scan |
| 118 | auto left_outer_join = make_unique<LogicalComparisonJoin>(JoinType::LEFT); |
| 119 | auto left_index = binder.GenerateTableIndex(); |
| 120 | auto delim_scan = make_unique<LogicalDelimGet>(left_index, delim_types); |
| 121 | left_outer_join->children.push_back(move(delim_scan)); |
| 122 | left_outer_join->children.push_back(move(plan)); |
| 123 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 124 | JoinCondition cond; |
| 125 | cond.left = |
| 126 | make_unique<BoundColumnRefExpression>(correlated_columns[i].type, ColumnBinding(left_index, i)); |
| 127 | cond.right = make_unique<BoundColumnRefExpression>( |
| 128 | correlated_columns[i].type, |
| 129 | ColumnBinding(aggr.group_index, (aggr.groups.size() - correlated_columns.size()) + i)); |
| 130 | cond.comparison = ExpressionType::COMPARE_EQUAL; |
| 131 | cond.null_values_are_equal = true; |
| 132 | left_outer_join->conditions.push_back(move(cond)); |
| 133 | } |
| 134 | // for any COUNT aggregate we replace references to the column with: CASE WHEN COUNT(*) IS NULL THEN 0 |
| 135 | // ELSE COUNT(*) END |
| 136 | for (idx_t i = 0; i < aggr.expressions.size(); i++) { |
| 137 | assert(aggr.expressions[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); |
| 138 | auto bound = (BoundAggregateExpression *)&*aggr.expressions[i]; |
| 139 | vector<SQLType> arguments; |
| 140 | if (bound->function == CountFun::GetFunction() || bound->function == CountStarFun::GetFunction()) { |
| 141 | // have to replace this ColumnBinding with the CASE expression |
| 142 | replacement_map[ColumnBinding(aggr.aggregate_index, i)] = i; |
| 143 | } |
| 144 | } |
| 145 | // now we update the delim_index |
| 146 | |
| 147 | base_binding.table_index = left_index; |
| 148 | this->delim_offset = base_binding.column_index = 0; |
| 149 | this->data_offset = 0; |
| 150 | return move(left_outer_join); |
| 151 | } else { |
| 152 | // update the delim_index |
| 153 | base_binding.table_index = aggr.group_index; |
| 154 | this->delim_offset = base_binding.column_index = aggr.groups.size() - correlated_columns.size(); |
| 155 | this->data_offset = aggr.groups.size(); |
| 156 | return plan; |
| 157 | } |
| 158 | } |
| 159 | case LogicalOperatorType::CROSS_PRODUCT: { |
| 160 | // cross product |
| 161 | // push into both sides of the plan |
| 162 | bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; |
| 163 | bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; |
| 164 | if (!right_has_correlation) { |
| 165 | // only left has correlation: push into left |
| 166 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 167 | return plan; |
| 168 | } |
| 169 | if (!left_has_correlation) { |
| 170 | // only right has correlation: push into right |
| 171 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
| 172 | return plan; |
| 173 | } |
| 174 | // both sides have correlation |
| 175 | // turn into an inner join |
| 176 | auto join = make_unique<LogicalComparisonJoin>(JoinType::INNER); |
| 177 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 178 | auto left_binding = this->base_binding; |
| 179 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
| 180 | // add the correlated columns to the join conditions |
| 181 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 182 | JoinCondition cond; |
| 183 | cond.left = make_unique<BoundColumnRefExpression>( |
| 184 | correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); |
| 185 | cond.right = make_unique<BoundColumnRefExpression>( |
| 186 | correlated_columns[i].type, ColumnBinding(base_binding.table_index, base_binding.column_index + i)); |
| 187 | cond.comparison = ExpressionType::COMPARE_EQUAL; |
| 188 | cond.null_values_are_equal = true; |
| 189 | join->conditions.push_back(move(cond)); |
| 190 | } |
| 191 | join->children.push_back(move(plan->children[0])); |
| 192 | join->children.push_back(move(plan->children[1])); |
| 193 | return move(join); |
| 194 | } |
| 195 | case LogicalOperatorType::COMPARISON_JOIN: { |
| 196 | auto &join = (LogicalComparisonJoin &)*plan; |
| 197 | assert(plan->children.size() == 2); |
| 198 | // check the correlated expressions in the children of the join |
| 199 | bool left_has_correlation = has_correlated_expressions.find(plan->children[0].get())->second; |
| 200 | bool right_has_correlation = has_correlated_expressions.find(plan->children[1].get())->second; |
| 201 | |
| 202 | if (join.join_type == JoinType::INNER) { |
| 203 | // inner join |
| 204 | if (!right_has_correlation) { |
| 205 | // only left has correlation: push into left |
| 206 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 207 | return plan; |
| 208 | } |
| 209 | if (!left_has_correlation) { |
| 210 | // only right has correlation: push into right |
| 211 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
| 212 | return plan; |
| 213 | } |
| 214 | } else if (join.join_type == JoinType::LEFT) { |
| 215 | // left outer join |
| 216 | if (!right_has_correlation) { |
| 217 | // only left has correlation: push into left |
| 218 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 219 | return plan; |
| 220 | } |
| 221 | } else if (join.join_type == JoinType::MARK) { |
| 222 | if (right_has_correlation) { |
| 223 | throw Exception("MARK join with correlation in RHS not supported" ); |
| 224 | } |
| 225 | // push the child into the LHS |
| 226 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 227 | // rewrite expressions in the join conditions |
| 228 | RewriteCorrelatedExpressions rewriter(base_binding, correlated_map); |
| 229 | rewriter.VisitOperator(*plan); |
| 230 | return plan; |
| 231 | } else { |
| 232 | throw Exception("Unsupported join type for flattening correlated subquery" ); |
| 233 | } |
| 234 | // both sides have correlation |
| 235 | // push into both sides |
| 236 | // NOTE: for OUTER JOINS it matters what the BASE BINDING is after the join |
| 237 | // for the LEFT OUTER JOIN, we want the LEFT side to be the base binding after we push |
| 238 | // because the RIGHT binding might contain NULL values |
| 239 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 240 | auto left_binding = this->base_binding; |
| 241 | plan->children[1] = PushDownDependentJoinInternal(move(plan->children[1])); |
| 242 | auto right_binding = this->base_binding; |
| 243 | if (join.join_type == JoinType::LEFT) { |
| 244 | this->base_binding = left_binding; |
| 245 | } |
| 246 | // add the correlated columns to the join conditions |
| 247 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 248 | JoinCondition cond; |
| 249 | |
| 250 | cond.left = make_unique<BoundColumnRefExpression>( |
| 251 | correlated_columns[i].type, ColumnBinding(left_binding.table_index, left_binding.column_index + i)); |
| 252 | cond.right = make_unique<BoundColumnRefExpression>( |
| 253 | correlated_columns[i].type, ColumnBinding(right_binding.table_index, right_binding.column_index + i)); |
| 254 | cond.comparison = ExpressionType::COMPARE_EQUAL; |
| 255 | cond.null_values_are_equal = true; |
| 256 | join.conditions.push_back(move(cond)); |
| 257 | } |
| 258 | // then we replace any correlated expressions with the corresponding entry in the correlated_map |
| 259 | RewriteCorrelatedExpressions rewriter(right_binding, correlated_map); |
| 260 | rewriter.VisitOperator(*plan); |
| 261 | return plan; |
| 262 | } |
| 263 | case LogicalOperatorType::LIMIT: { |
| 264 | auto &limit = (LogicalLimit &)*plan; |
| 265 | if (limit.offset > 0) { |
| 266 | throw ParserException("OFFSET not supported in correlated subquery" ); |
| 267 | } |
| 268 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 269 | if (limit.limit == 0) { |
| 270 | // limit = 0 means we return zero columns here |
| 271 | return plan; |
| 272 | } else { |
| 273 | // limit > 0 does nothing |
| 274 | return move(plan->children[0]); |
| 275 | } |
| 276 | } |
| 277 | case LogicalOperatorType::WINDOW: { |
| 278 | auto &window = (LogicalWindow &)*plan; |
| 279 | // push into children |
| 280 | plan->children[0] = PushDownDependentJoinInternal(move(plan->children[0])); |
| 281 | // add the correlated columns to the PARTITION BY clauses in the Window |
| 282 | for (auto &expr : window.expressions) { |
| 283 | assert(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); |
| 284 | auto &w = (BoundWindowExpression &)*expr; |
| 285 | for (idx_t i = 0; i < correlated_columns.size(); i++) { |
| 286 | w.partitions.push_back(make_unique<BoundColumnRefExpression>( |
| 287 | correlated_columns[i].type, |
| 288 | ColumnBinding(base_binding.table_index, base_binding.column_index + i))); |
| 289 | } |
| 290 | } |
| 291 | return plan; |
| 292 | } |
| 293 | case LogicalOperatorType::EXCEPT: |
| 294 | case LogicalOperatorType::INTERSECT: |
| 295 | case LogicalOperatorType::UNION: { |
| 296 | auto &setop = (LogicalSetOperation &)*plan; |
| 297 | // set operator, push into both children |
| 298 | plan->children[0] = PushDownDependentJoin(move(plan->children[0])); |
| 299 | plan->children[1] = PushDownDependentJoin(move(plan->children[1])); |
| 300 | // we have to refer to the setop index now |
| 301 | base_binding.table_index = setop.table_index; |
| 302 | base_binding.column_index = setop.column_count; |
| 303 | setop.column_count += correlated_columns.size(); |
| 304 | return plan; |
| 305 | } |
| 306 | case LogicalOperatorType::DISTINCT: |
| 307 | plan->children[0] = PushDownDependentJoin(move(plan->children[0])); |
| 308 | return plan; |
| 309 | case LogicalOperatorType::ORDER_BY: |
| 310 | throw ParserException("ORDER BY not supported in correlated subquery" ); |
| 311 | default: |
| 312 | throw NotImplementedException("Logical operator type \"%s\" for dependent join" , |
| 313 | LogicalOperatorToString(plan->type).c_str()); |
| 314 | } |
| 315 | } |
| 316 | |