| 1 | #include "duckdb/main/relation/update_relation.hpp" |
| 2 | #include "duckdb/parser/statement/update_statement.hpp" |
| 3 | #include "duckdb/planner/binder.hpp" |
| 4 | #include "duckdb/main/client_context.hpp" |
| 5 | #include "duckdb/parser/tableref/basetableref.hpp" |
| 6 | |
| 7 | namespace duckdb { |
| 8 | |
| 9 | UpdateRelation::UpdateRelation(ClientContext &context, unique_ptr<ParsedExpression> condition_p, string schema_name_p, |
| 10 | string table_name_p, vector<string> update_columns_p, |
| 11 | vector<unique_ptr<ParsedExpression>> expressions_p) |
| 12 | : Relation(context, RelationType::UPDATE_RELATION), condition(move(condition_p)), schema_name(move(schema_name_p)), |
| 13 | table_name(move(table_name_p)), update_columns(move(update_columns_p)), expressions(move(expressions_p)) { |
| 14 | assert(update_columns.size() == expressions.size()); |
| 15 | context.TryBindRelation(*this, this->columns); |
| 16 | } |
| 17 | |
| 18 | unique_ptr<QueryNode> UpdateRelation::GetQueryNode() { |
| 19 | throw InternalException("Cannot create a query node from a UpdateRelation!" ); |
| 20 | } |
| 21 | |
| 22 | BoundStatement UpdateRelation::Bind(Binder &binder) { |
| 23 | auto basetable = make_unique<BaseTableRef>(); |
| 24 | basetable->schema_name = schema_name; |
| 25 | basetable->table_name = table_name; |
| 26 | |
| 27 | UpdateStatement stmt; |
| 28 | stmt.condition = condition ? condition->Copy() : nullptr; |
| 29 | stmt.table = move(basetable); |
| 30 | stmt.columns = update_columns; |
| 31 | for (auto &expr : expressions) { |
| 32 | stmt.expressions.push_back(expr->Copy()); |
| 33 | } |
| 34 | return binder.Bind((SQLStatement &)stmt); |
| 35 | } |
| 36 | |
| 37 | const vector<ColumnDefinition> &UpdateRelation::Columns() { |
| 38 | return columns; |
| 39 | } |
| 40 | |
| 41 | string UpdateRelation::ToString(idx_t depth) { |
| 42 | string str = RenderWhitespace(depth) + "UPDATE " + table_name + " SET\n" ; |
| 43 | for (idx_t i = 0; i < expressions.size(); i++) { |
| 44 | str += update_columns[i] + " = " + expressions[i]->ToString() + "\n" ; |
| 45 | } |
| 46 | if (condition) { |
| 47 | str += "WHERE " + condition->ToString() + "\n" ; |
| 48 | } |
| 49 | return str; |
| 50 | } |
| 51 | |
| 52 | } // namespace duckdb |
| 53 | |