1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "duckdb/common/adbc/driver_manager.h" |
19 | #include "duckdb/common/adbc/adbc.h" |
20 | #include "duckdb/common/adbc/adbc.hpp" |
21 | |
22 | #include <algorithm> |
23 | #include <cstring> |
24 | #include <string> |
25 | #include <unordered_map> |
26 | #include <utility> |
27 | |
28 | #if defined(_WIN32) |
29 | #include <windows.h> // Must come first |
30 | |
31 | #include <libloaderapi.h> |
32 | #include <strsafe.h> |
33 | #else |
34 | #include <dlfcn.h> |
35 | #endif // defined(_WIN32) |
36 | |
37 | namespace duckdb_adbc { |
38 | |
39 | // Platform-specific helpers |
40 | |
41 | #if defined(_WIN32) |
42 | /// Append a description of the Windows error to the buffer. |
43 | void GetWinError(std::string *buffer) { |
44 | DWORD rc = GetLastError(); |
45 | LPVOID message; |
46 | |
47 | FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, |
48 | /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), |
49 | reinterpret_cast<LPSTR>(&message), /*nSize=*/0, /*Arguments=*/nullptr); |
50 | |
51 | (*buffer) += '('; |
52 | (*buffer) += std::to_string(rc); |
53 | (*buffer) += ") " ; |
54 | (*buffer) += reinterpret_cast<char *>(message); |
55 | LocalFree(message); |
56 | } |
57 | |
58 | #endif // defined(_WIN32) |
59 | |
60 | // Error handling |
61 | |
62 | void ReleaseError(struct AdbcError *error) { |
63 | if (error) { |
64 | if (error->message) { |
65 | delete[] error->message; |
66 | } |
67 | error->message = nullptr; |
68 | error->release = nullptr; |
69 | } |
70 | } |
71 | |
72 | void SetError(struct AdbcError *error, const std::string &message) { |
73 | if (!error) { |
74 | return; |
75 | } |
76 | if (error->message) { |
77 | // Append |
78 | std::string buffer = error->message; |
79 | buffer.reserve(res_arg: buffer.size() + message.size() + 1); |
80 | buffer += '\n'; |
81 | buffer += message; |
82 | error->release(error); |
83 | |
84 | error->message = new char[buffer.size() + 1]; |
85 | buffer.copy(s: error->message, n: buffer.size()); |
86 | error->message[buffer.size()] = '\0'; |
87 | } else { |
88 | error->message = new char[message.size() + 1]; |
89 | message.copy(s: error->message, n: message.size()); |
90 | error->message[message.size()] = '\0'; |
91 | } |
92 | error->release = ReleaseError; |
93 | } |
94 | |
95 | // Driver state |
96 | |
97 | /// Hold the driver DLL and the driver release callback in the driver struct. |
98 | struct ManagerDriverState { |
99 | // The original release callback |
100 | AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error); |
101 | |
102 | #if defined(_WIN32) |
103 | // The loaded DLL |
104 | HMODULE handle; |
105 | #endif // defined(_WIN32) |
106 | }; |
107 | |
108 | /// Unload the driver DLL. |
109 | static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) { |
110 | AdbcStatusCode status = ADBC_STATUS_OK; |
111 | |
112 | if (!driver->private_manager) |
113 | return status; |
114 | ManagerDriverState *state = reinterpret_cast<ManagerDriverState *>(driver->private_manager); |
115 | |
116 | if (state->driver_release) { |
117 | status = state->driver_release(driver, error); |
118 | } |
119 | |
120 | #if defined(_WIN32) |
121 | // TODO(apache/arrow-adbc#204): causes tests to segfault |
122 | // if (!FreeLibrary(state->handle)) { |
123 | // std::string message = "FreeLibrary() failed: "; |
124 | // GetWinError(&message); |
125 | // SetError(error, message); |
126 | // } |
127 | #endif // defined(_WIN32) |
128 | |
129 | driver->private_manager = nullptr; |
130 | delete state; |
131 | return status; |
132 | } |
133 | |
134 | // Default stubs |
135 | |
136 | AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, |
137 | struct ArrowArrayStream *out, struct AdbcError *error) { |
138 | return ADBC_STATUS_NOT_IMPLEMENTED; |
139 | } |
140 | |
141 | AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *, const char *, const char *, const char *, |
142 | struct ArrowSchema *, struct AdbcError *error) { |
143 | return ADBC_STATUS_NOT_IMPLEMENTED; |
144 | } |
145 | |
146 | AdbcStatusCode StatementBind(struct AdbcStatement *, struct ArrowArray *, struct ArrowSchema *, |
147 | struct AdbcError *error) { |
148 | return ADBC_STATUS_NOT_IMPLEMENTED; |
149 | } |
150 | |
151 | AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, |
152 | struct AdbcError *error) { |
153 | return ADBC_STATUS_NOT_IMPLEMENTED; |
154 | } |
155 | AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *, const uint8_t *, size_t, struct AdbcError *error) { |
156 | return ADBC_STATUS_NOT_IMPLEMENTED; |
157 | } |
158 | |
159 | /// Temporary state while the database is being configured. |
160 | struct TempDatabase { |
161 | std::unordered_map<std::string, std::string> options; |
162 | std::string driver; |
163 | // Default name (see adbc.h) |
164 | std::string entrypoint = "AdbcDriverInit" ; |
165 | AdbcDriverInitFunc init_func = nullptr; |
166 | }; |
167 | |
/// Settings collected for a connection before it is initialized.
struct TempConnection {
    // Buffered options, keyed by option name
    std::unordered_map<std::string, std::string> options;
};
172 | |
173 | // Direct implementations of API methods |
174 | |
175 | AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { |
176 | // Allocate a temporary structure to store options pre-Init |
177 | database->private_data = new TempDatabase(); |
178 | database->private_driver = nullptr; |
179 | return ADBC_STATUS_OK; |
180 | } |
181 | |
182 | AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, |
183 | struct AdbcError *error) { |
184 | if (database->private_driver) { |
185 | return database->private_driver->DatabaseSetOption(database, key, value, error); |
186 | } |
187 | |
188 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
189 | if (std::strcmp(s1: key, s2: "driver" ) == 0) { |
190 | args->driver = value; |
191 | } else if (std::strcmp(s1: key, s2: "entrypoint" ) == 0) { |
192 | args->entrypoint = value; |
193 | } else { |
194 | args->options[key] = value; |
195 | } |
196 | return ADBC_STATUS_OK; |
197 | } |
198 | |
199 | AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, |
200 | struct AdbcError *error) { |
201 | if (database->private_driver) { |
202 | return ADBC_STATUS_INVALID_STATE; |
203 | } |
204 | |
205 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
206 | args->init_func = init_func; |
207 | return ADBC_STATUS_OK; |
208 | } |
209 | |
210 | AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { |
211 | if (!database->private_data) { |
212 | SetError(error, message: "Must call AdbcDatabaseNew first" ); |
213 | return ADBC_STATUS_INVALID_STATE; |
214 | } |
215 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
216 | if (args->init_func) { |
217 | // Do nothing |
218 | } else if (args->driver.empty()) { |
219 | SetError(error, message: "Must provide 'driver' parameter" ); |
220 | return ADBC_STATUS_INVALID_ARGUMENT; |
221 | } |
222 | |
223 | database->private_driver = new AdbcDriver; |
224 | std::memset(s: database->private_driver, c: 0, n: sizeof(AdbcDriver)); |
225 | AdbcStatusCode status; |
226 | // So we don't confuse a driver into thinking it's initialized already |
227 | database->private_data = nullptr; |
228 | if (args->init_func) { |
229 | status = AdbcLoadDriverFromInitFunc(init_func: args->init_func, ADBC_VERSION_1_0_0, driver: database->private_driver, error); |
230 | } else { |
231 | status = AdbcLoadDriver(driver_name: args->driver.c_str(), entrypoint: args->entrypoint.c_str(), ADBC_VERSION_1_0_0, |
232 | driver: database->private_driver, error); |
233 | } |
234 | if (status != ADBC_STATUS_OK) { |
235 | // Restore private_data so it will be released by AdbcDatabaseRelease |
236 | database->private_data = args; |
237 | if (database->private_driver->release) { |
238 | database->private_driver->release(database->private_driver, error); |
239 | } |
240 | delete database->private_driver; |
241 | database->private_driver = nullptr; |
242 | return status; |
243 | } |
244 | status = database->private_driver->DatabaseNew(database, error); |
245 | if (status != ADBC_STATUS_OK) { |
246 | if (database->private_driver->release) { |
247 | database->private_driver->release(database->private_driver, error); |
248 | } |
249 | delete database->private_driver; |
250 | database->private_driver = nullptr; |
251 | return status; |
252 | } |
253 | for (const auto &option : args->options) { |
254 | status = |
255 | database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); |
256 | if (status != ADBC_STATUS_OK) { |
257 | delete args; |
258 | // Release the database |
259 | std::ignore = database->private_driver->DatabaseRelease(database, error); |
260 | if (database->private_driver->release) { |
261 | database->private_driver->release(database->private_driver, error); |
262 | } |
263 | delete database->private_driver; |
264 | database->private_driver = nullptr; |
265 | // Should be redundant, but ensure that AdbcDatabaseRelease |
266 | // below doesn't think that it contains a TempDatabase |
267 | database->private_data = nullptr; |
268 | return status; |
269 | } |
270 | } |
271 | delete args; |
272 | return database->private_driver->DatabaseInit(database, error); |
273 | } |
274 | |
275 | AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { |
276 | if (!database->private_driver) { |
277 | if (database->private_data) { |
278 | TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); |
279 | delete args; |
280 | database->private_data = nullptr; |
281 | return ADBC_STATUS_OK; |
282 | } |
283 | return ADBC_STATUS_INVALID_STATE; |
284 | } |
285 | auto status = database->private_driver->DatabaseRelease(database, error); |
286 | if (database->private_driver->release) { |
287 | database->private_driver->release(database->private_driver, error); |
288 | } |
289 | delete database->private_driver; |
290 | database->private_data = nullptr; |
291 | database->private_driver = nullptr; |
292 | return status; |
293 | } |
294 | |
295 | AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { |
296 | if (!connection->private_driver) { |
297 | return ADBC_STATUS_INVALID_STATE; |
298 | } |
299 | return connection->private_driver->ConnectionCommit(connection, error); |
300 | } |
301 | |
302 | AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, uint32_t *info_codes, size_t info_codes_length, |
303 | struct ArrowArrayStream *out, struct AdbcError *error) { |
304 | if (!connection->private_driver) { |
305 | return ADBC_STATUS_INVALID_STATE; |
306 | } |
307 | return connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); |
308 | } |
309 | |
310 | AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, |
311 | const char *db_schema, const char *table_name, const char **table_types, |
312 | const char *column_name, struct ArrowArrayStream *stream, |
313 | struct AdbcError *error) { |
314 | if (!connection->private_driver) { |
315 | return ADBC_STATUS_INVALID_STATE; |
316 | } |
317 | return connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, |
318 | table_types, column_name, stream, error); |
319 | } |
320 | |
321 | AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, |
322 | const char *db_schema, const char *table_name, struct ArrowSchema *schema, |
323 | struct AdbcError *error) { |
324 | if (!connection->private_driver) { |
325 | return ADBC_STATUS_INVALID_STATE; |
326 | } |
327 | return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, |
328 | error); |
329 | } |
330 | |
331 | AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream, |
332 | struct AdbcError *error) { |
333 | if (!connection->private_driver) { |
334 | return ADBC_STATUS_INVALID_STATE; |
335 | } |
336 | return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); |
337 | } |
338 | |
339 | AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, |
340 | struct AdbcError *error) { |
341 | if (!connection->private_data) { |
342 | SetError(error, message: "Must call AdbcConnectionNew first" ); |
343 | return ADBC_STATUS_INVALID_STATE; |
344 | } else if (!database->private_driver) { |
345 | SetError(error, message: "Database is not initialized" ); |
346 | return ADBC_STATUS_INVALID_ARGUMENT; |
347 | } |
348 | TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); |
349 | connection->private_data = nullptr; |
350 | std::unordered_map<std::string, std::string> options = std::move(args->options); |
351 | delete args; |
352 | |
353 | auto status = database->private_driver->ConnectionNew(connection, error); |
354 | if (status != ADBC_STATUS_OK) |
355 | return status; |
356 | connection->private_driver = database->private_driver; |
357 | |
358 | for (const auto &option : options) { |
359 | status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(), |
360 | error); |
361 | if (status != ADBC_STATUS_OK) |
362 | return status; |
363 | } |
364 | return connection->private_driver->ConnectionInit(connection, database, error); |
365 | } |
366 | |
367 | AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { |
368 | // Allocate a temporary structure to store options pre-Init, because |
369 | // we don't get access to the database (and hence the driver |
370 | // function table) until then |
371 | connection->private_data = new TempConnection; |
372 | connection->private_driver = nullptr; |
373 | return ADBC_STATUS_OK; |
374 | } |
375 | |
376 | AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, |
377 | size_t serialized_length, struct ArrowArrayStream *out, |
378 | struct AdbcError *error) { |
379 | if (!connection->private_driver) { |
380 | return ADBC_STATUS_INVALID_STATE; |
381 | } |
382 | return connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, out, |
383 | error); |
384 | } |
385 | |
386 | AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { |
387 | if (!connection->private_driver) { |
388 | if (connection->private_data) { |
389 | TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); |
390 | delete args; |
391 | connection->private_data = nullptr; |
392 | return ADBC_STATUS_OK; |
393 | } |
394 | return ADBC_STATUS_INVALID_STATE; |
395 | } |
396 | auto status = connection->private_driver->ConnectionRelease(connection, error); |
397 | connection->private_driver = nullptr; |
398 | return status; |
399 | } |
400 | |
401 | AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { |
402 | if (!connection->private_driver) { |
403 | return ADBC_STATUS_INVALID_STATE; |
404 | } |
405 | return connection->private_driver->ConnectionRollback(connection, error); |
406 | } |
407 | |
408 | AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, |
409 | struct AdbcError *error) { |
410 | if (!connection->private_data) { |
411 | SetError(error, message: "AdbcConnectionSetOption: must AdbcConnectionNew first" ); |
412 | return ADBC_STATUS_INVALID_STATE; |
413 | } |
414 | if (!connection->private_driver) { |
415 | // Init not yet called, save the option |
416 | TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); |
417 | args->options[key] = value; |
418 | return ADBC_STATUS_OK; |
419 | } |
420 | return connection->private_driver->ConnectionSetOption(connection, key, value, error); |
421 | } |
422 | |
423 | AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, |
424 | struct AdbcError *error) { |
425 | if (!statement->private_driver) { |
426 | return ADBC_STATUS_INVALID_STATE; |
427 | } |
428 | return statement->private_driver->StatementBind(statement, values, schema, error); |
429 | } |
430 | |
431 | AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, |
432 | struct AdbcError *error) { |
433 | if (!statement->private_driver) { |
434 | return ADBC_STATUS_INVALID_STATE; |
435 | } |
436 | return statement->private_driver->StatementBindStream(statement, stream, error); |
437 | } |
438 | |
439 | // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' |
440 | AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema, |
441 | struct AdbcPartitions *partitions, int64_t *rows_affected, |
442 | struct AdbcError *error) { |
443 | if (!statement->private_driver) { |
444 | return ADBC_STATUS_INVALID_STATE; |
445 | } |
446 | return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error); |
447 | } |
448 | |
449 | AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, |
450 | int64_t *rows_affected, struct AdbcError *error) { |
451 | if (!statement) { |
452 | return ADBC_STATUS_INVALID_ARGUMENT; |
453 | } |
454 | if (!statement->private_driver) { |
455 | return ADBC_STATUS_INVALID_STATE; |
456 | } |
457 | return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error); |
458 | } |
459 | |
460 | AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, |
461 | struct AdbcError *error) { |
462 | if (!statement->private_driver) { |
463 | return ADBC_STATUS_INVALID_STATE; |
464 | } |
465 | return statement->private_driver->StatementGetParameterSchema(statement, schema, error); |
466 | } |
467 | |
468 | AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, |
469 | struct AdbcError *error) { |
470 | if (!connection) { |
471 | return ADBC_STATUS_INVALID_ARGUMENT; |
472 | } |
473 | if (!connection->private_driver) { |
474 | return ADBC_STATUS_INVALID_STATE; |
475 | } |
476 | auto status = connection->private_driver->StatementNew(connection, statement, error); |
477 | statement->private_driver = connection->private_driver; |
478 | return status; |
479 | } |
480 | |
481 | AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { |
482 | if (!statement->private_driver) { |
483 | return ADBC_STATUS_INVALID_STATE; |
484 | } |
485 | return statement->private_driver->StatementPrepare(statement, error); |
486 | } |
487 | |
488 | AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { |
489 | if (!statement->private_driver) { |
490 | return ADBC_STATUS_INVALID_STATE; |
491 | } |
492 | auto status = statement->private_driver->StatementRelease(statement, error); |
493 | statement->private_driver = nullptr; |
494 | return status; |
495 | } |
496 | |
497 | AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, |
498 | struct AdbcError *error) { |
499 | if (!statement->private_driver) { |
500 | return ADBC_STATUS_INVALID_STATE; |
501 | } |
502 | return statement->private_driver->StatementSetOption(statement, key, value, error); |
503 | } |
504 | |
505 | AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { |
506 | if (!statement->private_driver) { |
507 | return ADBC_STATUS_INVALID_STATE; |
508 | } |
509 | return statement->private_driver->StatementSetSqlQuery(statement, query, error); |
510 | } |
511 | |
512 | AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, |
513 | struct AdbcError *error) { |
514 | if (!statement->private_driver) { |
515 | return ADBC_STATUS_INVALID_STATE; |
516 | } |
517 | return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); |
518 | } |
519 | |
520 | const char *AdbcStatusCodeMessage(AdbcStatusCode code) { |
521 | #define STRINGIFY(s) #s |
522 | #define STRINGIFY_VALUE(s) STRINGIFY(s) |
523 | #define CASE(CONSTANT) \ |
524 | case CONSTANT: \ |
525 | return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")"; |
526 | |
527 | switch (code) { |
528 | CASE(ADBC_STATUS_OK); |
529 | CASE(ADBC_STATUS_UNKNOWN); |
530 | CASE(ADBC_STATUS_NOT_IMPLEMENTED); |
531 | CASE(ADBC_STATUS_NOT_FOUND); |
532 | CASE(ADBC_STATUS_ALREADY_EXISTS); |
533 | CASE(ADBC_STATUS_INVALID_ARGUMENT); |
534 | CASE(ADBC_STATUS_INVALID_STATE); |
535 | CASE(ADBC_STATUS_INVALID_DATA); |
536 | CASE(ADBC_STATUS_INTEGRITY); |
537 | CASE(ADBC_STATUS_INTERNAL); |
538 | CASE(ADBC_STATUS_IO); |
539 | CASE(ADBC_STATUS_CANCELLED); |
540 | CASE(ADBC_STATUS_TIMEOUT); |
541 | CASE(ADBC_STATUS_UNAUTHENTICATED); |
542 | CASE(ADBC_STATUS_UNAUTHORIZED); |
543 | default: |
544 | return "(invalid code)" ; |
545 | } |
546 | #undef CASE |
547 | #undef STRINGIFY_VALUE |
548 | #undef STRINGIFY |
549 | } |
550 | |
551 | AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver, |
552 | struct AdbcError *error) { |
553 | AdbcDriverInitFunc init_func; |
554 | std::string error_message; |
555 | |
556 | if (version != ADBC_VERSION_1_0_0) { |
557 | SetError(error, message: "Only ADBC 1.0.0 is supported" ); |
558 | return ADBC_STATUS_NOT_IMPLEMENTED; |
559 | } |
560 | |
561 | auto *driver = reinterpret_cast<struct AdbcDriver *>(raw_driver); |
562 | |
563 | if (!entrypoint) { |
564 | // Default entrypoint (see adbc.h) |
565 | entrypoint = "AdbcDriverInit" ; |
566 | } |
567 | |
568 | #if defined(_WIN32) |
569 | |
570 | HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); |
571 | if (!handle) { |
572 | error_message += driver_name; |
573 | error_message += ": LoadLibraryExA() failed: " ; |
574 | GetWinError(&error_message); |
575 | |
576 | std::string full_driver_name = driver_name; |
577 | full_driver_name += ".lib" ; |
578 | handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); |
579 | if (!handle) { |
580 | error_message += '\n'; |
581 | error_message += full_driver_name; |
582 | error_message += ": LoadLibraryExA() failed: " ; |
583 | GetWinError(&error_message); |
584 | } |
585 | } |
586 | if (!handle) { |
587 | SetError(error, error_message); |
588 | return ADBC_STATUS_INTERNAL; |
589 | } |
590 | |
591 | void *load_handle = reinterpret_cast<void *>(GetProcAddress(handle, entrypoint)); |
592 | init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle); |
593 | if (!init_func) { |
594 | std::string message = "GetProcAddress(" ; |
595 | message += entrypoint; |
596 | message += ") failed: " ; |
597 | GetWinError(&message); |
598 | if (!FreeLibrary(handle)) { |
599 | message += "\nFreeLibrary() failed: " ; |
600 | GetWinError(&message); |
601 | } |
602 | SetError(error, message); |
603 | return ADBC_STATUS_INTERNAL; |
604 | } |
605 | |
606 | #else |
607 | |
608 | #if defined(__APPLE__) |
609 | const std::string kPlatformLibraryPrefix = "lib" ; |
610 | const std::string kPlatformLibrarySuffix = ".dylib" ; |
611 | #else |
612 | const std::string kPlatformLibraryPrefix = "lib" ; |
613 | const std::string kPlatformLibrarySuffix = ".so" ; |
614 | #endif // defined(__APPLE__) |
615 | |
616 | void *handle = dlopen(file: driver_name, RTLD_NOW | RTLD_LOCAL); |
617 | if (!handle) { |
618 | error_message = "dlopen() failed: " ; |
619 | error_message += dlerror(); |
620 | |
621 | // If applicable, append the shared library prefix/extension and |
622 | // try again (this way you don't have to hardcode driver names by |
623 | // platform in the application) |
624 | const std::string driver_str = driver_name; |
625 | |
626 | std::string full_driver_name; |
627 | if (driver_str.size() < kPlatformLibraryPrefix.size() || |
628 | driver_str.compare(pos: 0, n: kPlatformLibraryPrefix.size(), str: kPlatformLibraryPrefix) != 0) { |
629 | full_driver_name += kPlatformLibraryPrefix; |
630 | } |
631 | full_driver_name += driver_name; |
632 | if (driver_str.size() < kPlatformLibrarySuffix.size() || |
633 | driver_str.compare(pos: full_driver_name.size() - kPlatformLibrarySuffix.size(), n: kPlatformLibrarySuffix.size(), |
634 | str: kPlatformLibrarySuffix) != 0) { |
635 | full_driver_name += kPlatformLibrarySuffix; |
636 | } |
637 | handle = dlopen(file: full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); |
638 | if (!handle) { |
639 | error_message += "\ndlopen() failed: " ; |
640 | error_message += dlerror(); |
641 | } |
642 | } |
643 | if (!handle) { |
644 | SetError(error, message: error_message); |
645 | // AdbcDatabaseInit tries to call this if set |
646 | driver->release = nullptr; |
647 | return ADBC_STATUS_INTERNAL; |
648 | } |
649 | |
650 | void *load_handle = dlsym(handle: handle, name: entrypoint); |
651 | if (!load_handle) { |
652 | std::string message = "dlsym(" ; |
653 | message += entrypoint; |
654 | message += ") failed: " ; |
655 | message += dlerror(); |
656 | SetError(error, message); |
657 | return ADBC_STATUS_INTERNAL; |
658 | } |
659 | init_func = reinterpret_cast<AdbcDriverInitFunc>(load_handle); |
660 | |
661 | #endif // defined(_WIN32) |
662 | |
663 | AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); |
664 | if (status == ADBC_STATUS_OK) { |
665 | ManagerDriverState *state = new ManagerDriverState; |
666 | state->driver_release = driver->release; |
667 | #if defined(_WIN32) |
668 | state->handle = handle; |
669 | #endif // defined(_WIN32) |
670 | driver->release = &ReleaseDriver; |
671 | driver->private_manager = state; |
672 | } else { |
673 | #if defined(_WIN32) |
674 | if (!FreeLibrary(handle)) { |
675 | std::string message = "FreeLibrary() failed: " ; |
676 | GetWinError(&message); |
677 | SetError(error, message); |
678 | } |
679 | #endif // defined(_WIN32) |
680 | } |
681 | return status; |
682 | } |
683 | |
684 | AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver, |
685 | struct AdbcError *error) { |
686 | #define FILL_DEFAULT(DRIVER, STUB) \ |
687 | if (!DRIVER->STUB) { \ |
688 | DRIVER->STUB = &STUB; \ |
689 | } |
690 | #define CHECK_REQUIRED(DRIVER, STUB) \ |
691 | if (!DRIVER->STUB) { \ |
692 | SetError(error, "Driver does not implement required function Adbc" #STUB); \ |
693 | return ADBC_STATUS_INTERNAL; \ |
694 | } |
695 | |
696 | auto result = init_func(version, raw_driver, error); |
697 | if (result != ADBC_STATUS_OK) { |
698 | return result; |
699 | } |
700 | |
701 | if (version == ADBC_VERSION_1_0_0) { |
702 | auto *driver = reinterpret_cast<struct AdbcDriver *>(raw_driver); |
703 | CHECK_REQUIRED(driver, DatabaseNew); |
704 | CHECK_REQUIRED(driver, DatabaseInit); |
705 | CHECK_REQUIRED(driver, DatabaseRelease); |
706 | FILL_DEFAULT(driver, DatabaseSetOption); |
707 | |
708 | CHECK_REQUIRED(driver, ConnectionNew); |
709 | CHECK_REQUIRED(driver, ConnectionInit); |
710 | CHECK_REQUIRED(driver, ConnectionRelease); |
711 | FILL_DEFAULT(driver, ConnectionCommit); |
712 | FILL_DEFAULT(driver, ConnectionGetInfo); |
713 | FILL_DEFAULT(driver, ConnectionGetObjects); |
714 | FILL_DEFAULT(driver, ConnectionGetTableSchema); |
715 | FILL_DEFAULT(driver, ConnectionGetTableTypes); |
716 | FILL_DEFAULT(driver, ConnectionReadPartition); |
717 | FILL_DEFAULT(driver, ConnectionRollback); |
718 | FILL_DEFAULT(driver, ConnectionSetOption); |
719 | |
720 | FILL_DEFAULT(driver, StatementExecutePartitions); |
721 | CHECK_REQUIRED(driver, StatementExecuteQuery); |
722 | CHECK_REQUIRED(driver, StatementNew); |
723 | CHECK_REQUIRED(driver, StatementRelease); |
724 | FILL_DEFAULT(driver, StatementBind); |
725 | FILL_DEFAULT(driver, StatementGetParameterSchema); |
726 | FILL_DEFAULT(driver, StatementPrepare); |
727 | FILL_DEFAULT(driver, StatementSetOption); |
728 | FILL_DEFAULT(driver, StatementSetSqlQuery); |
729 | FILL_DEFAULT(driver, StatementSetSubstraitPlan); |
730 | } |
731 | |
732 | return ADBC_STATUS_OK; |
733 | |
734 | #undef FILL_DEFAULT |
735 | #undef CHECK_REQUIRED |
736 | } |
737 | } // namespace duckdb_adbc |
738 | |