1#include "catch.hpp"
2#include "duckdb/common/value_operations/value_operations.hpp"
3#include "test_helpers.hpp"
4
5#include <atomic>
6#include <random>
7#include <thread>
8
9using namespace duckdb;
10using namespace std;
11
12namespace test_concurrent_update {
13
//! Number of update transactions performed per updating thread.
static constexpr int TRANSACTION_UPDATE_COUNT {1000};
//! Number of rows inserted into the `accounts` table (one writer thread per row in the multi-updater test).
static constexpr int TOTAL_ACCOUNTS {20};
//! Balance every account starts with.
static constexpr int MONEY_PER_ACCOUNT {10};
17
18TEST_CASE("Single thread update", "[transactions]") {
19 unique_ptr<MaterializedQueryResult> result;
20 DuckDB db(nullptr);
21 Connection con(db);
22
23 // initialize the database
24 con.Query("CREATE TABLE integers(i INTEGER);");
25 int sum = 0;
26 for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) {
27 for (size_t j = 0; j < 10; j++) {
28 con.Query("INSERT INTO integers VALUES (" + to_string(j + 1) + ");");
29 sum += j + 1;
30 }
31 }
32
33 // check the sum
34 result = con.Query("SELECT SUM(i) FROM integers");
35 REQUIRE(CHECK_COLUMN(result, 0, {sum}));
36
37 // simple update, we should update INSERT_ELEMENTS elements
38 result = con.Query("UPDATE integers SET i=4 WHERE i=2");
39 REQUIRE(CHECK_COLUMN(result, 0, {TOTAL_ACCOUNTS}));
40
41 // check updated sum
42 result = con.Query("SELECT SUM(i) FROM integers");
43 REQUIRE(CHECK_COLUMN(result, 0, {sum + 2 * TOTAL_ACCOUNTS}));
44}
45
46static volatile bool finished_updating = false;
47static void read_total_balance(DuckDB *db, bool *read_correct) {
48 *read_correct = true;
49 Connection con(*db);
50 while (!finished_updating) {
51 // the total balance should remain constant regardless of updates
52 auto result = con.Query("SELECT SUM(money) FROM accounts");
53 if (!CHECK_COLUMN(result, 0, {TOTAL_ACCOUNTS * MONEY_PER_ACCOUNT})) {
54 *read_correct = false;
55 }
56 }
57}
58
59TEST_CASE("Concurrent update", "[updates][.]") {
60 unique_ptr<MaterializedQueryResult> result;
61 DuckDB db(nullptr);
62 Connection con(db);
63
64 // fixed seed random numbers
65 mt19937 generator;
66 generator.seed(42);
67 uniform_int_distribution<int> account_distribution(0, TOTAL_ACCOUNTS - 1);
68 auto random_account = bind(account_distribution, generator);
69
70 uniform_int_distribution<int> amount_distribution(0, MONEY_PER_ACCOUNT);
71 auto random_amount = bind(amount_distribution, generator);
72
73 finished_updating = false;
74 // initialize the database
75 con.Query("CREATE TABLE accounts(id INTEGER, money INTEGER)");
76 for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) {
77 con.Query("INSERT INTO accounts VALUES (" + to_string(i) + ", " + to_string(MONEY_PER_ACCOUNT) + ");");
78 }
79
80 bool read_correct;
81 // launch separate thread for reading aggregate
82 thread read_thread(read_total_balance, &db, &read_correct);
83
84 // start vigorously updating balances in this thread
85 for (size_t i = 0; i < TRANSACTION_UPDATE_COUNT; i++) {
86 int from = random_account();
87 int to = random_account();
88 while (to == from) {
89 to = random_account();
90 }
91 int amount = random_amount();
92
93 REQUIRE_NO_FAIL(con.Query("BEGIN TRANSACTION"));
94 result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(from));
95 Value money_from = result->collection.GetValue(0, 0);
96 result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(to));
97 Value money_to = result->collection.GetValue(0, 0);
98
99 REQUIRE_NO_FAIL(
100 con.Query("UPDATE accounts SET money = money - " + to_string(amount) + " WHERE id = " + to_string(from)));
101 REQUIRE_NO_FAIL(
102 con.Query("UPDATE accounts SET money = money + " + to_string(amount) + " WHERE id = " + to_string(to)));
103
104 result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(from));
105 Value new_money_from = result->collection.GetValue(0, 0);
106 result = con.Query("SELECT money FROM accounts WHERE id=" + to_string(to));
107 Value new_money_to = result->collection.GetValue(0, 0);
108
109 Value expected_money_from, expected_money_to;
110
111 expected_money_from = money_from - amount;
112 expected_money_to = money_to + amount;
113
114 REQUIRE(new_money_from == expected_money_from);
115 REQUIRE(new_money_to == expected_money_to);
116
117 REQUIRE_NO_FAIL(con.Query("COMMIT"));
118 }
119 finished_updating = true;
120 read_thread.join();
121 REQUIRE(read_correct);
122}
123
124static std::atomic<size_t> finished_threads;
125
126static void write_random_numbers_to_account(DuckDB *db, bool *correct, size_t nr) {
127 correct[nr] = true;
128 Connection con(*db);
129 for (size_t i = 0; i < TRANSACTION_UPDATE_COUNT; i++) {
130 // just make some changes to the total
131 // the total amount of money after the commit is the same
132 if (!con.Query("BEGIN TRANSACTION")->success) {
133 correct[nr] = false;
134 }
135 if (!con.Query("UPDATE accounts SET money = money + " + to_string(i * 2) + " WHERE id = " + to_string(nr))
136 ->success) {
137 correct[nr] = false;
138 }
139 if (!con.Query("UPDATE accounts SET money = money - " + to_string(i) + " WHERE id = " + to_string(nr))
140 ->success) {
141 correct[nr] = false;
142 }
143 if (!con.Query("UPDATE accounts SET money = money - " + to_string(i * 2) + " WHERE id = " + to_string(nr))
144 ->success) {
145 correct[nr] = false;
146 }
147 if (!con.Query("UPDATE accounts SET money = money + " + to_string(i) + " WHERE id = " + to_string(nr))
148 ->success) {
149 correct[nr] = false;
150 }
151 // we test both commit and rollback
152 // the result of both should be the same since the updates have a
153 // net-zero effect
154 if (!con.Query(nr % 2 == 0 ? "COMMIT" : "ROLLBACK")->success) {
155 correct[nr] = false;
156 }
157 }
158 finished_threads++;
159 if (finished_threads == TOTAL_ACCOUNTS) {
160 finished_updating = true;
161 }
162}
163
164TEST_CASE("Multiple concurrent updaters", "[updates][.]") {
165 unique_ptr<MaterializedQueryResult> result;
166 DuckDB db(nullptr);
167 Connection con(db);
168
169 finished_updating = false;
170 finished_threads = 0;
171 // initialize the database
172 con.Query("CREATE TABLE accounts(id INTEGER, money INTEGER)");
173 for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) {
174 con.Query("INSERT INTO accounts VALUES (" + to_string(i) + ", " + to_string(MONEY_PER_ACCOUNT) + ");");
175 }
176
177 bool correct[TOTAL_ACCOUNTS];
178 bool read_correct;
179 std::thread write_threads[TOTAL_ACCOUNTS];
180 // launch a thread for reading the table
181 thread read_thread(read_total_balance, &db, &read_correct);
182 // launch several threads for updating the table
183 for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) {
184 write_threads[i] = thread(write_random_numbers_to_account, &db, correct, i);
185 }
186 read_thread.join();
187 for (size_t i = 0; i < TOTAL_ACCOUNTS; i++) {
188 write_threads[i].join();
189 REQUIRE(correct[i]);
190 }
191}
192
193} // namespace test_concurrent_update
194