#include <algorithm>
#include <chrono>
#include <cstdio>
#include <iostream>
#include <random>
#include <thread>
5 | |
6 | #include "duckdb.hpp" |
7 | #include "duckdb/common/types/data_chunk.hpp" |
8 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
9 | #include "duckdb/common/unordered_map.hpp" |
10 | #include "duckdb/common/string_util.hpp" |
11 | #include "duckdb/main/client_context.hpp" |
12 | |
13 | // you can set this to enable compression. You will need to link zlib as well. |
14 | // #define CPPHTTPLIB_ZLIB_SUPPORT 1 |
15 | |
16 | #include "httplib.hpp" |
17 | #include "json.hpp" |
18 | |
19 | #include <unordered_map> |
20 | |
21 | using namespace httplib; |
22 | using namespace duckdb; |
23 | using namespace nlohmann; |
24 | |
25 | void print_help() { |
26 | fprintf(stderr, "🦆 Usage: duckdb_rest_server\n" ); |
27 | fprintf(stderr, " --listen=[address] listening address\n" ); |
28 | fprintf(stderr, " --port=[no] listening port\n" ); |
29 | fprintf(stderr, " --database=[file] use given database file\n" ); |
30 | fprintf(stderr, " --read_only open database in read-only mode\n" ); |
31 | fprintf(stderr, " --query_timeout=[sec] query timeout in seconds\n" ); |
32 | fprintf(stderr, " --fetch_timeout=[sec] result set timeout in seconds\n" ); |
33 | fprintf(stderr, " --log=[file] log queries to file\n\n" ); |
34 | fprintf(stderr, "Version: %s\n" , DUCKDB_SOURCE_ID); |
35 | |
36 | } |
37 | |
// Character-set approach adapted from https://stackoverflow.com/a/12468109/2652376
/// Builds a random string of `length` characters drawn from [0-9A-Za-z].
///
/// @param length number of characters to generate; 0 yields an empty string
/// @return the generated string
std::string random_string(size_t length) {
    static const char charset[] = "0123456789"
                                  "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                                  "abcdefghijklmnopqrstuvwxyz";
    // One engine per thread: rand() relies on hidden global state and is not
    // required to be thread-safe, while this function is called from
    // concurrently running request handlers.
    thread_local std::mt19937 engine{std::random_device{}()};
    // sizeof(charset) counts the trailing '\0', hence "- 2" for the last valid
    // index. A distribution avoids the modulo bias of `rand() % n`.
    std::uniform_int_distribution<size_t> pick(0, sizeof(charset) - 2);

    std::string str(length, '\0');
    std::generate_n(str.begin(), length, [&]() -> char { return charset[pick(engine)]; });
    return str;
}
51 | |
52 | struct RestClientState { |
53 | unique_ptr<QueryResult> res; |
54 | unique_ptr<Connection> con; |
55 | time_t touched; |
56 | }; |
57 | |
/// Wire formats a response can be encoded in; serialize_json picks one from
/// the request's Accept header and defaults to JSON.
enum ReturnContentType {
    JSON,         // plain JSON text (or JSONP when a callback is requested)
    BSON,         // json::to_bson
    CBOR,         // json::to_cbor
    MESSAGE_PACK, // json::to_msgpack
    UBJSON        // json::to_ubjson
};
59 | |
60 | template <class T, class TARGET> static void assign_json_loop(Vector &v, idx_t col_idx, idx_t count, json &j) { |
61 | auto data_ptr = FlatVector::GetData<T>(v); |
62 | auto &nullmask = FlatVector::Nullmask(v); |
63 | for (idx_t i = 0; i < count; i++) { |
64 | if (!nullmask[i]) { |
65 | j["data" ][col_idx] += (TARGET)data_ptr[i]; |
66 | |
67 | } else { |
68 | j["data" ][col_idx] += nullptr; |
69 | } |
70 | } |
71 | } |
72 | |
73 | void serialize_chunk(QueryResult *res, DataChunk *chunk, json &j) { |
74 | assert(res); |
75 | Vector v2(TypeId::VARCHAR); |
76 | for (size_t col_idx = 0; col_idx < chunk->column_count(); col_idx++) { |
77 | Vector *v = &chunk->data[col_idx]; |
78 | switch (res->sql_types[col_idx].id) { |
79 | case SQLTypeId::DATE: |
80 | case SQLTypeId::TIME: |
81 | case SQLTypeId::TIMESTAMP: { |
82 | VectorOperations::Cast(*v, v2, res->sql_types[col_idx], SQLType::VARCHAR, chunk->size()); |
83 | v = &v2; |
84 | break; |
85 | } |
86 | default: |
87 | break; |
88 | } |
89 | v->Normalify(chunk->size()); |
90 | assert(v); |
91 | switch (v->type) { |
92 | case TypeId::BOOL: |
93 | assign_json_loop<bool, int64_t>(*v, col_idx, chunk->size(), j); |
94 | break; |
95 | case TypeId::INT8: |
96 | assign_json_loop<int8_t, int64_t>(*v, col_idx, chunk->size(), j); |
97 | break; |
98 | case TypeId::INT16: |
99 | assign_json_loop<int16_t, int64_t>(*v, col_idx, chunk->size(), j); |
100 | break; |
101 | case TypeId::INT32: |
102 | assign_json_loop<int32_t, int64_t>(*v, col_idx, chunk->size(), j); |
103 | break; |
104 | case TypeId::INT64: |
105 | assign_json_loop<int64_t, int64_t>(*v, col_idx, chunk->size(), j); |
106 | break; |
107 | case TypeId::FLOAT: |
108 | assign_json_loop<float, double>(*v, col_idx, chunk->size(), j); |
109 | break; |
110 | case TypeId::DOUBLE: |
111 | assign_json_loop<float, double>(*v, col_idx, chunk->size(), j); |
112 | break; |
113 | case TypeId::VARCHAR: { |
114 | auto data_ptr = FlatVector::GetData<string_t>(*v); |
115 | auto &nullmask = FlatVector::Nullmask(*v); |
116 | for (idx_t i = 0; i < chunk->size(); i++) { |
117 | if (!nullmask[i]) { |
118 | j["data" ][col_idx] += data_ptr[i].GetData(); |
119 | |
120 | } else { |
121 | j["data" ][col_idx] += nullptr; |
122 | } |
123 | } |
124 | break; |
125 | } |
126 | default: |
127 | throw std::runtime_error("Unsupported Type" ); |
128 | } |
129 | } |
130 | } |
131 | |
132 | void serialize_json(const Request &req, Response &resp, json &j) { |
133 | auto return_type = ReturnContentType::JSON; |
134 | j["duckdb_version" ] = DUCKDB_SOURCE_ID; |
135 | |
136 | if (req.has_header("Accept" )) { |
137 | auto accept = req.get_header_value("Accept" ); |
138 | if (accept.rfind("application/bson" , 0) == 0 || accept.rfind("application/x-bson" , 0) == 0) { |
139 | return_type = ReturnContentType::BSON; |
140 | } else if (accept.rfind("application/cbor" , 0) == 0) { |
141 | return_type = ReturnContentType::CBOR; |
142 | } else if (accept.rfind("application/msgpack" , 0) == 0 || accept.rfind("application/x-msgpack" , 0) == 0 || |
143 | accept.rfind("application/vnd.msgpack" , 0) == 0) { |
144 | return_type = ReturnContentType::MESSAGE_PACK; |
145 | } else if (accept.rfind("application/ubjson" , 0) == 0) { |
146 | return_type = ReturnContentType::UBJSON; |
147 | } |
148 | } |
149 | |
150 | switch (return_type) { |
151 | case ReturnContentType::JSON: { |
152 | if (req.has_param("callback" )) { |
153 | auto jsonp_callback = req.get_param_value("callback" ); |
154 | resp.set_content(jsonp_callback + "(" + j.dump() + ");" , "application/javascript" ); |
155 | |
156 | } else { |
157 | resp.set_content(j.dump(), "application/json" ); |
158 | } |
159 | break; |
160 | } |
161 | case ReturnContentType::BSON: { |
162 | auto bson = json::to_bson(j); |
163 | resp.set_content((const char *)bson.data(), bson.size(), "application/bson" ); |
164 | break; |
165 | } |
166 | case ReturnContentType::CBOR: { |
167 | auto cbor = json::to_cbor(j); |
168 | resp.set_content((const char *)cbor.data(), cbor.size(), "application/cbor" ); |
169 | break; |
170 | } |
171 | case ReturnContentType::MESSAGE_PACK: { |
172 | auto msgpack = json::to_msgpack(j); |
173 | resp.set_content((const char *)msgpack.data(), msgpack.size(), "application/msgpack" ); |
174 | break; |
175 | } |
176 | case ReturnContentType::UBJSON: { |
177 | auto ubjson = json::to_ubjson(j); |
178 | resp.set_content((const char *)ubjson.data(), ubjson.size(), "application/ubjson" ); |
179 | break; |
180 | } |
181 | } |
182 | } |
183 | |
184 | void sleep_thread(Connection *conn, bool *is_active, int timeout_duration) { |
185 | // timeout is given in seconds |
186 | // we wait 10ms per iteration, so timeout * 100 gives us the amount of |
187 | // iterations |
188 | assert(conn); |
189 | assert(is_active); |
190 | |
191 | if (timeout_duration < 0) { |
192 | return; |
193 | } |
194 | for (size_t i = 0; i < (size_t)(timeout_duration * 100) && *is_active; i++) { |
195 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); |
196 | } |
197 | if (*is_active) { |
198 | conn->Interrupt(); |
199 | } |
200 | } |
201 | |
202 | void client_state_cleanup(unordered_map<string, RestClientState> *map, std::mutex *mutex, int timeout_duration) { |
203 | // timeout is given in seconds |
204 | while (true) { |
205 | // sleep for half the timeout duration |
206 | std::this_thread::sleep_for(std::chrono::milliseconds((timeout_duration * 1000) / 2)); |
207 | { |
208 | std::lock_guard<std::mutex> guard(*mutex); |
209 | auto now = std::time(nullptr); |
210 | for (auto it = map->cbegin(); it != map->cend();) { |
211 | if (now - it->second.touched > timeout_duration) { |
212 | it = map->erase(it); |
213 | } else { |
214 | ++it; |
215 | } |
216 | } |
217 | } |
218 | } |
219 | } |
220 | |
221 | int main(int argc, char **argv) { |
222 | Server svr; |
223 | if (!svr.is_valid()) { |
224 | printf("server has an error...\n" ); |
225 | return -1; |
226 | } |
227 | |
228 | std::mutex out_mutex; |
229 | srand(time(nullptr)); |
230 | |
231 | DBConfig config; |
232 | string dbfile = "" ; |
233 | string logfile_name; |
234 | |
235 | string listen = "localhost" ; |
236 | int port = 1294; |
237 | std::ofstream logfile; |
238 | |
239 | int query_timeout = 60; |
240 | int fetch_timeout = 60 * 5; |
241 | |
242 | // parse config |
243 | for (int arg_index = 1; arg_index < argc; ++arg_index) { |
244 | string arg = argv[arg_index]; |
245 | if (arg == "--help" ) { |
246 | print_help(); |
247 | exit(0); |
248 | } else if (arg == "--read_only" ) { |
249 | config.access_mode = AccessMode::READ_ONLY; |
250 | } else if (StringUtil::StartsWith(arg, "--database=" )) { |
251 | auto splits = StringUtil::Split(arg, '='); |
252 | if (splits.size() != 2) { |
253 | print_help(); |
254 | exit(1); |
255 | } |
256 | dbfile = string(splits[1]); |
257 | } else if (StringUtil::StartsWith(arg, "--log=" )) { |
258 | auto splits = StringUtil::Split(arg, '='); |
259 | if (splits.size() != 2) { |
260 | print_help(); |
261 | exit(1); |
262 | } |
263 | logfile_name = string(splits[1]); |
264 | } else if (StringUtil::StartsWith(arg, "--listen=" )) { |
265 | auto splits = StringUtil::Split(arg, '='); |
266 | if (splits.size() != 2) { |
267 | print_help(); |
268 | exit(1); |
269 | } |
270 | listen = string(splits[1]); |
271 | } else if (StringUtil::StartsWith(arg, "--port=" )) { |
272 | auto splits = StringUtil::Split(arg, '='); |
273 | if (splits.size() != 2) { |
274 | print_help(); |
275 | exit(1); |
276 | } |
277 | port = std::stoi(splits[1]); |
278 | |
279 | } else if (StringUtil::StartsWith(arg, "--query_timeout=" )) { |
280 | auto splits = StringUtil::Split(arg, '='); |
281 | if (splits.size() != 2) { |
282 | print_help(); |
283 | exit(1); |
284 | } |
285 | query_timeout = std::stoi(splits[1]); |
286 | |
287 | } else if (StringUtil::StartsWith(arg, "--fetch_timeout=" )) { |
288 | auto splits = StringUtil::Split(arg, '='); |
289 | if (splits.size() != 2) { |
290 | print_help(); |
291 | exit(1); |
292 | } |
293 | fetch_timeout = std::stoi(splits[1]); |
294 | |
295 | } else { |
296 | fprintf(stderr, "Error: unknown argument %s\n" , arg.c_str()); |
297 | print_help(); |
298 | exit(1); |
299 | } |
300 | } |
301 | |
302 | unordered_map<string, RestClientState> client_state_map; |
303 | std::mutex client_state_map_mutex; |
304 | std::thread client_state_cleanup_thread(client_state_cleanup, &client_state_map, &client_state_map_mutex, |
305 | fetch_timeout); |
306 | |
307 | if (!logfile_name.empty()) { |
308 | logfile.open(logfile_name, std::ios_base::app); |
309 | } |
310 | |
311 | DuckDB duckdb(dbfile.empty() ? nullptr : dbfile.c_str(), &config); |
312 | |
313 | svr.Get("/query" , [&](const Request &req, Response &resp) { |
314 | auto q = req.get_param_value("q" ); |
315 | { |
316 | std::lock_guard<std::mutex> guard(out_mutex); |
317 | logfile << q << " ; -- DFgoEnx9UIRgHFsVYW8K" << std::endl |
318 | << std::flush; // using a terminator that will **never** occur in queries |
319 | } |
320 | |
321 | json j; |
322 | |
323 | RestClientState state; |
324 | state.con = make_unique<Connection>(duckdb); |
325 | state.con->EnableProfiling(); |
326 | state.touched = std::time(nullptr); |
327 | bool is_active = true; |
328 | |
329 | std::thread interrupt_thread(sleep_thread, state.con.get(), &is_active, query_timeout); |
330 | auto res = state.con->context->Query(q, true); |
331 | |
332 | is_active = false; |
333 | interrupt_thread.join(); |
334 | |
335 | state.res = move(res); |
336 | |
337 | if (state.res->success) { |
338 | j = {{"query" , q}, |
339 | {"success" , state.res->success}, |
340 | {"column_count" , state.res->types.size()}, |
341 | |
342 | {"statement_type" , StatementTypeToString(state.res->statement_type)}, |
343 | {"names" , json(state.res->names)}, |
344 | {"name_index_map" , json::object()}, |
345 | {"types" , json::array()}, |
346 | {"sql_types" , json::array()}, |
347 | {"data" , json::array()}}; |
348 | |
349 | for (auto &sql_type : state.res->sql_types) { |
350 | j["sql_types" ] += SQLTypeToString(sql_type); |
351 | } |
352 | for (auto &type : state.res->types) { |
353 | j["types" ] += TypeIdToString(type); |
354 | } |
355 | |
356 | // make it easier to get col data by name |
357 | size_t col_idx = 0; |
358 | for (auto &name : state.res->names) { |
359 | j["name_index_map" ][name] = col_idx; |
360 | col_idx++; |
361 | } |
362 | |
363 | // only do this if query was successful |
364 | string query_ref = random_string(10); |
365 | j["ref" ] = query_ref; |
366 | auto chunk = state.res->Fetch(); |
367 | serialize_chunk(state.res.get(), chunk.get(), j); |
368 | { |
369 | std::lock_guard<std::mutex> guard(client_state_map_mutex); |
370 | client_state_map[query_ref] = move(state); |
371 | } |
372 | |
373 | } else { |
374 | j = {{"query" , q}, {"success" , state.res->success}, {"error" , state.res->error}}; |
375 | } |
376 | |
377 | serialize_json(req, resp, j); |
378 | }); |
379 | |
380 | svr.Get("/fetch" , [&](const Request &req, Response &resp) { |
381 | auto ref = req.get_param_value("ref" ); |
382 | json j; |
383 | RestClientState state; |
384 | bool found_state = false; |
385 | { |
386 | std::lock_guard<std::mutex> guard(client_state_map_mutex); |
387 | auto it = client_state_map.find(ref); |
388 | if (it != client_state_map.end()) { |
389 | state = move(it->second); |
390 | client_state_map.erase(it); |
391 | found_state = true; |
392 | } |
393 | } |
394 | |
395 | if (found_state) { |
396 | bool is_active = true; |
397 | std::thread interrupt_thread(sleep_thread, state.con.get(), &is_active, query_timeout); |
398 | auto chunk = state.res->Fetch(); |
399 | is_active = false; |
400 | interrupt_thread.join(); |
401 | |
402 | j = {{"success" , true}, {"ref" , ref}, {"count" , chunk->size()}, {"data" , json::array()}}; |
403 | serialize_chunk(state.res.get(), chunk.get(), j); |
404 | if (chunk->size() != 0) { |
405 | std::lock_guard<std::mutex> guard(client_state_map_mutex); |
406 | state.touched = std::time(nullptr); |
407 | client_state_map[ref] = move(state); |
408 | } |
409 | } else { |
410 | j = {{"success" , false}, {"error" , "Unable to find ref." }}; |
411 | } |
412 | |
413 | serialize_json(req, resp, j); |
414 | }); |
415 | |
416 | svr.Get("/close" , [&](const Request &req, Response &resp) { |
417 | auto ref = req.get_param_value("ref" ); |
418 | Connection conn(duckdb); |
419 | json j; |
420 | std::lock_guard<std::mutex> guard(client_state_map_mutex); |
421 | if (client_state_map.find(ref) != client_state_map.end()) { |
422 | client_state_map.erase(client_state_map.find(ref)); |
423 | j = {{"success" , true}, {"ref" , ref}}; |
424 | } else { |
425 | j = {{"success" , false}, {"error" , "Unable to find ref." }}; |
426 | } |
427 | |
428 | serialize_json(req, resp, j); |
429 | }); |
430 | |
431 | std::cout << "🦆 serving " + dbfile + " on http://" + listen + ":" + std::to_string(port) + "\n" ; |
432 | |
433 | svr.listen(listen.c_str(), port); |
434 | return 0; |
435 | } |
436 | |