1#include "duckdb/parser/expression/columnref_expression.hpp"
2#include "duckdb/parser/statement/update_statement.hpp"
3#include "duckdb/planner/binder.hpp"
4#include "duckdb/planner/tableref/bound_joinref.hpp"
5#include "duckdb/planner/bound_tableref.hpp"
6#include "duckdb/planner/constraints/bound_check_constraint.hpp"
7#include "duckdb/planner/expression/bound_columnref_expression.hpp"
8#include "duckdb/planner/expression/bound_default_expression.hpp"
9#include "duckdb/planner/expression_binder/update_binder.hpp"
10#include "duckdb/planner/expression_binder/where_binder.hpp"
11#include "duckdb/planner/operator/logical_filter.hpp"
12#include "duckdb/planner/operator/logical_get.hpp"
13#include "duckdb/planner/operator/logical_projection.hpp"
14#include "duckdb/planner/operator/logical_update.hpp"
15#include "duckdb/planner/tableref/bound_basetableref.hpp"
16#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
17#include "duckdb/storage/data_table.hpp"
18
19#include <algorithm>
20
21namespace duckdb {
22
23// This creates a LogicalProjection and moves 'root' into it as a child
24// unless there are no expressions to project, in which case it just returns 'root'
25unique_ptr<LogicalOperator> Binder::BindUpdateSet(LogicalOperator &op, unique_ptr<LogicalOperator> root,
26 UpdateSetInfo &set_info, TableCatalogEntry &table,
27 vector<PhysicalIndex> &columns) {
28 auto proj_index = GenerateTableIndex();
29
30 vector<unique_ptr<Expression>> projection_expressions;
31 D_ASSERT(set_info.columns.size() == set_info.expressions.size());
32 for (idx_t i = 0; i < set_info.columns.size(); i++) {
33 auto &colname = set_info.columns[i];
34 auto &expr = set_info.expressions[i];
35 if (!table.ColumnExists(name: colname)) {
36 throw BinderException("Referenced update column %s not found in table!", colname);
37 }
38 auto &column = table.GetColumn(name: colname);
39 if (column.Generated()) {
40 throw BinderException("Cant update column \"%s\" because it is a generated column!", column.Name());
41 }
42 if (std::find(first: columns.begin(), last: columns.end(), val: column.Physical()) != columns.end()) {
43 throw BinderException("Multiple assignments to same column \"%s\"", colname);
44 }
45 columns.push_back(x: column.Physical());
46 if (expr->type == ExpressionType::VALUE_DEFAULT) {
47 op.expressions.push_back(x: make_uniq<BoundDefaultExpression>(args: column.Type()));
48 } else {
49 UpdateBinder binder(*this, context);
50 binder.target_type = column.Type();
51 auto bound_expr = binder.Bind(expr);
52 PlanSubqueries(expr&: bound_expr, root);
53
54 op.expressions.push_back(x: make_uniq<BoundColumnRefExpression>(
55 args&: bound_expr->return_type, args: ColumnBinding(proj_index, projection_expressions.size())));
56 projection_expressions.push_back(x: std::move(bound_expr));
57 }
58 }
59 if (op.type != LogicalOperatorType::LOGICAL_UPDATE && projection_expressions.empty()) {
60 return root;
61 }
62 // now create the projection
63 auto proj = make_uniq<LogicalProjection>(args&: proj_index, args: std::move(projection_expressions));
64 proj->AddChild(child: std::move(root));
65 return unique_ptr_cast<LogicalProjection, LogicalOperator>(src: std::move(proj));
66}
67
68BoundStatement Binder::Bind(UpdateStatement &stmt) {
69 BoundStatement result;
70 unique_ptr<LogicalOperator> root;
71
72 // visit the table reference
73 auto bound_table = Bind(ref&: *stmt.table);
74 if (bound_table->type != TableReferenceType::BASE_TABLE) {
75 throw BinderException("Can only update base table!");
76 }
77 auto &table_binding = bound_table->Cast<BoundBaseTableRef>();
78 auto &table = table_binding.table;
79
80 // Add CTEs as bindable
81 AddCTEMap(cte_map&: stmt.cte_map);
82
83 optional_ptr<LogicalGet> get;
84 if (stmt.from_table) {
85 auto from_binder = Binder::CreateBinder(context, parent: this);
86 BoundJoinRef bound_crossproduct(JoinRefType::CROSS);
87 bound_crossproduct.left = std::move(bound_table);
88 bound_crossproduct.right = from_binder->Bind(ref&: *stmt.from_table);
89 root = CreatePlan(ref&: bound_crossproduct);
90 get = &root->children[0]->Cast<LogicalGet>();
91 bind_context.AddContext(other: std::move(from_binder->bind_context));
92 } else {
93 root = CreatePlan(ref&: *bound_table);
94 get = &root->Cast<LogicalGet>();
95 }
96
97 if (!table.temporary) {
98 // update of persistent table: not read only!
99 properties.modified_databases.insert(x: table.catalog.GetName());
100 }
101 auto update = make_uniq<LogicalUpdate>(args&: table);
102
103 // set return_chunk boolean early because it needs uses update_is_del_and_insert logic
104 if (!stmt.returning_list.empty()) {
105 update->return_chunk = true;
106 }
107 // bind the default values
108 BindDefaultValues(columns: table.GetColumns(), bound_defaults&: update->bound_defaults);
109
110 // project any additional columns required for the condition/expressions
111 if (stmt.set_info->condition) {
112 WhereBinder binder(*this, context);
113 auto condition = binder.Bind(expr&: stmt.set_info->condition);
114
115 PlanSubqueries(expr&: condition, root);
116 auto filter = make_uniq<LogicalFilter>(args: std::move(condition));
117 filter->AddChild(child: std::move(root));
118 root = std::move(filter);
119 }
120
121 D_ASSERT(stmt.set_info);
122 D_ASSERT(stmt.set_info->columns.size() == stmt.set_info->expressions.size());
123
124 auto proj_tmp = BindUpdateSet(op&: *update, root: std::move(root), set_info&: *stmt.set_info, table, columns&: update->columns);
125 D_ASSERT(proj_tmp->type == LogicalOperatorType::LOGICAL_PROJECTION);
126 auto proj = unique_ptr_cast<LogicalOperator, LogicalProjection>(src: std::move(proj_tmp));
127
128 // bind any extra columns necessary for CHECK constraints or indexes
129 table.BindUpdateConstraints(get&: *get, proj&: *proj, update&: *update, context);
130
131 // finally add the row id column to the projection list
132 proj->expressions.push_back(x: make_uniq<BoundColumnRefExpression>(
133 args: LogicalType::ROW_TYPE, args: ColumnBinding(get->table_index, get->column_ids.size())));
134 get->column_ids.push_back(x: COLUMN_IDENTIFIER_ROW_ID);
135
136 // set the projection as child of the update node and finalize the result
137 update->AddChild(child: std::move(proj));
138
139 auto update_table_index = GenerateTableIndex();
140 update->table_index = update_table_index;
141 if (!stmt.returning_list.empty()) {
142 unique_ptr<LogicalOperator> update_as_logicaloperator = std::move(update);
143
144 return BindReturning(returning_list: std::move(stmt.returning_list), table, alias: stmt.table->alias, update_table_index,
145 child_operator: std::move(update_as_logicaloperator), result: std::move(result));
146 }
147
148 result.names = {"Count"};
149 result.types = {LogicalType::BIGINT};
150 result.plan = std::move(update);
151 properties.allow_stream_result = false;
152 properties.return_type = StatementReturnType::CHANGED_ROWS;
153 return result;
154}
155
156} // namespace duckdb
157