1 | #include "catch.hpp" |
2 | #include "duckdb/common/value_operations/value_operations.hpp" |
3 | #include "test_helpers.hpp" |
4 | |
5 | #include <atomic> |
6 | #include <random> |
7 | #include <thread> |
8 | |
9 | using namespace duckdb; |
10 | using namespace std; |
11 | |
12 | namespace test_concurrent_update { |
13 | |
14 | static constexpr int TRANSACTION_UPDATE_COUNT = 1000; |
15 | static constexpr int TOTAL_ACCOUNTS = 20; |
16 | static constexpr int MONEY_PER_ACCOUNT = 10; |
17 | |
18 | TEST_CASE("Single thread update" , "[transactions]" ) { |
19 | unique_ptr<MaterializedQueryResult> result; |
20 | DuckDB db(nullptr); |
21 | Connection con(db); |
22 | |
23 | // initialize the database |
24 | con.Query("CREATE TABLE integers(i INTEGER);" ); |
25 | int sum = 0; |
26 | for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) { |
27 | for (size_t j = 0; j < 10; j++) { |
28 | con.Query("INSERT INTO integers VALUES (" + to_string(j + 1) + ");" ); |
29 | sum += j + 1; |
30 | } |
31 | } |
32 | |
33 | // check the sum |
34 | result = con.Query("SELECT SUM(i) FROM integers" ); |
35 | REQUIRE(CHECK_COLUMN(result, 0, {sum})); |
36 | |
37 | // simple update, we should update INSERT_ELEMENTS elements |
38 | result = con.Query("UPDATE integers SET i=4 WHERE i=2" ); |
39 | REQUIRE(CHECK_COLUMN(result, 0, {TOTAL_ACCOUNTS})); |
40 | |
41 | // check updated sum |
42 | result = con.Query("SELECT SUM(i) FROM integers" ); |
43 | REQUIRE(CHECK_COLUMN(result, 0, {sum + 2 * TOTAL_ACCOUNTS})); |
44 | } |
45 | |
46 | static volatile bool finished_updating = false; |
47 | static void read_total_balance(DuckDB *db, bool *read_correct) { |
48 | *read_correct = true; |
49 | Connection con(*db); |
50 | while (!finished_updating) { |
51 | // the total balance should remain constant regardless of updates |
52 | auto result = con.Query("SELECT SUM(money) FROM accounts" ); |
53 | if (!CHECK_COLUMN(result, 0, {TOTAL_ACCOUNTS * MONEY_PER_ACCOUNT})) { |
54 | *read_correct = false; |
55 | } |
56 | } |
57 | } |
58 | |
59 | TEST_CASE("Concurrent update" , "[updates][.]" ) { |
60 | unique_ptr<MaterializedQueryResult> result; |
61 | DuckDB db(nullptr); |
62 | Connection con(db); |
63 | |
64 | // fixed seed random numbers |
65 | mt19937 generator; |
66 | generator.seed(42); |
67 | uniform_int_distribution<int> account_distribution(0, TOTAL_ACCOUNTS - 1); |
68 | auto random_account = bind(account_distribution, generator); |
69 | |
70 | uniform_int_distribution<int> amount_distribution(0, MONEY_PER_ACCOUNT); |
71 | auto random_amount = bind(amount_distribution, generator); |
72 | |
73 | finished_updating = false; |
74 | // initialize the database |
75 | con.Query("CREATE TABLE accounts(id INTEGER, money INTEGER)" ); |
76 | for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) { |
77 | con.Query("INSERT INTO accounts VALUES (" + to_string(i) + ", " + to_string(MONEY_PER_ACCOUNT) + ");" ); |
78 | } |
79 | |
80 | bool read_correct; |
81 | // launch separate thread for reading aggregate |
82 | thread read_thread(read_total_balance, &db, &read_correct); |
83 | |
84 | // start vigorously updating balances in this thread |
85 | for (size_t i = 0; i < TRANSACTION_UPDATE_COUNT; i++) { |
86 | int from = random_account(); |
87 | int to = random_account(); |
88 | while (to == from) { |
89 | to = random_account(); |
90 | } |
91 | int amount = random_amount(); |
92 | |
93 | REQUIRE_NO_FAIL(con.Query("BEGIN TRANSACTION" )); |
94 | result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(from)); |
95 | Value money_from = result->collection.GetValue(0, 0); |
96 | result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(to)); |
97 | Value money_to = result->collection.GetValue(0, 0); |
98 | |
99 | REQUIRE_NO_FAIL( |
100 | con.Query("UPDATE accounts SET money = money - " + to_string(amount) + " WHERE id = " + to_string(from))); |
101 | REQUIRE_NO_FAIL( |
102 | con.Query("UPDATE accounts SET money = money + " + to_string(amount) + " WHERE id = " + to_string(to))); |
103 | |
104 | result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(from)); |
105 | Value new_money_from = result->collection.GetValue(0, 0); |
106 | result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(to)); |
107 | Value new_money_to = result->collection.GetValue(0, 0); |
108 | |
109 | Value expected_money_from, expected_money_to; |
110 | |
111 | expected_money_from = money_from - amount; |
112 | expected_money_to = money_to + amount; |
113 | |
114 | REQUIRE(new_money_from == expected_money_from); |
115 | REQUIRE(new_money_to == expected_money_to); |
116 | |
117 | REQUIRE_NO_FAIL(con.Query("COMMIT" )); |
118 | } |
119 | finished_updating = true; |
120 | read_thread.join(); |
121 | REQUIRE(read_correct); |
122 | } |
123 | |
124 | static std::atomic<size_t> finished_threads; |
125 | |
126 | static void write_random_numbers_to_account(DuckDB *db, bool *correct, size_t nr) { |
127 | correct[nr] = true; |
128 | Connection con(*db); |
129 | for (size_t i = 0; i < TRANSACTION_UPDATE_COUNT; i++) { |
130 | // just make some changes to the total |
131 | // the total amount of money after the commit is the same |
132 | if (!con.Query("BEGIN TRANSACTION" )->success) { |
133 | correct[nr] = false; |
134 | } |
135 | if (!con.Query("UPDATE accounts SET money = money + " + to_string(i * 2) + " WHERE id = " + to_string(nr)) |
136 | ->success) { |
137 | correct[nr] = false; |
138 | } |
139 | if (!con.Query("UPDATE accounts SET money = money - " + to_string(i) + " WHERE id = " + to_string(nr)) |
140 | ->success) { |
141 | correct[nr] = false; |
142 | } |
143 | if (!con.Query("UPDATE accounts SET money = money - " + to_string(i * 2) + " WHERE id = " + to_string(nr)) |
144 | ->success) { |
145 | correct[nr] = false; |
146 | } |
147 | if (!con.Query("UPDATE accounts SET money = money + " + to_string(i) + " WHERE id = " + to_string(nr)) |
148 | ->success) { |
149 | correct[nr] = false; |
150 | } |
151 | // we test both commit and rollback |
152 | // the result of both should be the same since the updates have a |
153 | // net-zero effect |
154 | if (!con.Query(nr % 2 == 0 ? "COMMIT" : "ROLLBACK" )->success) { |
155 | correct[nr] = false; |
156 | } |
157 | } |
158 | finished_threads++; |
159 | if (finished_threads == TOTAL_ACCOUNTS) { |
160 | finished_updating = true; |
161 | } |
162 | } |
163 | |
164 | TEST_CASE("Multiple concurrent updaters" , "[updates][.]" ) { |
165 | unique_ptr<MaterializedQueryResult> result; |
166 | DuckDB db(nullptr); |
167 | Connection con(db); |
168 | |
169 | finished_updating = false; |
170 | finished_threads = 0; |
171 | // initialize the database |
172 | con.Query("CREATE TABLE accounts(id INTEGER, money INTEGER)" ); |
173 | for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) { |
174 | con.Query("INSERT INTO accounts VALUES (" + to_string(i) + ", " + to_string(MONEY_PER_ACCOUNT) + ");" ); |
175 | } |
176 | |
177 | bool correct[TOTAL_ACCOUNTS]; |
178 | bool read_correct; |
179 | std::thread write_threads[TOTAL_ACCOUNTS]; |
180 | // launch a thread for reading the table |
181 | thread read_thread(read_total_balance, &db, &read_correct); |
182 | // launch several threads for updating the table |
183 | for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) { |
184 | write_threads[i] = thread(write_random_numbers_to_account, &db, correct, i); |
185 | } |
186 | read_thread.join(); |
187 | for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) { |
188 | write_threads[i].join(); |
189 | REQUIRE(correct[i]); |
190 | } |
191 | } |
192 | |
193 | } // namespace test_concurrent_update |
194 | |