#include "duckdb.hh"
#include "dbgen.hpp"

#include <atomic>
#include <cassert>
#include <chrono>
#include <cstring>
#include <iostream>
#include <regex>
#include <stdexcept>
#include <string>
#include <thread>
12
13using namespace duckdb;
14using namespace std;
15
// Pattern for DuckDB error messages that report a parser (syntax) error.
// dut_duckdb::test uses it to tell syntax errors apart from other failures.
static const std::regex e_syntax("Query Error: syntax error at or near .*", std::regex::ECMAScript);
17
18duckdb_connection::duckdb_connection(string &conninfo) {
19 // in-memory database
20 database = make_unique<DuckDB>(nullptr);
21 connection = make_unique<Connection>(*database);
22}
23
24void duckdb_connection::q(const char *query) {
25 auto result = connection->Query(query);
26 if (!result->success) {
27 throw runtime_error(result->error);
28 }
29}
30
31schema_duckdb::schema_duckdb(std::string &conninfo, bool no_catalog) : duckdb_connection(conninfo) {
32 // generate empty TPC-H schema
33 tpch::dbgen(0, *database);
34
35 cerr << "Loading tables...";
36 auto result = connection->Query("SELECT * FROM sqlite_master() WHERE type IN ('table', 'view')");
37 if (!result->success) {
38 throw runtime_error(result->error);
39 }
40 for (size_t i = 0; i < result->collection.count; i++) {
41 auto type = result->collection.GetValue(0, i).str_value;
42 auto name = result->collection.GetValue(2, i).str_value;
43 bool view = type == "view";
44 table tab(name, "main", !view, !view);
45 tables.push_back(tab);
46 }
47 cerr << "done." << endl;
48
49 if (tables.size() == 0) {
50 throw std::runtime_error("No tables available in catalog!");
51 }
52
53 cerr << "Loading columns and constraints...";
54
55 for (auto t = tables.begin(); t != tables.end(); ++t) {
56 result = connection->Query("PRAGMA table_info('" + t->name + "')");
57 if (!result->success) {
58 throw runtime_error(result->error);
59 }
60 for (size_t i = 0; i < result->collection.count; i++) {
61 auto name = result->collection.GetValue(1, i).str_value;
62 auto type = result->collection.GetValue(2, i).str_value;
63 column c(name, sqltype::get(type));
64 t->columns().push_back(c);
65 }
66 }
67
68 cerr << "done." << endl;
69
70#define BINOP(n, t) \
71 do { \
72 op o(#n, sqltype::get(#t), sqltype::get(#t), sqltype::get(#t)); \
73 register_operator(o); \
74 } while (0)
75
76 // BINOP(||, TEXT);
77 BINOP(*, INTEGER);
78 BINOP(/, INTEGER);
79
80 BINOP(+, INTEGER);
81 BINOP(-, INTEGER);
82
83 // BINOP(>>, INTEGER);
84 // BINOP(<<, INTEGER);
85
86 // BINOP(&, INTEGER);
87 // BINOP(|, INTEGER);
88
89 BINOP(<, INTEGER);
90 BINOP(<=, INTEGER);
91 BINOP(>, INTEGER);
92 BINOP(>=, INTEGER);
93
94 BINOP(=, INTEGER);
95 BINOP(<>, INTEGER);
96 BINOP(IS, INTEGER);
97 BINOP(IS NOT, INTEGER);
98
99 BINOP(AND, INTEGER);
100 BINOP(OR, INTEGER);
101
102#define FUNC(n, r) \
103 do { \
104 routine proc("", "", sqltype::get(#r), #n); \
105 register_routine(proc); \
106 } while (0)
107
108#define FUNC1(n, r, a) \
109 do { \
110 routine proc("", "", sqltype::get(#r), #n); \
111 proc.argtypes.push_back(sqltype::get(#a)); \
112 register_routine(proc); \
113 } while (0)
114
115#define FUNC2(n, r, a, b) \
116 do { \
117 routine proc("", "", sqltype::get(#r), #n); \
118 proc.argtypes.push_back(sqltype::get(#a)); \
119 proc.argtypes.push_back(sqltype::get(#b)); \
120 register_routine(proc); \
121 } while (0)
122
123#define FUNC3(n, r, a, b, c) \
124 do { \
125 routine proc("", "", sqltype::get(#r), #n); \
126 proc.argtypes.push_back(sqltype::get(#a)); \
127 proc.argtypes.push_back(sqltype::get(#b)); \
128 proc.argtypes.push_back(sqltype::get(#c)); \
129 register_routine(proc); \
130 } while (0)
131
132 // FUNC(last_insert_rowid, INTEGER);
133 // FUNC(random, INTEGER);
134 // FUNC(sqlite_source_id, TEXT);
135 // FUNC(sqlite_version, TEXT);
136 // FUNC(total_changes, INTEGER);
137
138 FUNC1(abs, INTEGER, REAL);
139 // FUNC1(hex, TEXT, TEXT);
140 // FUNC1(length, INTEGER, TEXT);
141 // FUNC1(lower, TEXT, TEXT);
142 // FUNC1(ltrim, TEXT, TEXT);
143 // FUNC1(quote, TEXT, TEXT);
144 // FUNC1(randomblob, TEXT, INTEGER);
145 // FUNC1(round, INTEGER, REAL);
146 // FUNC1(rtrim, TEXT, TEXT);
147 // FUNC1(soundex, TEXT, TEXT);
148 // FUNC1(sqlite_compileoption_get, TEXT, INTEGER);
149 // FUNC1(sqlite_compileoption_used, INTEGER, TEXT);
150 // FUNC1(trim, TEXT, TEXT);
151 // FUNC1(typeof, TEXT, INTEGER);
152 // FUNC1(typeof, TEXT, NUMERIC);
153 // FUNC1(typeof, TEXT, REAL);
154 // FUNC1(typeof, TEXT, TEXT);
155 // FUNC1(unicode, INTEGER, TEXT);
156 // FUNC1(upper, TEXT, TEXT);
157 // FUNC1(zeroblob, TEXT, INTEGER);
158
159 // FUNC2(glob, INTEGER, TEXT, TEXT);
160 // FUNC2(instr, INTEGER, TEXT, TEXT);
161 // FUNC2(like, INTEGER, TEXT, TEXT);
162 // FUNC2(ltrim, TEXT, TEXT, TEXT);
163 // FUNC2(rtrim, TEXT, TEXT, TEXT);
164 // FUNC2(trim, TEXT, TEXT, TEXT);
165 // FUNC2(round, INTEGER, REAL, INTEGER);
166 // FUNC2(substr, TEXT, TEXT, INTEGER);
167
168 // FUNC3(substr, TEXT, TEXT, INTEGER, INTEGER);
169 // FUNC3(replace, TEXT, TEXT, TEXT, TEXT);
170
171#define AGG(n, r, a) \
172 do { \
173 routine proc("", "", sqltype::get(#r), #n); \
174 proc.argtypes.push_back(sqltype::get(#a)); \
175 register_aggregate(proc); \
176 } while (0)
177
178 AGG(avg, INTEGER, INTEGER);
179 AGG(avg, REAL, REAL);
180 AGG(count, INTEGER, REAL);
181 AGG(count, INTEGER, TEXT);
182 AGG(count, INTEGER, INTEGER);
183 // AGG(group_concat, TEXT, TEXT);
184 AGG(max, REAL, REAL);
185 AGG(max, INTEGER, INTEGER);
186 AGG(min, REAL, REAL);
187 AGG(min, INTEGER, INTEGER);
188 AGG(sum, REAL, REAL);
189 AGG(sum, INTEGER, INTEGER);
190 // AGG(total, REAL, INTEGER);
191 // AGG(total, REAL, REAL);
192
193 booltype = sqltype::get("INTEGER");
194 inttype = sqltype::get("INTEGER");
195
196 internaltype = sqltype::get("internal");
197 arraytype = sqltype::get("ARRAY");
198
199 true_literal = "1";
200 false_literal = "0";
201
202 generate_indexes();
203}
204
205dut_duckdb::dut_duckdb(std::string &conninfo) : duckdb_connection(conninfo) {
206 cerr << "Generating TPC-H...";
207 tpch::dbgen(0.1, *database);
208 cerr << "done." << endl;
209 // q("PRAGMA main.auto_vacuum = 2");
210}
211
212volatile bool is_active = false;
213// timeout is 10ms * TIMEOUT_TICKS
214#define TIMEOUT_TICKS 50
215
216void sleep_thread(Connection *connection) {
217 for (size_t i = 0; i < TIMEOUT_TICKS && is_active; i++) {
218 std::this_thread::sleep_for(std::chrono::milliseconds(10));
219 }
220 if (is_active) {
221 connection->Interrupt();
222 }
223}
224
225void dut_duckdb::test(const std::string &stmt) {
226 is_active = true;
227 thread interrupt_thread(sleep_thread, connection.get());
228 auto result = connection->Query(stmt);
229 is_active = false;
230 interrupt_thread.join();
231
232 if (!result->success) {
233 auto error = result->error.c_str();
234 try {
235 if (regex_match(error, e_syntax))
236 throw dut::syntax(error);
237 else
238 throw dut::failure(error);
239 } catch (dut::failure &e) {
240 throw;
241 }
242 }
243}
244