1 | #include "duckdb/execution/operator/persistent/physical_update.hpp" |
2 | |
3 | #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" |
4 | #include "duckdb/common/types/column/column_data_collection.hpp" |
5 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
6 | #include "duckdb/execution/expression_executor.hpp" |
7 | #include "duckdb/main/client_context.hpp" |
8 | #include "duckdb/parallel/thread_context.hpp" |
9 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
10 | #include "duckdb/storage/data_table.hpp" |
11 | |
12 | namespace duckdb { |
13 | |
14 | PhysicalUpdate::PhysicalUpdate(vector<LogicalType> types, TableCatalogEntry &tableref, DataTable &table, |
15 | vector<PhysicalIndex> columns, vector<unique_ptr<Expression>> expressions, |
16 | vector<unique_ptr<Expression>> bound_defaults, idx_t estimated_cardinality, |
17 | bool return_chunk) |
18 | : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref), |
19 | table(table), columns(std::move(columns)), expressions(std::move(expressions)), |
20 | bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk) { |
21 | } |
22 | |
23 | //===--------------------------------------------------------------------===// |
24 | // Sink |
25 | //===--------------------------------------------------------------------===// |
26 | class UpdateGlobalState : public GlobalSinkState { |
27 | public: |
28 | explicit UpdateGlobalState(ClientContext &context, const vector<LogicalType> &return_types) |
29 | : updated_count(0), return_collection(context, return_types) { |
30 | } |
31 | |
32 | mutex lock; |
33 | idx_t updated_count; |
34 | unordered_set<row_t> updated_columns; |
35 | ColumnDataCollection return_collection; |
36 | }; |
37 | |
38 | class UpdateLocalState : public LocalSinkState { |
39 | public: |
40 | UpdateLocalState(ClientContext &context, const vector<unique_ptr<Expression>> &expressions, |
41 | const vector<LogicalType> &table_types, const vector<unique_ptr<Expression>> &bound_defaults) |
42 | : default_executor(context, bound_defaults) { |
43 | // initialize the update chunk |
44 | auto &allocator = Allocator::Get(context); |
45 | vector<LogicalType> update_types; |
46 | update_types.reserve(n: expressions.size()); |
47 | for (auto &expr : expressions) { |
48 | update_types.push_back(x: expr->return_type); |
49 | } |
50 | update_chunk.Initialize(allocator, types: update_types); |
51 | // initialize the mock chunk |
52 | mock_chunk.Initialize(allocator, types: table_types); |
53 | } |
54 | |
55 | DataChunk update_chunk; |
56 | DataChunk mock_chunk; |
57 | ExpressionExecutor default_executor; |
58 | }; |
59 | |
60 | SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { |
61 | auto &gstate = input.global_state.Cast<UpdateGlobalState>(); |
62 | auto &lstate = input.local_state.Cast<UpdateLocalState>(); |
63 | |
64 | DataChunk &update_chunk = lstate.update_chunk; |
65 | DataChunk &mock_chunk = lstate.mock_chunk; |
66 | |
67 | chunk.Flatten(); |
68 | lstate.default_executor.SetChunk(chunk); |
69 | |
70 | // update data in the base table |
71 | // the row ids are given to us as the last column of the child chunk |
72 | auto &row_ids = chunk.data[chunk.ColumnCount() - 1]; |
73 | update_chunk.Reset(); |
74 | update_chunk.SetCardinality(chunk); |
75 | |
76 | for (idx_t i = 0; i < expressions.size(); i++) { |
77 | if (expressions[i]->type == ExpressionType::VALUE_DEFAULT) { |
78 | // default expression, set to the default value of the column |
79 | lstate.default_executor.ExecuteExpression(expr_idx: columns[i].index, result&: update_chunk.data[i]); |
80 | } else { |
81 | D_ASSERT(expressions[i]->type == ExpressionType::BOUND_REF); |
82 | // index into child chunk |
83 | auto &binding = expressions[i]->Cast<BoundReferenceExpression>(); |
84 | update_chunk.data[i].Reference(other&: chunk.data[binding.index]); |
85 | } |
86 | } |
87 | |
88 | lock_guard<mutex> glock(gstate.lock); |
89 | if (update_is_del_and_insert) { |
90 | // index update or update on complex type, perform a delete and an append instead |
91 | |
92 | // figure out which rows have not yet been deleted in this update |
93 | // this is required since we might see the same row_id multiple times |
94 | // in the case of an UPDATE query that e.g. has joins |
95 | auto row_id_data = FlatVector::GetData<row_t>(vector&: row_ids); |
96 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
97 | idx_t update_count = 0; |
98 | for (idx_t i = 0; i < update_chunk.size(); i++) { |
99 | auto row_id = row_id_data[i]; |
100 | if (gstate.updated_columns.find(x: row_id) == gstate.updated_columns.end()) { |
101 | gstate.updated_columns.insert(x: row_id); |
102 | sel.set_index(idx: update_count++, loc: i); |
103 | } |
104 | } |
105 | if (update_count != update_chunk.size()) { |
106 | // we need to slice here |
107 | update_chunk.Slice(sel_vector: sel, count: update_count); |
108 | } |
109 | table.Delete(table&: tableref, context&: context.client, row_ids, count: update_chunk.size()); |
110 | // for the append we need to arrange the columns in a specific manner (namely the "standard table order") |
111 | mock_chunk.SetCardinality(update_chunk); |
112 | for (idx_t i = 0; i < columns.size(); i++) { |
113 | mock_chunk.data[columns[i].index].Reference(other&: update_chunk.data[i]); |
114 | } |
115 | table.LocalAppend(table&: tableref, context&: context.client, chunk&: mock_chunk); |
116 | } else { |
117 | if (return_chunk) { |
118 | mock_chunk.SetCardinality(update_chunk); |
119 | for (idx_t i = 0; i < columns.size(); i++) { |
120 | mock_chunk.data[columns[i].index].Reference(other&: update_chunk.data[i]); |
121 | } |
122 | } |
123 | table.Update(table&: tableref, context&: context.client, row_ids, column_ids: columns, data&: update_chunk); |
124 | } |
125 | |
126 | if (return_chunk) { |
127 | gstate.return_collection.Append(new_chunk&: mock_chunk); |
128 | } |
129 | |
130 | gstate.updated_count += chunk.size(); |
131 | |
132 | return SinkResultType::NEED_MORE_INPUT; |
133 | } |
134 | |
135 | unique_ptr<GlobalSinkState> PhysicalUpdate::GetGlobalSinkState(ClientContext &context) const { |
136 | return make_uniq<UpdateGlobalState>(args&: context, args: GetTypes()); |
137 | } |
138 | |
139 | unique_ptr<LocalSinkState> PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const { |
140 | return make_uniq<UpdateLocalState>(args&: context.client, args: expressions, args: table.GetTypes(), args: bound_defaults); |
141 | } |
142 | |
143 | void PhysicalUpdate::Combine(ExecutionContext &context, GlobalSinkState &gstate, LocalSinkState &lstate) const { |
144 | auto &state = lstate.Cast<UpdateLocalState>(); |
145 | auto &client_profiler = QueryProfiler::Get(context&: context.client); |
146 | context.thread.profiler.Flush(phys_op: *this, expression_executor&: state.default_executor, name: "default_executor" , id: 1); |
147 | client_profiler.Flush(profiler&: context.thread.profiler); |
148 | } |
149 | |
150 | //===--------------------------------------------------------------------===// |
151 | // Source |
152 | //===--------------------------------------------------------------------===// |
153 | class UpdateSourceState : public GlobalSourceState { |
154 | public: |
155 | explicit UpdateSourceState(const PhysicalUpdate &op) { |
156 | if (op.return_chunk) { |
157 | D_ASSERT(op.sink_state); |
158 | auto &g = op.sink_state->Cast<UpdateGlobalState>(); |
159 | g.return_collection.InitializeScan(state&: scan_state); |
160 | } |
161 | } |
162 | |
163 | ColumnDataScanState scan_state; |
164 | }; |
165 | |
166 | unique_ptr<GlobalSourceState> PhysicalUpdate::GetGlobalSourceState(ClientContext &context) const { |
167 | return make_uniq<UpdateSourceState>(args: *this); |
168 | } |
169 | |
170 | SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk, |
171 | OperatorSourceInput &input) const { |
172 | auto &state = input.global_state.Cast<UpdateSourceState>(); |
173 | auto &g = sink_state->Cast<UpdateGlobalState>(); |
174 | if (!return_chunk) { |
175 | chunk.SetCardinality(1); |
176 | chunk.SetValue(col_idx: 0, index: 0, val: Value::BIGINT(value: g.updated_count)); |
177 | return SourceResultType::FINISHED; |
178 | } |
179 | |
180 | g.return_collection.Scan(state&: state.scan_state, result&: chunk); |
181 | |
182 | return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; |
183 | } |
184 | |
185 | } // namespace duckdb |
186 | |