1 | #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" |
2 | |
3 | #include "duckdb/planner/expression/bound_case_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_operator_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" |
8 | #include "duckdb/planner/expression_iterator.hpp" |
9 | |
10 | using namespace duckdb; |
11 | using namespace std; |
12 | |
13 | RewriteCorrelatedExpressions::RewriteCorrelatedExpressions(ColumnBinding base_binding, |
14 | column_binding_map_t<idx_t> &correlated_map) |
15 | : base_binding(base_binding), correlated_map(correlated_map) { |
16 | } |
17 | |
18 | void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) { |
19 | VisitOperatorExpressions(op); |
20 | } |
21 | |
22 | unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, |
23 | unique_ptr<Expression> *expr_ptr) { |
24 | if (expr.depth == 0) { |
25 | return nullptr; |
26 | } |
27 | // correlated column reference |
28 | // replace with the entry referring to the duplicate eliminated scan |
29 | assert(expr.depth == 1); |
30 | auto entry = correlated_map.find(expr.binding); |
31 | assert(entry != correlated_map.end()); |
32 | |
33 | expr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); |
34 | expr.depth = 0; |
35 | return nullptr; |
36 | } |
37 | |
38 | unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr, |
39 | unique_ptr<Expression> *expr_ptr) { |
40 | if (!expr.IsCorrelated()) { |
41 | return nullptr; |
42 | } |
43 | // subquery detected within this subquery |
44 | // recursively rewrite it using the RewriteCorrelatedRecursive class |
45 | RewriteCorrelatedRecursive rewrite(expr, base_binding, correlated_map); |
46 | rewrite.RewriteCorrelatedSubquery(expr); |
47 | return nullptr; |
48 | } |
49 | |
50 | RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedRecursive( |
51 | BoundSubqueryExpression &parent, ColumnBinding base_binding, column_binding_map_t<idx_t> &correlated_map) |
52 | : parent(parent), base_binding(base_binding), correlated_map(correlated_map) { |
53 | } |
54 | |
55 | void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedSubquery( |
56 | BoundSubqueryExpression &expr) { |
57 | // rewrite the binding in the correlated list of the subquery) |
58 | for (auto &corr : expr.binder->correlated_columns) { |
59 | auto entry = correlated_map.find(corr.binding); |
60 | if (entry != correlated_map.end()) { |
61 | corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); |
62 | } |
63 | } |
64 | // now rewrite any correlated BoundColumnRef expressions inside the subquery |
65 | ExpressionIterator::EnumerateQueryNodeChildren(*expr.subquery, |
66 | [&](Expression &child) { RewriteCorrelatedExpressions(child); }); |
67 | } |
68 | |
69 | void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedExpressions(Expression &child) { |
70 | if (child.type == ExpressionType::BOUND_COLUMN_REF) { |
71 | // bound column reference |
72 | auto &bound_colref = (BoundColumnRefExpression &)child; |
73 | if (bound_colref.depth == 0) { |
74 | // not a correlated column, ignore |
75 | return; |
76 | } |
77 | // correlated column |
78 | // check the correlated map |
79 | auto entry = correlated_map.find(bound_colref.binding); |
80 | if (entry != correlated_map.end()) { |
81 | // we found the column in the correlated map! |
82 | // update the binding and reduce the depth by 1 |
83 | |
84 | bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); |
85 | bound_colref.depth--; |
86 | } |
87 | } else if (child.type == ExpressionType::SUBQUERY) { |
88 | // we encountered another subquery: rewrite recursively |
89 | assert(child.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); |
90 | auto &bound_subquery = (BoundSubqueryExpression &)child; |
91 | RewriteCorrelatedRecursive rewrite(bound_subquery, base_binding, correlated_map); |
92 | rewrite.RewriteCorrelatedSubquery(bound_subquery); |
93 | } |
94 | } |
95 | |
96 | RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t<idx_t> &replacement_map) |
97 | : replacement_map(replacement_map) { |
98 | } |
99 | |
100 | unique_ptr<Expression> RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr, |
101 | unique_ptr<Expression> *expr_ptr) { |
102 | auto entry = replacement_map.find(expr.binding); |
103 | if (entry != replacement_map.end()) { |
104 | // reference to a COUNT(*) aggregate |
105 | // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END |
106 | auto is_null = make_unique<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NULL, TypeId::BOOL); |
107 | is_null->children.push_back(expr.Copy()); |
108 | auto check = move(is_null); |
109 | auto result_if_true = make_unique<BoundConstantExpression>(Value::Numeric(expr.return_type, 0)); |
110 | auto result_if_false = move(*expr_ptr); |
111 | return make_unique<BoundCaseExpression>(move(check), move(result_if_true), move(result_if_false)); |
112 | } |
113 | return nullptr; |
114 | } |
115 | |