#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <fstream>
#include <iostream>
#include <mutex>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "duckdb.hpp"
#include "duckdb/common/types/data_chunk.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/main/client_context.hpp"

// you can set this to enable compression. You will need to link zlib as well.
// #define CPPHTTPLIB_ZLIB_SUPPORT 1

#include "httplib.hpp"
#include "json.hpp"

using namespace httplib;
using namespace duckdb;
using namespace nlohmann;
24
25void print_help() {
26 fprintf(stderr, "🦆 Usage: duckdb_rest_server\n");
27 fprintf(stderr, " --listen=[address] listening address\n");
28 fprintf(stderr, " --port=[no] listening port\n");
29 fprintf(stderr, " --database=[file] use given database file\n");
30 fprintf(stderr, " --read_only open database in read-only mode\n");
31 fprintf(stderr, " --query_timeout=[sec] query timeout in seconds\n");
32 fprintf(stderr, " --fetch_timeout=[sec] result set timeout in seconds\n");
33 fprintf(stderr, " --log=[file] log queries to file\n\n");
34 fprintf(stderr, "Version: %s\n", DUCKDB_SOURCE_ID);
35
36}
37
38// https://stackoverflow.com/a/12468109/2652376
39std::string random_string(size_t length) {
40 auto randchar = []() -> char {
41 const char charset[] = "0123456789"
42 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
43 "abcdefghijklmnopqrstuvwxyz";
44 const size_t max_index = (sizeof(charset) - 1);
45 return charset[rand() % max_index];
46 };
47 std::string str(length, 0);
48 std::generate_n(str.begin(), length, randchar);
49 return str;
50}
51
52struct RestClientState {
53 unique_ptr<QueryResult> res;
54 unique_ptr<Connection> con;
55 time_t touched;
56};
57
// Encoding of a response body. serialize_json() picks one from the request's Accept header.
// Scoped so the enumerator names (JSON, BSON, ...) do not leak into the global namespace.
enum class ReturnContentType { JSON, BSON, CBOR, MESSAGE_PACK, UBJSON };
59
60template <class T, class TARGET> static void assign_json_loop(Vector &v, idx_t col_idx, idx_t count, json &j) {
61 auto data_ptr = FlatVector::GetData<T>(v);
62 auto &nullmask = FlatVector::Nullmask(v);
63 for (idx_t i = 0; i < count; i++) {
64 if (!nullmask[i]) {
65 j["data"][col_idx] += (TARGET)data_ptr[i];
66
67 } else {
68 j["data"][col_idx] += nullptr;
69 }
70 }
71}
72
73void serialize_chunk(QueryResult *res, DataChunk *chunk, json &j) {
74 assert(res);
75 Vector v2(TypeId::VARCHAR);
76 for (size_t col_idx = 0; col_idx < chunk->column_count(); col_idx++) {
77 Vector *v = &chunk->data[col_idx];
78 switch (res->sql_types[col_idx].id) {
79 case SQLTypeId::DATE:
80 case SQLTypeId::TIME:
81 case SQLTypeId::TIMESTAMP: {
82 VectorOperations::Cast(*v, v2, res->sql_types[col_idx], SQLType::VARCHAR, chunk->size());
83 v = &v2;
84 break;
85 }
86 default:
87 break;
88 }
89 v->Normalify(chunk->size());
90 assert(v);
91 switch (v->type) {
92 case TypeId::BOOL:
93 assign_json_loop<bool, int64_t>(*v, col_idx, chunk->size(), j);
94 break;
95 case TypeId::INT8:
96 assign_json_loop<int8_t, int64_t>(*v, col_idx, chunk->size(), j);
97 break;
98 case TypeId::INT16:
99 assign_json_loop<int16_t, int64_t>(*v, col_idx, chunk->size(), j);
100 break;
101 case TypeId::INT32:
102 assign_json_loop<int32_t, int64_t>(*v, col_idx, chunk->size(), j);
103 break;
104 case TypeId::INT64:
105 assign_json_loop<int64_t, int64_t>(*v, col_idx, chunk->size(), j);
106 break;
107 case TypeId::FLOAT:
108 assign_json_loop<float, double>(*v, col_idx, chunk->size(), j);
109 break;
110 case TypeId::DOUBLE:
111 assign_json_loop<float, double>(*v, col_idx, chunk->size(), j);
112 break;
113 case TypeId::VARCHAR: {
114 auto data_ptr = FlatVector::GetData<string_t>(*v);
115 auto &nullmask = FlatVector::Nullmask(*v);
116 for (idx_t i = 0; i < chunk->size(); i++) {
117 if (!nullmask[i]) {
118 j["data"][col_idx] += data_ptr[i].GetData();
119
120 } else {
121 j["data"][col_idx] += nullptr;
122 }
123 }
124 break;
125 }
126 default:
127 throw std::runtime_error("Unsupported Type");
128 }
129 }
130}
131
132void serialize_json(const Request &req, Response &resp, json &j) {
133 auto return_type = ReturnContentType::JSON;
134 j["duckdb_version"] = DUCKDB_SOURCE_ID;
135
136 if (req.has_header("Accept")) {
137 auto accept = req.get_header_value("Accept");
138 if (accept.rfind("application/bson", 0) == 0 || accept.rfind("application/x-bson", 0) == 0) {
139 return_type = ReturnContentType::BSON;
140 } else if (accept.rfind("application/cbor", 0) == 0) {
141 return_type = ReturnContentType::CBOR;
142 } else if (accept.rfind("application/msgpack", 0) == 0 || accept.rfind("application/x-msgpack", 0) == 0 ||
143 accept.rfind("application/vnd.msgpack", 0) == 0) {
144 return_type = ReturnContentType::MESSAGE_PACK;
145 } else if (accept.rfind("application/ubjson", 0) == 0) {
146 return_type = ReturnContentType::UBJSON;
147 }
148 }
149
150 switch (return_type) {
151 case ReturnContentType::JSON: {
152 if (req.has_param("callback")) {
153 auto jsonp_callback = req.get_param_value("callback");
154 resp.set_content(jsonp_callback + "(" + j.dump() + ");", "application/javascript");
155
156 } else {
157 resp.set_content(j.dump(), "application/json");
158 }
159 break;
160 }
161 case ReturnContentType::BSON: {
162 auto bson = json::to_bson(j);
163 resp.set_content((const char *)bson.data(), bson.size(), "application/bson");
164 break;
165 }
166 case ReturnContentType::CBOR: {
167 auto cbor = json::to_cbor(j);
168 resp.set_content((const char *)cbor.data(), cbor.size(), "application/cbor");
169 break;
170 }
171 case ReturnContentType::MESSAGE_PACK: {
172 auto msgpack = json::to_msgpack(j);
173 resp.set_content((const char *)msgpack.data(), msgpack.size(), "application/msgpack");
174 break;
175 }
176 case ReturnContentType::UBJSON: {
177 auto ubjson = json::to_ubjson(j);
178 resp.set_content((const char *)ubjson.data(), ubjson.size(), "application/ubjson");
179 break;
180 }
181 }
182}
183
184void sleep_thread(Connection *conn, bool *is_active, int timeout_duration) {
185 // timeout is given in seconds
186 // we wait 10ms per iteration, so timeout * 100 gives us the amount of
187 // iterations
188 assert(conn);
189 assert(is_active);
190
191 if (timeout_duration < 0) {
192 return;
193 }
194 for (size_t i = 0; i < (size_t)(timeout_duration * 100) && *is_active; i++) {
195 std::this_thread::sleep_for(std::chrono::milliseconds(10));
196 }
197 if (*is_active) {
198 conn->Interrupt();
199 }
200}
201
202void client_state_cleanup(unordered_map<string, RestClientState> *map, std::mutex *mutex, int timeout_duration) {
203 // timeout is given in seconds
204 while (true) {
205 // sleep for half the timeout duration
206 std::this_thread::sleep_for(std::chrono::milliseconds((timeout_duration * 1000) / 2));
207 {
208 std::lock_guard<std::mutex> guard(*mutex);
209 auto now = std::time(nullptr);
210 for (auto it = map->cbegin(); it != map->cend();) {
211 if (now - it->second.touched > timeout_duration) {
212 it = map->erase(it);
213 } else {
214 ++it;
215 }
216 }
217 }
218 }
219}
220
221int main(int argc, char **argv) {
222 Server svr;
223 if (!svr.is_valid()) {
224 printf("server has an error...\n");
225 return -1;
226 }
227
228 std::mutex out_mutex;
229 srand(time(nullptr));
230
231 DBConfig config;
232 string dbfile = "";
233 string logfile_name;
234
235 string listen = "localhost";
236 int port = 1294;
237 std::ofstream logfile;
238
239 int query_timeout = 60;
240 int fetch_timeout = 60 * 5;
241
242 // parse config
243 for (int arg_index = 1; arg_index < argc; ++arg_index) {
244 string arg = argv[arg_index];
245 if (arg == "--help") {
246 print_help();
247 exit(0);
248 } else if (arg == "--read_only") {
249 config.access_mode = AccessMode::READ_ONLY;
250 } else if (StringUtil::StartsWith(arg, "--database=")) {
251 auto splits = StringUtil::Split(arg, '=');
252 if (splits.size() != 2) {
253 print_help();
254 exit(1);
255 }
256 dbfile = string(splits[1]);
257 } else if (StringUtil::StartsWith(arg, "--log=")) {
258 auto splits = StringUtil::Split(arg, '=');
259 if (splits.size() != 2) {
260 print_help();
261 exit(1);
262 }
263 logfile_name = string(splits[1]);
264 } else if (StringUtil::StartsWith(arg, "--listen=")) {
265 auto splits = StringUtil::Split(arg, '=');
266 if (splits.size() != 2) {
267 print_help();
268 exit(1);
269 }
270 listen = string(splits[1]);
271 } else if (StringUtil::StartsWith(arg, "--port=")) {
272 auto splits = StringUtil::Split(arg, '=');
273 if (splits.size() != 2) {
274 print_help();
275 exit(1);
276 }
277 port = std::stoi(splits[1]);
278
279 } else if (StringUtil::StartsWith(arg, "--query_timeout=")) {
280 auto splits = StringUtil::Split(arg, '=');
281 if (splits.size() != 2) {
282 print_help();
283 exit(1);
284 }
285 query_timeout = std::stoi(splits[1]);
286
287 } else if (StringUtil::StartsWith(arg, "--fetch_timeout=")) {
288 auto splits = StringUtil::Split(arg, '=');
289 if (splits.size() != 2) {
290 print_help();
291 exit(1);
292 }
293 fetch_timeout = std::stoi(splits[1]);
294
295 } else {
296 fprintf(stderr, "Error: unknown argument %s\n", arg.c_str());
297 print_help();
298 exit(1);
299 }
300 }
301
302 unordered_map<string, RestClientState> client_state_map;
303 std::mutex client_state_map_mutex;
304 std::thread client_state_cleanup_thread(client_state_cleanup, &client_state_map, &client_state_map_mutex,
305 fetch_timeout);
306
307 if (!logfile_name.empty()) {
308 logfile.open(logfile_name, std::ios_base::app);
309 }
310
311 DuckDB duckdb(dbfile.empty() ? nullptr : dbfile.c_str(), &config);
312
313 svr.Get("/query", [&](const Request &req, Response &resp) {
314 auto q = req.get_param_value("q");
315 {
316 std::lock_guard<std::mutex> guard(out_mutex);
317 logfile << q << " ; -- DFgoEnx9UIRgHFsVYW8K" << std::endl
318 << std::flush; // using a terminator that will **never** occur in queries
319 }
320
321 json j;
322
323 RestClientState state;
324 state.con = make_unique<Connection>(duckdb);
325 state.con->EnableProfiling();
326 state.touched = std::time(nullptr);
327 bool is_active = true;
328
329 std::thread interrupt_thread(sleep_thread, state.con.get(), &is_active, query_timeout);
330 auto res = state.con->context->Query(q, true);
331
332 is_active = false;
333 interrupt_thread.join();
334
335 state.res = move(res);
336
337 if (state.res->success) {
338 j = {{"query", q},
339 {"success", state.res->success},
340 {"column_count", state.res->types.size()},
341
342 {"statement_type", StatementTypeToString(state.res->statement_type)},
343 {"names", json(state.res->names)},
344 {"name_index_map", json::object()},
345 {"types", json::array()},
346 {"sql_types", json::array()},
347 {"data", json::array()}};
348
349 for (auto &sql_type : state.res->sql_types) {
350 j["sql_types"] += SQLTypeToString(sql_type);
351 }
352 for (auto &type : state.res->types) {
353 j["types"] += TypeIdToString(type);
354 }
355
356 // make it easier to get col data by name
357 size_t col_idx = 0;
358 for (auto &name : state.res->names) {
359 j["name_index_map"][name] = col_idx;
360 col_idx++;
361 }
362
363 // only do this if query was successful
364 string query_ref = random_string(10);
365 j["ref"] = query_ref;
366 auto chunk = state.res->Fetch();
367 serialize_chunk(state.res.get(), chunk.get(), j);
368 {
369 std::lock_guard<std::mutex> guard(client_state_map_mutex);
370 client_state_map[query_ref] = move(state);
371 }
372
373 } else {
374 j = {{"query", q}, {"success", state.res->success}, {"error", state.res->error}};
375 }
376
377 serialize_json(req, resp, j);
378 });
379
380 svr.Get("/fetch", [&](const Request &req, Response &resp) {
381 auto ref = req.get_param_value("ref");
382 json j;
383 RestClientState state;
384 bool found_state = false;
385 {
386 std::lock_guard<std::mutex> guard(client_state_map_mutex);
387 auto it = client_state_map.find(ref);
388 if (it != client_state_map.end()) {
389 state = move(it->second);
390 client_state_map.erase(it);
391 found_state = true;
392 }
393 }
394
395 if (found_state) {
396 bool is_active = true;
397 std::thread interrupt_thread(sleep_thread, state.con.get(), &is_active, query_timeout);
398 auto chunk = state.res->Fetch();
399 is_active = false;
400 interrupt_thread.join();
401
402 j = {{"success", true}, {"ref", ref}, {"count", chunk->size()}, {"data", json::array()}};
403 serialize_chunk(state.res.get(), chunk.get(), j);
404 if (chunk->size() != 0) {
405 std::lock_guard<std::mutex> guard(client_state_map_mutex);
406 state.touched = std::time(nullptr);
407 client_state_map[ref] = move(state);
408 }
409 } else {
410 j = {{"success", false}, {"error", "Unable to find ref."}};
411 }
412
413 serialize_json(req, resp, j);
414 });
415
416 svr.Get("/close", [&](const Request &req, Response &resp) {
417 auto ref = req.get_param_value("ref");
418 Connection conn(duckdb);
419 json j;
420 std::lock_guard<std::mutex> guard(client_state_map_mutex);
421 if (client_state_map.find(ref) != client_state_map.end()) {
422 client_state_map.erase(client_state_map.find(ref));
423 j = {{"success", true}, {"ref", ref}};
424 } else {
425 j = {{"success", false}, {"error", "Unable to find ref."}};
426 }
427
428 serialize_json(req, resp, j);
429 });
430
431 std::cout << "🦆 serving " + dbfile + " on http://" + listen + ":" + std::to_string(port) + "\n";
432
433 svr.listen(listen.c_str(), port);
434 return 0;
435}
436