| 1 | #include "duckdb/main/relation/update_relation.hpp" |
| 2 | #include "duckdb/parser/statement/update_statement.hpp" |
| 3 | #include "duckdb/planner/binder.hpp" |
| 4 | #include "duckdb/main/client_context.hpp" |
| 5 | #include "duckdb/parser/tableref/basetableref.hpp" |
| 6 | |
| 7 | namespace duckdb { |
| 8 | |
| 9 | UpdateRelation::UpdateRelation(ClientContextWrapper &context, unique_ptr<ParsedExpression> condition_p, |
| 10 | string schema_name_p, string table_name_p, vector<string> update_columns_p, |
| 11 | vector<unique_ptr<ParsedExpression>> expressions_p) |
| 12 | : Relation(context, RelationType::UPDATE_RELATION), condition(std::move(condition_p)), |
| 13 | schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)), |
| 14 | update_columns(std::move(update_columns_p)), expressions(std::move(expressions_p)) { |
| 15 | D_ASSERT(update_columns.size() == expressions.size()); |
| 16 | context.GetContext()->TryBindRelation(relation&: *this, result_columns&: this->columns); |
| 17 | } |
| 18 | |
| 19 | BoundStatement UpdateRelation::Bind(Binder &binder) { |
| 20 | auto basetable = make_uniq<BaseTableRef>(); |
| 21 | basetable->schema_name = schema_name; |
| 22 | basetable->table_name = table_name; |
| 23 | |
| 24 | UpdateStatement stmt; |
| 25 | stmt.set_info = make_uniq<UpdateSetInfo>(); |
| 26 | |
| 27 | stmt.set_info->condition = condition ? condition->Copy() : nullptr; |
| 28 | stmt.table = std::move(basetable); |
| 29 | stmt.set_info->columns = update_columns; |
| 30 | for (auto &expr : expressions) { |
| 31 | stmt.set_info->expressions.push_back(x: expr->Copy()); |
| 32 | } |
| 33 | return binder.Bind(statement&: stmt.Cast<SQLStatement>()); |
| 34 | } |
| 35 | |
| 36 | const vector<ColumnDefinition> &UpdateRelation::Columns() { |
| 37 | return columns; |
| 38 | } |
| 39 | |
| 40 | string UpdateRelation::ToString(idx_t depth) { |
| 41 | string str = RenderWhitespace(depth) + "UPDATE " + table_name + " SET\n" ; |
| 42 | for (idx_t i = 0; i < expressions.size(); i++) { |
| 43 | str += update_columns[i] + " = " + expressions[i]->ToString() + "\n" ; |
| 44 | } |
| 45 | if (condition) { |
| 46 | str += "WHERE " + condition->ToString() + "\n" ; |
| 47 | } |
| 48 | return str; |
| 49 | } |
| 50 | |
| 51 | } // namespace duckdb |
| 52 | |