1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/ipc/reader.h" |
19 | |
20 | #include <cstdint> |
21 | #include <cstring> |
22 | #include <sstream> |
23 | #include <string> |
24 | #include <type_traits> |
25 | #include <utility> |
26 | #include <vector> |
27 | |
28 | #include <flatbuffers/flatbuffers.h> // IWYU pragma: export |
29 | |
30 | #include "arrow/array.h" |
31 | #include "arrow/buffer.h" |
32 | #include "arrow/io/interfaces.h" |
33 | #include "arrow/io/memory.h" |
34 | #include "arrow/ipc/File_generated.h" // IWYU pragma: export |
35 | #include "arrow/ipc/Message_generated.h" |
36 | #include "arrow/ipc/Schema_generated.h" |
37 | #include "arrow/ipc/dictionary.h" |
38 | #include "arrow/ipc/message.h" |
39 | #include "arrow/ipc/metadata-internal.h" |
40 | #include "arrow/record_batch.h" |
41 | #include "arrow/sparse_tensor.h" |
42 | #include "arrow/status.h" |
43 | #include "arrow/tensor.h" |
44 | #include "arrow/type.h" |
45 | #include "arrow/util/logging.h" |
46 | #include "arrow/visitor_inline.h" |
47 | |
48 | namespace arrow { |
49 | |
50 | namespace flatbuf = org::apache::arrow::flatbuf; |
51 | |
52 | namespace ipc { |
53 | |
54 | using internal::FileBlock; |
55 | using internal::kArrowMagicBytes; |
56 | |
57 | // ---------------------------------------------------------------------- |
58 | // Record batch read path |
59 | |
60 | /// Accessor class for flatbuffers metadata |
61 | class IpcComponentSource { |
62 | public: |
63 | IpcComponentSource(const flatbuf::RecordBatch* metadata, io::RandomAccessFile* file) |
64 | : metadata_(metadata), file_(file) {} |
65 | |
66 | Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) { |
67 | const flatbuf::Buffer* buffer = metadata_->buffers()->Get(buffer_index); |
68 | |
69 | if (buffer->length() == 0) { |
70 | *out = nullptr; |
71 | return Status::OK(); |
72 | } else { |
73 | DCHECK(BitUtil::IsMultipleOf8(buffer->offset())) |
74 | << "Buffer " << buffer_index |
75 | << " did not start on 8-byte aligned offset: " << buffer->offset(); |
76 | return file_->ReadAt(buffer->offset(), buffer->length(), out); |
77 | } |
78 | } |
79 | |
80 | Status GetFieldMetadata(int field_index, ArrayData* out) { |
81 | auto nodes = metadata_->nodes(); |
82 | // pop off a field |
83 | if (field_index >= static_cast<int>(nodes->size())) { |
84 | return Status::Invalid("Ran out of field metadata, likely malformed" ); |
85 | } |
86 | const flatbuf::FieldNode* node = nodes->Get(field_index); |
87 | |
88 | out->length = node->length(); |
89 | out->null_count = node->null_count(); |
90 | out->offset = 0; |
91 | return Status::OK(); |
92 | } |
93 | |
94 | private: |
95 | const flatbuf::RecordBatch* metadata_; |
96 | io::RandomAccessFile* file_; |
97 | }; |
98 | |
/// Bookkeeping struct for loading array objects from their constituent pieces of raw data
///
/// The field_index and buffer_index are incremented in the ArrayLoader
/// based on how much of the batch is "consumed" (through nested data
/// reconstruction, for example)
struct ArrayLoaderContext {
  IpcComponentSource* source;  // borrowed; resolves buffers and field nodes
  int buffer_index;            // next body buffer to consume
  int field_index;             // next field node to consume
  int max_recursion_depth;     // remaining allowed nesting depth
};

// Forward declaration: LoadArray and ArrayLoader are mutually recursive
// (dictionary types load their index arrays through LoadArray).
static Status LoadArray(const std::shared_ptr<DataType>& type,
                        ArrayLoaderContext* context, ArrayData* out);
113 | |
/// Type visitor that reconstructs one (possibly nested) ArrayData from the
/// flatbuffer field nodes and body buffers, consuming entries from the shared
/// ArrayLoaderContext as it recurses.
class ArrayLoader {
 public:
  ArrayLoader(const std::shared_ptr<DataType>& type, ArrayData* out,
              ArrayLoaderContext* context)
      : type_(type), context_(context), out_(out) {}

  /// Entry point: enforces the recursion budget, then dispatches to the
  /// Visit() overload matching type_.
  Status Load() {
    if (context_->max_recursion_depth <= 0) {
      return Status::Invalid("Max recursion depth reached" );
    }

    out_->type = type_;

    RETURN_NOT_OK(VisitTypeInline(*type_, this));
    return Status::OK();
  }

  // Fetch body buffer `buffer_index` through the component source
  Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
    return context_->source->GetBuffer(buffer_index, out);
  }

  // Consume one field node (length/null_count) and the validity bitmap slot.
  Status LoadCommon() {
    // This only contains the length and null count, which we need to figure
    // out what to do with the buffers. For example, if null_count == 0, then
    // we can skip that buffer without reading from shared memory
    RETURN_NOT_OK(context_->source->GetFieldMetadata(context_->field_index++, out_));

    // extract null_bitmap which is common to all arrays
    if (out_->null_count == 0) {
      out_->buffers[0] = nullptr;
    } else {
      RETURN_NOT_OK(GetBuffer(context_->buffer_index, &out_->buffers[0]));
    }
    // The bitmap buffer slot is consumed even when the read was skipped
    context_->buffer_index++;
    return Status::OK();
  }

  // Fixed-width layout: validity bitmap + one data buffer
  template <typename TYPE>
  Status LoadPrimitive() {
    out_->buffers.resize(2);

    RETURN_NOT_OK(LoadCommon());
    if (out_->length > 0) {
      RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1]));
    } else {
      // Zero-length array: still consume the slot, substitute an empty buffer
      context_->buffer_index++;
      out_->buffers[1].reset(new Buffer(nullptr, 0));
    }
    return Status::OK();
  }

  // Variable-width layout: validity bitmap + offsets + data
  template <typename TYPE>
  Status LoadBinary() {
    out_->buffers.resize(3);

    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1]));
    return GetBuffer(context_->buffer_index++, &out_->buffers[2]);
  }

  // Recursively load one child array, charging the recursion budget for the
  // duration of the child load
  Status LoadChild(const Field& field, ArrayData* out) {
    ArrayLoader loader(field.type(), out, context_);
    --context_->max_recursion_depth;
    RETURN_NOT_OK(loader.Load());
    ++context_->max_recursion_depth;
    return Status::OK();
  }

  // Load every child field in order, appending to out_->child_data
  Status LoadChildren(std::vector<std::shared_ptr<Field>> child_fields) {
    out_->child_data.reserve(static_cast<int>(child_fields.size()));

    for (const auto& child_field : child_fields) {
      auto field_array = std::make_shared<ArrayData>();
      RETURN_NOT_OK(LoadChild(*child_field.get(), field_array.get()));
      out_->child_data.emplace_back(field_array);
    }
    return Status::OK();
  }

  Status Visit(const NullType& type) {
    out_->buffers.resize(1);
    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[0]));
    return Status::OK();
  }

  // All fixed-width types except FixedSizeBinary and Dictionary, which have
  // their own overloads below
  template <typename T>
  typename std::enable_if<std::is_base_of<FixedWidthType, T>::value &&
                              !std::is_base_of<FixedSizeBinaryType, T>::value &&
                              !std::is_base_of<DictionaryType, T>::value,
                          Status>::type
  Visit(const T& type) {
    return LoadPrimitive<T>();
  }

  // Binary-like types (Binary, String)
  template <typename T>
  typename std::enable_if<std::is_base_of<BinaryType, T>::value, Status>::type Visit(
      const T& type) {
    return LoadBinary<T>();
  }

  Status Visit(const FixedSizeBinaryType& type) {
    out_->buffers.resize(2);
    RETURN_NOT_OK(LoadCommon());
    return GetBuffer(context_->buffer_index++, &out_->buffers[1]);
  }

  Status Visit(const ListType& type) {
    out_->buffers.resize(2);

    RETURN_NOT_OK(LoadCommon());
    RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1]));

    const int num_children = type.num_children();
    if (num_children != 1) {
      return Status::Invalid("Wrong number of children: " , num_children);
    }

    return LoadChildren(type.children());
  }

  Status Visit(const StructType& type) {
    out_->buffers.resize(1);
    RETURN_NOT_OK(LoadCommon());
    return LoadChildren(type.children());
  }

  Status Visit(const UnionType& type) {
    out_->buffers.resize(3);

    RETURN_NOT_OK(LoadCommon());
    if (out_->length > 0) {
      // buffers[1]: type ids; buffers[2]: value offsets (dense unions only)
      RETURN_NOT_OK(GetBuffer(context_->buffer_index, &out_->buffers[1]));
      if (type.mode() == UnionMode::DENSE) {
        RETURN_NOT_OK(GetBuffer(context_->buffer_index + 1, &out_->buffers[2]));
      }
    }
    // Consume the slots even if the array was empty and nothing was read
    context_->buffer_index += type.mode() == UnionMode::DENSE ? 2 : 1;
    return LoadChildren(type.children());
  }

  Status Visit(const DictionaryType& type) {
    // Load the physical index array, then restore the logical dictionary type
    RETURN_NOT_OK(LoadArray(type.index_type(), context_, out_));
    out_->type = type_;
    return Status::OK();
  }

 private:
  const std::shared_ptr<DataType> type_;
  ArrayLoaderContext* context_;

  // Used in visitor pattern
  ArrayData* out_;
};
268 | |
269 | static Status LoadArray(const std::shared_ptr<DataType>& type, |
270 | ArrayLoaderContext* context, ArrayData* out) { |
271 | ArrayLoader loader(type, out, context); |
272 | return loader.Load(); |
273 | } |
274 | |
// Convenience overload: read a record batch with the default maximum nesting
// depth.
Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& schema,
                       io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out) {
  return ReadRecordBatch(metadata, schema, kMaxNestingDepth, file, out);
}

// Read a record batch from a complete IPC Message (metadata + body buffer).
Status ReadRecordBatch(const Message& message, const std::shared_ptr<Schema>& schema,
                       std::shared_ptr<RecordBatch>* out) {
  io::BufferReader reader(message.body());
  DCHECK_EQ(message.type(), Message::RECORD_BATCH);
  return ReadRecordBatch(*message.metadata(), schema, kMaxNestingDepth, &reader, out);
}
286 | |
287 | // ---------------------------------------------------------------------- |
288 | // Array loading |
289 | |
// Materialize one ArrayData per schema field from `source` and assemble them
// into a RecordBatch. The loader context advances through field nodes and
// buffers as a side effect of each LoadArray call, so field order matters.
static Status LoadRecordBatchFromSource(const std::shared_ptr<Schema>& schema,
                                        int64_t num_rows, int max_recursion_depth,
                                        IpcComponentSource* source,
                                        std::shared_ptr<RecordBatch>* out) {
  ArrayLoaderContext context;
  context.source = source;
  context.field_index = 0;
  context.buffer_index = 0;
  context.max_recursion_depth = max_recursion_depth;

  std::vector<std::shared_ptr<ArrayData>> arrays(schema->num_fields());
  for (int i = 0; i < schema->num_fields(); ++i) {
    auto arr = std::make_shared<ArrayData>();
    RETURN_NOT_OK(LoadArray(schema->field(i)->type(), &context, arr.get()));
    DCHECK_EQ(num_rows, arr->length) << "Array length did not match record batch length" ;
    arrays[i] = std::move(arr);
  }

  *out = RecordBatch::Make(schema, num_rows, std::move(arrays));
  return Status::OK();
}
311 | |
// Read a record batch directly from an already-verified flatbuffer header and
// the file holding the message body.
static inline Status ReadRecordBatch(const flatbuf::RecordBatch* metadata,
                                     const std::shared_ptr<Schema>& schema,
                                     int max_recursion_depth, io::RandomAccessFile* file,
                                     std::shared_ptr<RecordBatch>* out) {
  IpcComponentSource source(metadata, file);
  return LoadRecordBatchFromSource(schema, metadata->length(), max_recursion_depth,
                                   &source, out);
}
320 | |
321 | Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& schema, |
322 | int max_recursion_depth, io::RandomAccessFile* file, |
323 | std::shared_ptr<RecordBatch>* out) { |
324 | auto message = flatbuf::GetMessage(metadata.data()); |
325 | if (message->header_type() != flatbuf::MessageHeader_RecordBatch) { |
326 | DCHECK_EQ(message->header_type(), flatbuf::MessageHeader_RecordBatch); |
327 | } |
328 | if (message->header() == nullptr) { |
329 | return Status::IOError("Header-pointer of flatbuffer-encoded Message is null." ); |
330 | } |
331 | auto batch = reinterpret_cast<const flatbuf::RecordBatch*>(message->header()); |
332 | return ReadRecordBatch(batch, schema, max_recursion_depth, file, out); |
333 | } |
334 | |
// Read a DictionaryBatch message: look up the dictionary's value type by id,
// then decode the embedded single-column record batch holding the values.
//
// \param[out] dictionary_id the id declared in the DictionaryBatch header
// \param[out] out the dictionary values array
Status ReadDictionary(const Buffer& metadata, const DictionaryTypeMap& dictionary_types,
                      io::RandomAccessFile* file, int64_t* dictionary_id,
                      std::shared_ptr<Array>* out) {
  auto message = flatbuf::GetMessage(metadata.data());
  auto dictionary_batch =
      reinterpret_cast<const flatbuf::DictionaryBatch*>(message->header());

  int64_t id = *dictionary_id = dictionary_batch->id();
  auto it = dictionary_types.find(id);
  if (it == dictionary_types.end()) {
    return Status::KeyError("Do not have type metadata for dictionary with id: " , id);
  }

  std::vector<std::shared_ptr<Field>> fields = {it->second};

  // We need a schema for the record batch
  auto dummy_schema = std::make_shared<Schema>(fields);

  // The dictionary is embedded in a record batch with a single column
  std::shared_ptr<RecordBatch> batch;
  auto batch_meta =
      reinterpret_cast<const flatbuf::RecordBatch*>(dictionary_batch->data());
  RETURN_NOT_OK(
      ReadRecordBatch(batch_meta, dummy_schema, kMaxNestingDepth, file, &batch));
  if (batch->num_columns() != 1) {
    return Status::Invalid("Dictionary record batch must only contain one field" );
  }

  *out = batch->column(0);
  return Status::OK();
}
366 | |
367 | static Status ReadMessageAndValidate(MessageReader* reader, Message::Type expected_type, |
368 | bool allow_null, std::unique_ptr<Message>* message) { |
369 | RETURN_NOT_OK(reader->ReadNextMessage(message)); |
370 | |
371 | if (!(*message) && !allow_null) { |
372 | return Status::Invalid("Expected " , FormatMessageType(expected_type), |
373 | " message in stream, was null or length 0" ); |
374 | } |
375 | |
376 | if ((*message) == nullptr) { |
377 | return Status::OK(); |
378 | } |
379 | |
380 | if ((*message)->type() != expected_type) { |
381 | return Status::IOError( |
382 | "Message not expected type: " , FormatMessageType(expected_type), |
383 | ", was: " , (*message)->type()); |
384 | } |
385 | return Status::OK(); |
386 | } |
387 | |
388 | // ---------------------------------------------------------------------- |
389 | // RecordBatchStreamReader implementation |
390 | |
// Convert a flatbuffer Block descriptor to the internal FileBlock POD
// (offset, metadata length, body length).
static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
  return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()};
}
394 | |
// Pimpl for RecordBatchStreamReader: reads a schema message, then the
// dictionaries it declares, then record batches until end of stream.
class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 public:
  RecordBatchStreamReaderImpl() {}
  ~RecordBatchStreamReaderImpl() {}

  // Take ownership of the message reader and eagerly read schema +
  // dictionaries so that schema() is available immediately after Open.
  Status Open(std::unique_ptr<MessageReader> message_reader) {
    message_reader_ = std::move(message_reader);
    return ReadSchema();
  }

  // Read one DICTIONARY_BATCH message and register its values in the memo.
  Status ReadNextDictionary() {
    std::unique_ptr<Message> message;
    RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::DICTIONARY_BATCH,
                                         false, &message));

    io::BufferReader reader(message->body());

    std::shared_ptr<Array> dictionary;
    int64_t id;
    RETURN_NOT_OK(ReadDictionary(*message->metadata(), dictionary_types_, &reader, &id,
                                 &dictionary));
    return dictionary_memo_.AddDictionary(id, dictionary);
  }

  // Read the SCHEMA message and all dictionary batches it implies, then
  // reconstruct the logical schema.
  Status ReadSchema() {
    std::unique_ptr<Message> message;
    RETURN_NOT_OK(
        ReadMessageAndValidate(message_reader_.get(), Message::SCHEMA, false, &message));

    if (message->header() == nullptr) {
      return Status::IOError("Header-pointer of flatbuffer-encoded Message is null." );
    }
    RETURN_NOT_OK(internal::GetDictionaryTypes(message->header(), &dictionary_types_));

    // TODO(wesm): In future, we may want to reconcile the ids in the stream with
    // those found in the schema
    int num_dictionaries = static_cast<int>(dictionary_types_.size());
    for (int i = 0; i < num_dictionaries; ++i) {
      RETURN_NOT_OK(ReadNextDictionary());
    }

    return internal::GetSchema(message->header(), dictionary_memo_, &schema_);
  }

  // Read the next record batch; *batch is set to null at end of stream.
  Status ReadNext(std::shared_ptr<RecordBatch>* batch) {
    std::unique_ptr<Message> message;
    RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::RECORD_BATCH,
                                         true, &message));

    if (message == nullptr) {
      // End of stream
      *batch = nullptr;
      return Status::OK();
    }

    io::BufferReader reader(message->body());
    return ReadRecordBatch(*message->metadata(), schema_, &reader, batch);
  }

  std::shared_ptr<Schema> schema() const { return schema_; }

 private:
  std::unique_ptr<MessageReader> message_reader_;

  // dictionary_id -> type
  DictionaryTypeMap dictionary_types_;
  DictionaryMemo dictionary_memo_;
  std::shared_ptr<Schema> schema_;
};
464 | |
// Pimpl construction; all state lives in RecordBatchStreamReaderImpl.
RecordBatchStreamReader::RecordBatchStreamReader() {
  impl_.reset(new RecordBatchStreamReaderImpl());
}

RecordBatchStreamReader::~RecordBatchStreamReader() {}

// Factory: wrap an existing MessageReader. Reads schema + dictionaries
// eagerly, so failures surface here rather than on first ReadNext.
Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
                                     std::shared_ptr<RecordBatchReader>* reader) {
  // Private ctor
  auto result = std::shared_ptr<RecordBatchStreamReader>(new RecordBatchStreamReader());
  RETURN_NOT_OK(result->impl_->Open(std::move(message_reader)));
  *reader = result;
  return Status::OK();
}

// Factory: non-owning stream.
Status RecordBatchStreamReader::Open(io::InputStream* stream,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(MessageReader::Open(stream), out);
}

// Factory: shared (owning) stream.
Status RecordBatchStreamReader::Open(const std::shared_ptr<io::InputStream>& stream,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(MessageReader::Open(stream), out);
}

std::shared_ptr<Schema> RecordBatchStreamReader::schema() const {
  return impl_->schema();
}

Status RecordBatchStreamReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
  return impl_->ReadNext(batch);
}
497 | |
498 | // ---------------------------------------------------------------------- |
499 | // Reader implementation |
500 | |
501 | class RecordBatchFileReader::RecordBatchFileReaderImpl { |
502 | public: |
503 | RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) { |
504 | dictionary_memo_ = std::make_shared<DictionaryMemo>(); |
505 | } |
506 | |
507 | Status () { |
508 | int magic_size = static_cast<int>(strlen(kArrowMagicBytes)); |
509 | |
510 | if (footer_offset_ <= magic_size * 2 + 4) { |
511 | return Status::Invalid("File is too small: " , footer_offset_); |
512 | } |
513 | |
514 | std::shared_ptr<Buffer> buffer; |
515 | int file_end_size = static_cast<int>(magic_size + sizeof(int32_t)); |
516 | RETURN_NOT_OK(file_->ReadAt(footer_offset_ - file_end_size, file_end_size, &buffer)); |
517 | |
518 | const int64_t = magic_size + sizeof(int32_t); |
519 | if (buffer->size() < expected_footer_size) { |
520 | return Status::Invalid("Unable to read " , expected_footer_size, "from end of file" ); |
521 | } |
522 | |
523 | if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) { |
524 | return Status::Invalid("Not an Arrow file" ); |
525 | } |
526 | |
527 | int32_t = *reinterpret_cast<const int32_t*>(buffer->data()); |
528 | |
529 | if (footer_length <= 0 || footer_length + magic_size * 2 + 4 > footer_offset_) { |
530 | return Status::Invalid("File is smaller than indicated metadata size" ); |
531 | } |
532 | |
533 | // Now read the footer |
534 | RETURN_NOT_OK(file_->ReadAt(footer_offset_ - footer_length - file_end_size, |
535 | footer_length, &footer_buffer_)); |
536 | |
537 | // TODO(wesm): Verify the footer |
538 | footer_ = flatbuf::GetFooter(footer_buffer_->data()); |
539 | |
540 | return Status::OK(); |
541 | } |
542 | |
543 | int num_dictionaries() const { return footer_->dictionaries()->size(); } |
544 | |
545 | int num_record_batches() const { return footer_->recordBatches()->size(); } |
546 | |
547 | MetadataVersion version() const { |
548 | return internal::GetMetadataVersion(footer_->version()); |
549 | } |
550 | |
551 | FileBlock record_batch(int i) const { |
552 | return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i)); |
553 | } |
554 | |
555 | FileBlock dictionary(int i) const { |
556 | return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i)); |
557 | } |
558 | |
559 | Status ReadRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) { |
560 | DCHECK_GE(i, 0); |
561 | DCHECK_LT(i, num_record_batches()); |
562 | FileBlock block = record_batch(i); |
563 | |
564 | DCHECK(BitUtil::IsMultipleOf8(block.offset)); |
565 | DCHECK(BitUtil::IsMultipleOf8(block.metadata_length)); |
566 | DCHECK(BitUtil::IsMultipleOf8(block.body_length)); |
567 | |
568 | std::unique_ptr<Message> message; |
569 | RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, &message)); |
570 | |
571 | // TODO(wesm): this breaks integration tests, see ARROW-3256 |
572 | // DCHECK_EQ(message->body_length(), block.body_length); |
573 | |
574 | io::BufferReader reader(message->body()); |
575 | return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &reader, batch); |
576 | } |
577 | |
578 | Status ReadSchema() { |
579 | RETURN_NOT_OK(internal::GetDictionaryTypes(footer_->schema(), &dictionary_fields_)); |
580 | |
581 | // Read all the dictionaries |
582 | for (int i = 0; i < num_dictionaries(); ++i) { |
583 | FileBlock block = dictionary(i); |
584 | |
585 | DCHECK(BitUtil::IsMultipleOf8(block.offset)); |
586 | DCHECK(BitUtil::IsMultipleOf8(block.metadata_length)); |
587 | DCHECK(BitUtil::IsMultipleOf8(block.body_length)); |
588 | |
589 | std::unique_ptr<Message> message; |
590 | RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, &message)); |
591 | |
592 | // TODO(wesm): this breaks integration tests, see ARROW-3256 |
593 | // DCHECK_EQ(message->body_length(), block.body_length); |
594 | |
595 | io::BufferReader reader(message->body()); |
596 | |
597 | std::shared_ptr<Array> dictionary; |
598 | int64_t dictionary_id; |
599 | RETURN_NOT_OK(ReadDictionary(*message->metadata(), dictionary_fields_, &reader, |
600 | &dictionary_id, &dictionary)); |
601 | RETURN_NOT_OK(dictionary_memo_->AddDictionary(dictionary_id, dictionary)); |
602 | } |
603 | |
604 | // Get the schema |
605 | return internal::GetSchema(footer_->schema(), *dictionary_memo_, &schema_); |
606 | } |
607 | |
608 | Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t ) { |
609 | owned_file_ = file; |
610 | return Open(file.get(), footer_offset); |
611 | } |
612 | |
613 | Status Open(io::RandomAccessFile* file, int64_t ) { |
614 | file_ = file; |
615 | footer_offset_ = footer_offset; |
616 | RETURN_NOT_OK(ReadFooter()); |
617 | return ReadSchema(); |
618 | } |
619 | |
620 | std::shared_ptr<Schema> schema() const { return schema_; } |
621 | |
622 | private: |
623 | io::RandomAccessFile* file_; |
624 | |
625 | std::shared_ptr<io::RandomAccessFile> owned_file_; |
626 | |
627 | // The location where the Arrow file layout ends. May be the end of the file |
628 | // or some other location if embedded in a larger file. |
629 | int64_t ; |
630 | |
631 | // Footer metadata |
632 | std::shared_ptr<Buffer> ; |
633 | const flatbuf::Footer* ; |
634 | |
635 | DictionaryTypeMap dictionary_fields_; |
636 | std::shared_ptr<DictionaryMemo> dictionary_memo_; |
637 | |
638 | // Reconstructed schema, including any read dictionaries |
639 | std::shared_ptr<Schema> schema_; |
640 | }; |
641 | |
// Pimpl construction; all state lives in RecordBatchFileReaderImpl.
RecordBatchFileReader::RecordBatchFileReader() {
  impl_.reset(new RecordBatchFileReaderImpl());
}

RecordBatchFileReader::~RecordBatchFileReader() {}
647 | |
648 | Status RecordBatchFileReader::Open(io::RandomAccessFile* file, |
649 | std::shared_ptr<RecordBatchFileReader>* reader) { |
650 | int64_t ; |
651 | RETURN_NOT_OK(file->GetSize(&footer_offset)); |
652 | return Open(file, footer_offset, reader); |
653 | } |
654 | |
655 | Status RecordBatchFileReader::Open(io::RandomAccessFile* file, int64_t , |
656 | std::shared_ptr<RecordBatchFileReader>* reader) { |
657 | *reader = std::shared_ptr<RecordBatchFileReader>(new RecordBatchFileReader()); |
658 | return (*reader)->impl_->Open(file, footer_offset); |
659 | } |
660 | |
661 | Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file, |
662 | std::shared_ptr<RecordBatchFileReader>* reader) { |
663 | int64_t ; |
664 | RETURN_NOT_OK(file->GetSize(&footer_offset)); |
665 | return Open(file, footer_offset, reader); |
666 | } |
667 | |
668 | Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file, |
669 | int64_t , |
670 | std::shared_ptr<RecordBatchFileReader>* reader) { |
671 | *reader = std::shared_ptr<RecordBatchFileReader>(new RecordBatchFileReader()); |
672 | return (*reader)->impl_->Open(file, footer_offset); |
673 | } |
674 | |
// Thin delegators to the pimpl.
std::shared_ptr<Schema> RecordBatchFileReader::schema() const { return impl_->schema(); }

int RecordBatchFileReader::num_record_batches() const {
  return impl_->num_record_batches();
}

MetadataVersion RecordBatchFileReader::version() const { return impl_->version(); }

Status RecordBatchFileReader::ReadRecordBatch(int i,
                                              std::shared_ptr<RecordBatch>* batch) {
  return impl_->ReadRecordBatch(i, batch);
}
687 | |
688 | static Status ReadContiguousPayload(io::InputStream* file, |
689 | std::unique_ptr<Message>* message) { |
690 | RETURN_NOT_OK(ReadMessage(file, message)); |
691 | if (*message == nullptr) { |
692 | return Status::Invalid("Unable to read metadata at offset" ); |
693 | } |
694 | return Status::OK(); |
695 | } |
696 | |
// Read just the schema (plus any dictionaries) from the head of a stream by
// opening a full stream reader and discarding it.
Status ReadSchema(io::InputStream* stream, std::shared_ptr<Schema>* out) {
  std::shared_ptr<RecordBatchReader> reader;
  RETURN_NOT_OK(RecordBatchStreamReader::Open(stream, &reader));
  *out = reader->schema();
  return Status::OK();
}

// Read a single contiguous (metadata + body) record batch message from a
// stream, validating against the supplied schema.
Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, io::InputStream* file,
                       std::shared_ptr<RecordBatch>* out) {
  std::unique_ptr<Message> message;
  RETURN_NOT_OK(ReadContiguousPayload(file, &message));
  io::BufferReader buffer_reader(message->body());
  return ReadRecordBatch(*message->metadata(), schema, kMaxNestingDepth, &buffer_reader,
                         out);
}
712 | |
// Read a Tensor message (metadata + body) from a stream.
Status ReadTensor(io::InputStream* file, std::shared_ptr<Tensor>* out) {
  std::unique_ptr<Message> message;
  RETURN_NOT_OK(ReadContiguousPayload(file, &message));
  return ReadTensor(*message, out);
}

// Reconstruct a Tensor from a complete Message: metadata describes type,
// shape, strides and dimension names; the body holds the data buffer.
Status ReadTensor(const Message& message, std::shared_ptr<Tensor>* out) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  std::vector<std::string> dim_names;
  RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
                                            &dim_names));
  *out = std::make_shared<Tensor>(type, message.body(), shape, strides, dim_names);
  return Status::OK();
}
729 | |
namespace {

// Read the coordinates tensor of a COO sparse index from the message body.
// The coordinates are materialized as a (non_zero_length x ndim) int64
// tensor; the chosen strides (elsize, elsize * non_zero_length) select a
// column-major (Fortran-order) layout.
Status ReadSparseCOOIndex(const flatbuf::SparseTensor* sparse_tensor, int64_t ndim,
                          int64_t non_zero_length, io::RandomAccessFile* file,
                          std::shared_ptr<SparseIndex>* out) {
  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCOO();
  auto* indices_buffer = sparse_index->indicesBuffer();
  std::shared_ptr<Buffer> indices_data;
  RETURN_NOT_OK(
      file->ReadAt(indices_buffer->offset(), indices_buffer->length(), &indices_data));
  std::vector<int64_t> shape({non_zero_length, ndim});
  const int64_t elsize = sizeof(int64_t);
  std::vector<int64_t> strides({elsize, elsize * non_zero_length});
  *out = std::make_shared<SparseCOOIndex>(
      std::make_shared<SparseCOOIndex::CoordsTensor>(indices_data, shape, strides));
  return Status::OK();
}

// Read the indptr (ndim + 1 entries) and indices (non_zero_length entries)
// tensors of a CSR sparse matrix index from the message body.
Status ReadSparseCSRIndex(const flatbuf::SparseTensor* sparse_tensor, int64_t ndim,
                          int64_t non_zero_length, io::RandomAccessFile* file,
                          std::shared_ptr<SparseIndex>* out) {
  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSR();

  auto* indptr_buffer = sparse_index->indptrBuffer();
  std::shared_ptr<Buffer> indptr_data;
  RETURN_NOT_OK(
      file->ReadAt(indptr_buffer->offset(), indptr_buffer->length(), &indptr_data));

  auto* indices_buffer = sparse_index->indicesBuffer();
  std::shared_ptr<Buffer> indices_data;
  RETURN_NOT_OK(
      file->ReadAt(indices_buffer->offset(), indices_buffer->length(), &indices_data));

  std::vector<int64_t> indptr_shape({ndim + 1});
  std::vector<int64_t> indices_shape({non_zero_length});
  *out = std::make_shared<SparseCSRIndex>(
      std::make_shared<SparseCSRIndex::IndexTensor>(indptr_data, indptr_shape),
      std::make_shared<SparseCSRIndex::IndexTensor>(indices_data, indices_shape));
  return Status::OK();
}

// Assemble a COO sparse tensor from its already-read pieces.
// NOTE: non_zero_length is unused here; the index already encodes it.
Status MakeSparseTensorWithSparseCOOIndex(
    const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
    const std::vector<std::string>& dim_names,
    const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
    const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensor>* out) {
  *out = std::make_shared<SparseTensorImpl<SparseCOOIndex>>(sparse_index, type, data,
                                                            shape, dim_names);
  return Status::OK();
}

// Assemble a CSR sparse tensor from its already-read pieces.
// NOTE: non_zero_length is unused here; the index already encodes it.
Status MakeSparseTensorWithSparseCSRIndex(
    const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
    const std::vector<std::string>& dim_names,
    const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
    const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensor>* out) {
  *out = std::make_shared<SparseTensorImpl<SparseCSRIndex>>(sparse_index, type, data,
                                                            shape, dim_names);
  return Status::OK();
}

}  // namespace
792 | |
// Reconstruct a SparseTensor from its serialized Message metadata plus the
// file holding the message body (sparse index buffers + data buffer).
Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
                        std::shared_ptr<SparseTensor>* out) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length;
  SparseTensorFormat::type sparse_tensor_format_id;

  RETURN_NOT_OK(internal::GetSparseTensorMetadata(
      metadata, &type, &shape, &dim_names, &non_zero_length, &sparse_tensor_format_id));

  auto message = flatbuf::GetMessage(metadata.data());
  auto sparse_tensor = reinterpret_cast<const flatbuf::SparseTensor*>(message->header());
  const flatbuf::Buffer* buffer = sparse_tensor->data();
  DCHECK(BitUtil::IsMultipleOf8(buffer->offset()))
      << "Buffer of sparse index data "
      << "did not start on 8-byte aligned offset: " << buffer->offset();

  // Read the non-zero values
  std::shared_ptr<Buffer> data;
  RETURN_NOT_OK(file->ReadAt(buffer->offset(), buffer->length(), &data));

  // Read the format-specific sparse index, then assemble the tensor
  std::shared_ptr<SparseIndex> sparse_index;
  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO:
      RETURN_NOT_OK(ReadSparseCOOIndex(sparse_tensor, shape.size(), non_zero_length, file,
                                       &sparse_index));
      return MakeSparseTensorWithSparseCOOIndex(
          type, shape, dim_names, std::dynamic_pointer_cast<SparseCOOIndex>(sparse_index),
          non_zero_length, data, out);

    case SparseTensorFormat::CSR:
      RETURN_NOT_OK(ReadSparseCSRIndex(sparse_tensor, shape.size(), non_zero_length, file,
                                       &sparse_index));
      return MakeSparseTensorWithSparseCSRIndex(
          type, shape, dim_names, std::dynamic_pointer_cast<SparseCSRIndex>(sparse_index),
          non_zero_length, data, out);

    default:
      return Status::Invalid("Unsupported sparse index format" );
  }
}
834 | |
// Read a SparseTensor from a complete in-memory Message.
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* out) {
  io::BufferReader buffer_reader(message.body());
  return ReadSparseTensor(*message.metadata(), &buffer_reader, out);
}

// Read a SparseTensor message (metadata + body) from a stream.
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* out) {
  std::unique_ptr<Message> message;
  RETURN_NOT_OK(ReadContiguousPayload(file, &message));
  DCHECK_EQ(message->type(), Message::SPARSE_TENSOR);
  io::BufferReader buffer_reader(message->body());
  return ReadSparseTensor(*message->metadata(), &buffer_reader, out);
}
847 | |
848 | } // namespace ipc |
849 | } // namespace arrow |
850 | |