1#include "duckdb/execution/operator/persistent/physical_update.hpp"
2
3#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
4#include "duckdb/common/types/column/column_data_collection.hpp"
5#include "duckdb/common/vector_operations/vector_operations.hpp"
6#include "duckdb/execution/expression_executor.hpp"
7#include "duckdb/main/client_context.hpp"
8#include "duckdb/parallel/thread_context.hpp"
9#include "duckdb/planner/expression/bound_reference_expression.hpp"
10#include "duckdb/storage/data_table.hpp"
11
12namespace duckdb {
13
14PhysicalUpdate::PhysicalUpdate(vector<LogicalType> types, TableCatalogEntry &tableref, DataTable &table,
15 vector<PhysicalIndex> columns, vector<unique_ptr<Expression>> expressions,
16 vector<unique_ptr<Expression>> bound_defaults, idx_t estimated_cardinality,
17 bool return_chunk)
18 : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref),
19 table(table), columns(std::move(columns)), expressions(std::move(expressions)),
20 bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk) {
21}
22
23//===--------------------------------------------------------------------===//
24// Sink
25//===--------------------------------------------------------------------===//
26class UpdateGlobalState : public GlobalSinkState {
27public:
28 explicit UpdateGlobalState(ClientContext &context, const vector<LogicalType> &return_types)
29 : updated_count(0), return_collection(context, return_types) {
30 }
31
32 mutex lock;
33 idx_t updated_count;
34 unordered_set<row_t> updated_columns;
35 ColumnDataCollection return_collection;
36};
37
38class UpdateLocalState : public LocalSinkState {
39public:
40 UpdateLocalState(ClientContext &context, const vector<unique_ptr<Expression>> &expressions,
41 const vector<LogicalType> &table_types, const vector<unique_ptr<Expression>> &bound_defaults)
42 : default_executor(context, bound_defaults) {
43 // initialize the update chunk
44 auto &allocator = Allocator::Get(context);
45 vector<LogicalType> update_types;
46 update_types.reserve(n: expressions.size());
47 for (auto &expr : expressions) {
48 update_types.push_back(x: expr->return_type);
49 }
50 update_chunk.Initialize(allocator, types: update_types);
51 // initialize the mock chunk
52 mock_chunk.Initialize(allocator, types: table_types);
53 }
54
55 DataChunk update_chunk;
56 DataChunk mock_chunk;
57 ExpressionExecutor default_executor;
58};
59
60SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const {
61 auto &gstate = input.global_state.Cast<UpdateGlobalState>();
62 auto &lstate = input.local_state.Cast<UpdateLocalState>();
63
64 DataChunk &update_chunk = lstate.update_chunk;
65 DataChunk &mock_chunk = lstate.mock_chunk;
66
67 chunk.Flatten();
68 lstate.default_executor.SetChunk(chunk);
69
70 // update data in the base table
71 // the row ids are given to us as the last column of the child chunk
72 auto &row_ids = chunk.data[chunk.ColumnCount() - 1];
73 update_chunk.Reset();
74 update_chunk.SetCardinality(chunk);
75
76 for (idx_t i = 0; i < expressions.size(); i++) {
77 if (expressions[i]->type == ExpressionType::VALUE_DEFAULT) {
78 // default expression, set to the default value of the column
79 lstate.default_executor.ExecuteExpression(expr_idx: columns[i].index, result&: update_chunk.data[i]);
80 } else {
81 D_ASSERT(expressions[i]->type == ExpressionType::BOUND_REF);
82 // index into child chunk
83 auto &binding = expressions[i]->Cast<BoundReferenceExpression>();
84 update_chunk.data[i].Reference(other&: chunk.data[binding.index]);
85 }
86 }
87
88 lock_guard<mutex> glock(gstate.lock);
89 if (update_is_del_and_insert) {
90 // index update or update on complex type, perform a delete and an append instead
91
92 // figure out which rows have not yet been deleted in this update
93 // this is required since we might see the same row_id multiple times
94 // in the case of an UPDATE query that e.g. has joins
95 auto row_id_data = FlatVector::GetData<row_t>(vector&: row_ids);
96 SelectionVector sel(STANDARD_VECTOR_SIZE);
97 idx_t update_count = 0;
98 for (idx_t i = 0; i < update_chunk.size(); i++) {
99 auto row_id = row_id_data[i];
100 if (gstate.updated_columns.find(x: row_id) == gstate.updated_columns.end()) {
101 gstate.updated_columns.insert(x: row_id);
102 sel.set_index(idx: update_count++, loc: i);
103 }
104 }
105 if (update_count != update_chunk.size()) {
106 // we need to slice here
107 update_chunk.Slice(sel_vector: sel, count: update_count);
108 }
109 table.Delete(table&: tableref, context&: context.client, row_ids, count: update_chunk.size());
110 // for the append we need to arrange the columns in a specific manner (namely the "standard table order")
111 mock_chunk.SetCardinality(update_chunk);
112 for (idx_t i = 0; i < columns.size(); i++) {
113 mock_chunk.data[columns[i].index].Reference(other&: update_chunk.data[i]);
114 }
115 table.LocalAppend(table&: tableref, context&: context.client, chunk&: mock_chunk);
116 } else {
117 if (return_chunk) {
118 mock_chunk.SetCardinality(update_chunk);
119 for (idx_t i = 0; i < columns.size(); i++) {
120 mock_chunk.data[columns[i].index].Reference(other&: update_chunk.data[i]);
121 }
122 }
123 table.Update(table&: tableref, context&: context.client, row_ids, column_ids: columns, data&: update_chunk);
124 }
125
126 if (return_chunk) {
127 gstate.return_collection.Append(new_chunk&: mock_chunk);
128 }
129
130 gstate.updated_count += chunk.size();
131
132 return SinkResultType::NEED_MORE_INPUT;
133}
134
135unique_ptr<GlobalSinkState> PhysicalUpdate::GetGlobalSinkState(ClientContext &context) const {
136 return make_uniq<UpdateGlobalState>(args&: context, args: GetTypes());
137}
138
139unique_ptr<LocalSinkState> PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const {
140 return make_uniq<UpdateLocalState>(args&: context.client, args: expressions, args: table.GetTypes(), args: bound_defaults);
141}
142
143void PhysicalUpdate::Combine(ExecutionContext &context, GlobalSinkState &gstate, LocalSinkState &lstate) const {
144 auto &state = lstate.Cast<UpdateLocalState>();
145 auto &client_profiler = QueryProfiler::Get(context&: context.client);
146 context.thread.profiler.Flush(phys_op: *this, expression_executor&: state.default_executor, name: "default_executor", id: 1);
147 client_profiler.Flush(profiler&: context.thread.profiler);
148}
149
150//===--------------------------------------------------------------------===//
151// Source
152//===--------------------------------------------------------------------===//
153class UpdateSourceState : public GlobalSourceState {
154public:
155 explicit UpdateSourceState(const PhysicalUpdate &op) {
156 if (op.return_chunk) {
157 D_ASSERT(op.sink_state);
158 auto &g = op.sink_state->Cast<UpdateGlobalState>();
159 g.return_collection.InitializeScan(state&: scan_state);
160 }
161 }
162
163 ColumnDataScanState scan_state;
164};
165
166unique_ptr<GlobalSourceState> PhysicalUpdate::GetGlobalSourceState(ClientContext &context) const {
167 return make_uniq<UpdateSourceState>(args: *this);
168}
169
170SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk,
171 OperatorSourceInput &input) const {
172 auto &state = input.global_state.Cast<UpdateSourceState>();
173 auto &g = sink_state->Cast<UpdateGlobalState>();
174 if (!return_chunk) {
175 chunk.SetCardinality(1);
176 chunk.SetValue(col_idx: 0, index: 0, val: Value::BIGINT(value: g.updated_count));
177 return SourceResultType::FINISHED;
178 }
179
180 g.return_collection.Scan(state&: state.scan_state, result&: chunk);
181
182 return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
183}
184
185} // namespace duckdb
186