#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include "sqlite3.h"

#include <atomic>
#include <chrono>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
6
7using namespace std;
8
9static int concatenate_results(void *arg, int ncols, char **vals, char **colnames) {
10 auto &results = *((vector<vector<string>> *)arg);
11 if (results.size() == 0) {
12 results.resize(ncols);
13 }
14 for (int i = 0; i < ncols; i++) {
15 results[i].push_back(vals[i] ? vals[i] : "");
16 }
17 return SQLITE_OK;
18}
19
20// C++ wrapper class for the C wrapper API that wraps our C++ API, because why not
21class SQLiteDBWrapper {
22public:
23 SQLiteDBWrapper() : db(nullptr) {
24 }
25 ~SQLiteDBWrapper() {
26 if (db) {
27 sqlite3_close(db);
28 }
29 }
30
31 sqlite3 *db;
32 vector<vector<string>> results;
33
34public:
35 int Open(string filename) {
36 return sqlite3_open(filename.c_str(), &db) == SQLITE_OK;
37 }
38
39 string GetErrorMessage() {
40 auto err = sqlite3_errmsg(db);
41 return err ? string(err) : string();
42 }
43
44 bool Execute(string query) {
45 results.clear();
46 char *errmsg = nullptr;
47 int rc = sqlite3_exec(db, query.c_str(), concatenate_results, &results, &errmsg);
48 if (errmsg) {
49 sqlite3_free(errmsg);
50 }
51 return rc == SQLITE_OK;
52 }
53
54 void PrintResult() {
55 for (size_t row_idx = 0; row_idx < results[0].size(); row_idx++) {
56 for (size_t col_idx = 0; col_idx < results.size(); col_idx++) {
57 printf("%s|", results[col_idx][row_idx].c_str());
58 }
59 printf("\n");
60 }
61 }
62
63 bool CheckColumn(size_t column, vector<string> expected_data) {
64 if (column >= results.size()) {
65 fprintf(stderr, "Column index is out of range!\n");
66 PrintResult();
67 return false;
68 }
69 if (results[column].size() != expected_data.size()) {
70 fprintf(stderr, "Row counts do not match!\n");
71 PrintResult();
72 return false;
73 }
74 for (size_t i = 0; i < expected_data.size(); i++) {
75 if (expected_data[i] != results[column][i]) {
76 fprintf(stderr, "Value does not match: expected \"%s\" but got \"%s\"\n", expected_data[i].c_str(),
77 results[column][i].c_str());
78 return false;
79 }
80 }
81 return true;
82 }
83};
84
85class SQLiteStmtWrapper {
86public:
87 SQLiteStmtWrapper() : stmt(nullptr) {
88 }
89 ~SQLiteStmtWrapper() {
90 Finalize();
91 }
92
93 sqlite3_stmt *stmt;
94 string error_message;
95
96 int Prepare(sqlite3 *db, const char *zSql, int nByte, const char **pzTail) {
97 Finalize();
98 return sqlite3_prepare_v2(db, zSql, nByte, &stmt, pzTail);
99 }
100
101 void Finalize() {
102 if (stmt) {
103 sqlite3_finalize(stmt);
104 stmt = nullptr;
105 }
106 }
107};
108
109TEST_CASE("Basic sqlite wrapper usage", "[sqlite3wrapper]") {
110 SQLiteDBWrapper db;
111
112 // open an in-memory db
113 REQUIRE(db.Open(":memory:"));
114
115 // standard selection
116 REQUIRE(db.Execute("SELECT 42;"));
117 REQUIRE(db.CheckColumn(0, {"42"}));
118
119 // simple statements
120 REQUIRE(db.Execute("CREATE TABLE test(i INTEGER)"));
121 REQUIRE(db.Execute("INSERT INTO test VALUES (1), (2), (3)"));
122 REQUIRE(db.Execute("SELECT SUM(t1.i)::BIGINT FROM test t1, test t2, test t3;"));
123 REQUIRE(db.CheckColumn(0, {"54"}));
124
125 REQUIRE(db.Execute("DELETE FROM test WHERE i=2"));
126 REQUIRE(db.Execute("UPDATE test SET i=i+1"));
127 REQUIRE(db.Execute("SELECT * FROM test ORDER BY 1;"));
128 REQUIRE(db.CheckColumn(0, {"2", "4"}));
129
130 // test different types
131#ifndef SQLITE_TEST
132 REQUIRE(
133 db.Execute("SELECT CAST('1992-01-01' AS DATE), 3, 'hello world', CAST('1992-01-01 00:00:00' AS TIMESTAMP);"));
134 REQUIRE(db.CheckColumn(0, {"1992-01-01"}));
135 REQUIRE(db.CheckColumn(1, {"3"}));
136 REQUIRE(db.CheckColumn(2, {"hello world"}));
137 REQUIRE(db.CheckColumn(3, {"1992-01-01 00:00:00"}));
138#endif
139
140 // handle errors
141 // syntax error
142 REQUIRE(!db.Execute("SELEC 42"));
143 // catalog error
144 REQUIRE(!db.Execute("SELECT * FROM nonexistant_tbl"));
145}
146
147TEST_CASE("Basic prepared statement usage", "[sqlite3wrapper]") {
148 SQLiteDBWrapper db;
149 SQLiteStmtWrapper stmt;
150
151 // open an in-memory db
152 REQUIRE(db.Open(":memory:"));
153 REQUIRE(db.Execute("CREATE TABLE test(i INTEGER, j BIGINT, k DATE, l VARCHAR)"));
154#ifndef SQLITE_TEST
155 // sqlite3_prepare_v2 errors
156 // nullptr for db/stmt, note: normal sqlite segfaults here
157 REQUIRE(sqlite3_prepare_v2(nullptr, "INSERT INTO test VALUES ($1, $2, $3, $4)", -1, nullptr, nullptr) ==
158 SQLITE_MISUSE);
159 REQUIRE(sqlite3_prepare_v2(db.db, "INSERT INTO test VALUES ($1, $2, $3, $4)", -1, nullptr, nullptr) ==
160 SQLITE_MISUSE);
161#endif
162 // prepared statement
163 REQUIRE(stmt.Prepare(db.db, "INSERT INTO test VALUES ($1, $2, $3, $4)", -1, nullptr) == SQLITE_OK);
164
165 // test for parameter count, names and indexes
166 REQUIRE(sqlite3_bind_parameter_count(nullptr) == 0);
167 REQUIRE(sqlite3_bind_parameter_count(stmt.stmt) == 4);
168 for (int i = 1; i < 5; i++) {
169 REQUIRE(sqlite3_bind_parameter_name(nullptr, i) == nullptr);
170 REQUIRE(sqlite3_bind_parameter_index(nullptr, nullptr) == 0);
171 REQUIRE(sqlite3_bind_parameter_index(stmt.stmt, nullptr) == 0);
172 REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, i) != nullptr);
173 REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, i) == string("$") + to_string(i));
174 REQUIRE(sqlite3_bind_parameter_index(stmt.stmt, sqlite3_bind_parameter_name(stmt.stmt, i)) == i);
175 }
176 REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, 0) == nullptr);
177 REQUIRE(sqlite3_bind_parameter_name(stmt.stmt, 5) == nullptr);
178
179#ifndef SQLITE_TEST
180 // this segfaults in SQLITE
181 REQUIRE(sqlite3_clear_bindings(nullptr) == SQLITE_MISUSE);
182#endif
183 REQUIRE(sqlite3_clear_bindings(stmt.stmt) == SQLITE_OK);
184 REQUIRE(sqlite3_clear_bindings(stmt.stmt) == SQLITE_OK);
185 // test for binding parameters
186 // incorrect bindings: nullptr as statement, wrong type and out of range binding
187 REQUIRE(sqlite3_bind_int(nullptr, 1, 1) == SQLITE_MISUSE);
188 REQUIRE(sqlite3_bind_int(stmt.stmt, 0, 1) == SQLITE_RANGE);
189 REQUIRE(sqlite3_bind_int(stmt.stmt, 5, 1) == SQLITE_RANGE);
190
191 // we can bind the incorrect type just fine
192 // error will only be thrown on execution
193 REQUIRE(sqlite3_bind_text(stmt.stmt, 1, "hello world", -1, nullptr) == SQLITE_OK);
194 REQUIRE(sqlite3_bind_int(stmt.stmt, 1, 1) == SQLITE_OK);
195 // we can rebind the same parameter
196 REQUIRE(sqlite3_bind_int(stmt.stmt, 1, 2) == SQLITE_OK);
197 REQUIRE(sqlite3_bind_int64(stmt.stmt, 2, 1000) == SQLITE_OK);
198 REQUIRE(sqlite3_bind_text(stmt.stmt, 3, "1992-01-01", -1, nullptr) == SQLITE_OK);
199 REQUIRE(sqlite3_bind_text(stmt.stmt, 4, "hello world", -1, nullptr) == SQLITE_OK);
200
201 REQUIRE(sqlite3_step(nullptr) == SQLITE_MISUSE);
202 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
203
204 // reset the statement
205 REQUIRE(sqlite3_reset(nullptr) == SQLITE_OK);
206 REQUIRE(sqlite3_reset(stmt.stmt) == SQLITE_OK);
207 // we can reset multiple times
208 REQUIRE(sqlite3_reset(stmt.stmt) == SQLITE_OK);
209
210 REQUIRE(sqlite3_bind_null(stmt.stmt, 1) == SQLITE_OK);
211 REQUIRE(sqlite3_bind_null(stmt.stmt, 2) == SQLITE_OK);
212 REQUIRE(sqlite3_bind_null(stmt.stmt, 3) == SQLITE_OK);
213 REQUIRE(sqlite3_bind_null(stmt.stmt, 4) == SQLITE_OK);
214
215 // we can step multiple times
216 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
217 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
218 REQUIRE(sqlite3_reset(stmt.stmt) == SQLITE_OK);
219 // after a reset we still have our bound values
220 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
221 // clearing the bindings results in us not having any values though
222 REQUIRE(sqlite3_clear_bindings(stmt.stmt) == SQLITE_OK);
223 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
224
225 REQUIRE(db.Execute("SELECT * FROM test ORDER BY 1"));
226 REQUIRE(db.CheckColumn(0, {"", "", "", "", "2"}));
227 REQUIRE(db.CheckColumn(1, {"", "", "", "", "1000"}));
228 REQUIRE(db.CheckColumn(2, {"", "", "", "", "1992-01-01"}));
229 REQUIRE(db.CheckColumn(3, {"", "", "", "", "hello world"}));
230
231 REQUIRE(sqlite3_finalize(nullptr) == SQLITE_OK);
232
233 // first prepare the statement again
234 REQUIRE(stmt.Prepare(db.db, "SELECT * FROM test WHERE i=CAST($1 AS INTEGER)", -1, nullptr) == SQLITE_OK);
235 // bind a non-integer here
236 REQUIRE(sqlite3_bind_text(stmt.stmt, 1, "hello", -1, nullptr) == SQLITE_OK);
237#ifndef SQLITE_TEST
238 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ERROR);
239 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ERROR);
240 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ERROR);
241 // need to be prepare aggain
242 REQUIRE(stmt.Prepare(db.db, "SELECT * FROM test WHERE i=CAST($1 AS INTEGER)", -1, nullptr) == SQLITE_OK);
243 REQUIRE(sqlite3_bind_text(stmt.stmt, 1, "2", -1, nullptr) == SQLITE_OK);
244 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ROW);
245#else
246 // sqlite allows string to int casts ("hello" becomes 0)
247 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
248 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
249#endif
250
251 // rebind and call again
252 // need to reset first
253 REQUIRE(sqlite3_bind_text(stmt.stmt, 1, "1", -1, nullptr) == SQLITE_MISUSE);
254 REQUIRE(sqlite3_reset(stmt.stmt) == SQLITE_OK);
255
256 REQUIRE(sqlite3_bind_text(stmt.stmt, 1, "2", -1, nullptr) == SQLITE_OK);
257 // repeatedly call sqlite3_step on a SELECT statement
258 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ROW);
259 // verify the results
260 REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 0)) == string("2"));
261 REQUIRE(sqlite3_column_int(stmt.stmt, 0) == 2);
262 REQUIRE(sqlite3_column_int64(stmt.stmt, 0) == 2);
263 REQUIRE(sqlite3_column_double(stmt.stmt, 0) == 2);
264 REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 1)) == string("1000"));
265 REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 2)) == string("1992-01-01"));
266 REQUIRE(string((char *)sqlite3_column_text(stmt.stmt, 3)) == string("hello world"));
267 REQUIRE(sqlite3_column_int(stmt.stmt, 3) == 0);
268 REQUIRE(sqlite3_column_int64(stmt.stmt, 3) == 0);
269 REQUIRE(sqlite3_column_double(stmt.stmt, 3) == 0);
270 REQUIRE(sqlite3_column_text(stmt.stmt, -1) == nullptr);
271 REQUIRE(sqlite3_column_text(stmt.stmt, 10) == nullptr);
272
273 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
274 // no data in the current row
275 REQUIRE(sqlite3_column_int(stmt.stmt, 0) == 0);
276 REQUIRE(sqlite3_column_int(nullptr, 0) == 0);
277 // the query resets again after SQLITE_DONE
278 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ROW);
279 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
280
281 // sqlite bind and errors
282 REQUIRE(stmt.Prepare(db.db, "SELECT * FROM non_existant_table", -1, nullptr) == SQLITE_ERROR);
283 REQUIRE(stmt.stmt == nullptr);
284
285 // sqlite3 prepare leftovers
286 // empty statement
287 const char *leftover;
288 REQUIRE(stmt.Prepare(db.db, "", -1, &leftover) == SQLITE_OK);
289 REQUIRE(leftover != nullptr);
290 REQUIRE(string(leftover) == "");
291 // leftover comment
292 REQUIRE(stmt.Prepare(db.db, "SELECT 42; --hello\nSELECT 3", -1, &leftover) == SQLITE_OK);
293 REQUIRE(leftover != nullptr);
294 REQUIRE(string(leftover) == " --hello\nSELECT 3");
295 // leftover extra statement
296 REQUIRE(stmt.Prepare(db.db, "SELECT 42--hello;\n, 3; SELECT 17", -1, &leftover) == SQLITE_OK);
297 REQUIRE(leftover != nullptr);
298 REQUIRE(string(leftover) == " SELECT 17");
299 // no query
300 REQUIRE(stmt.Prepare(db.db, nullptr, -1, &leftover) == SQLITE_MISUSE);
301
302 // sqlite3 prepare nByte
303 // any negative value can be used, not just -1
304 REQUIRE(stmt.Prepare(db.db, "SELECT 42", -1000, &leftover) == SQLITE_OK);
305 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ROW);
306 REQUIRE(sqlite3_column_int(stmt.stmt, 0) == 42);
307 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
308 // we can use nByte to skip reading part of string (in this case, skip WHERE 1=0)
309 REQUIRE(stmt.Prepare(db.db, "SELECT 42 WHERE 1=0", 9, &leftover) == SQLITE_OK);
310 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_ROW);
311 REQUIRE(sqlite3_column_int(stmt.stmt, 0) == 42);
312 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
313 // using too large nByte?
314 REQUIRE(stmt.Prepare(db.db, "SELECT 42 WHERE 1=0", 19, &leftover) == SQLITE_OK);
315 REQUIRE(sqlite3_step(stmt.stmt) == SQLITE_DONE);
316}
317
318static void sqlite3_interrupt_fast(SQLiteDBWrapper *db, bool *success) {
319 *success = db->Execute("SELECT SUM(i1.i) FROM integers i1, integers i2, integers i3, integers i4, integers i5");
320}
321
322TEST_CASE("Test sqlite3_interrupt", "[sqlite3wrapper]") {
323 SQLiteDBWrapper db;
324 bool success;
325
326 // open an in-memory db
327 REQUIRE(db.Open(":memory:"));
328 REQUIRE(db.Execute("CREATE TABLE integers(i INTEGER)"));
329 // create a database with 5 values
330 REQUIRE(db.Execute("INSERT INTO integers VALUES (1), (2), (3), (4), (5)"));
331 // 5 + 5 * 5 = 30 values
332 REQUIRE(db.Execute("INSERT INTO integers SELECT i1.i FROM integers i1, integers i2"));
333 // 30 + 30 * 30 = 930 values
334 REQUIRE(db.Execute("INSERT INTO integers SELECT i1.i FROM integers i1, integers i2"));
335 // run a thread that will run a big cross product
336 thread t1(sqlite3_interrupt_fast, &db, &success);
337 // wait a second and interrupt the db
338 std::this_thread::sleep_for(std::chrono::milliseconds(1000));
339 sqlite3_interrupt(db.db);
340 // join the thread again
341 t1.join();
342 // the execution should have been cancelled
343 REQUIRE(!success);
344}
345
346TEST_CASE("Test different statement types", "[sqlite3wrapper]") {
347 SQLiteDBWrapper db;
348
349 // open an in-memory db
350 REQUIRE(db.Open(":memory:"));
351 // create
352 REQUIRE(db.Execute("CREATE TABLE integers(i INTEGER)"));
353 // prepare
354 REQUIRE(db.Execute("PREPARE v1 AS INSERT INTO integers VALUES (?)"));
355 // execute
356 REQUIRE(db.Execute("EXECUTE v1(1)"));
357 REQUIRE(db.Execute("EXECUTE v1(2)"));
358 REQUIRE(db.Execute("EXECUTE v1(3)"));
359 // select
360 REQUIRE(db.Execute("SELECT * FROM integers ORDER BY 1"));
361 REQUIRE(db.CheckColumn(0, {"1", "2", "3"}));
362
363 // update
364 REQUIRE(db.Execute("UPDATE integers SET i=i+1"));
365 // delete
366 REQUIRE(db.Execute("DELETE FROM integers WHERE i=4"));
367 // verify
368 REQUIRE(db.Execute("SELECT * FROM integers ORDER BY 1"));
369 REQUIRE(db.CheckColumn(0, {"2", "3"}));
370
371 // transactions
372 REQUIRE(db.Execute("BEGIN TRANSACTION"));
373 REQUIRE(db.Execute("UPDATE integers SET i=i+1"));
374 REQUIRE(db.Execute("ROLLBACK"));
375 // verify
376 REQUIRE(db.Execute("SELECT * FROM integers ORDER BY 1"));
377 REQUIRE(db.CheckColumn(0, {"2", "3"}));
378
379 // commit
380 REQUIRE(db.Execute("BEGIN TRANSACTION"));
381 REQUIRE(db.Execute("UPDATE integers SET i=i+1"));
382 REQUIRE(db.Execute("COMMIT"));
383 // verify
384 REQUIRE(db.Execute("SELECT * FROM integers ORDER BY 1"));
385 REQUIRE(db.CheckColumn(0, {"3", "4"}));
386}
387
388TEST_CASE("Test rollback of aborted transaction", "[sqlite3wrapper]") {
389 SQLiteDBWrapper db;
390
391 // open an in-memory db
392 REQUIRE(db.Open(":memory:"));
393
394 // can start a transaction
395 REQUIRE(db.Execute("START TRANSACTION"));
396 // cannot start a transaction within a transaction
397 REQUIRE(!db.Execute("START TRANSACTION"));
398 // now we need to rollback!
399 REQUIRE(db.Execute("ROLLBACK"));
400 // can start a transaction again after a rollback
401 REQUIRE(db.Execute("START TRANSACTION"));
402}
403