#include "duckdb/common/arrow/arrow_wrapper.hpp"
#include "duckdb/common/arrow/arrow_converter.hpp"

#include "duckdb/common/assert.hpp"
#include "duckdb/common/exception.hpp"

#include "duckdb/main/stream_query_result.hpp"

#include "duckdb/common/arrow/result_arrow_wrapper.hpp"
#include "duckdb/common/arrow/arrow_appender.hpp"
#include "duckdb/main/query_result.hpp"

namespace duckdb {

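// The Arrow C data interface marks a struct as released by setting its `release` callback to
// nullptr. The wrapper destructors below therefore only call `release` on structs (and their
// children) that still own data, and clear the callback afterwards.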
ArrowSchemaWrapper::~ArrowSchemaWrapper() {
	if (arrow_schema.release) {
		for (int64_t child_idx = 0; child_idx < arrow_schema.n_children; child_idx++) {
			auto &child = *arrow_schema.children[child_idx];
			if (child.release) {
				child.release(&child);
			}
		}
		arrow_schema.release(&arrow_schema);
		arrow_schema.release = nullptr;
	}
}

ArrowArrayWrapper::~ArrowArrayWrapper() {
	if (arrow_array.release) {
		for (int64_t child_idx = 0; child_idx < arrow_array.n_children; child_idx++) {
			auto &child = *arrow_array.children[child_idx];
			if (child.release) {
				child.release(&child);
			}
		}
		arrow_array.release(&arrow_array);
		arrow_array.release = nullptr;
	}
}

ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() {
	if (arrow_array_stream.release) {
		arrow_array_stream.release(&arrow_array_stream);
		arrow_array_stream.release = nullptr;
	}
}

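// Fetch the schema from the underlying C stream and validate that it was not handed over
// already released and that it describes at least one column.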
void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) {
	D_ASSERT(arrow_array_stream.get_schema);
	// LCOV_EXCL_START
	if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema)) {
		throw InvalidInputException("arrow_scan: get_schema failed(): %s", string(GetError()));
	}
	if (!schema.arrow_schema.release) {
		throw InvalidInputException("arrow_scan: released schema passed");
	}
	if (schema.arrow_schema.n_children < 1) {
		throw InvalidInputException("arrow_scan: empty schema passed");
	}
	// LCOV_EXCL_STOP
}

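// Fetch the next array from the stream. When the stream is exhausted, get_next leaves the
// output array released (release == nullptr), which callers use to detect end-of-stream.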
shared_ptr<ArrowArrayWrapper> ArrowArrayStreamWrapper::GetNextChunk() {
	auto current_chunk = make_shared<ArrowArrayWrapper>();
	if (arrow_array_stream.get_next(&arrow_array_stream, &current_chunk->arrow_array)) { // LCOV_EXCL_START
		throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError()));
	} // LCOV_EXCL_STOP

	return current_chunk;
}

const char *ArrowArrayStreamWrapper::GetError() { // LCOV_EXCL_START
	return arrow_array_stream.get_last_error(&arrow_array_stream);
} // LCOV_EXCL_STOP

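// The following static members implement the ArrowArrayStream callbacks on top of a DuckDB
// QueryResult: they return 0 on success and a non-zero value on error, with the error message
// made available through get_last_error.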
int ResultArrowArrayStreamWrapper::MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) {
	if (!stream->release) {
		return -1;
	}
	auto my_stream = reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
	if (!my_stream->column_types.empty()) {
		ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names,
		                              QueryResult::GetArrowOptions(*my_stream->result));
		return 0;
	}

	auto &result = *my_stream->result;
	if (result.HasError()) {
		my_stream->last_error = result.GetErrorObject();
		return -1;
	}
	if (result.type == QueryResultType::STREAM_RESULT) {
		auto &stream_result = result.Cast<StreamQueryResult>();
		if (!stream_result.IsOpen()) {
			my_stream->last_error = PreservedError("Query Stream is closed");
			return -1;
		}
	}
	if (my_stream->column_types.empty()) {
		my_stream->column_types = result.types;
		my_stream->column_names = result.names;
	}
	ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names,
	                              QueryResult::GetArrowOptions(*my_stream->result));
	return 0;
}

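// Convert up to batch_size rows of the query result into a single Arrow array. A released
// output array (out->release == nullptr) signals the end of the stream.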
int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) {
	if (!stream->release) {
		return -1;
	}
	auto my_stream = reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
	auto &result = *my_stream->result;
	if (result.HasError()) {
		my_stream->last_error = result.GetErrorObject();
		return -1;
	}
	if (result.type == QueryResultType::STREAM_RESULT) {
		auto &stream_result = result.Cast<StreamQueryResult>();
		if (!stream_result.IsOpen()) {
			// Nothing to output
			out->release = nullptr;
			return 0;
		}
	}
	if (my_stream->column_types.empty()) {
		my_stream->column_types = result.types;
		my_stream->column_names = result.names;
	}
	idx_t result_count;
	PreservedError error;
	if (!ArrowUtil::TryFetchChunk(&result, my_stream->batch_size, out, result_count, error)) {
		D_ASSERT(error);
		my_stream->last_error = error;
		return -1;
	}
	if (result_count == 0) {
		// Nothing to output
		out->release = nullptr;
	}
	return 0;
}

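// The stream owns its wrapper through private_data; releasing the stream deletes the wrapper
// (and with it the wrapped QueryResult).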
void ResultArrowArrayStreamWrapper::MyStreamRelease(struct ArrowArrayStream *stream) {
	if (!stream || !stream->release) {
		return;
	}
	stream->release = nullptr;
	delete reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
}

const char *ResultArrowArrayStreamWrapper::MyStreamGetLastError(struct ArrowArrayStream *stream) {
	if (!stream->release) {
		return "stream was released";
	}
	D_ASSERT(stream->private_data);
	auto my_stream = reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
	return my_stream->last_error.Message().c_str();
}

ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr<QueryResult> result_p, idx_t batch_size_p)
    : result(std::move(result_p)) {
	//! We first initialize the private data of the stream
	stream.private_data = this;
	//! The approximate batch size must be strictly positive
	if (batch_size_p == 0) {
		throw std::runtime_error("Approximate Batch Size of Record Batch MUST be higher than 0");
	}
	batch_size = batch_size_p;
	//! We initialize the stream functions
	stream.get_schema = ResultArrowArrayStreamWrapper::MyStreamGetSchema;
	stream.get_next = ResultArrowArrayStreamWrapper::MyStreamGetNext;
	stream.release = ResultArrowArrayStreamWrapper::MyStreamRelease;
	stream.get_last_error = ResultArrowArrayStreamWrapper::MyStreamGetLastError;
}

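// A closed stream result is not an error: the function returns true and leaves the chunk empty,
// which the caller interprets as end-of-stream.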
bool ArrowUtil::TryFetchNext(QueryResult &result, unique_ptr<DataChunk> &chunk, PreservedError &error) {
	if (result.type == QueryResultType::STREAM_RESULT) {
		auto &stream_result = result.Cast<StreamQueryResult>();
		if (!stream_result.IsOpen()) {
			return true;
		}
	}
	return result.TryFetch(chunk, error);
}

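// Append rows into a single Arrow array until chunk_size rows have been collected or the result
// is exhausted. A DataChunk that is only partially consumed is stashed in result->current_chunk
// so the next call can resume from the remaining rows.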
bool ArrowUtil::TryFetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out, idx_t &count,
                              PreservedError &error) {
	count = 0;
	ArrowAppender appender(result->types, chunk_size, QueryResult::GetArrowOptions(*result));
	auto &current_chunk = result->current_chunk;
	if (current_chunk.Valid()) {
		// We start by scanning the non-finished current chunk
		idx_t cur_consumption = current_chunk.RemainingSize() > chunk_size ? chunk_size : current_chunk.RemainingSize();
		count += cur_consumption;
		appender.Append(*current_chunk.data_chunk, current_chunk.position, current_chunk.position + cur_consumption,
		                current_chunk.data_chunk->size());
		current_chunk.position += cur_consumption;
	}
	while (count < chunk_size) {
		unique_ptr<DataChunk> data_chunk;
		if (!TryFetchNext(*result, data_chunk, error)) {
			if (result->HasError()) {
				error = result->GetErrorObject();
			}
			return false;
		}
		if (!data_chunk || data_chunk->size() == 0) {
			break;
		}
		if (count + data_chunk->size() > chunk_size) {
			// We have to split the chunk between this and the next batch
			idx_t available_space = chunk_size - count;
			appender.Append(*data_chunk, 0, available_space, data_chunk->size());
			count += available_space;
			current_chunk.data_chunk = std::move(data_chunk);
			current_chunk.position = available_space;
		} else {
			count += data_chunk->size();
			appender.Append(*data_chunk, 0, data_chunk->size(), data_chunk->size());
		}
	}
	if (count > 0) {
		*out = appender.Finalize();
	}
	return true;
}

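// Throwing wrapper around TryFetchChunk: a failed fetch raises the preserved error instead of
// returning it.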
idx_t ArrowUtil::FetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out) {
	PreservedError error;
	idx_t result_count;
	if (!TryFetchChunk(result, chunk_size, out, result_count, error)) {
		error.Throw();
	}
	return result_count;
}

} // namespace duckdb