#include <chrono>
#include <cstdio>
#include <thread>
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <mutex>
#include <string>

#include "duckdb.hpp"
#include "duckdb/common/types/data_chunk.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/main/client_context.hpp"

// you can set this to enable compression. You will need to link zlib as well.
// #define CPPHTTPLIB_ZLIB_SUPPORT 1

#include "httplib.hpp"
#include "json.hpp"

#include <unordered_map>

using namespace httplib;
using namespace duckdb;
using namespace nlohmann;

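// A small REST server exposing DuckDB over HTTP. Illustrative example session, assuming the
// default listen address and port (the ref value comes from the /query response):
//   curl "http://localhost:1294/query?q=SELECT%2042"
//   curl "http://localhost:1294/fetch?ref=<ref>"
//   curl "http://localhost:1294/close?ref=<ref>"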
void print_help() {
    fprintf(stderr, "🦆 Usage: duckdb_rest_server\n");
    fprintf(stderr, "          --listen=[address]    listening address\n");
    fprintf(stderr, "          --port=[no]           listening port\n");
    fprintf(stderr, "          --database=[file]     use given database file\n");
    fprintf(stderr, "          --read_only           open database in read-only mode\n");
    fprintf(stderr, "          --query_timeout=[sec] query timeout in seconds\n");
    fprintf(stderr, "          --fetch_timeout=[sec] result set timeout in seconds\n");
    fprintf(stderr, "          --log=[file]          log queries to file\n\n");
    fprintf(stderr, "Version: %s\n", DUCKDB_SOURCE_ID);
}

// https://stackoverflow.com/a/12468109/2652376
std::string random_string(size_t length) {
    auto randchar = []() -> char {
        const char charset[] = "0123456789"
                               "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                               "abcdefghijklmnopqrstuvwxyz";
        const size_t max_index = (sizeof(charset) - 1);
        return charset[rand() % max_index];
    };
    std::string str(length, 0);
    std::generate_n(str.begin(), length, randchar);
    return str;
}

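// Per-query state kept between requests: the connection, the (partially consumed) streaming
// result, and a last-touched timestamp used to expire idle entries.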
struct RestClientState {
    unique_ptr<QueryResult> res;
    unique_ptr<Connection> con;
    time_t touched;
};

enum ReturnContentType { JSON, BSON, CBOR, MESSAGE_PACK, UBJSON };

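// Appends one column of a flat vector to j["data"][col_idx], casting each value from the
// physical type T to the JSON-facing type TARGET; NULL entries become JSON null.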
template <class T, class TARGET> static void assign_json_loop(Vector &v, idx_t col_idx, idx_t count, json &j) {
    auto data_ptr = FlatVector::GetData<T>(v);
    auto &nullmask = FlatVector::Nullmask(v);
    for (idx_t i = 0; i < count; i++) {
        if (!nullmask[i]) {
            j["data"][col_idx] += (TARGET)data_ptr[i];
        } else {
            j["data"][col_idx] += nullptr;
        }
    }
}

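// Converts one DataChunk into JSON columns: date/time/timestamp values are first cast to
// VARCHAR so they serialize as strings, numeric and boolean columns are emitted as numbers,
// and NULLs become JSON null.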
void serialize_chunk(QueryResult *res, DataChunk *chunk, json &j) {
    assert(res);
    Vector v2(TypeId::VARCHAR);
    for (size_t col_idx = 0; col_idx < chunk->column_count(); col_idx++) {
        Vector *v = &chunk->data[col_idx];
        switch (res->sql_types[col_idx].id) {
        case SQLTypeId::DATE:
        case SQLTypeId::TIME:
        case SQLTypeId::TIMESTAMP: {
            VectorOperations::Cast(*v, v2, res->sql_types[col_idx], SQLType::VARCHAR, chunk->size());
            v = &v2;
            break;
        }
        default:
            break;
        }
        v->Normalify(chunk->size());
        assert(v);
        switch (v->type) {
        case TypeId::BOOL:
            assign_json_loop<bool, int64_t>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::INT8:
            assign_json_loop<int8_t, int64_t>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::INT16:
            assign_json_loop<int16_t, int64_t>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::INT32:
            assign_json_loop<int32_t, int64_t>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::INT64:
            assign_json_loop<int64_t, int64_t>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::FLOAT:
            assign_json_loop<float, double>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::DOUBLE:
            assign_json_loop<double, double>(*v, col_idx, chunk->size(), j);
            break;
        case TypeId::VARCHAR: {
            auto data_ptr = FlatVector::GetData<string_t>(*v);
            auto &nullmask = FlatVector::Nullmask(*v);
            for (idx_t i = 0; i < chunk->size(); i++) {
                if (!nullmask[i]) {
                    j["data"][col_idx] += data_ptr[i].GetData();
                } else {
                    j["data"][col_idx] += nullptr;
                }
            }
            break;
        }
        default:
            throw std::runtime_error("Unsupported Type");
        }
    }
}

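// Writes j into the HTTP response. The wire format is picked from the Accept header
// (BSON, CBOR, MessagePack or UBJSON); plain JSON is the default and is wrapped in a
// JSONP callback when a "callback" parameter is present.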
void serialize_json(const Request &req, Response &resp, json &j) {
    auto return_type = ReturnContentType::JSON;
    j["duckdb_version"] = DUCKDB_SOURCE_ID;

    if (req.has_header("Accept")) {
        auto accept = req.get_header_value("Accept");
        if (accept.rfind("application/bson", 0) == 0 || accept.rfind("application/x-bson", 0) == 0) {
            return_type = ReturnContentType::BSON;
        } else if (accept.rfind("application/cbor", 0) == 0) {
            return_type = ReturnContentType::CBOR;
        } else if (accept.rfind("application/msgpack", 0) == 0 || accept.rfind("application/x-msgpack", 0) == 0 ||
                   accept.rfind("application/vnd.msgpack", 0) == 0) {
            return_type = ReturnContentType::MESSAGE_PACK;
        } else if (accept.rfind("application/ubjson", 0) == 0) {
            return_type = ReturnContentType::UBJSON;
        }
    }

    switch (return_type) {
    case ReturnContentType::JSON: {
        if (req.has_param("callback")) {
            auto jsonp_callback = req.get_param_value("callback");
            resp.set_content(jsonp_callback + "(" + j.dump() + ");", "application/javascript");
        } else {
            resp.set_content(j.dump(), "application/json");
        }
        break;
    }
    case ReturnContentType::BSON: {
        auto bson = json::to_bson(j);
        resp.set_content((const char *)bson.data(), bson.size(), "application/bson");
        break;
    }
    case ReturnContentType::CBOR: {
        auto cbor = json::to_cbor(j);
        resp.set_content((const char *)cbor.data(), cbor.size(), "application/cbor");
        break;
    }
    case ReturnContentType::MESSAGE_PACK: {
        auto msgpack = json::to_msgpack(j);
        resp.set_content((const char *)msgpack.data(), msgpack.size(), "application/msgpack");
        break;
    }
    case ReturnContentType::UBJSON: {
        auto ubjson = json::to_ubjson(j);
        resp.set_content((const char *)ubjson.data(), ubjson.size(), "application/ubjson");
        break;
    }
    }
}

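// Watchdog thread: polls every 10 ms and interrupts the connection if the query is still
// running when the timeout expires. A negative timeout disables the watchdog entirely.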
void sleep_thread(Connection *conn, bool *is_active, int timeout_duration) {
    // timeout is given in seconds
    // we wait 10ms per iteration, so timeout * 100 gives us the number of iterations
    assert(conn);
    assert(is_active);

    if (timeout_duration < 0) {
        return;
    }
    for (size_t i = 0; i < (size_t)(timeout_duration * 100) && *is_active; i++) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    if (*is_active) {
        conn->Interrupt();
    }
}

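// Background reaper: periodically removes client state entries that have not been touched
// within the fetch timeout, releasing their connections and any pending results.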
void client_state_cleanup(unordered_map<string, RestClientState> *map, std::mutex *mutex, int timeout_duration) {
    // timeout is given in seconds
    while (true) {
        // sleep for half the timeout duration
        std::this_thread::sleep_for(std::chrono::milliseconds((timeout_duration * 1000) / 2));
        {
            std::lock_guard<std::mutex> guard(*mutex);
            auto now = std::time(nullptr);
            for (auto it = map->cbegin(); it != map->cend();) {
                if (now - it->second.touched > timeout_duration) {
                    it = map->erase(it);
                } else {
                    ++it;
                }
            }
        }
    }
}

int main(int argc, char **argv) {
    Server svr;
    if (!svr.is_valid()) {
        printf("server has an error...\n");
        return -1;
    }

    std::mutex out_mutex;
    srand(time(nullptr));

    DBConfig config;
    string dbfile = "";
    string logfile_name;

    string listen = "localhost";
    int port = 1294;
    std::ofstream logfile;

    int query_timeout = 60;
    int fetch_timeout = 60 * 5;

    // parse config
    for (int arg_index = 1; arg_index < argc; ++arg_index) {
        string arg = argv[arg_index];
        if (arg == "--help") {
            print_help();
            exit(0);
        } else if (arg == "--read_only") {
            config.access_mode = AccessMode::READ_ONLY;
        } else if (StringUtil::StartsWith(arg, "--database=")) {
            auto splits = StringUtil::Split(arg, '=');
            if (splits.size() != 2) {
                print_help();
                exit(1);
            }
            dbfile = string(splits[1]);
        } else if (StringUtil::StartsWith(arg, "--log=")) {
            auto splits = StringUtil::Split(arg, '=');
            if (splits.size() != 2) {
                print_help();
                exit(1);
            }
            logfile_name = string(splits[1]);
        } else if (StringUtil::StartsWith(arg, "--listen=")) {
            auto splits = StringUtil::Split(arg, '=');
            if (splits.size() != 2) {
                print_help();
                exit(1);
            }
            listen = string(splits[1]);
        } else if (StringUtil::StartsWith(arg, "--port=")) {
            auto splits = StringUtil::Split(arg, '=');
            if (splits.size() != 2) {
                print_help();
                exit(1);
            }
            port = std::stoi(splits[1]);
        } else if (StringUtil::StartsWith(arg, "--query_timeout=")) {
            auto splits = StringUtil::Split(arg, '=');
            if (splits.size() != 2) {
                print_help();
                exit(1);
            }
            query_timeout = std::stoi(splits[1]);
        } else if (StringUtil::StartsWith(arg, "--fetch_timeout=")) {
            auto splits = StringUtil::Split(arg, '=');
            if (splits.size() != 2) {
                print_help();
                exit(1);
            }
            fetch_timeout = std::stoi(splits[1]);
        } else {
            fprintf(stderr, "Error: unknown argument %s\n", arg.c_str());
            print_help();
            exit(1);
        }
    }

    unordered_map<string, RestClientState> client_state_map;
    std::mutex client_state_map_mutex;
    std::thread client_state_cleanup_thread(client_state_cleanup, &client_state_map, &client_state_map_mutex,
                                            fetch_timeout);

    if (!logfile_name.empty()) {
        logfile.open(logfile_name, std::ios_base::app);
    }

    DuckDB duckdb(dbfile.empty() ? nullptr : dbfile.c_str(), &config);

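    // GET /query?q=<sql>: runs the query and returns column metadata plus the first chunk of
    // results. On success the response carries a "ref" token for use with /fetch and /close.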
    svr.Get("/query", [&](const Request &req, Response &resp) {
        auto q = req.get_param_value("q");
        if (logfile.is_open()) {
            std::lock_guard<std::mutex> guard(out_mutex);
            logfile << q << " ; -- DFgoEnx9UIRgHFsVYW8K" << std::endl
                    << std::flush; // using a terminator that will **never** occur in queries
        }

        json j;

        RestClientState state;
        state.con = make_unique<Connection>(duckdb);
        state.con->EnableProfiling();
        state.touched = std::time(nullptr);
        bool is_active = true;

        std::thread interrupt_thread(sleep_thread, state.con.get(), &is_active, query_timeout);
        auto res = state.con->context->Query(q, true);

        is_active = false;
        interrupt_thread.join();

        state.res = std::move(res);

        if (state.res->success) {
            j = {{"query", q},
                 {"success", state.res->success},
                 {"column_count", state.res->types.size()},
                 {"statement_type", StatementTypeToString(state.res->statement_type)},
                 {"names", json(state.res->names)},
                 {"name_index_map", json::object()},
                 {"types", json::array()},
                 {"sql_types", json::array()},
                 {"data", json::array()}};

            for (auto &sql_type : state.res->sql_types) {
                j["sql_types"] += SQLTypeToString(sql_type);
            }
            for (auto &type : state.res->types) {
                j["types"] += TypeIdToString(type);
            }

            // make it easier to get col data by name
            size_t col_idx = 0;
            for (auto &name : state.res->names) {
                j["name_index_map"][name] = col_idx;
                col_idx++;
            }

            // only do this if query was successful
            string query_ref = random_string(10);
            j["ref"] = query_ref;
            auto chunk = state.res->Fetch();
            serialize_chunk(state.res.get(), chunk.get(), j);
            {
                std::lock_guard<std::mutex> guard(client_state_map_mutex);
                client_state_map[query_ref] = std::move(state);
            }
        } else {
            j = {{"query", q}, {"success", state.res->success}, {"error", state.res->error}};
        }

        serialize_json(req, resp, j);
    });

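    // GET /fetch?ref=<ref>: returns the next chunk of a previously started query. Once the
    // result is exhausted (an empty chunk) the state is dropped and the ref becomes invalid.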
    svr.Get("/fetch", [&](const Request &req, Response &resp) {
        auto ref = req.get_param_value("ref");
        json j;
        RestClientState state;
        bool found_state = false;
        {
            std::lock_guard<std::mutex> guard(client_state_map_mutex);
            auto it = client_state_map.find(ref);
            if (it != client_state_map.end()) {
                state = std::move(it->second);
                client_state_map.erase(it);
                found_state = true;
            }
        }

        if (found_state) {
            bool is_active = true;
            std::thread interrupt_thread(sleep_thread, state.con.get(), &is_active, query_timeout);
            auto chunk = state.res->Fetch();
            is_active = false;
            interrupt_thread.join();

            j = {{"success", true}, {"ref", ref}, {"count", chunk->size()}, {"data", json::array()}};
            serialize_chunk(state.res.get(), chunk.get(), j);
            if (chunk->size() != 0) {
                std::lock_guard<std::mutex> guard(client_state_map_mutex);
                state.touched = std::time(nullptr);
                client_state_map[ref] = std::move(state);
            }
        } else {
            j = {{"success", false}, {"error", "Unable to find ref."}};
        }

        serialize_json(req, resp, j);
    });

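    // GET /close?ref=<ref>: discards the state associated with a ref without fetching any
    // remaining chunks.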
    svr.Get("/close", [&](const Request &req, Response &resp) {
        auto ref = req.get_param_value("ref");
        Connection conn(duckdb);
        json j;
        std::lock_guard<std::mutex> guard(client_state_map_mutex);
        auto it = client_state_map.find(ref);
        if (it != client_state_map.end()) {
            client_state_map.erase(it);
            j = {{"success", true}, {"ref", ref}};
        } else {
            j = {{"success", false}, {"error", "Unable to find ref."}};
        }

        serialize_json(req, resp, j);
    });

    std::cout << "🦆 serving " + dbfile + " on http://" + listen + ":" + std::to_string(port) + "\n";

    svr.listen(listen.c_str(), port);
    return 0;
}