| 1 | #include "duckdb/planner/expression_binder/lateral_binder.hpp" |
| 2 | #include "duckdb/planner/expression_iterator.hpp" |
| 3 | #include "duckdb/planner/logical_operator_visitor.hpp" |
| 4 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_subquery_expression.hpp" |
| 6 | |
| 7 | namespace duckdb { |
| 8 | |
| 9 | LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context) { |
| 10 | } |
| 11 | |
| 12 | void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { |
| 13 | if (expr.type == ExpressionType::BOUND_COLUMN_REF) { |
| 14 | auto &bound_colref = expr.Cast<BoundColumnRefExpression>(); |
| 15 | if (bound_colref.depth > 0) { |
| 16 | // add the correlated column info |
| 17 | CorrelatedColumnInfo info(bound_colref); |
| 18 | if (std::find(first: correlated_columns.begin(), last: correlated_columns.end(), val: info) == correlated_columns.end()) { |
| 19 | correlated_columns.push_back(x: std::move(info)); |
| 20 | } |
| 21 | } |
| 22 | } |
| 23 | ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { ExtractCorrelatedColumns(expr&: child); }); |
| 24 | } |
| 25 | |
| 26 | BindResult LateralBinder::BindColumnRef(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) { |
| 27 | if (depth == 0) { |
| 28 | throw InternalException("Lateral binder can only bind correlated columns" ); |
| 29 | } |
| 30 | auto result = ExpressionBinder::BindExpression(expr_ptr, depth); |
| 31 | if (result.HasError()) { |
| 32 | return result; |
| 33 | } |
| 34 | if (depth > 1) { |
| 35 | throw BinderException("Nested lateral joins are not supported yet" ); |
| 36 | } |
| 37 | ExtractCorrelatedColumns(expr&: *result.expression); |
| 38 | return result; |
| 39 | } |
| 40 | |
| 41 | vector<CorrelatedColumnInfo> LateralBinder::ExtractCorrelatedColumns(Binder &binder) { |
| 42 | |
| 43 | if (correlated_columns.empty()) { |
| 44 | return binder.correlated_columns; |
| 45 | } |
| 46 | |
| 47 | // clear outer |
| 48 | correlated_columns.clear(); |
| 49 | auto all_correlated_columns = binder.correlated_columns; |
| 50 | |
| 51 | // remove outer from inner |
| 52 | for (auto &corr_column : correlated_columns) { |
| 53 | auto entry = std::find(first: binder.correlated_columns.begin(), last: binder.correlated_columns.end(), val: corr_column); |
| 54 | if (entry != binder.correlated_columns.end()) { |
| 55 | binder.correlated_columns.erase(position: entry); |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | // add inner to outer |
| 60 | for (auto &corr_column : binder.correlated_columns) { |
| 61 | correlated_columns.push_back(x: corr_column); |
| 62 | } |
| 63 | |
| 64 | // clear inner |
| 65 | binder.correlated_columns.clear(); |
| 66 | return all_correlated_columns; |
| 67 | } |
| 68 | |
| 69 | BindResult LateralBinder::BindExpression(unique_ptr<ParsedExpression> &expr_ptr, idx_t depth, bool root_expression) { |
| 70 | auto &expr = *expr_ptr; |
| 71 | switch (expr.GetExpressionClass()) { |
| 72 | case ExpressionClass::DEFAULT: |
| 73 | return BindResult("LATERAL join cannot contain DEFAULT clause" ); |
| 74 | case ExpressionClass::WINDOW: |
| 75 | return BindResult("LATERAL join cannot contain window functions!" ); |
| 76 | case ExpressionClass::COLUMN_REF: |
| 77 | return BindColumnRef(expr_ptr, depth, root_expression); |
| 78 | default: |
| 79 | return ExpressionBinder::BindExpression(expr_ptr, depth); |
| 80 | } |
| 81 | } |
| 82 | |
| 83 | string LateralBinder::UnsupportedAggregateMessage() { |
| 84 | return "LATERAL join cannot contain aggregates!" ; |
| 85 | } |
| 86 | |
| 87 | class ExpressionDepthReducer : public LogicalOperatorVisitor { |
| 88 | public: |
| 89 | explicit ExpressionDepthReducer(const vector<CorrelatedColumnInfo> &correlated) : correlated_columns(correlated) { |
| 90 | } |
| 91 | |
| 92 | protected: |
| 93 | void ReduceColumnRefDepth(BoundColumnRefExpression &expr) { |
| 94 | // don't need to reduce this |
| 95 | if (expr.depth == 0) { |
| 96 | return; |
| 97 | } |
| 98 | for (auto &correlated : correlated_columns) { |
| 99 | if (correlated.binding == expr.binding) { |
| 100 | D_ASSERT(expr.depth > 1); |
| 101 | expr.depth--; |
| 102 | break; |
| 103 | } |
| 104 | } |
| 105 | } |
| 106 | |
| 107 | unique_ptr<Expression> VisitReplace(BoundColumnRefExpression &expr, unique_ptr<Expression> *expr_ptr) override { |
| 108 | ReduceColumnRefDepth(expr); |
| 109 | return nullptr; |
| 110 | } |
| 111 | |
| 112 | void ReduceExpressionSubquery(BoundSubqueryExpression &expr) { |
| 113 | for (auto &s_correlated : expr.binder->correlated_columns) { |
| 114 | for (auto &correlated : correlated_columns) { |
| 115 | if (correlated == s_correlated) { |
| 116 | s_correlated.depth--; |
| 117 | break; |
| 118 | } |
| 119 | } |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | void ReduceExpressionDepth(Expression &expr) { |
| 124 | if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { |
| 125 | ReduceColumnRefDepth(expr&: expr.Cast<BoundColumnRefExpression>()); |
| 126 | } |
| 127 | if (expr.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { |
| 128 | ReduceExpressionSubquery(expr&: expr.Cast<BoundSubqueryExpression>()); |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | unique_ptr<Expression> VisitReplace(BoundSubqueryExpression &expr, unique_ptr<Expression> *expr_ptr) override { |
| 133 | ReduceExpressionSubquery(expr); |
| 134 | ExpressionIterator::EnumerateQueryNodeChildren( |
| 135 | node&: *expr.subquery, callback: [&](Expression &child_expr) { ReduceExpressionDepth(expr&: child_expr); }); |
| 136 | return nullptr; |
| 137 | } |
| 138 | |
| 139 | const vector<CorrelatedColumnInfo> &correlated_columns; |
| 140 | }; |
| 141 | |
| 142 | void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector<CorrelatedColumnInfo> &correlated) { |
| 143 | ExpressionDepthReducer depth_reducer(correlated); |
| 144 | depth_reducer.VisitOperator(op); |
| 145 | } |
| 146 | |
| 147 | } // namespace duckdb |
| 148 | |