| 1 | #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" |
| 2 | |
| 3 | #include "duckdb/planner/expression/bound_case_expression.hpp" |
| 4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 6 | #include "duckdb/planner/expression/bound_operator_expression.hpp" |
| 7 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" |
| 8 | #include "duckdb/planner/expression_iterator.hpp" |
| 9 | |
| 10 | using namespace duckdb; |
| 11 | using namespace std; |
| 12 | |
| 13 | RewriteCorrelatedExpressions::RewriteCorrelatedExpressions(ColumnBinding base_binding, |
| 14 | column_binding_map_t<idx_t> &correlated_map) |
| 15 | : base_binding(base_binding), correlated_map(correlated_map) { |
| 16 | } |
| 17 | |
| 18 | void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) { |
| 19 | VisitOperatorExpressions(op); |
| 20 | } |
| 21 | |
| 22 | unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, |
| 23 | unique_ptr<Expression> *expr_ptr) { |
| 24 | if (expr.depth == 0) { |
| 25 | return nullptr; |
| 26 | } |
| 27 | // correlated column reference |
| 28 | // replace with the entry referring to the duplicate eliminated scan |
| 29 | assert(expr.depth == 1); |
| 30 | auto entry = correlated_map.find(expr.binding); |
| 31 | assert(entry != correlated_map.end()); |
| 32 | |
| 33 | expr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); |
| 34 | expr.depth = 0; |
| 35 | return nullptr; |
| 36 | } |
| 37 | |
| 38 | unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, |
| 39 | unique_ptr<Expression> *expr_ptr) { |
| 40 | if (!expr.IsCorrelated()) { |
| 41 | return nullptr; |
| 42 | } |
| 43 | // subquery detected within this subquery |
| 44 | // recursively rewrite it using the RewriteCorrelatedRecursive class |
| 45 | RewriteCorrelatedRecursive rewrite(expr, base_binding, correlated_map); |
| 46 | rewrite.RewriteCorrelatedSubquery(expr); |
| 47 | return nullptr; |
| 48 | } |
| 49 | |
| 50 | RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedRecursive( |
| 51 | BoundSubqueryExpression &parent, ColumnBinding base_binding, column_binding_map_t<idx_t> &correlated_map) |
| 52 | : parent(parent), base_binding(base_binding), correlated_map(correlated_map) { |
| 53 | } |
| 54 | |
| 55 | void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedSubquery( |
| 56 | BoundSubqueryExpression &expr) { |
| 57 | // rewrite the binding in the correlated list of the subquery) |
| 58 | for (auto &corr : expr.binder->correlated_columns) { |
| 59 | auto entry = correlated_map.find(corr.binding); |
| 60 | if (entry != correlated_map.end()) { |
| 61 | corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); |
| 62 | } |
| 63 | } |
| 64 | // now rewrite any correlated BoundColumnRef expressions inside the subquery |
| 65 | ExpressionIterator::EnumerateQueryNodeChildren(*expr.subquery, |
| 66 | [&](Expression &child) { RewriteCorrelatedExpressions(child); }); |
| 67 | } |
| 68 | |
| 69 | void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedExpressions(Expression &child) { |
| 70 | if (child.type == ExpressionType::BOUND_COLUMN_REF) { |
| 71 | // bound column reference |
| 72 | auto &bound_colref = (BoundColumnRefExpression &)child; |
| 73 | if (bound_colref.depth == 0) { |
| 74 | // not a correlated column, ignore |
| 75 | return; |
| 76 | } |
| 77 | // correlated column |
| 78 | // check the correlated map |
| 79 | auto entry = correlated_map.find(bound_colref.binding); |
| 80 | if (entry != correlated_map.end()) { |
| 81 | // we found the column in the correlated map! |
| 82 | // update the binding and reduce the depth by 1 |
| 83 | |
| 84 | bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); |
| 85 | bound_colref.depth--; |
| 86 | } |
| 87 | } else if (child.type == ExpressionType::SUBQUERY) { |
| 88 | // we encountered another subquery: rewrite recursively |
| 89 | assert(child.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); |
| 90 | auto &bound_subquery = (BoundSubqueryExpression &)child; |
| 91 | RewriteCorrelatedRecursive rewrite(bound_subquery, base_binding, correlated_map); |
| 92 | rewrite.RewriteCorrelatedSubquery(bound_subquery); |
| 93 | } |
| 94 | } |
| 95 | |
| 96 | RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t<idx_t> &replacement_map) |
| 97 | : replacement_map(replacement_map) { |
| 98 | } |
| 99 | |
| 100 | unique_ptr<Expression> RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr, |
| 101 | unique_ptr<Expression> *expr_ptr) { |
| 102 | auto entry = replacement_map.find(expr.binding); |
| 103 | if (entry != replacement_map.end()) { |
| 104 | // reference to a COUNT(*) aggregate |
| 105 | // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END |
| 106 | auto is_null = make_unique<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NULL, TypeId::BOOL); |
| 107 | is_null->children.push_back(expr.Copy()); |
| 108 | auto check = move(is_null); |
| 109 | auto result_if_true = make_unique<BoundConstantExpression>(Value::Numeric(expr.return_type, 0)); |
| 110 | auto result_if_false = move(*expr_ptr); |
| 111 | return make_unique<BoundCaseExpression>(move(check), move(result_if_true), move(result_if_false)); |
| 112 | } |
| 113 | return nullptr; |
| 114 | } |
| 115 | |