1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "duckdb/common/adbc/driver_manager.h"
19#include "duckdb/common/adbc/adbc.h"
20#include "duckdb/common/adbc/adbc.hpp"
21
22#include <algorithm>
23#include <cstring>
24#include <string>
25#include <unordered_map>
26#include <utility>
27
28#if defined(_WIN32)
29#include <windows.h> // Must come first
30
31#include <libloaderapi.h>
32#include <strsafe.h>
33#else
34#include <dlfcn.h>
35#endif // defined(_WIN32)
36
37namespace duckdb_adbc {
38
39// Platform-specific helpers
40
41#if defined(_WIN32)
42/// Append a description of the Windows error to the buffer.
43void GetWinError(std::string *buffer) {
44 DWORD rc = GetLastError();
45 LPVOID message;
46
47 FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
48 /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
49 reinterpret_cast<LPSTR>(&message), /*nSize=*/0, /*Arguments=*/nullptr);
50
51 (*buffer) += '(';
52 (*buffer) += std::to_string(rc);
53 (*buffer) += ") ";
54 (*buffer) += reinterpret_cast<char *>(message);
55 LocalFree(message);
56}
57
58#endif // defined(_WIN32)
59
60// Error handling
61
62void ReleaseError(struct AdbcError *error) {
63 if (error) {
64 if (error->message) {
65 delete[] error->message;
66 }
67 error->message = nullptr;
68 error->release = nullptr;
69 }
70}
71
72void SetError(struct AdbcError *error, const std::string &message) {
73 if (!error) {
74 return;
75 }
76 if (error->message) {
77 // Append
78 std::string buffer = error->message;
79 buffer.reserve(res_arg: buffer.size() + message.size() + 1);
80 buffer += '\n';
81 buffer += message;
82 error->release(error);
83
84 error->message = new char[buffer.size() + 1];
85 buffer.copy(s: error->message, n: buffer.size());
86 error->message[buffer.size()] = '\0';
87 } else {
88 error->message = new char[message.size() + 1];
89 message.copy(s: error->message, n: message.size());
90 error->message[message.size()] = '\0';
91 }
92 error->release = ReleaseError;
93}
94
95// Driver state
96
97/// Hold the driver DLL and the driver release callback in the driver struct.
98struct ManagerDriverState {
99 // The original release callback
100 AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error);
101
102#if defined(_WIN32)
103 // The loaded DLL
104 HMODULE handle;
105#endif // defined(_WIN32)
106};
107
108/// Unload the driver DLL.
109static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) {
110 AdbcStatusCode status = ADBC_STATUS_OK;
111
112 if (!driver->private_manager)
113 return status;
114 ManagerDriverState *state = reinterpret_cast<ManagerDriverState *>(driver->private_manager);
115
116 if (state->driver_release) {
117 status = state->driver_release(driver, error);
118 }
119
120#if defined(_WIN32)
121 // TODO(apache/arrow-adbc#204): causes tests to segfault
122 // if (!FreeLibrary(state->handle)) {
123 // std::string message = "FreeLibrary() failed: ";
124 // GetWinError(&message);
125 // SetError(error, message);
126 // }
127#endif // defined(_WIN32)
128
129 driver->private_manager = nullptr;
130 delete state;
131 return status;
132}
133
134// Default stubs
135
136AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length,
137 struct ArrowArrayStream *out, struct AdbcError *error) {
138 return ADBC_STATUS_NOT_IMPLEMENTED;
139}
140
141AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *, const char *, const char *, const char *,
142 struct ArrowSchema *, struct AdbcError *error) {
143 return ADBC_STATUS_NOT_IMPLEMENTED;
144}
145
146AdbcStatusCode StatementBind(struct AdbcStatement *, struct ArrowArray *, struct ArrowSchema *,
147 struct AdbcError *error) {
148 return ADBC_STATUS_NOT_IMPLEMENTED;
149}
150
151AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema,
152 struct AdbcError *error) {
153 return ADBC_STATUS_NOT_IMPLEMENTED;
154}
155AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *, const uint8_t *, size_t, struct AdbcError *error) {
156 return ADBC_STATUS_NOT_IMPLEMENTED;
157}
158
159/// Temporary state while the database is being configured.
160struct TempDatabase {
161 std::unordered_map<std::string, std::string> options;
162 std::string driver;
163 // Default name (see adbc.h)
164 std::string entrypoint = "AdbcDriverInit";
165 AdbcDriverInitFunc init_func = nullptr;
166};
167
/// Temporary state while the connection is being configured, i.e. between
/// AdbcConnectionNew and AdbcConnectionInit.
struct TempConnection {
	using OptionMap = std::unordered_map<std::string, std::string>;

	/// Key/value options replayed to the driver during AdbcConnectionInit
	OptionMap options;
};
172
173// Direct implementations of API methods
174
175AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) {
176 // Allocate a temporary structure to store options pre-Init
177 database->private_data = new TempDatabase();
178 database->private_driver = nullptr;
179 return ADBC_STATUS_OK;
180}
181
182AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value,
183 struct AdbcError *error) {
184 if (database->private_driver) {
185 return database->private_driver->DatabaseSetOption(database, key, value, error);
186 }
187
188 TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data);
189 if (std::strcmp(s1: key, s2: "driver") == 0) {
190 args->driver = value;
191 } else if (std::strcmp(s1: key, s2: "entrypoint") == 0) {
192 args->entrypoint = value;
193 } else {
194 args->options[key] = value;
195 }
196 return ADBC_STATUS_OK;
197}
198
199AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func,
200 struct AdbcError *error) {
201 if (database->private_driver) {
202 return ADBC_STATUS_INVALID_STATE;
203 }
204
205 TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data);
206 args->init_func = init_func;
207 return ADBC_STATUS_OK;
208}
209
210AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) {
211 if (!database->private_data) {
212 SetError(error, message: "Must call AdbcDatabaseNew first");
213 return ADBC_STATUS_INVALID_STATE;
214 }
215 TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data);
216 if (args->init_func) {
217 // Do nothing
218 } else if (args->driver.empty()) {
219 SetError(error, message: "Must provide 'driver' parameter");
220 return ADBC_STATUS_INVALID_ARGUMENT;
221 }
222
223 database->private_driver = new AdbcDriver;
224 std::memset(s: database->private_driver, c: 0, n: sizeof(AdbcDriver));
225 AdbcStatusCode status;
226 // So we don't confuse a driver into thinking it's initialized already
227 database->private_data = nullptr;
228 if (args->init_func) {
229 status = AdbcLoadDriverFromInitFunc(init_func: args->init_func, ADBC_VERSION_1_0_0, driver: database->private_driver, error);
230 } else {
231 status = AdbcLoadDriver(driver_name: args->driver.c_str(), entrypoint: args->entrypoint.c_str(), ADBC_VERSION_1_0_0,
232 driver: database->private_driver, error);
233 }
234 if (status != ADBC_STATUS_OK) {
235 // Restore private_data so it will be released by AdbcDatabaseRelease
236 database->private_data = args;
237 if (database->private_driver->release) {
238 database->private_driver->release(database->private_driver, error);
239 }
240 delete database->private_driver;
241 database->private_driver = nullptr;
242 return status;
243 }
244 status = database->private_driver->DatabaseNew(database, error);
245 if (status != ADBC_STATUS_OK) {
246 if (database->private_driver->release) {
247 database->private_driver->release(database->private_driver, error);
248 }
249 delete database->private_driver;
250 database->private_driver = nullptr;
251 return status;
252 }
253 for (const auto &option : args->options) {
254 status =
255 database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error);
256 if (status != ADBC_STATUS_OK) {
257 delete args;
258 // Release the database
259 std::ignore = database->private_driver->DatabaseRelease(database, error);
260 if (database->private_driver->release) {
261 database->private_driver->release(database->private_driver, error);
262 }
263 delete database->private_driver;
264 database->private_driver = nullptr;
265 // Should be redundant, but ensure that AdbcDatabaseRelease
266 // below doesn't think that it contains a TempDatabase
267 database->private_data = nullptr;
268 return status;
269 }
270 }
271 delete args;
272 return database->private_driver->DatabaseInit(database, error);
273}
274
275AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) {
276 if (!database->private_driver) {
277 if (database->private_data) {
278 TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data);
279 delete args;
280 database->private_data = nullptr;
281 return ADBC_STATUS_OK;
282 }
283 return ADBC_STATUS_INVALID_STATE;
284 }
285 auto status = database->private_driver->DatabaseRelease(database, error);
286 if (database->private_driver->release) {
287 database->private_driver->release(database->private_driver, error);
288 }
289 delete database->private_driver;
290 database->private_data = nullptr;
291 database->private_driver = nullptr;
292 return status;
293}
294
295AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) {
296 if (!connection->private_driver) {
297 return ADBC_STATUS_INVALID_STATE;
298 }
299 return connection->private_driver->ConnectionCommit(connection, error);
300}
301
302AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length,
303 struct ArrowArrayStream *out, struct AdbcError *error) {
304 if (!connection->private_driver) {
305 return ADBC_STATUS_INVALID_STATE;
306 }
307 return connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error);
308}
309
310AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog,
311 const char *db_schema, const char *table_name, const char **table_types,
312 const char *column_name, struct ArrowArrayStream *stream,
313 struct AdbcError *error) {
314 if (!connection->private_driver) {
315 return ADBC_STATUS_INVALID_STATE;
316 }
317 return connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
318 table_types, column_name, stream, error);
319}
320
321AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog,
322 const char *db_schema, const char *table_name, struct ArrowSchema *schema,
323 struct AdbcError *error) {
324 if (!connection->private_driver) {
325 return ADBC_STATUS_INVALID_STATE;
326 }
327 return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema,
328 error);
329}
330
331AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream,
332 struct AdbcError *error) {
333 if (!connection->private_driver) {
334 return ADBC_STATUS_INVALID_STATE;
335 }
336 return connection->private_driver->ConnectionGetTableTypes(connection, stream, error);
337}
338
339AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database,
340 struct AdbcError *error) {
341 if (!connection->private_data) {
342 SetError(error, message: "Must call AdbcConnectionNew first");
343 return ADBC_STATUS_INVALID_STATE;
344 } else if (!database->private_driver) {
345 SetError(error, message: "Database is not initialized");
346 return ADBC_STATUS_INVALID_ARGUMENT;
347 }
348 TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data);
349 connection->private_data = nullptr;
350 std::unordered_map<std::string, std::string> options = std::move(args->options);
351 delete args;
352
353 auto status = database->private_driver->ConnectionNew(connection, error);
354 if (status != ADBC_STATUS_OK)
355 return status;
356 connection->private_driver = database->private_driver;
357
358 for (const auto &option : options) {
359 status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(),
360 error);
361 if (status != ADBC_STATUS_OK)
362 return status;
363 }
364 return connection->private_driver->ConnectionInit(connection, database, error);
365}
366
367AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) {
368 // Allocate a temporary structure to store options pre-Init, because
369 // we don't get access to the database (and hence the driver
370 // function table) until then
371 connection->private_data = new TempConnection;
372 connection->private_driver = nullptr;
373 return ADBC_STATUS_OK;
374}
375
376AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition,
377 size_t serialized_length, struct ArrowArrayStream *out,
378 struct AdbcError *error) {
379 if (!connection->private_driver) {
380 return ADBC_STATUS_INVALID_STATE;
381 }
382 return connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, out,
383 error);
384}
385
386AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) {
387 if (!connection->private_driver) {
388 if (connection->private_data) {
389 TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data);
390 delete args;
391 connection->private_data = nullptr;
392 return ADBC_STATUS_OK;
393 }
394 return ADBC_STATUS_INVALID_STATE;
395 }
396 auto status = connection->private_driver->ConnectionRelease(connection, error);
397 connection->private_driver = nullptr;
398 return status;
399}
400
401AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) {
402 if (!connection->private_driver) {
403 return ADBC_STATUS_INVALID_STATE;
404 }
405 return connection->private_driver->ConnectionRollback(connection, error);
406}
407
408AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value,
409 struct AdbcError *error) {
410 if (!connection->private_data) {
411 SetError(error, message: "AdbcConnectionSetOption: must AdbcConnectionNew first");
412 return ADBC_STATUS_INVALID_STATE;
413 }
414 if (!connection->private_driver) {
415 // Init not yet called, save the option
416 TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data);
417 args->options[key] = value;
418 return ADBC_STATUS_OK;
419 }
420 return connection->private_driver->ConnectionSetOption(connection, key, value, error);
421}
422
423AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema,
424 struct AdbcError *error) {
425 if (!statement->private_driver) {
426 return ADBC_STATUS_INVALID_STATE;
427 }
428 return statement->private_driver->StatementBind(statement, values, schema, error);
429}
430
431AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream,
432 struct AdbcError *error) {
433 if (!statement->private_driver) {
434 return ADBC_STATUS_INVALID_STATE;
435 }
436 return statement->private_driver->StatementBindStream(statement, stream, error);
437}
438
439// XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema'
440AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema,
441 struct AdbcPartitions *partitions, int64_t *rows_affected,
442 struct AdbcError *error) {
443 if (!statement->private_driver) {
444 return ADBC_STATUS_INVALID_STATE;
445 }
446 return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error);
447}
448
449AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out,
450 int64_t *rows_affected, struct AdbcError *error) {
451 if (!statement) {
452 return ADBC_STATUS_INVALID_ARGUMENT;
453 }
454 if (!statement->private_driver) {
455 return ADBC_STATUS_INVALID_STATE;
456 }
457 return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error);
458}
459
460AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema,
461 struct AdbcError *error) {
462 if (!statement->private_driver) {
463 return ADBC_STATUS_INVALID_STATE;
464 }
465 return statement->private_driver->StatementGetParameterSchema(statement, schema, error);
466}
467
468AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement,
469 struct AdbcError *error) {
470 if (!connection) {
471 return ADBC_STATUS_INVALID_ARGUMENT;
472 }
473 if (!connection->private_driver) {
474 return ADBC_STATUS_INVALID_STATE;
475 }
476 auto status = connection->private_driver->StatementNew(connection, statement, error);
477 statement->private_driver = connection->private_driver;
478 return status;
479}
480
481AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) {
482 if (!statement->private_driver) {
483 return ADBC_STATUS_INVALID_STATE;
484 }
485 return statement->private_driver->StatementPrepare(statement, error);
486}
487
488AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) {
489 if (!statement->private_driver) {
490 return ADBC_STATUS_INVALID_STATE;
491 }
492 auto status = statement->private_driver->StatementRelease(statement, error);
493 statement->private_driver = nullptr;
494 return status;
495}
496
497AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value,
498 struct AdbcError *error) {
499 if (!statement->private_driver) {
500 return ADBC_STATUS_INVALID_STATE;
501 }
502 return statement->private_driver->StatementSetOption(statement, key, value, error);
503}
504
505AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) {
506 if (!statement->private_driver) {
507 return ADBC_STATUS_INVALID_STATE;
508 }
509 return statement->private_driver->StatementSetSqlQuery(statement, query, error);
510}
511
512AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length,
513 struct AdbcError *error) {
514 if (!statement->private_driver) {
515 return ADBC_STATUS_INVALID_STATE;
516 }
517 return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error);
518}
519
520const char *AdbcStatusCodeMessage(AdbcStatusCode code) {
521#define STRINGIFY(s) #s
522#define STRINGIFY_VALUE(s) STRINGIFY(s)
523#define CASE(CONSTANT) \
524 case CONSTANT: \
525 return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")";
526
527 switch (code) {
528 CASE(ADBC_STATUS_OK);
529 CASE(ADBC_STATUS_UNKNOWN);
530 CASE(ADBC_STATUS_NOT_IMPLEMENTED);
531 CASE(ADBC_STATUS_NOT_FOUND);
532 CASE(ADBC_STATUS_ALREADY_EXISTS);
533 CASE(ADBC_STATUS_INVALID_ARGUMENT);
534 CASE(ADBC_STATUS_INVALID_STATE);
535 CASE(ADBC_STATUS_INVALID_DATA);
536 CASE(ADBC_STATUS_INTEGRITY);
537 CASE(ADBC_STATUS_INTERNAL);
538 CASE(ADBC_STATUS_IO);
539 CASE(ADBC_STATUS_CANCELLED);
540 CASE(ADBC_STATUS_TIMEOUT);
541 CASE(ADBC_STATUS_UNAUTHENTICATED);
542 CASE(ADBC_STATUS_UNAUTHORIZED);
543 default:
544 return "(invalid code)";
545 }
546#undef CASE
547#undef STRINGIFY_VALUE
548#undef STRINGIFY
549}
550
551AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver,
552 struct AdbcError *error) {
553 AdbcDriverInitFunc init_func;
554 std::string error_message;
555
556 if (version != ADBC_VERSION_1_0_0) {
557 SetError(error, message: "Only ADBC 1.0.0 is supported");
558 return ADBC_STATUS_NOT_IMPLEMENTED;
559 }
560
561 auto *driver = reinterpret_cast<struct AdbcDriver *>(raw_driver);
562
563 if (!entrypoint) {
564 // Default entrypoint (see adbc.h)
565 entrypoint = "AdbcDriverInit";
566 }
567
568#if defined(_WIN32)
569
570 HMODULE handle = LoadLibraryExA(driver_name, NULL, 0);
571 if (!handle) {
572 error_message += driver_name;
573 error_message += ": LoadLibraryExA() failed: ";
574 GetWinError(&error_message);
575
576 std::string full_driver_name = driver_name;
577 full_driver_name += ".lib";
578 handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0);
579 if (!handle) {
580 error_message += '\n';
581 error_message += full_driver_name;
582 error_message += ": LoadLibraryExA() failed: ";
583 GetWinError(&error_message);
584 }
585 }
586 if (!handle) {
587 SetError(error, error_message);
588 return ADBC_STATUS_INTERNAL;
589 }
590
591 void *load_handle = reinterpret_cast<void *>(GetProcAddress(handle, entrypoint));
592 init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle);
593 if (!init_func) {
594 std::string message = "GetProcAddress(";
595 message += entrypoint;
596 message += ") failed: ";
597 GetWinError(&message);
598 if (!FreeLibrary(handle)) {
599 message += "\nFreeLibrary() failed: ";
600 GetWinError(&message);
601 }
602 SetError(error, message);
603 return ADBC_STATUS_INTERNAL;
604 }
605
606#else
607
608#if defined(__APPLE__)
609 const std::string kPlatformLibraryPrefix = "lib";
610 const std::string kPlatformLibrarySuffix = ".dylib";
611#else
612 const std::string kPlatformLibraryPrefix = "lib";
613 const std::string kPlatformLibrarySuffix = ".so";
614#endif // defined(__APPLE__)
615
616 void *handle = dlopen(file: driver_name, RTLD_NOW | RTLD_LOCAL);
617 if (!handle) {
618 error_message = "dlopen() failed: ";
619 error_message += dlerror();
620
621 // If applicable, append the shared library prefix/extension and
622 // try again (this way you don't have to hardcode driver names by
623 // platform in the application)
624 const std::string driver_str = driver_name;
625
626 std::string full_driver_name;
627 if (driver_str.size() < kPlatformLibraryPrefix.size() ||
628 driver_str.compare(pos: 0, n: kPlatformLibraryPrefix.size(), str: kPlatformLibraryPrefix) != 0) {
629 full_driver_name += kPlatformLibraryPrefix;
630 }
631 full_driver_name += driver_name;
632 if (driver_str.size() < kPlatformLibrarySuffix.size() ||
633 driver_str.compare(pos: full_driver_name.size() - kPlatformLibrarySuffix.size(), n: kPlatformLibrarySuffix.size(),
634 str: kPlatformLibrarySuffix) != 0) {
635 full_driver_name += kPlatformLibrarySuffix;
636 }
637 handle = dlopen(file: full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL);
638 if (!handle) {
639 error_message += "\ndlopen() failed: ";
640 error_message += dlerror();
641 }
642 }
643 if (!handle) {
644 SetError(error, message: error_message);
645 // AdbcDatabaseInit tries to call this if set
646 driver->release = nullptr;
647 return ADBC_STATUS_INTERNAL;
648 }
649
650 void *load_handle = dlsym(handle: handle, name: entrypoint);
651 if (!load_handle) {
652 std::string message = "dlsym(";
653 message += entrypoint;
654 message += ") failed: ";
655 message += dlerror();
656 SetError(error, message);
657 return ADBC_STATUS_INTERNAL;
658 }
659 init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle);
660
661#endif // defined(_WIN32)
662
663 AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error);
664 if (status == ADBC_STATUS_OK) {
665 ManagerDriverState *state = new ManagerDriverState;
666 state->driver_release = driver->release;
667#if defined(_WIN32)
668 state->handle = handle;
669#endif // defined(_WIN32)
670 driver->release = &ReleaseDriver;
671 driver->private_manager = state;
672 } else {
673#if defined(_WIN32)
674 if (!FreeLibrary(handle)) {
675 std::string message = "FreeLibrary() failed: ";
676 GetWinError(&message);
677 SetError(error, message);
678 }
679#endif // defined(_WIN32)
680 }
681 return status;
682}
683
684AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver,
685 struct AdbcError *error) {
686#define FILL_DEFAULT(DRIVER, STUB) \
687 if (!DRIVER->STUB) { \
688 DRIVER->STUB = &STUB; \
689 }
690#define CHECK_REQUIRED(DRIVER, STUB) \
691 if (!DRIVER->STUB) { \
692 SetError(error, "Driver does not implement required function Adbc" #STUB); \
693 return ADBC_STATUS_INTERNAL; \
694 }
695
696 auto result = init_func(version, raw_driver, error);
697 if (result != ADBC_STATUS_OK) {
698 return result;
699 }
700
701 if (version == ADBC_VERSION_1_0_0) {
702 auto *driver = reinterpret_cast<struct AdbcDriver *>(raw_driver);
703 CHECK_REQUIRED(driver, DatabaseNew);
704 CHECK_REQUIRED(driver, DatabaseInit);
705 CHECK_REQUIRED(driver, DatabaseRelease);
706 FILL_DEFAULT(driver, DatabaseSetOption);
707
708 CHECK_REQUIRED(driver, ConnectionNew);
709 CHECK_REQUIRED(driver, ConnectionInit);
710 CHECK_REQUIRED(driver, ConnectionRelease);
711 FILL_DEFAULT(driver, ConnectionCommit);
712 FILL_DEFAULT(driver, ConnectionGetInfo);
713 FILL_DEFAULT(driver, ConnectionGetObjects);
714 FILL_DEFAULT(driver, ConnectionGetTableSchema);
715 FILL_DEFAULT(driver, ConnectionGetTableTypes);
716 FILL_DEFAULT(driver, ConnectionReadPartition);
717 FILL_DEFAULT(driver, ConnectionRollback);
718 FILL_DEFAULT(driver, ConnectionSetOption);
719
720 FILL_DEFAULT(driver, StatementExecutePartitions);
721 CHECK_REQUIRED(driver, StatementExecuteQuery);
722 CHECK_REQUIRED(driver, StatementNew);
723 CHECK_REQUIRED(driver, StatementRelease);
724 FILL_DEFAULT(driver, StatementBind);
725 FILL_DEFAULT(driver, StatementGetParameterSchema);
726 FILL_DEFAULT(driver, StatementPrepare);
727 FILL_DEFAULT(driver, StatementSetOption);
728 FILL_DEFAULT(driver, StatementSetSqlQuery);
729 FILL_DEFAULT(driver, StatementSetSubstraitPlan);
730 }
731
732 return ADBC_STATUS_OK;
733
734#undef FILL_DEFAULT
735#undef CHECK_REQUIRED
736}
737} // namespace duckdb_adbc
738