1#include "duckdb/execution/column_binding_resolver.hpp"
2
3#include "duckdb/planner/operator/logical_comparison_join.hpp"
4#include "duckdb/planner/operator/logical_any_join.hpp"
5#include "duckdb/planner/operator/logical_create_index.hpp"
6#include "duckdb/planner/operator/logical_delim_join.hpp"
7#include "duckdb/planner/operator/logical_insert.hpp"
8
9#include "duckdb/planner/expression/bound_columnref_expression.hpp"
10#include "duckdb/planner/expression/bound_reference_expression.hpp"
11
12#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
13#include "duckdb/common/to_string.hpp"
14
15namespace duckdb {
16
17ColumnBindingResolver::ColumnBindingResolver() {
18}
19
20void ColumnBindingResolver::VisitOperator(LogicalOperator &op) {
21 switch (op.type) {
22 case LogicalOperatorType::LOGICAL_ASOF_JOIN:
23 case LogicalOperatorType::LOGICAL_COMPARISON_JOIN:
24 case LogicalOperatorType::LOGICAL_DELIM_JOIN: {
25 // special case: comparison join
26 auto &comp_join = op.Cast<LogicalComparisonJoin>();
27 // first get the bindings of the LHS and resolve the LHS expressions
28 VisitOperator(op&: *comp_join.children[0]);
29 for (auto &cond : comp_join.conditions) {
30 VisitExpression(expression: &cond.left);
31 }
32 if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) {
33 // visit the duplicate eliminated columns on the LHS, if any
34 auto &delim_join = op.Cast<LogicalDelimJoin>();
35 for (auto &expr : delim_join.duplicate_eliminated_columns) {
36 VisitExpression(expression: &expr);
37 }
38 }
39 // then get the bindings of the RHS and resolve the RHS expressions
40 VisitOperator(op&: *comp_join.children[1]);
41 for (auto &cond : comp_join.conditions) {
42 VisitExpression(expression: &cond.right);
43 }
44 // finally update the bindings with the result bindings of the join
45 bindings = op.GetColumnBindings();
46 return;
47 }
48 case LogicalOperatorType::LOGICAL_ANY_JOIN: {
49 // ANY join, this join is different because we evaluate the expression on the bindings of BOTH join sides at
50 // once i.e. we set the bindings first to the bindings of the entire join, and then resolve the expressions of
51 // this operator
52 VisitOperatorChildren(op);
53 bindings = op.GetColumnBindings();
54 auto &any_join = op.Cast<LogicalAnyJoin>();
55 if (any_join.join_type == JoinType::SEMI || any_join.join_type == JoinType::ANTI) {
56 auto right_bindings = op.children[1]->GetColumnBindings();
57 bindings.insert(position: bindings.end(), first: right_bindings.begin(), last: right_bindings.end());
58 }
59 VisitOperatorExpressions(op);
60 return;
61 }
62 case LogicalOperatorType::LOGICAL_CREATE_INDEX: {
63 // CREATE INDEX statement, add the columns of the table with table index 0 to the binding set
64 // afterwards bind the expressions of the CREATE INDEX statement
65 auto &create_index = op.Cast<LogicalCreateIndex>();
66 bindings = LogicalOperator::GenerateColumnBindings(table_idx: 0, column_count: create_index.table.GetColumns().LogicalColumnCount());
67 VisitOperatorExpressions(op);
68 return;
69 }
70 case LogicalOperatorType::LOGICAL_GET: {
71 //! We first need to update the current set of bindings and then visit operator expressions
72 bindings = op.GetColumnBindings();
73 VisitOperatorExpressions(op);
74 return;
75 }
76 case LogicalOperatorType::LOGICAL_INSERT: {
77 //! We want to execute the normal path, but also add a dummy 'excluded' binding if there is a
78 // ON CONFLICT DO UPDATE clause
79 auto &insert_op = op.Cast<LogicalInsert>();
80 if (insert_op.action_type != OnConflictAction::THROW) {
81 // Get the bindings from the children
82 VisitOperatorChildren(op);
83 auto column_count = insert_op.table.GetColumns().PhysicalColumnCount();
84 auto dummy_bindings = LogicalOperator::GenerateColumnBindings(table_idx: insert_op.excluded_table_index, column_count);
85 // Now insert our dummy bindings at the start of the bindings,
86 // so the first 'column_count' indices of the chunk are reserved for our 'excluded' columns
87 bindings.insert(position: bindings.begin(), first: dummy_bindings.begin(), last: dummy_bindings.end());
88 if (insert_op.on_conflict_condition) {
89 VisitExpression(expression: &insert_op.on_conflict_condition);
90 }
91 if (insert_op.do_update_condition) {
92 VisitExpression(expression: &insert_op.do_update_condition);
93 }
94 VisitOperatorExpressions(op);
95 bindings = op.GetColumnBindings();
96 return;
97 }
98 }
99 default:
100 break;
101 }
102 // general case
103 // first visit the children of this operator
104 VisitOperatorChildren(op);
105 // now visit the expressions of this operator to resolve any bound column references
106 VisitOperatorExpressions(op);
107 // finally update the current set of bindings to the current set of column bindings
108 bindings = op.GetColumnBindings();
109}
110
111unique_ptr<Expression> ColumnBindingResolver::VisitReplace(BoundColumnRefExpression &expr,
112 unique_ptr<Expression> *expr_ptr) {
113 D_ASSERT(expr.depth == 0);
114 // check the current set of column bindings to see which index corresponds to the column reference
115 for (idx_t i = 0; i < bindings.size(); i++) {
116 if (expr.binding == bindings[i]) {
117 return make_uniq<BoundReferenceExpression>(args&: expr.alias, args&: expr.return_type, args&: i);
118 }
119 }
120 // LCOV_EXCL_START
121 // could not bind the column reference, this should never happen and indicates a bug in the code
122 // generate an error message
123 string bound_columns = "[";
124 for (idx_t i = 0; i < bindings.size(); i++) {
125 if (i != 0) {
126 bound_columns += " ";
127 }
128 bound_columns += to_string(val: bindings[i].table_index) + "." + to_string(val: bindings[i].column_index);
129 }
130 bound_columns += "]";
131
132 throw InternalException("Failed to bind column reference \"%s\" [%d.%d] (bindings: %s)", expr.alias,
133 expr.binding.table_index, expr.binding.column_index, bound_columns);
134 // LCOV_EXCL_STOP
135}
136
137unordered_set<idx_t> ColumnBindingResolver::VerifyInternal(LogicalOperator &op) {
138 unordered_set<idx_t> result;
139 for (auto &child : op.children) {
140 auto child_indexes = VerifyInternal(op&: *child);
141 for (auto index : child_indexes) {
142 D_ASSERT(index != DConstants::INVALID_INDEX);
143 if (result.find(x: index) != result.end()) {
144 throw InternalException("Duplicate table index \"%lld\" found", index);
145 }
146 result.insert(x: index);
147 }
148 }
149 auto indexes = op.GetTableIndex();
150 for (auto index : indexes) {
151 D_ASSERT(index != DConstants::INVALID_INDEX);
152 if (result.find(x: index) != result.end()) {
153 throw InternalException("Duplicate table index \"%lld\" found", index);
154 }
155 result.insert(x: index);
156 }
157 return result;
158}
159
160void ColumnBindingResolver::Verify(LogicalOperator &op) {
161#ifdef DEBUG
162 VerifyInternal(op);
163#endif
164}
165
166} // namespace duckdb
167