1 | #include "duckdb/planner/expression_binder/lateral_binder.hpp" |
2 | #include "duckdb/planner/expression_iterator.hpp" |
3 | #include "duckdb/planner/logical_operator_visitor.hpp" |
4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { |
10 | } |
11 | |
12 | void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { |
13 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
14 | auto &bound_colref = expr.Cast<BoundColumnRefExpression>(); |
15 | if (bound_colref.depth > 0) { |
16 | // add the correlated column info |
17 | CorrelatedColumnInfo info(bound_colref); |
18 | if (std::find(first: correlated_columns.begin(), last: correlated_columns.end(), val: info) == correlated_columns.end()) { |
19 | correlated_columns.push_back(x: std::move(info)); |
20 | } |
21 | } |
22 | } |
23 | ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { ExtractCorrelatedColumns(expr&: child); }); |
24 | } |
25 | |
26 | BindResult LateralBinder::BindColumnRef(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) { |
27 | if (depth == 0) { |
28 | throw InternalException("Lateral binder can only bind correlated columns" ); |
29 | } |
30 | auto result = ExpressionBinder::BindExpression(expr_ptr, depth); |
31 | if (result.HasError()) { |
32 | return result; |
33 | } |
34 | if (depth > 1) { |
35 | throw BinderException("Nested lateral joins are not supported yet" ); |
36 | } |
37 | ExtractCorrelatedColumns(expr&: *result.expression); |
38 | return result; |
39 | } |
40 | |
41 | vector<CorrelatedColumnInfo> LateralBinder::ExtractCorrelatedColumns(Binder &binder) { |
42 | |
43 | if (correlated_columns.empty()) { |
44 | return binder.correlated_columns; |
45 | } |
46 | |
47 | // clear outer |
48 | correlated_columns.clear(); |
49 | auto all_correlated_columns = binder.correlated_columns; |
50 | |
51 | // remove outer from inner |
52 | for (auto &corr_column : correlated_columns) { |
53 | auto entry = std::find(first: binder.correlated_columns.begin(), last: binder.correlated_columns.end(), val: corr_column); |
54 | if (entry != binder.correlated_columns.end()) { |
55 | binder.correlated_columns.erase(position: entry); |
56 | } |
57 | } |
58 | |
59 | // add inner to outer |
60 | for (auto &corr_column : binder.correlated_columns) { |
61 | correlated_columns.push_back(x: corr_column); |
62 | } |
63 | |
64 | // clear inner |
65 | binder.correlated_columns.clear(); |
66 | return all_correlated_columns; |
67 | } |
68 | |
69 | BindResult LateralBinder::BindExpression(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) { |
70 | auto &expr = *expr_ptr; |
71 | switch (expr.GetExpressionClass()) { |
72 | case ExpressionClass::DEFAULT: |
73 | return BindResult("LATERAL join cannot contain DEFAULT clause" ); |
74 | case ExpressionClass::WINDOW: |
75 | return BindResult("LATERAL join cannot contain window functions!" ); |
76 | case ExpressionClass::COLUMN_REF: |
77 | return BindColumnRef(expr_ptr, depth, root_expression); |
78 | default: |
79 | return ExpressionBinder::BindExpression(expr_ptr, depth); |
80 | } |
81 | } |
82 | |
83 | string LateralBinder::UnsupportedAggregateMessage() { |
84 | return "LATERAL join cannot contain aggregates!" ; |
85 | } |
86 | |
87 | class ExpressionDepthReducer : public LogicalOperatorVisitor { |
88 | public: |
89 | explicit ExpressionDepthReducer(const vector<CorrelatedColumnInfo> &correlated) : correlated_columns(correlated) { |
90 | } |
91 | |
92 | protected: |
93 | void ReduceColumnRefDepth(BoundColumnRefExpression &expr) { |
94 | // don't need to reduce this |
95 | if (expr.depth == 0) { |
96 | return; |
97 | } |
98 | for (auto &correlated : correlated_columns) { |
99 | if (correlated.binding == expr.binding) { |
100 | D_ASSERT(expr.depth > 1); |
101 | expr.depth--; |
102 | break; |
103 | } |
104 | } |
105 | } |
106 | |
107 | unique_ptr<Expression> VisitReplace(BoundColumnRefExpression &expr, unique_ptr<Expression> *expr_ptr) override { |
108 | ReduceColumnRefDepth(expr); |
109 | return nullptr; |
110 | } |
111 | |
112 | void ReduceExpressionSubquery(BoundSubqueryExpression &expr) { |
113 | for (auto &s_correlated : expr.binder->correlated_columns) { |
114 | for (auto &correlated : correlated_columns) { |
115 | if (correlated == s_correlated) { |
116 | s_correlated.depth--; |
117 | break; |
118 | } |
119 | } |
120 | } |
121 | } |
122 | |
123 | void ReduceExpressionDepth(Expression &expr) { |
124 | if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { |
125 | ReduceColumnRefDepth(expr&: expr.Cast<BoundColumnRefExpression>()); |
126 | } |
127 | if (expr.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { |
128 | ReduceExpressionSubquery(expr&: expr.Cast<BoundSubqueryExpression>()); |
129 | } |
130 | } |
131 | |
132 | unique_ptr<Expression> VisitReplace(BoundSubqueryExpression &expr, unique_ptr<Expression> *expr_ptr) override { |
133 | ReduceExpressionSubquery(expr); |
134 | ExpressionIterator::EnumerateQueryNodeChildren( |
135 | node&: *expr.subquery, callback: [&](Expression &child_expr) { ReduceExpressionDepth(expr&: child_expr); }); |
136 | return nullptr; |
137 | } |
138 | |
139 | const vector<CorrelatedColumnInfo> &correlated_columns; |
140 | }; |
141 | |
142 | void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector<CorrelatedColumnInfo> &correlated) { |
143 | ExpressionDepthReducer depth_reducer(correlated); |
144 | depth_reducer.VisitOperator(op); |
145 | } |
146 | |
147 | } // namespace duckdb |
148 | |