| 1 | #include "duckdb/planner/table_binding.hpp" |
| 2 | |
| 3 | #include "duckdb/common/string_util.hpp" |
| 4 | #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" |
| 5 | #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" |
| 6 | #include "duckdb/parser/expression/columnref_expression.hpp" |
| 7 | #include "duckdb/parser/tableref/subqueryref.hpp" |
| 8 | #include "duckdb/planner/bind_context.hpp" |
| 9 | #include "duckdb/planner/bound_query_node.hpp" |
| 10 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 11 | #include "duckdb/planner/expression/bound_lambdaref_expression.hpp" |
| 12 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
| 13 | |
| 14 | #include <algorithm> |
| 15 | |
| 16 | namespace duckdb { |
| 17 | |
| 18 | Binding::Binding(BindingType binding_type, const string &alias, vector<LogicalType> coltypes, vector<string> colnames, |
| 19 | idx_t index) |
| 20 | : binding_type(binding_type), alias(alias), index(index), types(std::move(coltypes)), names(std::move(colnames)) { |
| 21 | D_ASSERT(types.size() == names.size()); |
| 22 | for (idx_t i = 0; i < names.size(); i++) { |
| 23 | auto &name = names[i]; |
| 24 | D_ASSERT(!name.empty()); |
| 25 | if (name_map.find(x: name) != name_map.end()) { |
| 26 | throw BinderException("table \"%s\" has duplicate column name \"%s\"" , alias, name); |
| 27 | } |
| 28 | name_map[name] = i; |
| 29 | } |
| 30 | } |
| 31 | |
| 32 | bool Binding::TryGetBindingIndex(const string &column_name, column_t &result) { |
| 33 | auto entry = name_map.find(x: column_name); |
| 34 | if (entry == name_map.end()) { |
| 35 | return false; |
| 36 | } |
| 37 | auto column_info = entry->second; |
| 38 | result = column_info; |
| 39 | return true; |
| 40 | } |
| 41 | |
| 42 | column_t Binding::GetBindingIndex(const string &column_name) { |
| 43 | column_t result; |
| 44 | if (!TryGetBindingIndex(column_name, result)) { |
| 45 | throw InternalException("Binding index for column \"%s\" not found" , column_name); |
| 46 | } |
| 47 | return result; |
| 48 | } |
| 49 | |
| 50 | bool Binding::HasMatchingBinding(const string &column_name) { |
| 51 | column_t result; |
| 52 | return TryGetBindingIndex(column_name, result); |
| 53 | } |
| 54 | |
| 55 | string Binding::ColumnNotFoundError(const string &column_name) const { |
| 56 | return StringUtil::Format(fmt_str: "Values list \"%s\" does not have a column named \"%s\"" , params: alias, params: column_name); |
| 57 | } |
| 58 | |
| 59 | BindResult Binding::Bind(ColumnRefExpression &colref, idx_t depth) { |
| 60 | column_t column_index; |
| 61 | bool success = false; |
| 62 | success = TryGetBindingIndex(column_name: colref.GetColumnName(), result&: column_index); |
| 63 | if (!success) { |
| 64 | return BindResult(ColumnNotFoundError(column_name: colref.GetColumnName())); |
| 65 | } |
| 66 | ColumnBinding binding; |
| 67 | binding.table_index = index; |
| 68 | binding.column_index = column_index; |
| 69 | LogicalType sql_type = types[column_index]; |
| 70 | if (colref.alias.empty()) { |
| 71 | colref.alias = names[column_index]; |
| 72 | } |
| 73 | return BindResult(make_uniq<BoundColumnRefExpression>(args: colref.GetName(), args&: sql_type, args&: binding, args&: depth)); |
| 74 | } |
| 75 | |
| 76 | optional_ptr<StandardEntry> Binding::GetStandardEntry() { |
| 77 | return nullptr; |
| 78 | } |
| 79 | |
| 80 | EntryBinding::EntryBinding(const string &alias, vector<LogicalType> types_p, vector<string> names_p, idx_t index, |
| 81 | StandardEntry &entry) |
| 82 | : Binding(BindingType::CATALOG_ENTRY, alias, std::move(types_p), std::move(names_p), index), entry(entry) { |
| 83 | } |
| 84 | |
| 85 | optional_ptr<StandardEntry> EntryBinding::GetStandardEntry() { |
| 86 | return &entry; |
| 87 | } |
| 88 | |
| 89 | TableBinding::TableBinding(const string &alias, vector<LogicalType> types_p, vector<string> names_p, |
| 90 | vector<column_t> &bound_column_ids, optional_ptr<StandardEntry> entry, idx_t index, |
| 91 | bool add_row_id) |
| 92 | : Binding(BindingType::TABLE, alias, std::move(types_p), std::move(names_p), index), |
| 93 | bound_column_ids(bound_column_ids), entry(entry) { |
| 94 | if (add_row_id) { |
| 95 | if (name_map.find(x: "rowid" ) == name_map.end()) { |
| 96 | name_map["rowid" ] = COLUMN_IDENTIFIER_ROW_ID; |
| 97 | } |
| 98 | } |
| 99 | } |
| 100 | |
| 101 | static void ReplaceAliases(ParsedExpression &expr, const ColumnList &list, |
| 102 | const unordered_map<idx_t, string> &alias_map) { |
| 103 | if (expr.type == ExpressionType::COLUMN_REF) { |
| 104 | auto &colref = expr.Cast<ColumnRefExpression>(); |
| 105 | D_ASSERT(!colref.IsQualified()); |
| 106 | auto &col_names = colref.column_names; |
| 107 | D_ASSERT(col_names.size() == 1); |
| 108 | auto idx_entry = list.GetColumnIndex(column_name&: col_names[0]); |
| 109 | auto &alias = alias_map.at(k: idx_entry.index); |
| 110 | col_names = {alias}; |
| 111 | } |
| 112 | ParsedExpressionIterator::EnumerateChildren( |
| 113 | expr, callback: [&](const ParsedExpression &child) { ReplaceAliases(expr&: (ParsedExpression &)child, list, alias_map); }); |
| 114 | } |
| 115 | |
| 116 | static void BakeTableName(ParsedExpression &expr, const string &table_name) { |
| 117 | if (expr.type == ExpressionType::COLUMN_REF) { |
| 118 | auto &colref = expr.Cast<ColumnRefExpression>(); |
| 119 | D_ASSERT(!colref.IsQualified()); |
| 120 | auto &col_names = colref.column_names; |
| 121 | col_names.insert(position: col_names.begin(), x: table_name); |
| 122 | } |
| 123 | ParsedExpressionIterator::EnumerateChildren( |
| 124 | expr, callback: [&](const ParsedExpression &child) { BakeTableName(expr&: (ParsedExpression &)child, table_name); }); |
| 125 | } |
| 126 | |
| 127 | unique_ptr<ParsedExpression> TableBinding::ExpandGeneratedColumn(const string &column_name) { |
| 128 | auto catalog_entry = GetStandardEntry(); |
| 129 | D_ASSERT(catalog_entry); // Should only be called on a TableBinding |
| 130 | |
| 131 | D_ASSERT(catalog_entry->type == CatalogType::TABLE_ENTRY); |
| 132 | auto &table_entry = catalog_entry->Cast<TableCatalogEntry>(); |
| 133 | |
| 134 | // Get the index of the generated column |
| 135 | auto column_index = GetBindingIndex(column_name); |
| 136 | D_ASSERT(table_entry.GetColumn(LogicalIndex(column_index)).Generated()); |
| 137 | // Get a copy of the generated column |
| 138 | auto expression = table_entry.GetColumn(idx: LogicalIndex(column_index)).GeneratedExpression().Copy(); |
| 139 | unordered_map<idx_t, string> alias_map; |
| 140 | for (auto &entry : name_map) { |
| 141 | alias_map[entry.second] = entry.first; |
| 142 | } |
| 143 | ReplaceAliases(expr&: *expression, list: table_entry.GetColumns(), alias_map); |
| 144 | BakeTableName(expr&: *expression, table_name: alias); |
| 145 | return (expression); |
| 146 | } |
| 147 | |
| 148 | const vector<column_t> &TableBinding::GetBoundColumnIds() const { |
| 149 | #ifdef DEBUG |
| 150 | unordered_set<column_t> column_ids; |
| 151 | for (auto &id : bound_column_ids) { |
| 152 | auto result = column_ids.insert(id); |
| 153 | // assert that all entries in the bound_column_ids are unique |
| 154 | D_ASSERT(result.second); |
| 155 | auto it = std::find_if(name_map.begin(), name_map.end(), |
| 156 | [&](const std::pair<const string, column_t> &it) { return it.second == id; }); |
| 157 | // assert that every id appears in the name_map |
| 158 | D_ASSERT(it != name_map.end()); |
| 159 | // the order that they appear in is not guaranteed to be sequential |
| 160 | } |
| 161 | #endif |
| 162 | return bound_column_ids; |
| 163 | } |
| 164 | |
| 165 | ColumnBinding TableBinding::GetColumnBinding(column_t column_index) { |
| 166 | auto &column_ids = bound_column_ids; |
| 167 | ColumnBinding binding; |
| 168 | |
| 169 | // Locate the column_id that matches the 'column_index' |
| 170 | auto it = std::find_if(first: column_ids.begin(), last: column_ids.end(), |
| 171 | pred: [&](const column_t &id) -> bool { return id == column_index; }); |
| 172 | // Get the index of it |
| 173 | binding.column_index = std::distance(first: column_ids.begin(), last: it); |
| 174 | // If it wasn't found, add it |
| 175 | if (it == column_ids.end()) { |
| 176 | column_ids.push_back(x: column_index); |
| 177 | } |
| 178 | |
| 179 | binding.table_index = index; |
| 180 | return binding; |
| 181 | } |
| 182 | |
| 183 | BindResult TableBinding::Bind(ColumnRefExpression &colref, idx_t depth) { |
| 184 | auto &column_name = colref.GetColumnName(); |
| 185 | column_t column_index; |
| 186 | bool success = false; |
| 187 | success = TryGetBindingIndex(column_name, result&: column_index); |
| 188 | if (!success) { |
| 189 | return BindResult(ColumnNotFoundError(column_name)); |
| 190 | } |
| 191 | auto entry = GetStandardEntry(); |
| 192 | if (entry && column_index != COLUMN_IDENTIFIER_ROW_ID) { |
| 193 | D_ASSERT(entry->type == CatalogType::TABLE_ENTRY); |
| 194 | // Either there is no table, or the columns category has to be standard |
| 195 | auto &table_entry = entry->Cast<TableCatalogEntry>(); |
| 196 | auto &column_entry = table_entry.GetColumn(idx: LogicalIndex(column_index)); |
| 197 | (void)table_entry; |
| 198 | (void)column_entry; |
| 199 | D_ASSERT(column_entry.Category() == TableColumnType::STANDARD); |
| 200 | } |
| 201 | // fetch the type of the column |
| 202 | LogicalType col_type; |
| 203 | if (column_index == COLUMN_IDENTIFIER_ROW_ID) { |
| 204 | // row id: BIGINT type |
| 205 | col_type = LogicalType::BIGINT; |
| 206 | } else { |
| 207 | // normal column: fetch type from base column |
| 208 | col_type = types[column_index]; |
| 209 | if (colref.alias.empty()) { |
| 210 | colref.alias = names[column_index]; |
| 211 | } |
| 212 | } |
| 213 | ColumnBinding binding = GetColumnBinding(column_index); |
| 214 | return BindResult(make_uniq<BoundColumnRefExpression>(args: colref.GetName(), args&: col_type, args&: binding, args&: depth)); |
| 215 | } |
| 216 | |
| 217 | optional_ptr<StandardEntry> TableBinding::GetStandardEntry() { |
| 218 | return entry; |
| 219 | } |
| 220 | |
| 221 | string TableBinding::ColumnNotFoundError(const string &column_name) const { |
| 222 | return StringUtil::Format(fmt_str: "Table \"%s\" does not have a column named \"%s\"" , params: alias, params: column_name); |
| 223 | } |
| 224 | |
| 225 | DummyBinding::DummyBinding(vector<LogicalType> types_p, vector<string> names_p, string dummy_name_p) |
| 226 | : Binding(BindingType::DUMMY, DummyBinding::DUMMY_NAME + dummy_name_p, std::move(types_p), std::move(names_p), |
| 227 | DConstants::INVALID_INDEX), |
| 228 | dummy_name(std::move(dummy_name_p)) { |
| 229 | } |
| 230 | |
| 231 | BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t depth) { |
| 232 | column_t column_index; |
| 233 | if (!TryGetBindingIndex(column_name: colref.GetColumnName(), result&: column_index)) { |
| 234 | throw InternalException("Column %s not found in bindings" , colref.GetColumnName()); |
| 235 | } |
| 236 | ColumnBinding binding(index, column_index); |
| 237 | |
| 238 | // we are binding a parameter to create the dummy binding, no arguments are supplied |
| 239 | return BindResult(make_uniq<BoundColumnRefExpression>(args: colref.GetName(), args&: types[column_index], args&: binding, args&: depth)); |
| 240 | } |
| 241 | |
| 242 | BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t lambda_index, idx_t depth) { |
| 243 | column_t column_index; |
| 244 | if (!TryGetBindingIndex(column_name: colref.GetColumnName(), result&: column_index)) { |
| 245 | throw InternalException("Column %s not found in bindings" , colref.GetColumnName()); |
| 246 | } |
| 247 | ColumnBinding binding(index, column_index); |
| 248 | return BindResult( |
| 249 | make_uniq<BoundLambdaRefExpression>(args: colref.GetName(), args&: types[column_index], args&: binding, args&: lambda_index, args&: depth)); |
| 250 | } |
| 251 | |
| 252 | unique_ptr<ParsedExpression> DummyBinding::ParamToArg(ColumnRefExpression &colref) { |
| 253 | column_t column_index; |
| 254 | if (!TryGetBindingIndex(column_name: colref.GetColumnName(), result&: column_index)) { |
| 255 | throw InternalException("Column %s not found in macro" , colref.GetColumnName()); |
| 256 | } |
| 257 | auto arg = (*arguments)[column_index]->Copy(); |
| 258 | arg->alias = colref.alias; |
| 259 | return arg; |
| 260 | } |
| 261 | |
| 262 | } // namespace duckdb |
| 263 | |