// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

#include <flatbuffers/flatbuffers.h>
#include <gtest/gtest.h>

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/builder.h"
#include "arrow/io/file.h"
#include "arrow/io/memory.h"
#include "arrow/io/test-common.h"
#include "arrow/ipc/Message_generated.h"  // IWYU pragma: keep
#include "arrow/ipc/message.h"
#include "arrow/ipc/metadata-internal.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/test-common.h"
#include "arrow/ipc/writer.h"
#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/tensor.h"
#include "arrow/test-util.h"
#include "arrow/type.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"

namespace arrow {

using internal::checked_cast;

namespace ipc {

using BatchVector = std::vector<std::shared_ptr<RecordBatch>>;

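// Checks schema serialization roundtrips: the schema is serialized to a
// buffer, read back, and compared for equality with the original.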
class TestSchemaMetadata : public ::testing::Test {
 public:
  void SetUp() {}

  void CheckRoundtrip(const Schema& schema) {
    std::shared_ptr<Buffer> buffer;
    ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer));

    std::shared_ptr<Schema> result;
    io::BufferReader reader(buffer);
    ASSERT_OK(ReadSchema(&reader, &result));
    AssertSchemaEqual(schema, *result);
  }
};

TEST(TestMessage, Equals) {
  std::string metadata = "foo";
  std::string body = "bar";

  auto b1 = std::make_shared<Buffer>(metadata);
  auto b2 = std::make_shared<Buffer>(metadata);
  auto b3 = std::make_shared<Buffer>(body);
  auto b4 = std::make_shared<Buffer>(body);

  Message msg1(b1, b3);
  Message msg2(b2, b4);
  Message msg3(b1, nullptr);
  Message msg4(b2, nullptr);

  ASSERT_TRUE(msg1.Equals(msg2));
  ASSERT_TRUE(msg3.Equals(msg4));

  ASSERT_FALSE(msg1.Equals(msg3));
  ASSERT_FALSE(msg3.Equals(msg1));

  // same metadata as msg1, different body
  Message msg5(b2, b1);
  ASSERT_FALSE(msg1.Equals(msg5));
  ASSERT_FALSE(msg5.Equals(msg1));
}

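// SerializeTo emits the encapsulated IPC message: a 4-byte metadata length
// prefix and the flatbuffer metadata, padded to the requested alignment,
// followed by the body. The total output length should therefore be
// RoundUp(metadata_size + 4, alignment) + body_length.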
TEST(TestMessage, SerializeTo) {
  const int64_t body_length = 64;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(flatbuf::CreateMessage(fbb, internal::kCurrentMetadataVersion,
                                    flatbuf::MessageHeader_RecordBatch, 0 /* header */,
                                    body_length));

  std::shared_ptr<Buffer> metadata;
  ASSERT_OK(internal::WriteFlatbufferBuilder(fbb, &metadata));

  std::string body = "abcdef";

  std::unique_ptr<Message> message;
  ASSERT_OK(Message::Open(metadata, std::make_shared<Buffer>(body), &message));

  int64_t output_length = 0;
  int64_t position = 0;

  std::shared_ptr<io::BufferOutputStream> stream;

  {
    const int32_t alignment = 8;

    ASSERT_OK(io::BufferOutputStream::Create(1 << 10, default_memory_pool(), &stream));
    ASSERT_OK(message->SerializeTo(stream.get(), alignment, &output_length));
    ASSERT_OK(stream->Tell(&position));
    ASSERT_EQ(BitUtil::RoundUp(metadata->size() + 4, alignment) + body_length,
              output_length);
    ASSERT_EQ(output_length, position);
  }

  {
    const int32_t alignment = 64;

    ASSERT_OK(io::BufferOutputStream::Create(1 << 10, default_memory_pool(), &stream));
    ASSERT_OK(message->SerializeTo(stream.get(), alignment, &output_length));
    ASSERT_OK(stream->Tell(&position));
    ASSERT_EQ(BitUtil::RoundUp(metadata->size() + 4, alignment) + body_length,
              output_length);
    ASSERT_EQ(output_length, position);
  }
}

TEST(TestMessage, Verify) {
  std::string metadata = "invalid";
  std::string body = "abcdef";

  Message message(std::make_shared<Buffer>(metadata), std::make_shared<Buffer>(body));
  ASSERT_FALSE(message.Verify());
}

const std::shared_ptr<DataType> INT32 = std::make_shared<Int32Type>();

TEST_F(TestSchemaMetadata, PrimitiveFields) {
  auto f0 = field("f0", std::make_shared<Int8Type>());
  auto f1 = field("f1", std::make_shared<Int16Type>(), false);
  auto f2 = field("f2", std::make_shared<Int32Type>());
  auto f3 = field("f3", std::make_shared<Int64Type>());
  auto f4 = field("f4", std::make_shared<UInt8Type>());
  auto f5 = field("f5", std::make_shared<UInt16Type>());
  auto f6 = field("f6", std::make_shared<UInt32Type>());
  auto f7 = field("f7", std::make_shared<UInt64Type>());
  auto f8 = field("f8", std::make_shared<FloatType>());
  auto f9 = field("f9", std::make_shared<DoubleType>(), false);
  auto f10 = field("f10", std::make_shared<BooleanType>());

  Schema schema({f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10});
  CheckRoundtrip(schema);
}

TEST_F(TestSchemaMetadata, NestedFields) {
  auto type = list(int32());
  auto f0 = field("f0", type);

  std::shared_ptr<StructType> type2(
      new StructType({field("k1", INT32), field("k2", INT32), field("k3", INT32)}));
  auto f1 = field("f1", type2);

  Schema schema({f0, f1});
  CheckRoundtrip(schema);
}

TEST_F(TestSchemaMetadata, KeyValueMetadata) {
  auto field_metadata = key_value_metadata({{"key", "value"}});
  auto schema_metadata = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});

  auto f0 = field("f0", std::make_shared<Int8Type>());
  auto f1 = field("f1", std::make_shared<Int16Type>(), false, field_metadata);

  Schema schema({f0, f1}, schema_metadata);
  CheckRoundtrip(schema);
}

#define BATCH_CASES()                                                                   \
  ::testing::Values(&MakeIntRecordBatch, &MakeListRecordBatch, &MakeNonNullRecordBatch, \
                    &MakeZeroLengthRecordBatch, &MakeDeeplyNestedList,                  \
                    &MakeStringTypesRecordBatchWithNulls, &MakeStruct, &MakeUnion,      \
                    &MakeDictionary, &MakeDates, &MakeTimestamps, &MakeTimes,           \
                    &MakeFWBinary, &MakeNull, &MakeDecimal, &MakeBooleanBatch);

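// Counter used to generate a unique file name for each memory-mapped scratch
// file created by the tests below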
static int g_file_number = 0;

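// Common roundtrip machinery: schemas roundtrip through a serialized buffer,
// while record batches roundtrip both through an in-memory buffer and through
// the IPC file format written to a memory-mapped file.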
class IpcTestFixture : public io::MemoryMapFixture {
 public:
  Status DoSchemaRoundTrip(const Schema& schema, std::shared_ptr<Schema>* result) {
    std::shared_ptr<Buffer> serialized_schema;
    RETURN_NOT_OK(SerializeSchema(schema, pool_, &serialized_schema));

    io::BufferReader buf_reader(serialized_schema);
    return ReadSchema(&buf_reader, result);
  }

  Status DoStandardRoundTrip(const RecordBatch& batch,
                             std::shared_ptr<RecordBatch>* batch_result) {
    std::shared_ptr<Buffer> serialized_batch;
    RETURN_NOT_OK(SerializeRecordBatch(batch, pool_, &serialized_batch));

    io::BufferReader buf_reader(serialized_batch);
    return ReadRecordBatch(batch.schema(), &buf_reader, batch_result);
  }

  Status DoLargeRoundTrip(const RecordBatch& batch, bool zero_data,
                          std::shared_ptr<RecordBatch>* result) {
    if (zero_data) {
      RETURN_NOT_OK(ZeroMemoryMap(mmap_.get()));
    }
    RETURN_NOT_OK(mmap_->Seek(0));

    std::shared_ptr<RecordBatchWriter> file_writer;
    RETURN_NOT_OK(RecordBatchFileWriter::Open(mmap_.get(), batch.schema(), &file_writer));
    RETURN_NOT_OK(file_writer->WriteRecordBatch(batch, true));
    RETURN_NOT_OK(file_writer->Close());

    int64_t offset;
    RETURN_NOT_OK(mmap_->Tell(&offset));

    std::shared_ptr<RecordBatchFileReader> file_reader;
    RETURN_NOT_OK(RecordBatchFileReader::Open(mmap_.get(), offset, &file_reader));

    return file_reader->ReadRecordBatch(0, result);
  }

  void CheckReadResult(const RecordBatch& result, const RecordBatch& expected) {
    EXPECT_EQ(expected.num_rows(), result.num_rows());

    ASSERT_TRUE(expected.schema()->Equals(*result.schema()));
    ASSERT_EQ(expected.num_columns(), result.num_columns())
        << expected.schema()->ToString() << " result: " << result.schema()->ToString();

    CompareBatchColumnsDetailed(result, expected);
  }

  void CheckRoundtrip(const RecordBatch& batch, int64_t buffer_size) {
    std::stringstream ss;
    ss << "test-write-row-batch-" << g_file_number++;
    ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(buffer_size, ss.str(), &mmap_));

    std::shared_ptr<Schema> schema_result;
    ASSERT_OK(DoSchemaRoundTrip(*batch.schema(), &schema_result));
    ASSERT_TRUE(batch.schema()->Equals(*schema_result));

    std::shared_ptr<RecordBatch> result;
    ASSERT_OK(DoStandardRoundTrip(batch, &result));
    CheckReadResult(*result, batch);

    ASSERT_OK(DoLargeRoundTrip(batch, true, &result));
    CheckReadResult(*result, batch);
  }

  void CheckRoundtrip(const std::shared_ptr<Array>& array, int64_t buffer_size) {
    auto f0 = arrow::field("f0", array->type());
    std::vector<std::shared_ptr<Field>> fields = {f0};
    auto schema = std::make_shared<Schema>(fields);

    auto batch = RecordBatch::Make(schema, 0, {array});
    CheckRoundtrip(*batch, buffer_size);
  }

 protected:
  std::shared_ptr<io::MemoryMappedFile> mmap_;
  MemoryPool* pool_;
};

class TestWriteRecordBatch : public ::testing::Test, public IpcTestFixture {
 public:
  void SetUp() { pool_ = default_memory_pool(); }
  void TearDown() { io::MemoryMapFixture::TearDown(); }
};

class TestIpcRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*>,
                         public IpcTestFixture {
 public:
  void SetUp() { pool_ = default_memory_pool(); }
  void TearDown() { io::MemoryMapFixture::TearDown(); }
};

TEST_P(TestIpcRoundTrip, RoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK((*GetParam())(&batch));  // NOLINT clang-tidy gtest issue

  CheckRoundtrip(*batch, 1 << 20);
}

TEST_F(TestIpcRoundTrip, MetadataVersion) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(MakeIntRecordBatch(&batch));

  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(1 << 16, "test-metadata", &mmap_));

  int32_t metadata_length;
  int64_t body_length;

  const int64_t buffer_offset = 0;

  ASSERT_OK(WriteRecordBatch(*batch, buffer_offset, mmap_.get(), &metadata_length,
                             &body_length, pool_));

  std::unique_ptr<Message> message;
  ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message));

  ASSERT_EQ(MetadataVersion::V4, message->metadata_version());
}

TEST_P(TestIpcRoundTrip, SliceRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK((*GetParam())(&batch));  // NOLINT clang-tidy gtest issue

  // Skip the zero-length case
  if (batch->num_rows() < 2) {
    return;
  }

  auto sliced_batch = batch->Slice(2, 10);
  CheckRoundtrip(*sliced_batch, 1 << 20);
}

TEST_P(TestIpcRoundTrip, ZeroLengthArrays) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK((*GetParam())(&batch));  // NOLINT clang-tidy gtest issue

  std::shared_ptr<RecordBatch> zero_length_batch;
  if (batch->num_rows() > 2) {
    zero_length_batch = batch->Slice(2, 0);
  } else {
    zero_length_batch = batch->Slice(0, 0);
  }

  CheckRoundtrip(*zero_length_batch, 1 << 20);

  // ARROW-544: check binary array
  std::shared_ptr<Buffer> value_offsets;
  ASSERT_OK(AllocateBuffer(pool_, sizeof(int32_t), &value_offsets));
  *reinterpret_cast<int32_t*>(value_offsets->mutable_data()) = 0;

  std::shared_ptr<Array> bin_array = std::make_shared<BinaryArray>(
      0, value_offsets, std::make_shared<Buffer>(nullptr, 0),
      std::make_shared<Buffer>(nullptr, 0));

  // null value_offsets
  std::shared_ptr<Array> bin_array2 = std::make_shared<BinaryArray>(0, nullptr, nullptr);

  CheckRoundtrip(bin_array, 1 << 20);
  CheckRoundtrip(bin_array2, 1 << 20);
}

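// Writing a sliced batch should truncate the serialized buffers to the slice,
// so the serialized size of a slice must be strictly smaller than that of the
// full batch.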
TEST_F(TestWriteRecordBatch, SliceTruncatesBuffers) {
  auto CheckArray = [this](const std::shared_ptr<Array>& array) {
    auto f0 = field("f0", array->type());
    auto schema = ::arrow::schema({f0});
    auto batch = RecordBatch::Make(schema, array->length(), {array});
    auto sliced_batch = batch->Slice(0, 5);

    int64_t full_size;
    int64_t sliced_size;

    ASSERT_OK(GetRecordBatchSize(*batch, &full_size));
    ASSERT_OK(GetRecordBatchSize(*sliced_batch, &sliced_size));
    ASSERT_TRUE(sliced_size < full_size) << sliced_size << " " << full_size;

    // make sure we can write and read it
    this->CheckRoundtrip(*sliced_batch, 1 << 20);
  };

  std::shared_ptr<Array> a0, a1;
  auto pool = default_memory_pool();

  // Integer
  ASSERT_OK(MakeRandomInt32Array(500, false, pool, &a0));
  CheckArray(a0);

  // String / Binary
  {
    auto s = MakeRandomBinaryArray<StringBuilder, char>(500, false, pool, &a0);
    ASSERT_TRUE(s.ok());
  }
  CheckArray(a0);

  // Boolean
  ASSERT_OK(MakeRandomBooleanArray(10000, false, &a0));
  CheckArray(a0);

  // List
  ASSERT_OK(MakeRandomInt32Array(500, false, pool, &a0));
  ASSERT_OK(MakeRandomListArray(a0, 200, false, pool, &a1));
  CheckArray(a1);

  // Struct
  auto struct_type = struct_({field("f0", a0->type())});
  std::vector<std::shared_ptr<Array>> struct_children = {a0};
  a1 = std::make_shared<StructArray>(struct_type, a0->length(), struct_children);
  CheckArray(a1);

  // Sparse Union
  auto union_type = union_({field("f0", a0->type())}, {0});
  std::vector<int32_t> type_ids(a0->length());
  std::shared_ptr<Buffer> ids_buffer;
  ASSERT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &ids_buffer));
  a1 =
      std::make_shared<UnionArray>(union_type, a0->length(), struct_children, ids_buffer);
  CheckArray(a1);

  // Dense union
  auto dense_union_type = union_({field("f0", a0->type())}, {0}, UnionMode::DENSE);
  std::vector<int32_t> type_offsets;
  for (int32_t i = 0; i < a0->length(); ++i) {
    type_offsets.push_back(i);
  }
  std::shared_ptr<Buffer> offsets_buffer;
  ASSERT_OK(CopyBufferFromVector(type_offsets, default_memory_pool(), &offsets_buffer));
  a1 = std::make_shared<UnionArray>(dense_union_type, a0->length(), struct_children,
                                    ids_buffer, offsets_buffer);
  CheckArray(a1);
}

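// GetRecordBatchSize must agree with the number of bytes actually emitted by
// WriteRecordBatch; the MockOutputStream records how many bytes were written
// without storing them.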
void TestGetRecordBatchSize(std::shared_ptr<RecordBatch> batch) {
  io::MockOutputStream mock;
  int32_t mock_metadata_length = -1;
  int64_t mock_body_length = -1;
  int64_t size = -1;
  ASSERT_OK(WriteRecordBatch(*batch, 0, &mock, &mock_metadata_length, &mock_body_length,
                             default_memory_pool()));
  ASSERT_OK(GetRecordBatchSize(*batch, &size));
  ASSERT_EQ(mock.GetExtentBytesWritten(), size);
}

TEST_F(TestWriteRecordBatch, IntegerGetRecordBatchSize) {
  std::shared_ptr<RecordBatch> batch;

  ASSERT_OK(MakeIntRecordBatch(&batch));
  TestGetRecordBatchSize(batch);

  ASSERT_OK(MakeListRecordBatch(&batch));
  TestGetRecordBatchSize(batch);

  ASSERT_OK(MakeZeroLengthRecordBatch(&batch));
  TestGetRecordBatchSize(batch);

  ASSERT_OK(MakeNonNullRecordBatch(&batch));
  TestGetRecordBatchSize(batch);

  ASSERT_OK(MakeDeeplyNestedList(&batch));
  TestGetRecordBatchSize(batch);
}

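// Builds a list type nested recursion_level deep and writes it to a memory
// map, optionally overriding the writer's maximum nesting depth. The default
// limit is 64 (see the ReadLimit test below), and exceeding it on either the
// write or the read path should fail cleanly with an Invalid status.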
class RecursionLimits : public ::testing::Test, public io::MemoryMapFixture {
 public:
  void SetUp() { pool_ = default_memory_pool(); }
  void TearDown() { io::MemoryMapFixture::TearDown(); }

  Status WriteToMmap(int recursion_level, bool override_level, int32_t* metadata_length,
                     int64_t* body_length, std::shared_ptr<RecordBatch>* batch,
                     std::shared_ptr<Schema>* schema) {
    const int batch_length = 5;
    auto type = int32();
    std::shared_ptr<Array> array;
    const bool include_nulls = true;
    RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool_, &array));
    for (int i = 0; i < recursion_level; ++i) {
      type = list(type);
      RETURN_NOT_OK(
          MakeRandomListArray(array, batch_length, include_nulls, pool_, &array));
    }

    auto f0 = field("f0", type);

    *schema = ::arrow::schema({f0});

    *batch = RecordBatch::Make(*schema, batch_length, {array});

    std::stringstream ss;
    ss << "test-write-past-max-recursion-" << g_file_number++;
    const int memory_map_size = 1 << 20;
    RETURN_NOT_OK(io::MemoryMapFixture::InitMemoryMap(memory_map_size, ss.str(), &mmap_));

    if (override_level) {
      return WriteRecordBatch(**batch, 0, mmap_.get(), metadata_length, body_length,
                              pool_, recursion_level + 1);
    } else {
      return WriteRecordBatch(**batch, 0, mmap_.get(), metadata_length, body_length,
                              pool_);
    }
  }

 protected:
  std::shared_ptr<io::MemoryMappedFile> mmap_;
  MemoryPool* pool_;
};

TEST_F(RecursionLimits, WriteLimit) {
  int32_t metadata_length = -1;
  int64_t body_length = -1;
  std::shared_ptr<Schema> schema;
  std::shared_ptr<RecordBatch> batch;
  ASSERT_RAISES(Invalid, WriteToMmap((1 << 8) + 1, false, &metadata_length, &body_length,
                                     &batch, &schema));
}

TEST_F(RecursionLimits, ReadLimit) {
  int32_t metadata_length = -1;
  int64_t body_length = -1;
  std::shared_ptr<Schema> schema;

  const int recursion_depth = 64;

  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(WriteToMmap(recursion_depth, true, &metadata_length, &body_length, &batch,
                        &schema));

  std::unique_ptr<Message> message;
  ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message));

  io::BufferReader reader(message->body());

  std::shared_ptr<RecordBatch> result;
  ASSERT_RAISES(Invalid, ReadRecordBatch(*message->metadata(), schema, &reader, &result));
}

// Test fails with a structured exception on Windows + Debug
#if !defined(_WIN32) || defined(NDEBUG)
TEST_F(RecursionLimits, StressLimit) {
  auto CheckDepth = [this](int recursion_depth, bool* it_works) {
    int32_t metadata_length = -1;
    int64_t body_length = -1;
    std::shared_ptr<Schema> schema;
    std::shared_ptr<RecordBatch> batch;
    ASSERT_OK(WriteToMmap(recursion_depth, true, &metadata_length, &body_length, &batch,
                          &schema));

    std::unique_ptr<Message> message;
    ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message));

    io::BufferReader reader(message->body());
    std::shared_ptr<RecordBatch> result;
    ASSERT_OK(ReadRecordBatch(*message->metadata(), schema, recursion_depth + 1, &reader,
                              &result));
    *it_works = result->Equals(*batch);
  };

  bool it_works = false;
  CheckDepth(100, &it_works);
  ASSERT_TRUE(it_works);

// Mitigate Valgrind's slowness
#if !defined(ARROW_VALGRIND)
  CheckDepth(500, &it_works);
  ASSERT_TRUE(it_works);
#endif
}
#endif  // !defined(_WIN32) || defined(NDEBUG)

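// Roundtrip through the IPC file format: batches are written into an
// in-memory buffer via RecordBatchFileWriter, then read back through a
// RecordBatchFileReader positioned at the footer offset.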
class TestFileFormat : public ::testing::TestWithParam<MakeRecordBatch*> {
 public:
  void SetUp() {
    pool_ = default_memory_pool();
    ASSERT_OK(AllocateResizableBuffer(pool_, 0, &buffer_));
    sink_.reset(new io::BufferOutputStream(buffer_));
  }
  void TearDown() {}

  Status RoundTripHelper(const BatchVector& in_batches, BatchVector* out_batches) {
    // Write the file
    std::shared_ptr<RecordBatchWriter> writer;
    RETURN_NOT_OK(
        RecordBatchFileWriter::Open(sink_.get(), in_batches[0]->schema(), &writer));

    const int num_batches = static_cast<int>(in_batches.size());

    for (const auto& batch : in_batches) {
      RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
    }
    RETURN_NOT_OK(writer->Close());
    RETURN_NOT_OK(sink_->Close());

    // Current offset into stream is the end of the file
    int64_t footer_offset;
    RETURN_NOT_OK(sink_->Tell(&footer_offset));

    // Open the file
    auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
    std::shared_ptr<RecordBatchFileReader> reader;
    RETURN_NOT_OK(RecordBatchFileReader::Open(buf_reader.get(), footer_offset, &reader));

    EXPECT_EQ(num_batches, reader->num_record_batches());
    for (int i = 0; i < num_batches; ++i) {
      std::shared_ptr<RecordBatch> chunk;
      RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk));
      out_batches->emplace_back(chunk);
    }

    return Status::OK();
  }

 protected:
  MemoryPool* pool_;

  std::unique_ptr<io::BufferOutputStream> sink_;
  std::shared_ptr<ResizableBuffer> buffer_;
};

TEST_P(TestFileFormat, RoundTrip) {
  std::shared_ptr<RecordBatch> batch1;
  std::shared_ptr<RecordBatch> batch2;
  ASSERT_OK((*GetParam())(&batch1));  // NOLINT clang-tidy gtest issue
  ASSERT_OK((*GetParam())(&batch2));  // NOLINT clang-tidy gtest issue

  BatchVector in_batches = {batch1, batch2};
  BatchVector out_batches;

  ASSERT_OK(RoundTripHelper(in_batches, &out_batches));

  // Compare batches
  for (size_t i = 0; i < in_batches.size(); ++i) {
    CompareBatch(*in_batches[i], *out_batches[i]);
  }
}

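// Roundtrip through the IPC stream format: unlike the file format there is no
// footer, so the reader simply consumes messages from the front of the stream.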
class TestStreamFormat : public ::testing::TestWithParam<MakeRecordBatch*> {
 public:
  void SetUp() {
    pool_ = default_memory_pool();
    ASSERT_OK(AllocateResizableBuffer(pool_, 0, &buffer_));
    sink_.reset(new io::BufferOutputStream(buffer_));
  }
  void TearDown() {}

  Status RoundTripHelper(const BatchVector& batches, BatchVector* out_batches) {
    // Write the stream
    std::shared_ptr<RecordBatchWriter> writer;
    RETURN_NOT_OK(
        RecordBatchStreamWriter::Open(sink_.get(), batches[0]->schema(), &writer));

    for (const auto& batch : batches) {
      RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
    }
    RETURN_NOT_OK(writer->Close());
    RETURN_NOT_OK(sink_->Close());

    // Open the stream
    io::BufferReader buf_reader(buffer_);

    std::shared_ptr<RecordBatchReader> reader;
    RETURN_NOT_OK(RecordBatchStreamReader::Open(&buf_reader, &reader));
    return reader->ReadAll(out_batches);
  }

 protected:
  MemoryPool* pool_;

  std::unique_ptr<io::BufferOutputStream> sink_;
  std::shared_ptr<ResizableBuffer> buffer_;
};

TEST_P(TestStreamFormat, RoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK((*GetParam())(&batch));  // NOLINT clang-tidy gtest issue

  BatchVector out_batches;

  ASSERT_OK(RoundTripHelper({batch, batch, batch}, &out_batches));

  // Compare batches; all were written from the same source batch
  for (size_t i = 0; i < out_batches.size(); ++i) {
    CompareBatch(*batch, *out_batches[i]);
  }
}

INSTANTIATE_TEST_CASE_P(GenericIpcRoundTripTests, TestIpcRoundTrip, BATCH_CASES());
INSTANTIATE_TEST_CASE_P(FileRoundTripTests, TestFileFormat, BATCH_CASES());
INSTANTIATE_TEST_CASE_P(StreamRoundTripTests, TestStreamFormat, BATCH_CASES());

// This test uses uninitialized memory: the builder reserves and advances past
// the bits without writing values, which is why Valgrind and ASAN skip it.

#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER))
TEST_F(TestIpcRoundTrip, LargeRecordBatch) {
  const int64_t length = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;

  BooleanBuilder builder(default_memory_pool());
  ASSERT_OK(builder.Reserve(length));
  ASSERT_OK(builder.Advance(length));

  std::shared_ptr<Array> array;
  ASSERT_OK(builder.Finish(&array));

  auto f0 = arrow::field("f0", array->type());
  std::vector<std::shared_ptr<Field>> fields = {f0};
  auto schema = std::make_shared<Schema>(fields);

  auto batch = RecordBatch::Make(schema, length, {array});

  std::string path = "test-write-large-record_batch";

  // 512 MB
  constexpr int64_t kBufferSize = 1 << 29;
  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));

  std::shared_ptr<RecordBatch> result;
  ASSERT_OK(DoLargeRoundTrip(*batch, false, &result));
  CheckReadResult(*result, *batch);

  ASSERT_EQ(length, result->num_rows());
}
#endif

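// Fields that shared a dictionary on the write side must still share a single
// dictionary instance after a roundtrip, including dictionaries nested inside
// list values.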
void CheckBatchDictionaries(const RecordBatch& batch) {
  // Check that dictionaries that should be the same are the same
  auto schema = batch.schema();

  const auto& t0 = checked_cast<const DictionaryType&>(*schema->field(0)->type());
  const auto& t1 = checked_cast<const DictionaryType&>(*schema->field(1)->type());

  ASSERT_EQ(t0.dictionary().get(), t1.dictionary().get());

  // Same dictionary used for list values
  const auto& t3 = checked_cast<const ListType&>(*schema->field(3)->type());
  const auto& t3_value = checked_cast<const DictionaryType&>(*t3.value_type());
  ASSERT_EQ(t0.dictionary().get(), t3_value.dictionary().get());
}

TEST_F(TestStreamFormat, DictionaryRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(MakeDictionary(&batch));

  BatchVector out_batches;
  ASSERT_OK(RoundTripHelper({batch}, &out_batches));

  CheckBatchDictionaries(*out_batches[0]);
}

TEST_F(TestStreamFormat, WriteTable) {
  std::shared_ptr<RecordBatch> b1, b2, b3;
  ASSERT_OK(MakeIntRecordBatch(&b1));
  ASSERT_OK(MakeIntRecordBatch(&b2));
  ASSERT_OK(MakeIntRecordBatch(&b3));

  BatchVector out_batches;
  ASSERT_OK(RoundTripHelper({b1, b2, b3}, &out_batches));

  ASSERT_TRUE(b1->Equals(*out_batches[0]));
  ASSERT_TRUE(b2->Equals(*out_batches[1]));
  ASSERT_TRUE(b3->Equals(*out_batches[2]));
}

TEST_F(TestFileFormat, DictionaryRoundTrip) {
  std::shared_ptr<RecordBatch> batch;
  ASSERT_OK(MakeDictionary(&batch));

  BatchVector out_batches;
  ASSERT_OK(RoundTripHelper({batch}, &out_batches));

  CheckBatchDictionaries(*out_batches[0]);
}

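// Tensor roundtrips through the memory map. The written body holds exactly
// elem_size * tensor.size() bytes; non-contiguous input is compacted on write
// (see the NonContiguous test).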
class TestTensorRoundTrip : public ::testing::Test, public IpcTestFixture {
 public:
  void SetUp() { pool_ = default_memory_pool(); }
  void TearDown() { io::MemoryMapFixture::TearDown(); }

  void CheckTensorRoundTrip(const Tensor& tensor) {
    int32_t metadata_length;
    int64_t body_length;

    const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
    const int elem_size = type.bit_width() / 8;

    ASSERT_OK(mmap_->Seek(0));

    ASSERT_OK(WriteTensor(tensor, mmap_.get(), &metadata_length, &body_length));

    const int64_t expected_body_length = elem_size * tensor.size();
    ASSERT_EQ(expected_body_length, body_length);

    ASSERT_OK(mmap_->Seek(0));

    std::shared_ptr<Tensor> result;
    ASSERT_OK(ReadTensor(mmap_.get(), &result));

    ASSERT_EQ(result->data()->size(), expected_body_length);
    ASSERT_TRUE(tensor.Equals(*result));
  }
};

TEST_F(TestTensorRoundTrip, BasicRoundtrip) {
  std::string path = "test-write-tensor";
  constexpr int64_t kBufferSize = 1 << 20;
  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));

  std::vector<int64_t> shape = {4, 6};
  std::vector<int64_t> strides = {48, 8};
  std::vector<std::string> dim_names = {"foo", "bar"};
  int64_t size = 24;

  std::vector<int64_t> values;
  randint(size, 0, 100, &values);

  auto data = Buffer::Wrap(values);

  Tensor t0(int64(), data, shape, strides, dim_names);
  Tensor t_no_dims(int64(), data, {}, {}, {});
  Tensor t_zero_length_dim(int64(), data, {0}, {8}, {"foo"});

  CheckTensorRoundTrip(t0);
  CheckTensorRoundTrip(t_no_dims);
  CheckTensorRoundTrip(t_zero_length_dim);

  int64_t serialized_size;
  ASSERT_OK(GetTensorSize(t0, &serialized_size));
  ASSERT_TRUE(serialized_size > static_cast<int64_t>(size * sizeof(int64_t)));

  // ARROW-2840: check that padding/alignment is taken into account
  std::vector<int64_t> shape_2 = {1, 1};
  std::vector<int64_t> strides_2 = {8, 8};
  Tensor t0_not_multiple_64(int64(), data, shape_2, strides_2, dim_names);
  CheckTensorRoundTrip(t0_not_multiple_64);
}

TEST_F(TestTensorRoundTrip, NonContiguous) {
  std::string path = "test-write-tensor-strided";
  constexpr int64_t kBufferSize = 1 << 20;
  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));

  std::vector<int64_t> values;
  randint(24, 0, 100, &values);

  auto data = Buffer::Wrap(values);
  Tensor tensor(int64(), data, {4, 3}, {48, 16});

  CheckTensorRoundTrip(tensor);
}

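// Sparse tensor roundtrips. The expected body length is the sum of the index
// buffers (COO: indices; CSR: indptr + indices) and the non-zero data.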
class TestSparseTensorRoundTrip : public ::testing::Test, public IpcTestFixture {
 public:
  void SetUp() { pool_ = default_memory_pool(); }
  void TearDown() { io::MemoryMapFixture::TearDown(); }

  template <typename SparseIndexType>
  void CheckSparseTensorRoundTrip(const SparseTensorImpl<SparseIndexType>& tensor) {
    GTEST_FAIL();
  }
};

template <>
void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCOOIndex>(
    const SparseTensorImpl<SparseCOOIndex>& tensor) {
  const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
  const int elem_size = type.bit_width() / 8;

  int32_t metadata_length;
  int64_t body_length;

  ASSERT_OK(mmap_->Seek(0));

  ASSERT_OK(WriteSparseTensor(tensor, mmap_.get(), &metadata_length, &body_length,
                              default_memory_pool()));

  const auto& sparse_index = checked_cast<const SparseCOOIndex&>(*tensor.sparse_index());
  const int64_t indices_length = elem_size * sparse_index.indices()->size();
  const int64_t data_length = elem_size * tensor.non_zero_length();
  const int64_t expected_body_length = indices_length + data_length;
  ASSERT_EQ(expected_body_length, body_length);

  ASSERT_OK(mmap_->Seek(0));

  std::shared_ptr<SparseTensor> result;
  ASSERT_OK(ReadSparseTensor(mmap_.get(), &result));

  const auto& resulted_sparse_index =
      checked_cast<const SparseCOOIndex&>(*result->sparse_index());
  ASSERT_EQ(resulted_sparse_index.indices()->data()->size(), indices_length);
  ASSERT_EQ(result->data()->size(), data_length);
  ASSERT_TRUE(result->Equals(tensor));
}

template <>
void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCSRIndex>(
    const SparseTensorImpl<SparseCSRIndex>& tensor) {
  const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
  const int elem_size = type.bit_width() / 8;

  int32_t metadata_length;
  int64_t body_length;

  ASSERT_OK(mmap_->Seek(0));

  ASSERT_OK(WriteSparseTensor(tensor, mmap_.get(), &metadata_length, &body_length,
                              default_memory_pool()));

  const auto& sparse_index = checked_cast<const SparseCSRIndex&>(*tensor.sparse_index());
  const int64_t indptr_length = elem_size * sparse_index.indptr()->size();
  const int64_t indices_length = elem_size * sparse_index.indices()->size();
  const int64_t data_length = elem_size * tensor.non_zero_length();
  const int64_t expected_body_length = indptr_length + indices_length + data_length;
  ASSERT_EQ(expected_body_length, body_length);

  ASSERT_OK(mmap_->Seek(0));

  std::shared_ptr<SparseTensor> result;
  ASSERT_OK(ReadSparseTensor(mmap_.get(), &result));

  const auto& resulted_sparse_index =
      checked_cast<const SparseCSRIndex&>(*result->sparse_index());
  ASSERT_EQ(resulted_sparse_index.indptr()->data()->size(), indptr_length);
  ASSERT_EQ(resulted_sparse_index.indices()->data()->size(), indices_length);
  ASSERT_EQ(result->data()->size(), data_length);
  ASSERT_TRUE(result->Equals(tensor));
}

TEST_F(TestSparseTensorRoundTrip, WithSparseCOOIndex) {
  std::string path = "test-write-sparse-coo-tensor";
  constexpr int64_t kBufferSize = 1 << 20;
  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));

  std::vector<int64_t> shape = {2, 3, 4};
  std::vector<std::string> dim_names = {"foo", "bar", "baz"};
  std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
                                 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};

  auto data = Buffer::Wrap(values);
  NumericTensor<Int64Type> t(data, shape, {}, dim_names);
  SparseTensorImpl<SparseCOOIndex> st(t);

  CheckSparseTensorRoundTrip(st);
}

TEST_F(TestSparseTensorRoundTrip, WithSparseCSRIndex) {
  std::string path = "test-write-sparse-csr-matrix";
  constexpr int64_t kBufferSize = 1 << 20;
  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));

  std::vector<int64_t> shape = {4, 6};
  std::vector<std::string> dim_names = {"foo", "bar"};
  std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
                                 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};

  auto data = Buffer::Wrap(values);
  NumericTensor<Int64Type> t(data, shape, {}, dim_names);
  SparseTensorImpl<SparseCSRIndex> st(t);

  CheckSparseTensorRoundTrip(st);
}

TEST(TestRecordBatchStreamReader, MalformedInput) {
  const std::string empty_str = "";
  const std::string garbage_str = "12345678";

  auto empty = std::make_shared<Buffer>(empty_str);
  auto garbage = std::make_shared<Buffer>(garbage_str);

  std::shared_ptr<RecordBatchReader> batch_reader;

  io::BufferReader empty_reader(empty);
  ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&empty_reader, &batch_reader));

  io::BufferReader garbage_reader(garbage);
  ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader, &batch_reader));
}

}  // namespace ipc
}  // namespace arrow