#include "duckdb/main/capi/capi_internal.hpp"
#include "duckdb/common/arrow/arrow_converter.hpp"

#include <memory>
#include <string>
3 | |
4 | using duckdb::ArrowConverter; |
5 | using duckdb::ArrowResultWrapper; |
6 | using duckdb::Connection; |
7 | using duckdb::DataChunk; |
8 | using duckdb::LogicalType; |
9 | using duckdb::MaterializedQueryResult; |
10 | using duckdb::PreparedStatementWrapper; |
11 | using duckdb::QueryResult; |
12 | using duckdb::QueryResultType; |
13 | |
14 | duckdb_state duckdb_query_arrow(duckdb_connection connection, const char *query, duckdb_arrow *out_result) { |
15 | Connection *conn = (Connection *)connection; |
16 | auto wrapper = new ArrowResultWrapper(); |
17 | wrapper->result = conn->Query(query); |
18 | *out_result = (duckdb_arrow)wrapper; |
19 | return !wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; |
20 | } |
21 | |
22 | duckdb_state duckdb_query_arrow_schema(duckdb_arrow result, duckdb_arrow_schema *out_schema) { |
23 | if (!out_schema) { |
24 | return DuckDBSuccess; |
25 | } |
26 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
27 | ArrowConverter::ToArrowSchema(out_schema: (ArrowSchema *)*out_schema, types: wrapper->result->types, names: wrapper->result->names, |
28 | options: wrapper->options); |
29 | return DuckDBSuccess; |
30 | } |
31 | |
32 | duckdb_state duckdb_query_arrow_array(duckdb_arrow result, duckdb_arrow_array *out_array) { |
33 | if (!out_array) { |
34 | return DuckDBSuccess; |
35 | } |
36 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
37 | auto success = wrapper->result->TryFetch(result&: wrapper->current_chunk, error&: wrapper->result->GetErrorObject()); |
38 | if (!success) { // LCOV_EXCL_START |
39 | return DuckDBError; |
40 | } // LCOV_EXCL_STOP |
41 | if (!wrapper->current_chunk || wrapper->current_chunk->size() == 0) { |
42 | return DuckDBSuccess; |
43 | } |
44 | ArrowConverter::ToArrowArray(input&: *wrapper->current_chunk, out_array: reinterpret_cast<ArrowArray *>(*out_array), options: wrapper->options); |
45 | return DuckDBSuccess; |
46 | } |
47 | |
48 | idx_t duckdb_arrow_row_count(duckdb_arrow result) { |
49 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
50 | if (wrapper->result->HasError()) { |
51 | return 0; |
52 | } |
53 | return wrapper->result->RowCount(); |
54 | } |
55 | |
56 | idx_t duckdb_arrow_column_count(duckdb_arrow result) { |
57 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
58 | return wrapper->result->ColumnCount(); |
59 | } |
60 | |
61 | idx_t duckdb_arrow_rows_changed(duckdb_arrow result) { |
62 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
63 | if (wrapper->result->HasError()) { |
64 | return 0; |
65 | } |
66 | idx_t rows_changed = 0; |
67 | auto &collection = wrapper->result->Collection(); |
68 | idx_t row_count = collection.Count(); |
69 | if (row_count > 0 && wrapper->result->properties.return_type == duckdb::StatementReturnType::CHANGED_ROWS) { |
70 | auto rows = collection.GetRows(); |
71 | D_ASSERT(row_count == 1); |
72 | D_ASSERT(rows.size() == 1); |
73 | rows_changed = rows[0].GetValue(column_index: 0).GetValue<int64_t>(); |
74 | } |
75 | return rows_changed; |
76 | } |
77 | |
78 | const char *duckdb_query_arrow_error(duckdb_arrow result) { |
79 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(result); |
80 | return wrapper->result->GetError().c_str(); |
81 | } |
82 | |
83 | void duckdb_destroy_arrow(duckdb_arrow *result) { |
84 | if (*result) { |
85 | auto wrapper = reinterpret_cast<ArrowResultWrapper *>(*result); |
86 | delete wrapper; |
87 | *result = nullptr; |
88 | } |
89 | } |
90 | |
91 | duckdb_state duckdb_execute_prepared_arrow(duckdb_prepared_statement prepared_statement, duckdb_arrow *out_result) { |
92 | auto wrapper = reinterpret_cast<PreparedStatementWrapper *>(prepared_statement); |
93 | if (!wrapper || !wrapper->statement || wrapper->statement->HasError() || !out_result) { |
94 | return DuckDBError; |
95 | } |
96 | auto arrow_wrapper = new ArrowResultWrapper(); |
97 | if (wrapper->statement->context->config.set_variables.find(x: "TimeZone" ) == |
98 | wrapper->statement->context->config.set_variables.end()) { |
99 | arrow_wrapper->options.time_zone = "UTC" ; |
100 | } else { |
101 | arrow_wrapper->options.time_zone = |
102 | wrapper->statement->context->config.set_variables["TimeZone" ].GetValue<std::string>(); |
103 | } |
104 | |
105 | auto result = wrapper->statement->Execute(values&: wrapper->values, allow_stream_result: false); |
106 | D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); |
107 | arrow_wrapper->result = duckdb::unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(result)); |
108 | *out_result = reinterpret_cast<duckdb_arrow>(arrow_wrapper); |
109 | return !arrow_wrapper->result->HasError() ? DuckDBSuccess : DuckDBError; |
110 | } |
111 | |