| 1 | #include "duckdb/planner/expression_binder/lateral_binder.hpp" | 
| 2 | #include "duckdb/planner/expression_iterator.hpp" | 
| 3 | #include "duckdb/planner/logical_operator_visitor.hpp" | 
| 4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" | 
| 5 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" | 
| 6 |  | 
| 7 | namespace duckdb { | 
| 8 |  | 
| 9 | LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { | 
| 10 | } | 
| 11 |  | 
| 12 | void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { | 
| 13 | 	if (expr.type == ExpressionType::BOUND_COLUMN_REF) { | 
| 14 | 		auto &bound_colref = expr.Cast<BoundColumnRefExpression>(); | 
| 15 | 		if (bound_colref.depth > 0) { | 
| 16 | 			// add the correlated column info | 
| 17 | 			CorrelatedColumnInfo info(bound_colref); | 
| 18 | 			if (std::find(first: correlated_columns.begin(), last: correlated_columns.end(), val: info) == correlated_columns.end()) { | 
| 19 | 				correlated_columns.push_back(x: std::move(info)); | 
| 20 | 			} | 
| 21 | 		} | 
| 22 | 	} | 
| 23 | 	ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { ExtractCorrelatedColumns(expr&: child); }); | 
| 24 | } | 
| 25 |  | 
| 26 | BindResult LateralBinder::BindColumnRef(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) { | 
| 27 | 	if (depth == 0) { | 
| 28 | 		throw InternalException("Lateral binder can only bind correlated columns" ); | 
| 29 | 	} | 
| 30 | 	auto result = ExpressionBinder::BindExpression(expr_ptr, depth); | 
| 31 | 	if (result.HasError()) { | 
| 32 | 		return result; | 
| 33 | 	} | 
| 34 | 	if (depth > 1) { | 
| 35 | 		throw BinderException("Nested lateral joins are not supported yet" ); | 
| 36 | 	} | 
| 37 | 	ExtractCorrelatedColumns(expr&: *result.expression); | 
| 38 | 	return result; | 
| 39 | } | 
| 40 |  | 
| 41 | vector<CorrelatedColumnInfo> LateralBinder::ExtractCorrelatedColumns(Binder &binder) { | 
| 42 |  | 
| 43 | 	if (correlated_columns.empty()) { | 
| 44 | 		return binder.correlated_columns; | 
| 45 | 	} | 
| 46 |  | 
| 47 | 	// clear outer | 
| 48 | 	correlated_columns.clear(); | 
| 49 | 	auto all_correlated_columns = binder.correlated_columns; | 
| 50 |  | 
| 51 | 	// remove outer from inner | 
| 52 | 	for (auto &corr_column : correlated_columns) { | 
| 53 | 		auto entry = std::find(first: binder.correlated_columns.begin(), last: binder.correlated_columns.end(), val: corr_column); | 
| 54 | 		if (entry != binder.correlated_columns.end()) { | 
| 55 | 			binder.correlated_columns.erase(position: entry); | 
| 56 | 		} | 
| 57 | 	} | 
| 58 |  | 
| 59 | 	// add inner to outer | 
| 60 | 	for (auto &corr_column : binder.correlated_columns) { | 
| 61 | 		correlated_columns.push_back(x: corr_column); | 
| 62 | 	} | 
| 63 |  | 
| 64 | 	// clear inner | 
| 65 | 	binder.correlated_columns.clear(); | 
| 66 | 	return all_correlated_columns; | 
| 67 | } | 
| 68 |  | 
| 69 | BindResult LateralBinder::BindExpression(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) { | 
| 70 | 	auto &expr = *expr_ptr; | 
| 71 | 	switch (expr.GetExpressionClass()) { | 
| 72 | 	case ExpressionClass::DEFAULT: | 
| 73 | 		return BindResult("LATERAL join cannot contain DEFAULT clause" ); | 
| 74 | 	case ExpressionClass::WINDOW: | 
| 75 | 		return BindResult("LATERAL join cannot contain window functions!" ); | 
| 76 | 	case ExpressionClass::COLUMN_REF: | 
| 77 | 		return BindColumnRef(expr_ptr, depth, root_expression); | 
| 78 | 	default: | 
| 79 | 		return ExpressionBinder::BindExpression(expr_ptr, depth); | 
| 80 | 	} | 
| 81 | } | 
| 82 |  | 
| 83 | string LateralBinder::UnsupportedAggregateMessage() { | 
| 84 | 	return "LATERAL join cannot contain aggregates!" ; | 
| 85 | } | 
| 86 |  | 
| 87 | class ExpressionDepthReducer : public LogicalOperatorVisitor { | 
| 88 | public: | 
| 89 | 	explicit ExpressionDepthReducer(const vector<CorrelatedColumnInfo> &correlated) : correlated_columns(correlated) { | 
| 90 | 	} | 
| 91 |  | 
| 92 | protected: | 
| 93 | 	void ReduceColumnRefDepth(BoundColumnRefExpression &expr) { | 
| 94 | 		// don't need to reduce this | 
| 95 | 		if (expr.depth == 0) { | 
| 96 | 			return; | 
| 97 | 		} | 
| 98 | 		for (auto &correlated : correlated_columns) { | 
| 99 | 			if (correlated.binding == expr.binding) { | 
| 100 | 				D_ASSERT(expr.depth > 1); | 
| 101 | 				expr.depth--; | 
| 102 | 				break; | 
| 103 | 			} | 
| 104 | 		} | 
| 105 | 	} | 
| 106 |  | 
| 107 | 	unique_ptr<Expression> VisitReplace(BoundColumnRefExpression &expr, unique_ptr<Expression> *expr_ptr) override { | 
| 108 | 		ReduceColumnRefDepth(expr); | 
| 109 | 		return nullptr; | 
| 110 | 	} | 
| 111 |  | 
| 112 | 	void ReduceExpressionSubquery(BoundSubqueryExpression &expr) { | 
| 113 | 		for (auto &s_correlated : expr.binder->correlated_columns) { | 
| 114 | 			for (auto &correlated : correlated_columns) { | 
| 115 | 				if (correlated == s_correlated) { | 
| 116 | 					s_correlated.depth--; | 
| 117 | 					break; | 
| 118 | 				} | 
| 119 | 			} | 
| 120 | 		} | 
| 121 | 	} | 
| 122 |  | 
| 123 | 	void ReduceExpressionDepth(Expression &expr) { | 
| 124 | 		if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { | 
| 125 | 			ReduceColumnRefDepth(expr&: expr.Cast<BoundColumnRefExpression>()); | 
| 126 | 		} | 
| 127 | 		if (expr.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { | 
| 128 | 			ReduceExpressionSubquery(expr&: expr.Cast<BoundSubqueryExpression>()); | 
| 129 | 		} | 
| 130 | 	} | 
| 131 |  | 
| 132 | 	unique_ptr<Expression> VisitReplace(BoundSubqueryExpression &expr, unique_ptr<Expression> *expr_ptr) override { | 
| 133 | 		ReduceExpressionSubquery(expr); | 
| 134 | 		ExpressionIterator::EnumerateQueryNodeChildren( | 
| 135 | 		    node&: *expr.subquery, callback: [&](Expression &child_expr) { ReduceExpressionDepth(expr&: child_expr); }); | 
| 136 | 		return nullptr; | 
| 137 | 	} | 
| 138 |  | 
| 139 | 	const vector<CorrelatedColumnInfo> &correlated_columns; | 
| 140 | }; | 
| 141 |  | 
| 142 | void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector<CorrelatedColumnInfo> &correlated) { | 
| 143 | 	ExpressionDepthReducer depth_reducer(correlated); | 
| 144 | 	depth_reducer.VisitOperator(op); | 
| 145 | } | 
| 146 |  | 
| 147 | } // namespace duckdb | 
| 148 |  |