1#include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp"
2
3#include "duckdb/planner/expression/bound_case_expression.hpp"
4#include "duckdb/planner/expression/bound_columnref_expression.hpp"
5#include "duckdb/planner/expression/bound_constant_expression.hpp"
6#include "duckdb/planner/expression/bound_operator_expression.hpp"
7#include "duckdb/planner/expression/bound_subquery_expression.hpp"
8#include "duckdb/planner/expression_iterator.hpp"
9
10using namespace duckdb;
11using namespace std;
12
13RewriteCorrelatedExpressions::RewriteCorrelatedExpressions(ColumnBinding base_binding,
14 column_binding_map_t<idx_t> &correlated_map)
15 : base_binding(base_binding), correlated_map(correlated_map) {
16}
17
18void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) {
19 VisitOperatorExpressions(op);
20}
21
22unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr,
23 unique_ptr<Expression> *expr_ptr) {
24 if (expr.depth == 0) {
25 return nullptr;
26 }
27 // correlated column reference
28 // replace with the entry referring to the duplicate eliminated scan
29 assert(expr.depth == 1);
30 auto entry = correlated_map.find(expr.binding);
31 assert(entry != correlated_map.end());
32
33 expr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second);
34 expr.depth = 0;
35 return nullptr;
36}
37
38unique_ptr<Expression> RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryExpression &expr,
39 unique_ptr<Expression> *expr_ptr) {
40 if (!expr.IsCorrelated()) {
41 return nullptr;
42 }
43 // subquery detected within this subquery
44 // recursively rewrite it using the RewriteCorrelatedRecursive class
45 RewriteCorrelatedRecursive rewrite(expr, base_binding, correlated_map);
46 rewrite.RewriteCorrelatedSubquery(expr);
47 return nullptr;
48}
49
50RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedRecursive(
51 BoundSubqueryExpression &parent, ColumnBinding base_binding, column_binding_map_t<idx_t> &correlated_map)
52 : parent(parent), base_binding(base_binding), correlated_map(correlated_map) {
53}
54
55void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(
56 BoundSubqueryExpression &expr) {
57 // rewrite the binding in the correlated list of the subquery)
58 for (auto &corr : expr.binder->correlated_columns) {
59 auto entry = correlated_map.find(corr.binding);
60 if (entry != correlated_map.end()) {
61 corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second);
62 }
63 }
64 // now rewrite any correlated BoundColumnRef expressions inside the subquery
65 ExpressionIterator::EnumerateQueryNodeChildren(*expr.subquery,
66 [&](Expression &child) { RewriteCorrelatedExpressions(child); });
67}
68
69void RewriteCorrelatedExpressions::RewriteCorrelatedRecursive::RewriteCorrelatedExpressions(Expression &child) {
70 if (child.type == ExpressionType::BOUND_COLUMN_REF) {
71 // bound column reference
72 auto &bound_colref = (BoundColumnRefExpression &)child;
73 if (bound_colref.depth == 0) {
74 // not a correlated column, ignore
75 return;
76 }
77 // correlated column
78 // check the correlated map
79 auto entry = correlated_map.find(bound_colref.binding);
80 if (entry != correlated_map.end()) {
81 // we found the column in the correlated map!
82 // update the binding and reduce the depth by 1
83
84 bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second);
85 bound_colref.depth--;
86 }
87 } else if (child.type == ExpressionType::SUBQUERY) {
88 // we encountered another subquery: rewrite recursively
89 assert(child.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY);
90 auto &bound_subquery = (BoundSubqueryExpression &)child;
91 RewriteCorrelatedRecursive rewrite(bound_subquery, base_binding, correlated_map);
92 rewrite.RewriteCorrelatedSubquery(bound_subquery);
93 }
94}
95
96RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t<idx_t> &replacement_map)
97 : replacement_map(replacement_map) {
98}
99
100unique_ptr<Expression> RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr,
101 unique_ptr<Expression> *expr_ptr) {
102 auto entry = replacement_map.find(expr.binding);
103 if (entry != replacement_map.end()) {
104 // reference to a COUNT(*) aggregate
105 // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END
106 auto is_null = make_unique<BoundOperatorExpression>(ExpressionType::OPERATOR_IS_NULL, TypeId::BOOL);
107 is_null->children.push_back(expr.Copy());
108 auto check = move(is_null);
109 auto result_if_true = make_unique<BoundConstantExpression>(Value::Numeric(expr.return_type, 0));
110 auto result_if_false = move(*expr_ptr);
111 return make_unique<BoundCaseExpression>(move(check), move(result_if_true), move(result_if_false));
112 }
113 return nullptr;
114}
115