| 1 | #include "duckdb/main/capi/capi_internal.hpp" |
| 2 | #include "duckdb/common/arrow/arrow_converter.hpp" |
| 3 | |
| 4 | using duckdb::ArrowConverter; |
| 5 | using duckdb::ArrowResultWrapper; |
| 6 | using duckdb::Connection; |
| 7 | using duckdb::DataChunk; |
| 8 | using duckdb::LogicalType; |
| 9 | using duckdb::MaterializedQueryResult; |
| 10 | using duckdb::PreparedStatementWrapper; |
| 11 | using duckdb::QueryResult; |
| 12 | using duckdb::QueryResultType; |
| 13 | |
| 14 | duckdb_state duckdb_query_arrow(duckdb_connection connection, const char *query, duckdb_arrow *out_result) { |
| 15 | Connection *conn = (Connection *)connection; |
| 16 | auto wrapper = new ArrowResultWrapper(); |
| 17 | wrapper->result = conn->Query(query); |
| 18 | *out_result = (duckdb_arrow)wrapper; |
| 19 | return !wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; |
| 20 | } |
| 21 | |
| 22 | duckdb_state duckdb_query_arrow_schema(duckdb_arrow result, duckdb_arrow_schema *out_schema) { |
| 23 | if (!out_schema) { |
| 24 | return DuckDBSuccess; |
| 25 | } |
| 26 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
| 27 | ArrowConverter::ToArrowSchema(out_schema: (ArrowSchema *)*out_schema, types: wrapper->result->types, names: wrapper->result->names, |
| 28 | options: wrapper->options); |
| 29 | return DuckDBSuccess; |
| 30 | } |
| 31 | |
| 32 | duckdb_state duckdb_query_arrow_array(duckdb_arrow result, duckdb_arrow_array *out_array) { |
| 33 | if (!out_array) { |
| 34 | return DuckDBSuccess; |
| 35 | } |
| 36 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
| 37 | auto success = wrapper->result->TryFetch(result&: wrapper->current_chunk, error&: wrapper->result->GetErrorObject()); |
| 38 | if (!success) { // LCOV_EXCL_START |
| 39 | return DuckDBError; |
| 40 | } // LCOV_EXCL_STOP |
| 41 | if (!wrapper->current_chunk || wrapper->current_chunk->size() == 0) { |
| 42 | return DuckDBSuccess; |
| 43 | } |
| 44 | ArrowConverter::ToArrowArray(input&: *wrapper->current_chunk, out_array: reinterpret_cast<ArrowArray *>(*out_array), options: wrapper->options); |
| 45 | return DuckDBSuccess; |
| 46 | } |
| 47 | |
| 48 | idx_t duckdb_arrow_row_count(duckdb_arrow result) { |
| 49 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
| 50 | if (wrapper->result->HasError()) { |
| 51 | return 0; |
| 52 | } |
| 53 | return wrapper->result->RowCount(); |
| 54 | } |
| 55 | |
| 56 | idx_t duckdb_arrow_column_count(duckdb_arrow result) { |
| 57 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
| 58 | return wrapper->result->ColumnCount(); |
| 59 | } |
| 60 | |
| 61 | idx_t duckdb_arrow_rows_changed(duckdb_arrow result) { |
| 62 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
| 63 | if (wrapper->result->HasError()) { |
| 64 | return 0; |
| 65 | } |
| 66 | idx_t rows_changed = 0; |
| 67 | auto &collection = wrapper->result->Collection(); |
| 68 | idx_t row_count = collection.Count(); |
| 69 | if (row_count > 0 && wrapper->result->properties.return_type == duckdb::StatementReturnType::CHANGED_ROWS) { |
| 70 | auto rows = collection.GetRows(); |
| 71 | D_ASSERT(row_count == 1); |
| 72 | D_ASSERT(rows.size() == 1); |
| 73 | rows_changed = rows[0].GetValue(column_index: 0).GetValue<int64_t>(); |
| 74 | } |
| 75 | return rows_changed; |
| 76 | } |
| 77 | |
| 78 | const char *duckdb_query_arrow_error(duckdb_arrow result) { |
| 79 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
| 80 | return wrapper->result->GetError().c_str(); |
| 81 | } |
| 82 | |
| 83 | void duckdb_destroy_arrow(duckdb_arrow *result) { |
| 84 | if (*result) { |
| 85 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(*result); |
| 86 | delete wrapper; |
| 87 | *result = nullptr; |
| 88 | } |
| 89 | } |
| 90 | |
| 91 | duckdb_state duckdb_execute_prepared_arrow(duckdb_prepared_statement prepared_statement, duckdb_arrow *out_result) { |
| 92 | auto wrapper = reinterpret_cast<PreparedStatementWrapper *>(prepared_statement); |
| 93 | if (!wrapper || !wrapper->statement || wrapper->statement->HasError() || !out_result) { |
| 94 | return DuckDBError; |
| 95 | } |
| 96 | auto arrow_wrapper = new ArrowResultWrapper(); |
| 97 | if (wrapper->statement->context->config.set_variables.find(x: "TimeZone" ) == |
| 98 | wrapper->statement->context->config.set_variables.end()) { |
| 99 | arrow_wrapper->options.time_zone = "UTC" ; |
| 100 | } else { |
| 101 | arrow_wrapper->options.time_zone = |
| 102 | wrapper->statement->context->config.set_variables["TimeZone" ].GetValue<std::string>(); |
| 103 | } |
| 104 | |
| 105 | auto result = wrapper->statement->Execute(values&: wrapper->values, allow_stream_result: false); |
| 106 | D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); |
| 107 | arrow_wrapper->result = duckdb::unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(result)); |
| 108 | *out_result = reinterpret_cast<duckdb_arrow>(arrow_wrapper); |
| 109 | return !arrow_wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; |
| 110 | } |
| 111 | |