| 1 | #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" | 
|---|
| 2 |  | 
|---|
| 3 | #include "duckdb/planner/expression/bound_case_expression.hpp" | 
|---|
| 4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" | 
|---|
| 5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" | 
|---|
| 6 | #include "duckdb/planner/expression/bound_operator_expression.hpp" | 
|---|
| 7 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" | 
|---|
| 8 | #include "duckdb/planner/expression_iterator.hpp" | 
|---|
| 9 |  | 
|---|
| 10 | using namespace duckdb; | 
|---|
| 11 | using namespace std; | 
|---|
| 12 |  | 
|---|
| 13 | RewriteCorrelatedExpressions::RewriteCorrelatedExpressions(ColumnBinding base_binding, | 
|---|
| 14 | column_binding_map_t<idx_t> &correlated_map) | 
|---|
| 15 | : base_binding(base_binding), correlated_map(correlated_map) { | 
|---|
| 16 | } | 
|---|
| 17 |  | 
|---|
| 18 | void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) { | 
|---|
| 19 | VisitOperatorExpressions(op); | 
|---|
| 20 | } | 
|---|
| 21 |  | 
|---|
| 22 | unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, | 
|---|
| 23 | unique_ptr<Expression> *expr_ptr) { | 
|---|
| 24 | if (expr.depth == 0) { | 
|---|
| 25 | return nullptr; | 
|---|
| 26 | } | 
|---|
| 27 | // correlated column reference | 
|---|
| 28 | // replace with the entry referring to the duplicate eliminated scan | 
|---|
| 29 | assert(expr.depth == 1); | 
|---|
| 30 | auto entry = correlated_map.find(expr.binding); | 
|---|
| 31 | assert(entry != correlated_map.end()); | 
|---|
| 32 |  | 
|---|
| 33 | expr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); | 
|---|
| 34 | expr.depth = 0; | 
|---|
| 35 | return nullptr; | 
|---|
| 36 | } | 
|---|
| 37 |  | 
|---|
| 38 | unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, | 
|---|
| 39 | unique_ptr<Expression> *expr_ptr) { | 
|---|
| 40 | if (!expr.IsCorrelated()) { | 
|---|
| 41 | return nullptr; | 
|---|
| 42 | } | 
|---|
| 43 | // subquery detected within this subquery | 
|---|
| 44 | // recursively rewrite it using the RewriteCorrelatedRecursive class | 
|---|
| 45 | RewriteCorrelatedRecursive rewrite(expr, base_binding, correlated_map); | 
|---|
| 46 | rewrite.RewriteCorrelatedSubquery(expr); | 
|---|
| 47 | return nullptr; | 
|---|
| 48 | } | 
|---|
| 49 |  | 
|---|
| 50 | RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedRecursive( | 
|---|
| 51 | BoundSubqueryExpression &parent, ColumnBinding base_binding, column_binding_map_t<idx_t> &correlated_map) | 
|---|
| 52 | : parent(parent), base_binding(base_binding), correlated_map(correlated_map) { | 
|---|
| 53 | } | 
|---|
| 54 |  | 
|---|
| 55 | void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedSubquery( | 
|---|
| 56 | BoundSubqueryExpression &expr) { | 
|---|
| 57 | // rewrite the binding in the correlated list of the subquery) | 
|---|
| 58 | for (auto &corr : expr.binder->correlated_columns) { | 
|---|
| 59 | auto entry = correlated_map.find(corr.binding); | 
|---|
| 60 | if (entry != correlated_map.end()) { | 
|---|
| 61 | corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); | 
|---|
| 62 | } | 
|---|
| 63 | } | 
|---|
| 64 | // now rewrite any correlated BoundColumnRef expressions inside the subquery | 
|---|
| 65 | ExpressionIterator::EnumerateQueryNodeChildren(*expr.subquery, | 
|---|
| 66 | [&](Expression &child) { RewriteCorrelatedExpressions(child); }); | 
|---|
| 67 | } | 
|---|
| 68 |  | 
|---|
| 69 | void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedExpressions(Expression &child) { | 
|---|
| 70 | if (child.type == ExpressionType::BOUND_COLUMN_REF) { | 
|---|
| 71 | // bound column reference | 
|---|
| 72 | auto &bound_colref = (BoundColumnRefExpression &)child; | 
|---|
| 73 | if (bound_colref.depth == 0) { | 
|---|
| 74 | // not a correlated column, ignore | 
|---|
| 75 | return; | 
|---|
| 76 | } | 
|---|
| 77 | // correlated column | 
|---|
| 78 | // check the correlated map | 
|---|
| 79 | auto entry = correlated_map.find(bound_colref.binding); | 
|---|
| 80 | if (entry != correlated_map.end()) { | 
|---|
| 81 | // we found the column in the correlated map! | 
|---|
| 82 | // update the binding and reduce the depth by 1 | 
|---|
| 83 |  | 
|---|
| 84 | bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); | 
|---|
| 85 | bound_colref.depth--; | 
|---|
| 86 | } | 
|---|
| 87 | } else if (child.type == ExpressionType::SUBQUERY) { | 
|---|
| 88 | // we encountered another subquery: rewrite recursively | 
|---|
| 89 | assert(child.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); | 
|---|
| 90 | auto &bound_subquery = (BoundSubqueryExpression &)child; | 
|---|
| 91 | RewriteCorrelatedRecursive rewrite(bound_subquery, base_binding, correlated_map); | 
|---|
| 92 | rewrite.RewriteCorrelatedSubquery(bound_subquery); | 
|---|
| 93 | } | 
|---|
| 94 | } | 
|---|
| 95 |  | 
|---|
| 96 | RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t<idx_t> &replacement_map) | 
|---|
| 97 | : replacement_map(replacement_map) { | 
|---|
| 98 | } | 
|---|
| 99 |  | 
|---|
| 100 | unique_ptr<Expression> RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr, | 
|---|
| 101 | unique_ptr<Expression> *expr_ptr) { | 
|---|
| 102 | auto entry = replacement_map.find(expr.binding); | 
|---|
| 103 | if (entry != replacement_map.end()) { | 
|---|
| 104 | // reference to a COUNT(*) aggregate | 
|---|
| 105 | // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END | 
|---|
| 106 | auto is_null = make_unique<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NULL, TypeId::BOOL); | 
|---|
| 107 | is_null->children.push_back(expr.Copy()); | 
|---|
| 108 | auto check = move(is_null); | 
|---|
| 109 | auto result_if_true = make_unique<BoundConstantExpression>(Value::Numeric(expr.return_type, 0)); | 
|---|
| 110 | auto result_if_false = move(*expr_ptr); | 
|---|
| 111 | return make_unique<BoundCaseExpression>(move(check), move(result_if_true), move(result_if_false)); | 
|---|
| 112 | } | 
|---|
| 113 | return nullptr; | 
|---|
| 114 | } | 
|---|
| 115 |  | 
|---|