1#include "duckdb/optimizer/filter_pullup.hpp"
2#include "duckdb/planner/expression/bound_columnref_expression.hpp"
3#include "duckdb/planner/expression_iterator.hpp"
4#include "duckdb/planner/operator/logical_empty_result.hpp"
5#include "duckdb/planner/operator/logical_projection.hpp"
6#include "duckdb/planner/expression/bound_comparison_expression.hpp"
7
8namespace duckdb {
9
10static void RevertFilterPullup(LogicalProjection &proj, vector<unique_ptr<Expression>> &expressions) {
11 unique_ptr<LogicalFilter> filter = make_uniq<LogicalFilter>();
12 for (idx_t i = 0; i < expressions.size(); ++i) {
13 filter->expressions.push_back(x: std::move(expressions[i]));
14 }
15 expressions.clear();
16 filter->children.push_back(x: std::move(proj.children[0]));
17 proj.children[0] = std::move(filter);
18}
19
20static void ReplaceExpressionBinding(vector<unique_ptr<Expression>> &proj_expressions, Expression &expr,
21 idx_t proj_table_idx) {
22 if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
23 bool found_proj_col = false;
24 BoundColumnRefExpression &colref = expr.Cast<BoundColumnRefExpression>();
25 // find the corresponding column index in the projection expressions
26 for (idx_t proj_idx = 0; proj_idx < proj_expressions.size(); proj_idx++) {
27 auto &proj_expr = *proj_expressions[proj_idx];
28 if (proj_expr.type == ExpressionType::BOUND_COLUMN_REF) {
29 if (colref.Equals(other: proj_expr)) {
30 colref.binding.table_index = proj_table_idx;
31 colref.binding.column_index = proj_idx;
32 found_proj_col = true;
33 break;
34 }
35 }
36 }
37 if (!found_proj_col) {
38 // Project a new column
39 auto new_colref = colref.Copy();
40 colref.binding.table_index = proj_table_idx;
41 colref.binding.column_index = proj_expressions.size();
42 proj_expressions.push_back(x: std::move(new_colref));
43 }
44 }
45 ExpressionIterator::EnumerateChildren(
46 expression&: expr, callback: [&](Expression &child) { return ReplaceExpressionBinding(proj_expressions, expr&: child, proj_table_idx); });
47}
48
49void FilterPullup::ProjectSetOperation(LogicalProjection &proj) {
50 vector<unique_ptr<Expression>> copy_proj_expressions;
51 // copying the project expressions, it's useful whether we should revert the filter pullup
52 for (idx_t i = 0; i < proj.expressions.size(); ++i) {
53 copy_proj_expressions.push_back(x: proj.expressions[i]->Copy());
54 }
55
56 // Replace filter expression bindings, when need we add new columns into the copied projection expression
57 vector<unique_ptr<Expression>> changed_filter_expressions;
58 for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) {
59 auto copy_filter_expr = filters_expr_pullup[i]->Copy();
60 ReplaceExpressionBinding(proj_expressions&: copy_proj_expressions, expr&: (Expression &)*copy_filter_expr, proj_table_idx: proj.table_index);
61 changed_filter_expressions.push_back(x: std::move(copy_filter_expr));
62 }
63
64 /// Case new columns were added into the projection
65 // we must skip filter pullup because adding new columns to these operators will change the result
66 if (copy_proj_expressions.size() > proj.expressions.size()) {
67 RevertFilterPullup(proj, expressions&: filters_expr_pullup);
68 return;
69 }
70
71 // now we must replace the filter bindings
72 D_ASSERT(filters_expr_pullup.size() == changed_filter_expressions.size());
73 for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) {
74 filters_expr_pullup[i] = std::move(changed_filter_expressions[i]);
75 }
76}
77
78unique_ptr<LogicalOperator> FilterPullup::PullupProjection(unique_ptr<LogicalOperator> op) {
79 D_ASSERT(op->type == LogicalOperatorType::LOGICAL_PROJECTION);
80 op->children[0] = Rewrite(op: std::move(op->children[0]));
81 if (!filters_expr_pullup.empty()) {
82 auto &proj = op->Cast<LogicalProjection>();
83 // INTERSECT, EXCEPT, and DISTINCT
84 if (!can_add_column) {
85 // special treatment for operators that cannot add columns, e.g., INTERSECT, EXCEPT, and DISTINCT
86 ProjectSetOperation(proj);
87 return op;
88 }
89
90 for (idx_t i = 0; i < filters_expr_pullup.size(); ++i) {
91 auto &expr = (Expression &)*filters_expr_pullup[i];
92 ReplaceExpressionBinding(proj_expressions&: proj.expressions, expr, proj_table_idx: proj.table_index);
93 }
94 }
95 return op;
96}
97
98} // namespace duckdb
99