1#include "duckdb/planner/expression_binder/lateral_binder.hpp"
2#include "duckdb/planner/expression_iterator.hpp"
3#include "duckdb/planner/logical_operator_visitor.hpp"
4#include "duckdb/planner/expression/bound_columnref_expression.hpp"
5#include "duckdb/planner/expression/bound_subquery_expression.hpp"
6
7namespace duckdb {
8
9LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) {
10}
11
12void LateralBinder::ExtractCorrelatedColumns(Expression &expr) {
13 if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
14 auto &bound_colref = expr.Cast<BoundColumnRefExpression>();
15 if (bound_colref.depth > 0) {
16 // add the correlated column info
17 CorrelatedColumnInfo info(bound_colref);
18 if (std::find(first: correlated_columns.begin(), last: correlated_columns.end(), val: info) == correlated_columns.end()) {
19 correlated_columns.push_back(x: std::move(info));
20 }
21 }
22 }
23 ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { ExtractCorrelatedColumns(expr&: child); });
24}
25
26BindResult LateralBinder::BindColumnRef(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) {
27 if (depth == 0) {
28 throw InternalException("Lateral binder can only bind correlated columns");
29 }
30 auto result = ExpressionBinder::BindExpression(expr_ptr, depth);
31 if (result.HasError()) {
32 return result;
33 }
34 if (depth > 1) {
35 throw BinderException("Nested lateral joins are not supported yet");
36 }
37 ExtractCorrelatedColumns(expr&: *result.expression);
38 return result;
39}
40
41vector<CorrelatedColumnInfo> LateralBinder::ExtractCorrelatedColumns(Binder &binder) {
42
43 if (correlated_columns.empty()) {
44 return binder.correlated_columns;
45 }
46
47 // clear outer
48 correlated_columns.clear();
49 auto all_correlated_columns = binder.correlated_columns;
50
51 // remove outer from inner
52 for (auto &corr_column : correlated_columns) {
53 auto entry = std::find(first: binder.correlated_columns.begin(), last: binder.correlated_columns.end(), val: corr_column);
54 if (entry != binder.correlated_columns.end()) {
55 binder.correlated_columns.erase(position: entry);
56 }
57 }
58
59 // add inner to outer
60 for (auto &corr_column : binder.correlated_columns) {
61 correlated_columns.push_back(x: corr_column);
62 }
63
64 // clear inner
65 binder.correlated_columns.clear();
66 return all_correlated_columns;
67}
68
69BindResult LateralBinder::BindExpression(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) {
70 auto &expr = *expr_ptr;
71 switch (expr.GetExpressionClass()) {
72 case ExpressionClass::DEFAULT:
73 return BindResult("LATERAL join cannot contain DEFAULT clause");
74 case ExpressionClass::WINDOW:
75 return BindResult("LATERAL join cannot contain window functions!");
76 case ExpressionClass::COLUMN_REF:
77 return BindColumnRef(expr_ptr, depth, root_expression);
78 default:
79 return ExpressionBinder::BindExpression(expr_ptr, depth);
80 }
81}
82
83string LateralBinder::UnsupportedAggregateMessage() {
84 return "LATERAL join cannot contain aggregates!";
85}
86
87class ExpressionDepthReducer : public LogicalOperatorVisitor {
88public:
89 explicit ExpressionDepthReducer(const vector<CorrelatedColumnInfo> &correlated) : correlated_columns(correlated) {
90 }
91
92protected:
93 void ReduceColumnRefDepth(BoundColumnRefExpression &expr) {
94 // don't need to reduce this
95 if (expr.depth == 0) {
96 return;
97 }
98 for (auto &correlated : correlated_columns) {
99 if (correlated.binding == expr.binding) {
100 D_ASSERT(expr.depth > 1);
101 expr.depth--;
102 break;
103 }
104 }
105 }
106
107 unique_ptr<Expression> VisitReplace(BoundColumnRefExpression &expr, unique_ptr<Expression> *expr_ptr) override {
108 ReduceColumnRefDepth(expr);
109 return nullptr;
110 }
111
112 void ReduceExpressionSubquery(BoundSubqueryExpression &expr) {
113 for (auto &s_correlated : expr.binder->correlated_columns) {
114 for (auto &correlated : correlated_columns) {
115 if (correlated == s_correlated) {
116 s_correlated.depth--;
117 break;
118 }
119 }
120 }
121 }
122
123 void ReduceExpressionDepth(Expression &expr) {
124 if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) {
125 ReduceColumnRefDepth(expr&: expr.Cast<BoundColumnRefExpression>());
126 }
127 if (expr.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) {
128 ReduceExpressionSubquery(expr&: expr.Cast<BoundSubqueryExpression>());
129 }
130 }
131
132 unique_ptr<Expression> VisitReplace(BoundSubqueryExpression &expr, unique_ptr<Expression> *expr_ptr) override {
133 ReduceExpressionSubquery(expr);
134 ExpressionIterator::EnumerateQueryNodeChildren(
135 node&: *expr.subquery, callback: [&](Expression &child_expr) { ReduceExpressionDepth(expr&: child_expr); });
136 return nullptr;
137 }
138
139 const vector<CorrelatedColumnInfo> &correlated_columns;
140};
141
142void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector<CorrelatedColumnInfo> &correlated) {
143 ExpressionDepthReducer depth_reducer(correlated);
144 depth_reducer.VisitOperator(op);
145}
146
147} // namespace duckdb
148