1#include "catch.hpp"
2#include "duckdb/common/value_operations/value_operations.hpp"
3#include "test_helpers.hpp"
4
5#include <atomic>
6#include <random>
7#include <thread>
8
9using namespace duckdb;
10using namespace std;
11
12namespace test_concurrent_delete {
13
//! Number of connections (Sequential delete) / worker threads (Concurrent delete) used by the tests.
static constexpr int THREAD_COUNT{10};
//! Rows inserted per distinct value, and the number of values each worker deletes in Concurrent delete.
static constexpr int INSERT_ELEMENTS{100};
16
17TEST_CASE("Single thread delete", "[transactions]") {
18 unique_ptr<QueryResult> result;
19 DuckDB db(nullptr);
20 Connection con(db);
21 vector<unique_ptr<Connection>> connections;
22
23 // initialize the database
24 con.Query("CREATE TABLE integers(i INTEGER);");
25 int sum = 0;
26 for (size_t i = 0; i < INSERT_ELEMENTS; i++) {
27 for (size_t j = 0; j < 10; j++) {
28 con.Query("INSERT INTO integers VALUES (" + to_string(j + 1) + ");");
29 sum += j + 1;
30 }
31 }
32
33 // check the sum
34 result = con.Query("SELECT SUM(i) FROM integers");
35 REQUIRE(CHECK_COLUMN(result, 0, {sum}));
36
37 // simple delete, we should delete INSERT_ELEMENTS elements
38 result = con.Query("DELETE FROM integers WHERE i=2");
39 REQUIRE(CHECK_COLUMN(result, 0, {INSERT_ELEMENTS}));
40
41 // check sum again
42 result = con.Query("SELECT SUM(i) FROM integers");
43 REQUIRE(CHECK_COLUMN(result, 0, {sum - 2 * INSERT_ELEMENTS}));
44}
45
46TEST_CASE("Sequential delete", "[transactions]") {
47 unique_ptr<MaterializedQueryResult> result;
48 DuckDB db(nullptr);
49 Connection con(db);
50 vector<unique_ptr<Connection>> connections;
51 Value count;
52
53 // initialize the database
54 con.Query("CREATE TABLE integers(i INTEGER);");
55
56 int sum = 0;
57 for (size_t i = 0; i < INSERT_ELEMENTS; i++) {
58 for (size_t j = 0; j < 10; j++) {
59 con.Query("INSERT INTO integers VALUES (" + to_string(j + 1) + ");");
60 sum += j + 1;
61 }
62 }
63
64 for (size_t i = 0; i < THREAD_COUNT; i++) {
65 connections.push_back(make_unique<Connection>(db));
66 connections[i]->Query("BEGIN TRANSACTION;");
67 }
68
69 for (size_t i = 0; i < THREAD_COUNT; i++) {
70 // check the current count
71 result = connections[i]->Query("SELECT SUM(i) FROM integers");
72 REQUIRE_NO_FAIL(*result);
73 count = result->collection.chunks[0]->GetValue(0, 0);
74 REQUIRE(count == sum);
75 // delete the elements for this thread
76 REQUIRE_NO_FAIL(connections[i]->Query("DELETE FROM integers WHERE i=" + to_string(i + 1)));
77 // check the updated count
78 result = connections[i]->Query("SELECT SUM(i) FROM integers");
79 REQUIRE_NO_FAIL(*result);
80 count = result->collection.chunks[0]->GetValue(0, 0);
81 REQUIRE(count == sum - (i + 1) * INSERT_ELEMENTS);
82 }
83 // check the count on the original connection
84 result = con.Query("SELECT SUM(i) FROM integers");
85 REQUIRE_NO_FAIL(*result);
86 count = result->collection.chunks[0]->GetValue(0, 0);
87 REQUIRE(count == sum);
88
89 // commit everything
90 for (size_t i = 0; i < THREAD_COUNT; i++) {
91 connections[i]->Query("COMMIT;");
92 }
93
94 // check that the count is 0 now
95 result = con.Query("SELECT COUNT(i) FROM integers");
96 REQUIRE_NO_FAIL(*result);
97 count = result->collection.chunks[0]->GetValue(0, 0);
98 REQUIRE(count == 0);
99}
100
101TEST_CASE("Rollback delete", "[transactions]") {
102 unique_ptr<MaterializedQueryResult> result;
103 DuckDB db(nullptr);
104 Connection con(db);
105 vector<unique_ptr<Connection>> connections;
106
107 // initialize the database
108 con.Query("CREATE TABLE integers(i INTEGER);");
109 int sum = 0;
110 for (size_t i = 0; i < INSERT_ELEMENTS; i++) {
111 for (size_t j = 0; j < 10; j++) {
112 con.Query("INSERT INTO integers VALUES (" + to_string(j + 1) + ");");
113 sum += j + 1;
114 }
115 }
116
117 // begin transaction
118 REQUIRE_NO_FAIL(con.Query("BEGIN TRANSACTION"));
119
120 // check the sum
121 result = con.Query("SELECT SUM(i) FROM integers");
122 REQUIRE(CHECK_COLUMN(result, 0, {sum}));
123
124 // simple delete
125 result = con.Query("DELETE FROM integers WHERE i=2");
126 REQUIRE(CHECK_COLUMN(result, 0, {100}));
127
128 // check sum again
129 result = con.Query("SELECT SUM(i) FROM integers");
130 REQUIRE(CHECK_COLUMN(result, 0, {sum - 2 * INSERT_ELEMENTS}));
131
132 // rollback transaction
133 REQUIRE_NO_FAIL(con.Query("ROLLBACK"));
134
135 // check the sum again
136 result = con.Query("SELECT SUM(i) FROM integers");
137 REQUIRE(CHECK_COLUMN(result, 0, {sum}));
138}
139
//! Number of delete_elements workers that have finished their deletes. It serves as a barrier
//! so that no worker commits before every worker is done deleting. std::atomic alone provides
//! the required inter-thread synchronization; `volatile` adds nothing for threading.
static std::atomic<int> finished_threads{0};
141
142static void delete_elements(DuckDB *db, bool *correct, size_t threadnr) {
143 correct[threadnr] = true;
144 Connection con(*db);
145 // initial count
146 con.Query("BEGIN TRANSACTION;");
147 auto result = con.Query("SELECT COUNT(*) FROM integers");
148 Value count = result->collection.chunks[0]->GetValue(0, 0);
149 auto start_count = count.GetValue<int64_t>();
150
151 for (size_t i = 0; i < INSERT_ELEMENTS; i++) {
152 // count should decrease by one for every delete we do
153 auto element = INSERT_ELEMENTS * threadnr + i;
154 if (!con.Query("DELETE FROM integers WHERE i=" + to_string(element))->success) {
155 correct[threadnr] = false;
156 }
157 result = con.Query("SELECT COUNT(*) FROM integers");
158 if (!result->success) {
159 correct[threadnr] = false;
160 } else {
161 Value new_count = result->collection.chunks[0]->GetValue(0, 0);
162 if (new_count != start_count - (i + 1)) {
163 correct[threadnr] = false;
164 }
165 count = new_count;
166 }
167 }
168 finished_threads++;
169 while (finished_threads != THREAD_COUNT)
170 ;
171 con.Query("COMMIT;");
172}
173
174TEST_CASE("Concurrent delete", "[transactions][.]") {
175 unique_ptr<MaterializedQueryResult> result;
176 DuckDB db(nullptr);
177 Connection con(db);
178
179 // initialize the database
180 REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER);"));
181 for (size_t i = 0; i < INSERT_ELEMENTS; i++) {
182 for (size_t j = 0; j < THREAD_COUNT; j++) {
183 auto element = INSERT_ELEMENTS * j + i;
184 con.Query("INSERT INTO integers VALUES (" + to_string(element) + ");");
185 }
186 }
187
188 finished_threads = 0;
189
190 bool correct[THREAD_COUNT];
191 thread threads[THREAD_COUNT];
192 for (size_t i = 0; i < THREAD_COUNT; i++) {
193 threads[i] = thread(delete_elements, &db, correct, i);
194 }
195
196 for (size_t i = 0; i < THREAD_COUNT; i++) {
197 threads[i].join();
198 REQUIRE(correct[i]);
199 }
200
201 // check that the count is 0 now
202 result = con.Query("SELECT COUNT(i) FROM integers");
203 REQUIRE_NO_FAIL(*result);
204 auto count = result->collection.chunks[0]->GetValue(0, 0);
205 REQUIRE(count == 0);
206}
207
208} // namespace test_concurrent_delete
209