#include "duckdb/common/arrow/arrow_wrapper.hpp"
#include "duckdb/common/arrow/arrow_converter.hpp"

#include "duckdb/common/assert.hpp"
#include "duckdb/common/exception.hpp"

#include "duckdb/main/stream_query_result.hpp"

#include "duckdb/common/arrow/result_arrow_wrapper.hpp"
#include "duckdb/common/arrow/arrow_appender.hpp"
#include "duckdb/main/query_result.hpp"

namespace duckdb {

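// RAII wrappers over the Arrow C data interface structs. Per that interface, an object
// is live while its `release` callback is non-null; destruction means invoking the
// callback and then clearing it to guard against a double release.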
ArrowSchemaWrapper::~ArrowSchemaWrapper() {
	if (arrow_schema.release) {
		for (int64_t child_idx = 0; child_idx < arrow_schema.n_children; child_idx++) {
			auto &child = *arrow_schema.children[child_idx];
			if (child.release) {
				child.release(&child);
			}
		}
		arrow_schema.release(&arrow_schema);
		arrow_schema.release = nullptr;
	}
}

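// Same pattern for arrays: explicitly release any children that are still live,
// then release the array itself.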
ArrowArrayWrapper::~ArrowArrayWrapper() {
	if (arrow_array.release) {
		for (int64_t child_idx = 0; child_idx < arrow_array.n_children; child_idx++) {
			auto &child = *arrow_array.children[child_idx];
			if (child.release) {
				child.release(&child);
			}
		}
		arrow_array.release(&arrow_array);
		arrow_array.release = nullptr;
	}
}

ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() {
	if (arrow_array_stream.release) {
		arrow_array_stream.release(&arrow_array_stream);
		arrow_array_stream.release = nullptr;
	}
}

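// Fetches the schema from the wrapped ArrowArrayStream. A non-zero return from
// get_schema signals an error; the message is obtained via get_last_error
// (see GetError below).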
void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) {
	D_ASSERT(arrow_array_stream.get_schema);
	// LCOV_EXCL_START
	if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema)) {
		throw InvalidInputException("arrow_scan: get_schema failed(): %s", string(GetError()));
	}
	if (!schema.arrow_schema.release) {
		throw InvalidInputException("arrow_scan: released schema passed");
	}
	if (schema.arrow_schema.n_children < 1) {
		throw InvalidInputException("arrow_scan: empty schema passed");
	}
	// LCOV_EXCL_STOP
}

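// Pulls the next batch from the stream. Per the ArrowArrayStream contract,
// end-of-stream is signaled by a returned array whose release callback is null,
// which callers detect on the returned wrapper.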
shared_ptr<ArrowArrayWrapper> ArrowArrayStreamWrapper::GetNextChunk() {
	auto current_chunk = make_shared<ArrowArrayWrapper>();
	if (arrow_array_stream.get_next(&arrow_array_stream, &current_chunk->arrow_array)) { // LCOV_EXCL_START
		throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError()));
	} // LCOV_EXCL_STOP

	return current_chunk;
}

const char *ArrowArrayStreamWrapper::GetError() { // LCOV_EXCL_START
	return arrow_array_stream.get_last_error(&arrow_array_stream);
} // LCOV_EXCL_STOP

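// The MyStream* functions below implement the ArrowArrayStream C callbacks on top
// of a DuckDB QueryResult; the owning wrapper is threaded through the stream's
// private_data pointer.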
int ResultArrowArrayStreamWrapper::MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) {
	if (!stream->release) {
		return -1;
	}
	auto my_stream = reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
	if (!my_stream->column_types.empty()) {
		ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names,
		                              QueryResult::GetArrowOptions(*my_stream->result));
		return 0;
	}

	auto &result = *my_stream->result;
	if (result.HasError()) {
		my_stream->last_error = result.GetErrorObject();
		return -1;
	}
	if (result.type == QueryResultType::STREAM_RESULT) {
		auto &stream_result = result.Cast<StreamQueryResult>();
		if (!stream_result.IsOpen()) {
			my_stream->last_error = PreservedError("Query Stream is closed");
			return -1;
		}
	}
	if (my_stream->column_types.empty()) {
		my_stream->column_types = result.types;
		my_stream->column_names = result.names;
	}
	ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names,
	                              QueryResult::GetArrowOptions(*my_stream->result));
	return 0;
}

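// Produces the next record batch of up to batch_size rows. Setting out->release to
// nullptr marks end-of-stream; -1 with last_error populated signals failure.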
int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) {
	if (!stream->release) {
		return -1;
	}
	auto my_stream = reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
	auto &result = *my_stream->result;
	if (result.HasError()) {
		my_stream->last_error = result.GetErrorObject();
		return -1;
	}
	if (result.type == QueryResultType::STREAM_RESULT) {
		auto &stream_result = result.Cast<StreamQueryResult>();
		if (!stream_result.IsOpen()) {
			// Nothing to output
			out->release = nullptr;
			return 0;
		}
	}
	if (my_stream->column_types.empty()) {
		my_stream->column_types = result.types;
		my_stream->column_names = result.names;
	}
	idx_t result_count;
	PreservedError error;
	if (!ArrowUtil::TryFetchChunk(&result, my_stream->batch_size, out, result_count, error)) {
		D_ASSERT(error);
		my_stream->last_error = error;
		return -1;
	}
	if (result_count == 0) {
		// Nothing to output
		out->release = nullptr;
	}
	return 0;
}

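// Releasing the stream destroys the owning wrapper, which lives behind
// private_data; the null check makes a second release a no-op.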
void ResultArrowArrayStreamWrapper::MyStreamRelease(struct ArrowArrayStream *stream) {
	if (!stream || !stream->release) {
		return;
	}
	stream->release = nullptr;
	delete reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
}

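// Returns the message of the stored last_error; the pointer stays valid until the
// error is overwritten or the wrapper is destroyed.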
const char *ResultArrowArrayStreamWrapper::MyStreamGetLastError(struct ArrowArrayStream *stream) {
	if (!stream->release) {
		return "stream was released";
	}
	D_ASSERT(stream->private_data);
	auto my_stream = reinterpret_cast<ResultArrowArrayStreamWrapper *>(stream->private_data);
	return my_stream->last_error.Message().c_str();
}

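// Wires the four ArrowArrayStream callbacks to the static methods above and
// validates the batch size. A minimal consumer-side sketch (illustrative only; the
// variable names and the 2048-row batch size are assumptions, not part of this file):
//
//   auto wrapper = new ResultArrowArrayStreamWrapper(std::move(result), 2048);
//   ArrowArrayStream *stream = &wrapper->stream;
//   ArrowSchema schema;
//   if (stream->get_schema(stream, &schema) == 0) {
//       ArrowArray batch;
//       while (stream->get_next(stream, &batch) == 0 && batch.release) {
//           /* consume batch, then */ batch.release(&batch);
//       }
//       schema.release(&schema);
//   }
//   stream->release(stream); // deletes the wrapper through private_data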
ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr<QueryResult> result_p, idx_t batch_size_p)
    : result(std::move(result_p)) {
	//! We first initialize the private data of the stream
	stream.private_data = this;
	//! The approximate batch size must be at least one row
	if (batch_size_p == 0) {
		throw std::runtime_error("Approximate Batch Size of Record Batch MUST be higher than 0");
	}
	batch_size = batch_size_p;
	//! We initialize the stream functions
	stream.get_schema = ResultArrowArrayStreamWrapper::MyStreamGetSchema;
	stream.get_next = ResultArrowArrayStreamWrapper::MyStreamGetNext;
	stream.release = ResultArrowArrayStreamWrapper::MyStreamRelease;
	stream.get_last_error = ResultArrowArrayStreamWrapper::MyStreamGetLastError;
}

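// A closed stream result counts as a successful fetch of nothing (chunk stays
// null); otherwise the fetch is delegated to QueryResult::TryFetch.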
bool ArrowUtil::TryFetchNext(QueryResult &result, unique_ptr<DataChunk> &chunk, PreservedError &error) {
	if (result.type == QueryResultType::STREAM_RESULT) {
		auto &stream_result = result.Cast<StreamQueryResult>();
		if (!stream_result.IsOpen()) {
			return true;
		}
	}
	return result.TryFetch(chunk, error);
}

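// Accumulates up to chunk_size rows into a single ArrowArray. Any partially
// consumed DataChunk is parked in result->current_chunk so the next call resumes
// where this one stopped.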
bool ArrowUtil::TryFetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out, idx_t &count,
                              PreservedError &error) {
	count = 0;
	ArrowAppender appender(result->types, chunk_size, QueryResult::GetArrowOptions(*result));
	auto &current_chunk = result->current_chunk;
	if (current_chunk.Valid()) {
		// We start by scanning the non-finished current chunk
		idx_t cur_consumption = current_chunk.RemainingSize() > chunk_size ? chunk_size : current_chunk.RemainingSize();
		count += cur_consumption;
		appender.Append(*current_chunk.data_chunk, current_chunk.position, current_chunk.position + cur_consumption,
		                current_chunk.data_chunk->size());
		current_chunk.position += cur_consumption;
	}
	while (count < chunk_size) {
		unique_ptr<DataChunk> data_chunk;
		if (!TryFetchNext(*result, data_chunk, error)) {
			if (result->HasError()) {
				error = result->GetErrorObject();
			}
			return false;
		}
		if (!data_chunk || data_chunk->size() == 0) {
			break;
		}
		if (count + data_chunk->size() > chunk_size) {
			// We have to split the chunk between this and the next batch
			idx_t available_space = chunk_size - count;
			appender.Append(*data_chunk, 0, available_space, data_chunk->size());
			count += available_space;
			current_chunk.data_chunk = std::move(data_chunk);
			current_chunk.position = available_space;
		} else {
			count += data_chunk->size();
			appender.Append(*data_chunk, 0, data_chunk->size(), data_chunk->size());
		}
	}
	if (count > 0) {
		*out = appender.Finalize();
	}
	return true;
}

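// Throwing variant of TryFetchChunk, used when the caller has no error channel of
// its own.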
idx_t ArrowUtil::FetchChunk(QueryResult *result, idx_t chunk_size, ArrowArray *out) {
	PreservedError error;
	idx_t result_count;
	if (!TryFetchChunk(result, chunk_size, out, result_count, error)) {
		error.Throw();
	}
	return result_count;
}

} // namespace duckdb