1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/ipc/reader.h"
19
20#include <cstdint>
21#include <cstring>
22#include <sstream>
23#include <string>
24#include <type_traits>
25#include <utility>
26#include <vector>
27
28#include <flatbuffers/flatbuffers.h> // IWYU pragma: export
29
30#include "arrow/array.h"
31#include "arrow/buffer.h"
32#include "arrow/io/interfaces.h"
33#include "arrow/io/memory.h"
34#include "arrow/ipc/File_generated.h" // IWYU pragma: export
35#include "arrow/ipc/Message_generated.h"
36#include "arrow/ipc/Schema_generated.h"
37#include "arrow/ipc/dictionary.h"
38#include "arrow/ipc/message.h"
39#include "arrow/ipc/metadata-internal.h"
40#include "arrow/record_batch.h"
41#include "arrow/sparse_tensor.h"
42#include "arrow/status.h"
43#include "arrow/tensor.h"
44#include "arrow/type.h"
45#include "arrow/util/logging.h"
46#include "arrow/visitor_inline.h"
47
48namespace arrow {
49
50namespace flatbuf = org::apache::arrow::flatbuf;
51
52namespace ipc {
53
54using internal::FileBlock;
55using internal::kArrowMagicBytes;
56
57// ----------------------------------------------------------------------
58// Record batch read path
59
60/// Accessor class for flatbuffers metadata
61class IpcComponentSource {
62 public:
63 IpcComponentSource(const flatbuf::RecordBatch* metadata, io::RandomAccessFile* file)
64 : metadata_(metadata), file_(file) {}
65
66 Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
67 const flatbuf::Buffer* buffer = metadata_->buffers()->Get(buffer_index);
68
69 if (buffer->length() == 0) {
70 *out = nullptr;
71 return Status::OK();
72 } else {
73 DCHECK(BitUtil::IsMultipleOf8(buffer->offset()))
74 << "Buffer " << buffer_index
75 << " did not start on 8-byte aligned offset: " << buffer->offset();
76 return file_->ReadAt(buffer->offset(), buffer->length(), out);
77 }
78 }
79
80 Status GetFieldMetadata(int field_index, ArrayData* out) {
81 auto nodes = metadata_->nodes();
82 // pop off a field
83 if (field_index >= static_cast<int>(nodes->size())) {
84 return Status::Invalid("Ran out of field metadata, likely malformed");
85 }
86 const flatbuf::FieldNode* node = nodes->Get(field_index);
87
88 out->length = node->length();
89 out->null_count = node->null_count();
90 out->offset = 0;
91 return Status::OK();
92 }
93
94 private:
95 const flatbuf::RecordBatch* metadata_;
96 io::RandomAccessFile* file_;
97};
98
99/// Bookkeeping struct for loading array objects from their constituent pieces of raw data
100///
101/// The field_index and buffer_index are incremented in the ArrayLoader
102/// based on how much of the batch is "consumed" (through nested data
103/// reconstruction, for example)
struct ArrayLoaderContext {
  // Supplies buffers and field metadata for the batch being loaded (not owned)
  IpcComponentSource* source;
  // Index of the next flatbuffer Buffer to consume
  int buffer_index;
  // Index of the next flatbuffer FieldNode to consume
  int field_index;
  // Remaining allowed nesting depth; loading fails once this reaches 0
  int max_recursion_depth;
};
110
111static Status LoadArray(const std::shared_ptr<DataType>& type,
112 ArrayLoaderContext* context, ArrayData* out);
113
114class ArrayLoader {
115 public:
116 ArrayLoader(const std::shared_ptr<DataType>& type, ArrayData* out,
117 ArrayLoaderContext* context)
118 : type_(type), context_(context), out_(out) {}
119
120 Status Load() {
121 if (context_->max_recursion_depth <= 0) {
122 return Status::Invalid("Max recursion depth reached");
123 }
124
125 out_->type = type_;
126
127 RETURN_NOT_OK(VisitTypeInline(*type_, this));
128 return Status::OK();
129 }
130
131 Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
132 return context_->source->GetBuffer(buffer_index, out);
133 }
134
135 Status LoadCommon() {
136 // This only contains the length and null count, which we need to figure
137 // out what to do with the buffers. For example, if null_count == 0, then
138 // we can skip that buffer without reading from shared memory
139 RETURN_NOT_OK(context_->source->GetFieldMetadata(context_->field_index++, out_));
140
141 // extract null_bitmap which is common to all arrays
142 if (out_->null_count == 0) {
143 out_->buffers[0] = nullptr;
144 } else {
145 RETURN_NOT_OK(GetBuffer(context_->buffer_index, &out_->buffers[0]));
146 }
147 context_->buffer_index++;
148 return Status::OK();
149 }
150
151 template <typename TYPE>
152 Status LoadPrimitive() {
153 out_->buffers.resize(2);
154
155 RETURN_NOT_OK(LoadCommon());
156 if (out_->length > 0) {
157 RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1]));
158 } else {
159 context_->buffer_index++;
160 out_->buffers[1].reset(new Buffer(nullptr, 0));
161 }
162 return Status::OK();
163 }
164
165 template <typename TYPE>
166 Status LoadBinary() {
167 out_->buffers.resize(3);
168
169 RETURN_NOT_OK(LoadCommon());
170 RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1]));
171 return GetBuffer(context_->buffer_index++, &out_->buffers[2]);
172 }
173
174 Status LoadChild(const Field& field, ArrayData* out) {
175 ArrayLoader loader(field.type(), out, context_);
176 --context_->max_recursion_depth;
177 RETURN_NOT_OK(loader.Load());
178 ++context_->max_recursion_depth;
179 return Status::OK();
180 }
181
182 Status LoadChildren(std::vector<std::shared_ptr<Field>> child_fields) {
183 out_->child_data.reserve(static_cast<int>(child_fields.size()));
184
185 for (const auto& child_field : child_fields) {
186 auto field_array = std::make_shared<ArrayData>();
187 RETURN_NOT_OK(LoadChild(*child_field.get(), field_array.get()));
188 out_->child_data.emplace_back(field_array);
189 }
190 return Status::OK();
191 }
192
193 Status Visit(const NullType& type) {
194 out_->buffers.resize(1);
195 RETURN_NOT_OK(LoadCommon());
196 RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[0]));
197 return Status::OK();
198 }
199
200 template <typename T>
201 typename std::enable_if<std::is_base_of<FixedWidthType, T>::value &&
202 !std::is_base_of<FixedSizeBinaryType, T>::value &&
203 !std::is_base_of<DictionaryType, T>::value,
204 Status>::type
205 Visit(const T& type) {
206 return LoadPrimitive<T>();
207 }
208
209 template <typename T>
210 typename std::enable_if<std::is_base_of<BinaryType, T>::value, Status>::type Visit(
211 const T& type) {
212 return LoadBinary<T>();
213 }
214
215 Status Visit(const FixedSizeBinaryType& type) {
216 out_->buffers.resize(2);
217 RETURN_NOT_OK(LoadCommon());
218 return GetBuffer(context_->buffer_index++, &out_->buffers[1]);
219 }
220
221 Status Visit(const ListType& type) {
222 out_->buffers.resize(2);
223
224 RETURN_NOT_OK(LoadCommon());
225 RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1]));
226
227 const int num_children = type.num_children();
228 if (num_children != 1) {
229 return Status::Invalid("Wrong number of children: ", num_children);
230 }
231
232 return LoadChildren(type.children());
233 }
234
235 Status Visit(const StructType& type) {
236 out_->buffers.resize(1);
237 RETURN_NOT_OK(LoadCommon());
238 return LoadChildren(type.children());
239 }
240
241 Status Visit(const UnionType& type) {
242 out_->buffers.resize(3);
243
244 RETURN_NOT_OK(LoadCommon());
245 if (out_->length > 0) {
246 RETURN_NOT_OK(GetBuffer(context_->buffer_index, &out_->buffers[1]));
247 if (type.mode() == UnionMode::DENSE) {
248 RETURN_NOT_OK(GetBuffer(context_->buffer_index + 1, &out_->buffers[2]));
249 }
250 }
251 context_->buffer_index += type.mode() == UnionMode::DENSE ? 2 : 1;
252 return LoadChildren(type.children());
253 }
254
255 Status Visit(const DictionaryType& type) {
256 RETURN_NOT_OK(LoadArray(type.index_type(), context_, out_));
257 out_->type = type_;
258 return Status::OK();
259 }
260
261 private:
262 const std::shared_ptr<DataType> type_;
263 ArrayLoaderContext* context_;
264
265 // Used in visitor pattern
266 ArrayData* out_;
267};
268
269static Status LoadArray(const std::shared_ptr<DataType>& type,
270 ArrayLoaderContext* context, ArrayData* out) {
271 ArrayLoader loader(type, out, context);
272 return loader.Load();
273}
274
// Convenience overload: read with the default maximum nesting depth
Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& schema,
                       io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out) {
  return ReadRecordBatch(metadata, schema, kMaxNestingDepth, file, out);
}
279
280Status ReadRecordBatch(const Message& message, const std::shared_ptr<Schema>& schema,
281 std::shared_ptr<RecordBatch>* out) {
282 io::BufferReader reader(message.body());
283 DCHECK_EQ(message.type(), Message::RECORD_BATCH);
284 return ReadRecordBatch(*message.metadata(), schema, kMaxNestingDepth, &reader, out);
285}
286
287// ----------------------------------------------------------------------
288// Array loading
289
290static Status LoadRecordBatchFromSource(const std::shared_ptr<Schema>& schema,
291 int64_t num_rows, int max_recursion_depth,
292 IpcComponentSource* source,
293 std::shared_ptr<RecordBatch>* out) {
294 ArrayLoaderContext context;
295 context.source = source;
296 context.field_index = 0;
297 context.buffer_index = 0;
298 context.max_recursion_depth = max_recursion_depth;
299
300 std::vector<std::shared_ptr<ArrayData>> arrays(schema->num_fields());
301 for (int i = 0; i < schema->num_fields(); ++i) {
302 auto arr = std::make_shared<ArrayData>();
303 RETURN_NOT_OK(LoadArray(schema->field(i)->type(), &context, arr.get()));
304 DCHECK_EQ(num_rows, arr->length) << "Array length did not match record batch length";
305 arrays[i] = std::move(arr);
306 }
307
308 *out = RecordBatch::Make(schema, num_rows, std::move(arrays));
309 return Status::OK();
310}
311
// Read a record batch described by the flatbuffer metadata, pulling the body
// buffers from `file`
static inline Status ReadRecordBatch(const flatbuf::RecordBatch* metadata,
                                     const std::shared_ptr<Schema>& schema,
                                     int max_recursion_depth, io::RandomAccessFile* file,
                                     std::shared_ptr<RecordBatch>* out) {
  IpcComponentSource source(metadata, file);
  return LoadRecordBatchFromSource(schema, metadata->length(), max_recursion_depth,
                                   &source, out);
}
320
321Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& schema,
322 int max_recursion_depth, io::RandomAccessFile* file,
323 std::shared_ptr<RecordBatch>* out) {
324 auto message = flatbuf::GetMessage(metadata.data());
325 if (message->header_type() != flatbuf::MessageHeader_RecordBatch) {
326 DCHECK_EQ(message->header_type(), flatbuf::MessageHeader_RecordBatch);
327 }
328 if (message->header() == nullptr) {
329 return Status::IOError("Header-pointer of flatbuffer-encoded Message is null.");
330 }
331 auto batch = reinterpret_cast<const flatbuf::RecordBatch*>(message->header());
332 return ReadRecordBatch(batch, schema, max_recursion_depth, file, out);
333}
334
335Status ReadDictionary(const Buffer& metadata, const DictionaryTypeMap& dictionary_types,
336 io::RandomAccessFile* file, int64_t* dictionary_id,
337 std::shared_ptr<Array>* out) {
338 auto message = flatbuf::GetMessage(metadata.data());
339 auto dictionary_batch =
340 reinterpret_cast<const flatbuf::DictionaryBatch*>(message->header());
341
342 int64_t id = *dictionary_id = dictionary_batch->id();
343 auto it = dictionary_types.find(id);
344 if (it == dictionary_types.end()) {
345 return Status::KeyError("Do not have type metadata for dictionary with id: ", id);
346 }
347
348 std::vector<std::shared_ptr<Field>> fields = {it->second};
349
350 // We need a schema for the record batch
351 auto dummy_schema = std::make_shared<Schema>(fields);
352
353 // The dictionary is embedded in a record batch with a single column
354 std::shared_ptr<RecordBatch> batch;
355 auto batch_meta =
356 reinterpret_cast<const flatbuf::RecordBatch*>(dictionary_batch->data());
357 RETURN_NOT_OK(
358 ReadRecordBatch(batch_meta, dummy_schema, kMaxNestingDepth, file, &batch));
359 if (batch->num_columns() != 1) {
360 return Status::Invalid("Dictionary record batch must only contain one field");
361 }
362
363 *out = batch->column(0);
364 return Status::OK();
365}
366
367static Status ReadMessageAndValidate(MessageReader* reader, Message::Type expected_type,
368 bool allow_null, std::unique_ptr<Message>* message) {
369 RETURN_NOT_OK(reader->ReadNextMessage(message));
370
371 if (!(*message) && !allow_null) {
372 return Status::Invalid("Expected ", FormatMessageType(expected_type),
373 " message in stream, was null or length 0");
374 }
375
376 if ((*message) == nullptr) {
377 return Status::OK();
378 }
379
380 if ((*message)->type() != expected_type) {
381 return Status::IOError(
382 "Message not expected type: ", FormatMessageType(expected_type),
383 ", was: ", (*message)->type());
384 }
385 return Status::OK();
386}
387
388// ----------------------------------------------------------------------
389// RecordBatchStreamReader implementation
390
// Convert a flatbuffer Block (offset, metadata length, body length) into the
// internal FileBlock struct
static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
  return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()};
}
394
// Implementation of the streaming reader. The stream protocol is stateful and
// order-dependent: a SCHEMA message, followed by the dictionary batches the
// schema requires, followed by record batches until end of stream.
class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 public:
  RecordBatchStreamReaderImpl() {}
  ~RecordBatchStreamReaderImpl() {}

  // Take ownership of the message reader and eagerly read the schema (and any
  // dictionary batches) from the head of the stream
  Status Open(std::unique_ptr<MessageReader> message_reader) {
    message_reader_ = std::move(message_reader);
    return ReadSchema();
  }

  // Read one DICTIONARY_BATCH message and record the decoded dictionary in
  // the memo under its id
  Status ReadNextDictionary() {
    std::unique_ptr<Message> message;
    RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::DICTIONARY_BATCH,
                                         false, &message));

    io::BufferReader reader(message->body());

    std::shared_ptr<Array> dictionary;
    int64_t id;
    RETURN_NOT_OK(ReadDictionary(*message->metadata(), dictionary_types_, &reader, &id,
                                 &dictionary));
    return dictionary_memo_.AddDictionary(id, dictionary);
  }

  // Read the SCHEMA message, then one dictionary batch per dictionary type
  // collected from the schema header
  Status ReadSchema() {
    std::unique_ptr<Message> message;
    RETURN_NOT_OK(
        ReadMessageAndValidate(message_reader_.get(), Message::SCHEMA, false, &message));

    if (message->header() == nullptr) {
      return Status::IOError("Header-pointer of flatbuffer-encoded Message is null.");
    }
    RETURN_NOT_OK(internal::GetDictionaryTypes(message->header(), &dictionary_types_));

    // TODO(wesm): In future, we may want to reconcile the ids in the stream with
    // those found in the schema
    int num_dictionaries = static_cast<int>(dictionary_types_.size());
    for (int i = 0; i < num_dictionaries; ++i) {
      RETURN_NOT_OK(ReadNextDictionary());
    }

    return internal::GetSchema(message->header(), dictionary_memo_, &schema_);
  }

  // Read the next record batch; *batch is set to null at end of stream
  Status ReadNext(std::shared_ptr<RecordBatch>* batch) {
    std::unique_ptr<Message> message;
    RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::RECORD_BATCH,
                                         true, &message));

    if (message == nullptr) {
      // End of stream
      *batch = nullptr;
      return Status::OK();
    }

    io::BufferReader reader(message->body());
    return ReadRecordBatch(*message->metadata(), schema_, &reader, batch);
  }

  std::shared_ptr<Schema> schema() const { return schema_; }

 private:
  std::unique_ptr<MessageReader> message_reader_;

  // dictionary_id -> type
  DictionaryTypeMap dictionary_types_;
  DictionaryMemo dictionary_memo_;
  // Schema reconstructed from the stream, including dictionary types
  std::shared_ptr<Schema> schema_;
};
464
// Constructor is private; instances are created through the static Open methods
RecordBatchStreamReader::RecordBatchStreamReader() {
  impl_.reset(new RecordBatchStreamReaderImpl());
}

// Out-of-line: the Impl type is only defined in this translation unit
RecordBatchStreamReader::~RecordBatchStreamReader() {}
470
471Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
472 std::shared_ptr<RecordBatchReader>* reader) {
473 // Private ctor
474 auto result = std::shared_ptr<RecordBatchStreamReader>(new RecordBatchStreamReader());
475 RETURN_NOT_OK(result->impl_->Open(std::move(message_reader)));
476 *reader = result;
477 return Status::OK();
478}
479
// Convenience overload: wrap a raw InputStream in a MessageReader
Status RecordBatchStreamReader::Open(io::InputStream* stream,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(MessageReader::Open(stream), out);
}

// Convenience overload: shared-ownership variant of the above
Status RecordBatchStreamReader::Open(const std::shared_ptr<io::InputStream>& stream,
                                     std::shared_ptr<RecordBatchReader>* out) {
  return Open(MessageReader::Open(stream), out);
}
489
// Forward to the pimpl: schema read from the head of the stream during Open
std::shared_ptr<Schema> RecordBatchStreamReader::schema() const {
  return impl_->schema();
}

// Forward to the pimpl: *batch is null at end of stream
Status RecordBatchStreamReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
  return impl_->ReadNext(batch);
}
497
498// ----------------------------------------------------------------------
499// Reader implementation
500
501class RecordBatchFileReader::RecordBatchFileReaderImpl {
502 public:
503 RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {
504 dictionary_memo_ = std::make_shared<DictionaryMemo>();
505 }
506
507 Status ReadFooter() {
508 int magic_size = static_cast<int>(strlen(kArrowMagicBytes));
509
510 if (footer_offset_ <= magic_size * 2 + 4) {
511 return Status::Invalid("File is too small: ", footer_offset_);
512 }
513
514 std::shared_ptr<Buffer> buffer;
515 int file_end_size = static_cast<int>(magic_size + sizeof(int32_t));
516 RETURN_NOT_OK(file_->ReadAt(footer_offset_ - file_end_size, file_end_size, &buffer));
517
518 const int64_t expected_footer_size = magic_size + sizeof(int32_t);
519 if (buffer->size() < expected_footer_size) {
520 return Status::Invalid("Unable to read ", expected_footer_size, "from end of file");
521 }
522
523 if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) {
524 return Status::Invalid("Not an Arrow file");
525 }
526
527 int32_t footer_length = *reinterpret_cast<const int32_t*>(buffer->data());
528
529 if (footer_length <= 0 || footer_length + magic_size * 2 + 4 > footer_offset_) {
530 return Status::Invalid("File is smaller than indicated metadata size");
531 }
532
533 // Now read the footer
534 RETURN_NOT_OK(file_->ReadAt(footer_offset_ - footer_length - file_end_size,
535 footer_length, &footer_buffer_));
536
537 // TODO(wesm): Verify the footer
538 footer_ = flatbuf::GetFooter(footer_buffer_->data());
539
540 return Status::OK();
541 }
542
543 int num_dictionaries() const { return footer_->dictionaries()->size(); }
544
545 int num_record_batches() const { return footer_->recordBatches()->size(); }
546
547 MetadataVersion version() const {
548 return internal::GetMetadataVersion(footer_->version());
549 }
550
551 FileBlock record_batch(int i) const {
552 return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i));
553 }
554
555 FileBlock dictionary(int i) const {
556 return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
557 }
558
559 Status ReadRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) {
560 DCHECK_GE(i, 0);
561 DCHECK_LT(i, num_record_batches());
562 FileBlock block = record_batch(i);
563
564 DCHECK(BitUtil::IsMultipleOf8(block.offset));
565 DCHECK(BitUtil::IsMultipleOf8(block.metadata_length));
566 DCHECK(BitUtil::IsMultipleOf8(block.body_length));
567
568 std::unique_ptr<Message> message;
569 RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, &message));
570
571 // TODO(wesm): this breaks integration tests, see ARROW-3256
572 // DCHECK_EQ(message->body_length(), block.body_length);
573
574 io::BufferReader reader(message->body());
575 return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &reader, batch);
576 }
577
578 Status ReadSchema() {
579 RETURN_NOT_OK(internal::GetDictionaryTypes(footer_->schema(), &dictionary_fields_));
580
581 // Read all the dictionaries
582 for (int i = 0; i < num_dictionaries(); ++i) {
583 FileBlock block = dictionary(i);
584
585 DCHECK(BitUtil::IsMultipleOf8(block.offset));
586 DCHECK(BitUtil::IsMultipleOf8(block.metadata_length));
587 DCHECK(BitUtil::IsMultipleOf8(block.body_length));
588
589 std::unique_ptr<Message> message;
590 RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, &message));
591
592 // TODO(wesm): this breaks integration tests, see ARROW-3256
593 // DCHECK_EQ(message->body_length(), block.body_length);
594
595 io::BufferReader reader(message->body());
596
597 std::shared_ptr<Array> dictionary;
598 int64_t dictionary_id;
599 RETURN_NOT_OK(ReadDictionary(*message->metadata(), dictionary_fields_, &reader,
600 &dictionary_id, &dictionary));
601 RETURN_NOT_OK(dictionary_memo_->AddDictionary(dictionary_id, dictionary));
602 }
603
604 // Get the schema
605 return internal::GetSchema(footer_->schema(), *dictionary_memo_, &schema_);
606 }
607
608 Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset) {
609 owned_file_ = file;
610 return Open(file.get(), footer_offset);
611 }
612
613 Status Open(io::RandomAccessFile* file, int64_t footer_offset) {
614 file_ = file;
615 footer_offset_ = footer_offset;
616 RETURN_NOT_OK(ReadFooter());
617 return ReadSchema();
618 }
619
620 std::shared_ptr<Schema> schema() const { return schema_; }
621
622 private:
623 io::RandomAccessFile* file_;
624
625 std::shared_ptr<io::RandomAccessFile> owned_file_;
626
627 // The location where the Arrow file layout ends. May be the end of the file
628 // or some other location if embedded in a larger file.
629 int64_t footer_offset_;
630
631 // Footer metadata
632 std::shared_ptr<Buffer> footer_buffer_;
633 const flatbuf::Footer* footer_;
634
635 DictionaryTypeMap dictionary_fields_;
636 std::shared_ptr<DictionaryMemo> dictionary_memo_;
637
638 // Reconstructed schema, including any read dictionaries
639 std::shared_ptr<Schema> schema_;
640};
641
// Constructor is private; instances are created through the static Open methods
RecordBatchFileReader::RecordBatchFileReader() {
  impl_.reset(new RecordBatchFileReaderImpl());
}

// Out-of-line: the Impl type is only defined in this translation unit
RecordBatchFileReader::~RecordBatchFileReader() {}

// Open at the end of the file: the footer is assumed to end at file size
Status RecordBatchFileReader::Open(io::RandomAccessFile* file,
                                   std::shared_ptr<RecordBatchFileReader>* reader) {
  int64_t footer_offset;
  RETURN_NOT_OK(file->GetSize(&footer_offset));
  return Open(file, footer_offset, reader);
}

// Open with an explicit footer end offset (for Arrow data embedded in a
// larger file)
Status RecordBatchFileReader::Open(io::RandomAccessFile* file, int64_t footer_offset,
                                   std::shared_ptr<RecordBatchFileReader>* reader) {
  *reader = std::shared_ptr<RecordBatchFileReader>(new RecordBatchFileReader());
  return (*reader)->impl_->Open(file, footer_offset);
}

// Shared-ownership variants of the two overloads above
Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
                                   std::shared_ptr<RecordBatchFileReader>* reader) {
  int64_t footer_offset;
  RETURN_NOT_OK(file->GetSize(&footer_offset));
  return Open(file, footer_offset, reader);
}

Status RecordBatchFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
                                   int64_t footer_offset,
                                   std::shared_ptr<RecordBatchFileReader>* reader) {
  *reader = std::shared_ptr<RecordBatchFileReader>(new RecordBatchFileReader());
  return (*reader)->impl_->Open(file, footer_offset);
}

// Forwarders to the pimpl
std::shared_ptr<Schema> RecordBatchFileReader::schema() const { return impl_->schema(); }

int RecordBatchFileReader::num_record_batches() const {
  return impl_->num_record_batches();
}

MetadataVersion RecordBatchFileReader::version() const { return impl_->version(); }

Status RecordBatchFileReader::ReadRecordBatch(int i,
                                              std::shared_ptr<RecordBatch>* batch) {
  return impl_->ReadRecordBatch(i, batch);
}
687
688static Status ReadContiguousPayload(io::InputStream* file,
689 std::unique_ptr<Message>* message) {
690 RETURN_NOT_OK(ReadMessage(file, message));
691 if (*message == nullptr) {
692 return Status::Invalid("Unable to read metadata at offset");
693 }
694 return Status::OK();
695}
696
// Read only the schema from the head of an IPC stream (this also consumes any
// dictionary batches needed to reconstruct the schema)
Status ReadSchema(io::InputStream* stream, std::shared_ptr<Schema>* out) {
  std::shared_ptr<RecordBatchReader> reader;
  RETURN_NOT_OK(RecordBatchStreamReader::Open(stream, &reader));
  *out = reader->schema();
  return Status::OK();
}
703
704Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, io::InputStream* file,
705 std::shared_ptr<RecordBatch>* out) {
706 std::unique_ptr<Message> message;
707 RETURN_NOT_OK(ReadContiguousPayload(file, &message));
708 io::BufferReader buffer_reader(message->body());
709 return ReadRecordBatch(*message->metadata(), schema, kMaxNestingDepth, &buffer_reader,
710 out);
711}
712
// Read one encapsulated message from the stream and decode it as a tensor
Status ReadTensor(io::InputStream* file, std::shared_ptr<Tensor>* out) {
  std::unique_ptr<Message> message;
  RETURN_NOT_OK(ReadContiguousPayload(file, &message));
  return ReadTensor(*message, out);
}
718
719Status ReadTensor(const Message& message, std::shared_ptr<Tensor>* out) {
720 std::shared_ptr<DataType> type;
721 std::vector<int64_t> shape;
722 std::vector<int64_t> strides;
723 std::vector<std::string> dim_names;
724 RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
725 &dim_names));
726 *out = std::make_shared<Tensor>(type, message.body(), shape, strides, dim_names);
727 return Status::OK();
728}
729
namespace {

// Read the COO coordinates tensor for a sparse tensor from `file`.
// The coordinates form a (non_zero_length x ndim) matrix; the strides
// {elsize, elsize * non_zero_length} describe a column-major layout.
// NOTE(review): elsize is hard-coded to sizeof(int64_t) — assumes the index
// values are 64-bit integers; confirm against the writer.
Status ReadSparseCOOIndex(const flatbuf::SparseTensor* sparse_tensor, int64_t ndim,
                          int64_t non_zero_length, io::RandomAccessFile* file,
                          std::shared_ptr<SparseIndex>* out) {
  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCOO();
  auto* indices_buffer = sparse_index->indicesBuffer();
  std::shared_ptr<Buffer> indices_data;
  RETURN_NOT_OK(
      file->ReadAt(indices_buffer->offset(), indices_buffer->length(), &indices_data));
  std::vector<int64_t> shape({non_zero_length, ndim});
  const int64_t elsize = sizeof(int64_t);
  std::vector<int64_t> strides({elsize, elsize * non_zero_length});
  *out = std::make_shared<SparseCOOIndex>(
      std::make_shared<SparseCOOIndex::CoordsTensor>(indices_data, shape, strides));
  return Status::OK();
}

// Read the CSR index (indptr of length ndim + 1 and indices of length
// non_zero_length) for a sparse matrix from `file`
Status ReadSparseCSRIndex(const flatbuf::SparseTensor* sparse_tensor, int64_t ndim,
                          int64_t non_zero_length, io::RandomAccessFile* file,
                          std::shared_ptr<SparseIndex>* out) {
  auto* sparse_index = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSR();

  auto* indptr_buffer = sparse_index->indptrBuffer();
  std::shared_ptr<Buffer> indptr_data;
  RETURN_NOT_OK(
      file->ReadAt(indptr_buffer->offset(), indptr_buffer->length(), &indptr_data));

  auto* indices_buffer = sparse_index->indicesBuffer();
  std::shared_ptr<Buffer> indices_data;
  RETURN_NOT_OK(
      file->ReadAt(indices_buffer->offset(), indices_buffer->length(), &indices_data));

  std::vector<int64_t> indptr_shape({ndim + 1});
  std::vector<int64_t> indices_shape({non_zero_length});
  *out = std::make_shared<SparseCSRIndex>(
      std::make_shared<SparseCSRIndex::IndexTensor>(indptr_data, indptr_shape),
      std::make_shared<SparseCSRIndex::IndexTensor>(indices_data, indices_shape));
  return Status::OK();
}

// Assemble a COO sparse tensor from its already-read pieces.
// Note: non_zero_length is currently unused here; the count is carried by the
// sparse index itself.
Status MakeSparseTensorWithSparseCOOIndex(
    const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
    const std::vector<std::string>& dim_names,
    const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
    const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensor>* out) {
  *out = std::make_shared<SparseTensorImpl<SparseCOOIndex>>(sparse_index, type, data,
                                                            shape, dim_names);
  return Status::OK();
}

// Assemble a CSR sparse tensor from its already-read pieces.
// Note: non_zero_length is currently unused here as well.
Status MakeSparseTensorWithSparseCSRIndex(
    const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
    const std::vector<std::string>& dim_names,
    const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
    const std::shared_ptr<Buffer>& data, std::shared_ptr<SparseTensor>* out) {
  *out = std::make_shared<SparseTensorImpl<SparseCSRIndex>>(sparse_index, type, data,
                                                            shape, dim_names);
  return Status::OK();
}

}  // namespace
792
// Read a sparse tensor given its serialized Message metadata buffer and a
// file holding the message body (index buffers + value data)
Status ReadSparseTensor(const Buffer& metadata, io::RandomAccessFile* file,
                        std::shared_ptr<SparseTensor>* out) {
  std::shared_ptr<DataType> type;
  std::vector<int64_t> shape;
  std::vector<std::string> dim_names;
  int64_t non_zero_length;
  SparseTensorFormat::type sparse_tensor_format_id;

  // Decode the type/shape/format information from the flatbuffer metadata
  RETURN_NOT_OK(internal::GetSparseTensorMetadata(
      metadata, &type, &shape, &dim_names, &non_zero_length, &sparse_tensor_format_id));

  auto message = flatbuf::GetMessage(metadata.data());
  auto sparse_tensor = reinterpret_cast<const flatbuf::SparseTensor*>(message->header());
  const flatbuf::Buffer* buffer = sparse_tensor->data();
  DCHECK(BitUtil::IsMultipleOf8(buffer->offset()))
      << "Buffer of sparse index data "
      << "did not start on 8-byte aligned offset: " << buffer->offset();

  // Read the non-zero value data
  std::shared_ptr<Buffer> data;
  RETURN_NOT_OK(file->ReadAt(buffer->offset(), buffer->length(), &data));

  // Dispatch on the sparse index format (COO or CSR)
  std::shared_ptr<SparseIndex> sparse_index;
  switch (sparse_tensor_format_id) {
    case SparseTensorFormat::COO:
      RETURN_NOT_OK(ReadSparseCOOIndex(sparse_tensor, shape.size(), non_zero_length, file,
                                       &sparse_index));
      return MakeSparseTensorWithSparseCOOIndex(
          type, shape, dim_names, std::dynamic_pointer_cast<SparseCOOIndex>(sparse_index),
          non_zero_length, data, out);

    case SparseTensorFormat::CSR:
      RETURN_NOT_OK(ReadSparseCSRIndex(sparse_tensor, shape.size(), non_zero_length, file,
                                       &sparse_index));
      return MakeSparseTensorWithSparseCSRIndex(
          type, shape, dim_names, std::dynamic_pointer_cast<SparseCSRIndex>(sparse_index),
          non_zero_length, data, out);

    default:
      return Status::Invalid("Unsupported sparse index format");
  }
}
834
// Read a sparse tensor from a complete message; the body holds both the
// sparse index buffers and the value data
Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* out) {
  io::BufferReader buffer_reader(message.body());
  return ReadSparseTensor(*message.metadata(), &buffer_reader, out);
}
839
// Read one encapsulated message from the stream and decode it as a sparse tensor
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* out) {
  std::unique_ptr<Message> message;
  RETURN_NOT_OK(ReadContiguousPayload(file, &message));
  DCHECK_EQ(message->type(), Message::SPARSE_TENSOR);
  io::BufferReader buffer_reader(message->body());
  return ReadSparseTensor(*message->metadata(), &buffer_reader, out);
}
847
848} // namespace ipc
849} // namespace arrow
850