| 1 | // Licensed to the Apache Software Foundation (ASF) under one |
| 2 | // or more contributor license agreements. See the NOTICE file |
| 3 | // distributed with this work for additional information |
| 4 | // regarding copyright ownership. The ASF licenses this file |
| 5 | // to you under the Apache License, Version 2.0 (the |
| 6 | // "License"); you may not use this file except in compliance |
| 7 | // with the License. You may obtain a copy of the License at |
| 8 | // |
| 9 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | // |
| 11 | // Unless required by applicable law or agreed to in writing, |
| 12 | // software distributed under the License is distributed on an |
| 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | // KIND, either express or implied. See the License for the |
| 15 | // specific language governing permissions and limitations |
| 16 | // under the License. |
| 17 | |
| 18 | #include "duckdb/common/adbc/driver_manager.h" |
| 19 | #include "duckdb/common/adbc/adbc.h" |
| 20 | #include "duckdb/common/adbc/adbc.hpp" |
| 21 | |
| 22 | #include <algorithm> |
| 23 | #include <cstring> |
| 24 | #include <string> |
| 25 | #include <unordered_map> |
| 26 | #include <utility> |
| 27 | |
| 28 | #if defined(_WIN32) |
| 29 | #include <windows.h> // Must come first |
| 30 | |
| 31 | #include <libloaderapi.h> |
| 32 | #include <strsafe.h> |
| 33 | #else |
| 34 | #include <dlfcn.h> |
| 35 | #endif // defined(_WIN32) |
| 36 | |
| 37 | namespace duckdb_adbc { |
| 38 | |
| 39 | // Platform-specific helpers |
| 40 | |
| 41 | #if defined(_WIN32) |
| 42 | /// Append a description of the Windows error to the buffer. |
| 43 | void GetWinError(std::string *buffer) { |
| 44 | DWORD rc = GetLastError(); |
| 45 | LPVOID message; |
| 46 | |
| 47 | FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, |
| 48 | /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), |
| 49 | reinterpret_cast<LPSTR>(&message), /*nSize=*/0, /*Arguments=*/nullptr); |
| 50 | |
| 51 | (*buffer) += '('; |
| 52 | (*buffer) += std::to_string(rc); |
| 53 | (*buffer) += ") " ; |
| 54 | (*buffer) += reinterpret_cast<char *>(message); |
| 55 | LocalFree(message); |
| 56 | } |
| 57 | |
| 58 | #endif // defined(_WIN32) |
| 59 | |
| 60 | // Error handling |
| 61 | |
| 62 | void ReleaseError(struct AdbcError *error) { |
| 63 | if (error) { |
| 64 | if (error->message) { |
| 65 | delete[] error->message; |
| 66 | } |
| 67 | error->message = nullptr; |
| 68 | error->release = nullptr; |
| 69 | } |
| 70 | } |
| 71 | |
| 72 | void SetError(struct AdbcError *error, const std::string &message) { |
| 73 | if (!error) { |
| 74 | return; |
| 75 | } |
| 76 | if (error->message) { |
| 77 | // Append |
| 78 | std::string buffer = error->message; |
| 79 | buffer.reserve(res_arg: buffer.size() + message.size() + 1); |
| 80 | buffer += '\n'; |
| 81 | buffer += message; |
| 82 | error->release(error); |
| 83 | |
| 84 | error->message = new char[buffer.size() + 1]; |
| 85 | buffer.copy(s: error->message, n: buffer.size()); |
| 86 | error->message[buffer.size()] = '\0'; |
| 87 | } else { |
| 88 | error->message = new char[message.size() + 1]; |
| 89 | message.copy(s: error->message, n: message.size()); |
| 90 | error->message[message.size()] = '\0'; |
| 91 | } |
| 92 | error->release = ReleaseError; |
| 93 | } |
| 94 | |
| 95 | // Driver state |
| 96 | |
| 97 | /// Hold the driver DLL and the driver release callback in the driver struct. |
| 98 | struct ManagerDriverState { |
| 99 | // The original release callback |
| 100 | AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error); |
| 101 | |
| 102 | #if defined(_WIN32) |
| 103 | // The loaded DLL |
| 104 | HMODULE handle; |
| 105 | #endif // defined(_WIN32) |
| 106 | }; |
| 107 | |
| 108 | /// Unload the driver DLL. |
| 109 | static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) { |
| 110 | AdbcStatusCode status = ADBC_STATUS_OK; |
| 111 | |
| 112 | if (!driver->private_manager) |
| 113 | return status; |
| 114 | ManagerDriverState *state = reinterpret_cast<ManagerDriverState *>(driver->private_manager); |
| 115 | |
| 116 | if (state->driver_release) { |
| 117 | status = state->driver_release(driver, error); |
| 118 | } |
| 119 | |
| 120 | #if defined(_WIN32) |
| 121 | // TODO(apache/arrow-adbc#204): causes tests to segfault |
| 122 | // if (!FreeLibrary(state->handle)) { |
| 123 | // std::string message = "FreeLibrary() failed: "; |
| 124 | // GetWinError(&message); |
| 125 | // SetError(error, message); |
| 126 | // } |
| 127 | #endif // defined(_WIN32) |
| 128 | |
| 129 | driver->private_manager = nullptr; |
| 130 | delete state; |
| 131 | return status; |
| 132 | } |
| 133 | |
| 134 | // Default stubs |
| 135 | |
| 136 | AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, |
| 137 | struct ArrowArrayStream *out, struct AdbcError *error) { |
| 138 | return ADBC_STATUS_NOT_IMPLEMENTED; |
| 139 | } |
| 140 | |
| 141 | AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *, const char *, const char *, const char *, |
| 142 | struct ArrowSchema *, struct AdbcError *error) { |
| 143 | return ADBC_STATUS_NOT_IMPLEMENTED; |
| 144 | } |
| 145 | |
| 146 | AdbcStatusCode StatementBind(struct AdbcStatement *, struct ArrowArray *, struct ArrowSchema *, |
| 147 | struct AdbcError *error) { |
| 148 | return ADBC_STATUS_NOT_IMPLEMENTED; |
| 149 | } |
| 150 | |
| 151 | AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, |
| 152 | struct AdbcError *error) { |
| 153 | return ADBC_STATUS_NOT_IMPLEMENTED; |
| 154 | } |
| 155 | AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *, const uint8_t *, size_t, struct AdbcError *error) { |
| 156 | return ADBC_STATUS_NOT_IMPLEMENTED; |
| 157 | } |
| 158 | |
| 159 | /// Temporary state while the database is being configured. |
| 160 | struct TempDatabase { |
| 161 | std::unordered_map<std::string, std::string> options; |
| 162 | std::string driver; |
| 163 | // Default name (see adbc.h) |
| 164 | std::string entrypoint = "AdbcDriverInit" ; |
| 165 | AdbcDriverInitFunc init_func = nullptr; |
| 166 | }; |
| 167 | |
/// Connection options collected between AdbcConnectionNew and AdbcConnectionInit. They are
/// replayed through the driver's ConnectionSetOption once the driver is known.
struct TempConnection {
	std::unordered_map<std::string, std::string> options {};
};
| 172 | |
| 173 | // Direct implementations of API methods |
| 174 | |
| 175 | AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { |
| 176 | // Allocate a temporary structure to store options pre-Init |
| 177 | database->private_data = new TempDatabase(); |
| 178 | database->private_driver = nullptr; |
| 179 | return ADBC_STATUS_OK; |
| 180 | } |
| 181 | |
| 182 | AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, |
| 183 | struct AdbcError *error) { |
| 184 | if (database->private_driver) { |
| 185 | return database->private_driver->DatabaseSetOption(database, key, value, error); |
| 186 | } |
| 187 | |
| 188 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
| 189 | if (std::strcmp(s1: key, s2: "driver" ) == 0) { |
| 190 | args->driver = value; |
| 191 | } else if (std::strcmp(s1: key, s2: "entrypoint" ) == 0) { |
| 192 | args->entrypoint = value; |
| 193 | } else { |
| 194 | args->options[key] = value; |
| 195 | } |
| 196 | return ADBC_STATUS_OK; |
| 197 | } |
| 198 | |
| 199 | AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, |
| 200 | struct AdbcError *error) { |
| 201 | if (database->private_driver) { |
| 202 | return ADBC_STATUS_INVALID_STATE; |
| 203 | } |
| 204 | |
| 205 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
| 206 | args->init_func = init_func; |
| 207 | return ADBC_STATUS_OK; |
| 208 | } |
| 209 | |
| 210 | AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { |
| 211 | if (!database->private_data) { |
| 212 | SetError(error, message: "Must call AdbcDatabaseNew first" ); |
| 213 | return ADBC_STATUS_INVALID_STATE; |
| 214 | } |
| 215 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
| 216 | if (args->init_func) { |
| 217 | // Do nothing |
| 218 | } else if (args->driver.empty()) { |
| 219 | SetError(error, message: "Must provide 'driver' parameter" ); |
| 220 | return ADBC_STATUS_INVALID_ARGUMENT; |
| 221 | } |
| 222 | |
| 223 | database->private_driver = new AdbcDriver; |
| 224 | std::memset(s: database->private_driver, c: 0, n: sizeof(AdbcDriver)); |
| 225 | AdbcStatusCode status; |
| 226 | // So we don't confuse a driver into thinking it's initialized already |
| 227 | database->private_data = nullptr; |
| 228 | if (args->init_func) { |
| 229 | status = AdbcLoadDriverFromInitFunc(init_func: args->init_func, ADBC_VERSION_1_0_0, driver: database->private_driver, error); |
| 230 | } else { |
| 231 | status = AdbcLoadDriver(driver_name: args->driver.c_str(), entrypoint: args->entrypoint.c_str(), ADBC_VERSION_1_0_0, |
| 232 | driver: database->private_driver, error); |
| 233 | } |
| 234 | if (status != ADBC_STATUS_OK) { |
| 235 | // Restore private_data so it will be released by AdbcDatabaseRelease |
| 236 | database->private_data = args; |
| 237 | if (database->private_driver->release) { |
| 238 | database->private_driver->release(database->private_driver, error); |
| 239 | } |
| 240 | delete database->private_driver; |
| 241 | database->private_driver = nullptr; |
| 242 | return status; |
| 243 | } |
| 244 | status = database->private_driver->DatabaseNew(database, error); |
| 245 | if (status != ADBC_STATUS_OK) { |
| 246 | if (database->private_driver->release) { |
| 247 | database->private_driver->release(database->private_driver, error); |
| 248 | } |
| 249 | delete database->private_driver; |
| 250 | database->private_driver = nullptr; |
| 251 | return status; |
| 252 | } |
| 253 | for (const auto &option : args->options) { |
| 254 | status = |
| 255 | database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); |
| 256 | if (status != ADBC_STATUS_OK) { |
| 257 | delete args; |
| 258 | // Release the database |
| 259 | std::ignore = database->private_driver->DatabaseRelease(database, error); |
| 260 | if (database->private_driver->release) { |
| 261 | database->private_driver->release(database->private_driver, error); |
| 262 | } |
| 263 | delete database->private_driver; |
| 264 | database->private_driver = nullptr; |
| 265 | // Should be redundant, but ensure that AdbcDatabaseRelease |
| 266 | // below doesn't think that it contains a TempDatabase |
| 267 | database->private_data = nullptr; |
| 268 | return status; |
| 269 | } |
| 270 | } |
| 271 | delete args; |
| 272 | return database->private_driver->DatabaseInit(database, error); |
| 273 | } |
| 274 | |
| 275 | AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { |
| 276 | if (!database->private_driver) { |
| 277 | if (database->private_data) { |
| 278 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
| 279 | delete args; |
| 280 | database->private_data = nullptr; |
| 281 | return ADBC_STATUS_OK; |
| 282 | } |
| 283 | return ADBC_STATUS_INVALID_STATE; |
| 284 | } |
| 285 | auto status = database->private_driver->DatabaseRelease(database, error); |
| 286 | if (database->private_driver->release) { |
| 287 | database->private_driver->release(database->private_driver, error); |
| 288 | } |
| 289 | delete database->private_driver; |
| 290 | database->private_data = nullptr; |
| 291 | database->private_driver = nullptr; |
| 292 | return status; |
| 293 | } |
| 294 | |
| 295 | AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { |
| 296 | if (!connection->private_driver) { |
| 297 | return ADBC_STATUS_INVALID_STATE; |
| 298 | } |
| 299 | return connection->private_driver->ConnectionCommit(connection, error); |
| 300 | } |
| 301 | |
| 302 | AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, |
| 303 | struct ArrowArrayStream *out, struct AdbcError *error) { |
| 304 | if (!connection->private_driver) { |
| 305 | return ADBC_STATUS_INVALID_STATE; |
| 306 | } |
| 307 | return connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); |
| 308 | } |
| 309 | |
| 310 | AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, |
| 311 | const char *db_schema, const char *table_name, const char **table_types, |
| 312 | const char *column_name, struct ArrowArrayStream *stream, |
| 313 | struct AdbcError *error) { |
| 314 | if (!connection->private_driver) { |
| 315 | return ADBC_STATUS_INVALID_STATE; |
| 316 | } |
| 317 | return connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, |
| 318 | table_types, column_name, stream, error); |
| 319 | } |
| 320 | |
| 321 | AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, |
| 322 | const char *db_schema, const char *table_name, struct ArrowSchema *schema, |
| 323 | struct AdbcError *error) { |
| 324 | if (!connection->private_driver) { |
| 325 | return ADBC_STATUS_INVALID_STATE; |
| 326 | } |
| 327 | return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, |
| 328 | error); |
| 329 | } |
| 330 | |
| 331 | AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream, |
| 332 | struct AdbcError *error) { |
| 333 | if (!connection->private_driver) { |
| 334 | return ADBC_STATUS_INVALID_STATE; |
| 335 | } |
| 336 | return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); |
| 337 | } |
| 338 | |
| 339 | AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, |
| 340 | struct AdbcError *error) { |
| 341 | if (!connection->private_data) { |
| 342 | SetError(error, message: "Must call AdbcConnectionNew first" ); |
| 343 | return ADBC_STATUS_INVALID_STATE; |
| 344 | } else if (!database->private_driver) { |
| 345 | SetError(error, message: "Database is not initialized" ); |
| 346 | return ADBC_STATUS_INVALID_ARGUMENT; |
| 347 | } |
| 348 | TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); |
| 349 | connection->private_data = nullptr; |
| 350 | std::unordered_map<std::string, std::string> options = std::move(args->options); |
| 351 | delete args; |
| 352 | |
| 353 | auto status = database->private_driver->ConnectionNew(connection, error); |
| 354 | if (status != ADBC_STATUS_OK) |
| 355 | return status; |
| 356 | connection->private_driver = database->private_driver; |
| 357 | |
| 358 | for (const auto &option : options) { |
| 359 | status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(), |
| 360 | error); |
| 361 | if (status != ADBC_STATUS_OK) |
| 362 | return status; |
| 363 | } |
| 364 | return connection->private_driver->ConnectionInit(connection, database, error); |
| 365 | } |
| 366 | |
| 367 | AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { |
| 368 | // Allocate a temporary structure to store options pre-Init, because |
| 369 | // we don't get access to the database (and hence the driver |
| 370 | // function table) until then |
| 371 | connection->private_data = new TempConnection; |
| 372 | connection->private_driver = nullptr; |
| 373 | return ADBC_STATUS_OK; |
| 374 | } |
| 375 | |
| 376 | AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, |
| 377 | size_t serialized_length, struct ArrowArrayStream *out, |
| 378 | struct AdbcError *error) { |
| 379 | if (!connection->private_driver) { |
| 380 | return ADBC_STATUS_INVALID_STATE; |
| 381 | } |
| 382 | return connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, out, |
| 383 | error); |
| 384 | } |
| 385 | |
| 386 | AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { |
| 387 | if (!connection->private_driver) { |
| 388 | if (connection->private_data) { |
| 389 | TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); |
| 390 | delete args; |
| 391 | connection->private_data = nullptr; |
| 392 | return ADBC_STATUS_OK; |
| 393 | } |
| 394 | return ADBC_STATUS_INVALID_STATE; |
| 395 | } |
| 396 | auto status = connection->private_driver->ConnectionRelease(connection, error); |
| 397 | connection->private_driver = nullptr; |
| 398 | return status; |
| 399 | } |
| 400 | |
| 401 | AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { |
| 402 | if (!connection->private_driver) { |
| 403 | return ADBC_STATUS_INVALID_STATE; |
| 404 | } |
| 405 | return connection->private_driver->ConnectionRollback(connection, error); |
| 406 | } |
| 407 | |
| 408 | AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, |
| 409 | struct AdbcError *error) { |
| 410 | if (!connection->private_data) { |
| 411 | SetError(error, message: "AdbcConnectionSetOption: must AdbcConnectionNew first" ); |
| 412 | return ADBC_STATUS_INVALID_STATE; |
| 413 | } |
| 414 | if (!connection->private_driver) { |
| 415 | // Init not yet called, save the option |
| 416 | TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); |
| 417 | args->options[key] = value; |
| 418 | return ADBC_STATUS_OK; |
| 419 | } |
| 420 | return connection->private_driver->ConnectionSetOption(connection, key, value, error); |
| 421 | } |
| 422 | |
| 423 | AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, |
| 424 | struct AdbcError *error) { |
| 425 | if (!statement->private_driver) { |
| 426 | return ADBC_STATUS_INVALID_STATE; |
| 427 | } |
| 428 | return statement->private_driver->StatementBind(statement, values, schema, error); |
| 429 | } |
| 430 | |
| 431 | AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, |
| 432 | struct AdbcError *error) { |
| 433 | if (!statement->private_driver) { |
| 434 | return ADBC_STATUS_INVALID_STATE; |
| 435 | } |
| 436 | return statement->private_driver->StatementBindStream(statement, stream, error); |
| 437 | } |
| 438 | |
| 439 | // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' |
| 440 | AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema, |
| 441 | struct AdbcPartitions *partitions, int64_t *rows_affected, |
| 442 | struct AdbcError *error) { |
| 443 | if (!statement->private_driver) { |
| 444 | return ADBC_STATUS_INVALID_STATE; |
| 445 | } |
| 446 | return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error); |
| 447 | } |
| 448 | |
| 449 | AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, |
| 450 | int64_t *rows_affected, struct AdbcError *error) { |
| 451 | if (!statement) { |
| 452 | return ADBC_STATUS_INVALID_ARGUMENT; |
| 453 | } |
| 454 | if (!statement->private_driver) { |
| 455 | return ADBC_STATUS_INVALID_STATE; |
| 456 | } |
| 457 | return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error); |
| 458 | } |
| 459 | |
| 460 | AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, |
| 461 | struct AdbcError *error) { |
| 462 | if (!statement->private_driver) { |
| 463 | return ADBC_STATUS_INVALID_STATE; |
| 464 | } |
| 465 | return statement->private_driver->StatementGetParameterSchema(statement, schema, error); |
| 466 | } |
| 467 | |
| 468 | AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, |
| 469 | struct AdbcError *error) { |
| 470 | if (!connection) { |
| 471 | return ADBC_STATUS_INVALID_ARGUMENT; |
| 472 | } |
| 473 | if (!connection->private_driver) { |
| 474 | return ADBC_STATUS_INVALID_STATE; |
| 475 | } |
| 476 | auto status = connection->private_driver->StatementNew(connection, statement, error); |
| 477 | statement->private_driver = connection->private_driver; |
| 478 | return status; |
| 479 | } |
| 480 | |
| 481 | AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { |
| 482 | if (!statement->private_driver) { |
| 483 | return ADBC_STATUS_INVALID_STATE; |
| 484 | } |
| 485 | return statement->private_driver->StatementPrepare(statement, error); |
| 486 | } |
| 487 | |
| 488 | AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { |
| 489 | if (!statement->private_driver) { |
| 490 | return ADBC_STATUS_INVALID_STATE; |
| 491 | } |
| 492 | auto status = statement->private_driver->StatementRelease(statement, error); |
| 493 | statement->private_driver = nullptr; |
| 494 | return status; |
| 495 | } |
| 496 | |
| 497 | AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, |
| 498 | struct AdbcError *error) { |
| 499 | if (!statement->private_driver) { |
| 500 | return ADBC_STATUS_INVALID_STATE; |
| 501 | } |
| 502 | return statement->private_driver->StatementSetOption(statement, key, value, error); |
| 503 | } |
| 504 | |
| 505 | AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { |
| 506 | if (!statement->private_driver) { |
| 507 | return ADBC_STATUS_INVALID_STATE; |
| 508 | } |
| 509 | return statement->private_driver->StatementSetSqlQuery(statement, query, error); |
| 510 | } |
| 511 | |
| 512 | AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, |
| 513 | struct AdbcError *error) { |
| 514 | if (!statement->private_driver) { |
| 515 | return ADBC_STATUS_INVALID_STATE; |
| 516 | } |
| 517 | return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); |
| 518 | } |
| 519 | |
| 520 | const char *AdbcStatusCodeMessage(AdbcStatusCode code) { |
| 521 | #define STRINGIFY(s) #s |
| 522 | #define STRINGIFY_VALUE(s) STRINGIFY(s) |
| 523 | #define CASE(CONSTANT) \ |
| 524 | case CONSTANT: \ |
| 525 | return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")"; |
| 526 | |
| 527 | switch (code) { |
| 528 | CASE(ADBC_STATUS_OK); |
| 529 | CASE(ADBC_STATUS_UNKNOWN); |
| 530 | CASE(ADBC_STATUS_NOT_IMPLEMENTED); |
| 531 | CASE(ADBC_STATUS_NOT_FOUND); |
| 532 | CASE(ADBC_STATUS_ALREADY_EXISTS); |
| 533 | CASE(ADBC_STATUS_INVALID_ARGUMENT); |
| 534 | CASE(ADBC_STATUS_INVALID_STATE); |
| 535 | CASE(ADBC_STATUS_INVALID_DATA); |
| 536 | CASE(ADBC_STATUS_INTEGRITY); |
| 537 | CASE(ADBC_STATUS_INTERNAL); |
| 538 | CASE(ADBC_STATUS_IO); |
| 539 | CASE(ADBC_STATUS_CANCELLED); |
| 540 | CASE(ADBC_STATUS_TIMEOUT); |
| 541 | CASE(ADBC_STATUS_UNAUTHENTICATED); |
| 542 | CASE(ADBC_STATUS_UNAUTHORIZED); |
| 543 | default: |
| 544 | return "(invalid code)" ; |
| 545 | } |
| 546 | #undef CASE |
| 547 | #undef STRINGIFY_VALUE |
| 548 | #undef STRINGIFY |
| 549 | } |
| 550 | |
| 551 | AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver, |
| 552 | struct AdbcError *error) { |
| 553 | AdbcDriverInitFunc init_func; |
| 554 | std::string error_message; |
| 555 | |
| 556 | if (version != ADBC_VERSION_1_0_0) { |
| 557 | SetError(error, message: "Only ADBC 1.0.0 is supported" ); |
| 558 | return ADBC_STATUS_NOT_IMPLEMENTED; |
| 559 | } |
| 560 | |
| 561 | auto *driver = reinterpret_cast<struct AdbcDriver *>(raw_driver); |
| 562 | |
| 563 | if (!entrypoint) { |
| 564 | // Default entrypoint (see adbc.h) |
| 565 | entrypoint = "AdbcDriverInit" ; |
| 566 | } |
| 567 | |
| 568 | #if defined(_WIN32) |
| 569 | |
| 570 | HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); |
| 571 | if (!handle) { |
| 572 | error_message += driver_name; |
| 573 | error_message += ": LoadLibraryExA() failed: " ; |
| 574 | GetWinError(&error_message); |
| 575 | |
| 576 | std::string full_driver_name = driver_name; |
| 577 | full_driver_name += ".lib" ; |
| 578 | handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); |
| 579 | if (!handle) { |
| 580 | error_message += '\n'; |
| 581 | error_message += full_driver_name; |
| 582 | error_message += ": LoadLibraryExA() failed: " ; |
| 583 | GetWinError(&error_message); |
| 584 | } |
| 585 | } |
| 586 | if (!handle) { |
| 587 | SetError(error, error_message); |
| 588 | return ADBC_STATUS_INTERNAL; |
| 589 | } |
| 590 | |
| 591 | void *load_handle = reinterpret_cast<void *>(GetProcAddress(handle, entrypoint)); |
| 592 | init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle); |
| 593 | if (!init_func) { |
| 594 | std::string message = "GetProcAddress(" ; |
| 595 | message += entrypoint; |
| 596 | message += ") failed: " ; |
| 597 | GetWinError(&message); |
| 598 | if (!FreeLibrary(handle)) { |
| 599 | message += "\nFreeLibrary() failed: " ; |
| 600 | GetWinError(&message); |
| 601 | } |
| 602 | SetError(error, message); |
| 603 | return ADBC_STATUS_INTERNAL; |
| 604 | } |
| 605 | |
| 606 | #else |
| 607 | |
| 608 | #if defined(__APPLE__) |
| 609 | const std::string kPlatformLibraryPrefix = "lib" ; |
| 610 | const std::string kPlatformLibrarySuffix = ".dylib" ; |
| 611 | #else |
| 612 | const std::string kPlatformLibraryPrefix = "lib" ; |
| 613 | const std::string kPlatformLibrarySuffix = ".so" ; |
| 614 | #endif // defined(__APPLE__) |
| 615 | |
| 616 | void *handle = dlopen(file: driver_name, RTLD_NOW | RTLD_LOCAL); |
| 617 | if (!handle) { |
| 618 | error_message = "dlopen() failed: " ; |
| 619 | error_message += dlerror(); |
| 620 | |
| 621 | // If applicable, append the shared library prefix/extension and |
| 622 | // try again (this way you don't have to hardcode driver names by |
| 623 | // platform in the application) |
| 624 | const std::string driver_str = driver_name; |
| 625 | |
| 626 | std::string full_driver_name; |
| 627 | if (driver_str.size() < kPlatformLibraryPrefix.size() || |
| 628 | driver_str.compare(pos: 0, n: kPlatformLibraryPrefix.size(), str: kPlatformLibraryPrefix) != 0) { |
| 629 | full_driver_name += kPlatformLibraryPrefix; |
| 630 | } |
| 631 | full_driver_name += driver_name; |
| 632 | if (driver_str.size() < kPlatformLibrarySuffix.size() || |
| 633 | driver_str.compare(pos: full_driver_name.size() - kPlatformLibrarySuffix.size(), n: kPlatformLibrarySuffix.size(), |
| 634 | str: kPlatformLibrarySuffix) != 0) { |
| 635 | full_driver_name += kPlatformLibrarySuffix; |
| 636 | } |
| 637 | handle = dlopen(file: full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); |
| 638 | if (!handle) { |
| 639 | error_message += "\ndlopen() failed: " ; |
| 640 | error_message += dlerror(); |
| 641 | } |
| 642 | } |
| 643 | if (!handle) { |
| 644 | SetError(error, message: error_message); |
| 645 | // AdbcDatabaseInit tries to call this if set |
| 646 | driver->release = nullptr; |
| 647 | return ADBC_STATUS_INTERNAL; |
| 648 | } |
| 649 | |
| 650 | void *load_handle = dlsym(handle: handle, name: entrypoint); |
| 651 | if (!load_handle) { |
| 652 | std::string message = "dlsym(" ; |
| 653 | message += entrypoint; |
| 654 | message += ") failed: " ; |
| 655 | message += dlerror(); |
| 656 | SetError(error, message); |
| 657 | return ADBC_STATUS_INTERNAL; |
| 658 | } |
| 659 | init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle); |
| 660 | |
| 661 | #endif // defined(_WIN32) |
| 662 | |
| 663 | AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); |
| 664 | if (status == ADBC_STATUS_OK) { |
| 665 | ManagerDriverState *state = new ManagerDriverState; |
| 666 | state->driver_release = driver->release; |
| 667 | #if defined(_WIN32) |
| 668 | state->handle = handle; |
| 669 | #endif // defined(_WIN32) |
| 670 | driver->release = &ReleaseDriver; |
| 671 | driver->private_manager = state; |
| 672 | } else { |
| 673 | #if defined(_WIN32) |
| 674 | if (!FreeLibrary(handle)) { |
| 675 | std::string message = "FreeLibrary() failed: " ; |
| 676 | GetWinError(&message); |
| 677 | SetError(error, message); |
| 678 | } |
| 679 | #endif // defined(_WIN32) |
| 680 | } |
| 681 | return status; |
| 682 | } |
| 683 | |
| 684 | AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver, |
| 685 | struct AdbcError *error) { |
| 686 | #define FILL_DEFAULT(DRIVER, STUB) \ |
| 687 | if (!DRIVER->STUB) { \ |
| 688 | DRIVER->STUB = &STUB; \ |
| 689 | } |
| 690 | #define CHECK_REQUIRED(DRIVER, STUB) \ |
| 691 | if (!DRIVER->STUB) { \ |
| 692 | SetError(error, "Driver does not implement required function Adbc" #STUB); \ |
| 693 | return ADBC_STATUS_INTERNAL; \ |
| 694 | } |
| 695 | |
| 696 | auto result = init_func(version, raw_driver, error); |
| 697 | if (result != ADBC_STATUS_OK) { |
| 698 | return result; |
| 699 | } |
| 700 | |
| 701 | if (version == ADBC_VERSION_1_0_0) { |
| 702 | auto *driver = reinterpret_cast<struct AdbcDriver *>(raw_driver); |
| 703 | CHECK_REQUIRED(driver, DatabaseNew); |
| 704 | CHECK_REQUIRED(driver, DatabaseInit); |
| 705 | CHECK_REQUIRED(driver, DatabaseRelease); |
| 706 | FILL_DEFAULT(driver, DatabaseSetOption); |
| 707 | |
| 708 | CHECK_REQUIRED(driver, ConnectionNew); |
| 709 | CHECK_REQUIRED(driver, ConnectionInit); |
| 710 | CHECK_REQUIRED(driver, ConnectionRelease); |
| 711 | FILL_DEFAULT(driver, ConnectionCommit); |
| 712 | FILL_DEFAULT(driver, ConnectionGetInfo); |
| 713 | FILL_DEFAULT(driver, ConnectionGetObjects); |
| 714 | FILL_DEFAULT(driver, ConnectionGetTableSchema); |
| 715 | FILL_DEFAULT(driver, ConnectionGetTableTypes); |
| 716 | FILL_DEFAULT(driver, ConnectionReadPartition); |
| 717 | FILL_DEFAULT(driver, ConnectionRollback); |
| 718 | FILL_DEFAULT(driver, ConnectionSetOption); |
| 719 | |
| 720 | FILL_DEFAULT(driver, StatementExecutePartitions); |
| 721 | CHECK_REQUIRED(driver, StatementExecuteQuery); |
| 722 | CHECK_REQUIRED(driver, StatementNew); |
| 723 | CHECK_REQUIRED(driver, StatementRelease); |
| 724 | FILL_DEFAULT(driver, StatementBind); |
| 725 | FILL_DEFAULT(driver, StatementGetParameterSchema); |
| 726 | FILL_DEFAULT(driver, StatementPrepare); |
| 727 | FILL_DEFAULT(driver, StatementSetOption); |
| 728 | FILL_DEFAULT(driver, StatementSetSqlQuery); |
| 729 | FILL_DEFAULT(driver, StatementSetSubstraitPlan); |
| 730 | } |
| 731 | |
| 732 | return ADBC_STATUS_OK; |
| 733 | |
| 734 | #undef FILL_DEFAULT |
| 735 | #undef CHECK_REQUIRED |
| 736 | } |
| 737 | } // namespace duckdb_adbc |
| 738 | |