1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include <cstdint> |
19 | #include <limits> |
20 | #include <memory> |
21 | #include <ostream> |
22 | #include <string> |
23 | |
24 | #include <flatbuffers/flatbuffers.h> |
25 | #include <gtest/gtest.h> |
26 | |
27 | #include "arrow/array.h" |
28 | #include "arrow/buffer.h" |
29 | #include "arrow/builder.h" |
30 | #include "arrow/io/file.h" |
31 | #include "arrow/io/memory.h" |
32 | #include "arrow/io/test-common.h" |
33 | #include "arrow/ipc/Message_generated.h" // IWYU pragma: keep |
34 | #include "arrow/ipc/message.h" |
35 | #include "arrow/ipc/metadata-internal.h" |
36 | #include "arrow/ipc/reader.h" |
37 | #include "arrow/ipc/test-common.h" |
38 | #include "arrow/ipc/writer.h" |
39 | #include "arrow/memory_pool.h" |
40 | #include "arrow/record_batch.h" |
41 | #include "arrow/sparse_tensor.h" |
42 | #include "arrow/status.h" |
43 | #include "arrow/tensor.h" |
44 | #include "arrow/test-util.h" |
45 | #include "arrow/type.h" |
46 | #include "arrow/util/bit-util.h" |
47 | #include "arrow/util/checked_cast.h" |
48 | |
49 | namespace arrow { |
50 | |
51 | using internal::checked_cast; |
52 | |
53 | namespace ipc { |
54 | |
55 | using BatchVector = std::vector<std::shared_ptr<RecordBatch>>; |
56 | |
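// Verifies that a Schema can be serialized to the IPC metadata format and read
// back without loss of field or metadata information.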
57 | class TestSchemaMetadata : public ::testing::Test { |
58 | public: |
59 | void SetUp() {} |
60 | |
61 | void CheckRoundtrip(const Schema& schema) { |
62 | std::shared_ptr<Buffer> buffer; |
63 | ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer)); |
64 | |
65 | std::shared_ptr<Schema> result; |
66 | io::BufferReader reader(buffer); |
67 | ASSERT_OK(ReadSchema(&reader, &result)); |
68 | AssertSchemaEqual(schema, *result); |
69 | } |
70 | }; |
71 | |
72 | TEST(TestMessage, Equals) { |
  std::string metadata = "foo";
  std::string body = "bar";
75 | |
76 | auto b1 = std::make_shared<Buffer>(metadata); |
77 | auto b2 = std::make_shared<Buffer>(metadata); |
78 | auto b3 = std::make_shared<Buffer>(body); |
79 | auto b4 = std::make_shared<Buffer>(body); |
80 | |
81 | Message msg1(b1, b3); |
82 | Message msg2(b2, b4); |
83 | Message msg3(b1, nullptr); |
84 | Message msg4(b2, nullptr); |
85 | |
86 | ASSERT_TRUE(msg1.Equals(msg2)); |
87 | ASSERT_TRUE(msg3.Equals(msg4)); |
88 | |
89 | ASSERT_FALSE(msg1.Equals(msg3)); |
90 | ASSERT_FALSE(msg3.Equals(msg1)); |
91 | |
92 | // same metadata as msg1, different body |
93 | Message msg5(b2, b1); |
94 | ASSERT_FALSE(msg1.Equals(msg5)); |
95 | ASSERT_FALSE(msg5.Equals(msg1)); |
96 | } |
97 | |
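// SerializeTo writes an encapsulated message: a 4-byte metadata length prefix,
// the metadata flatbuffer padded to the requested alignment, then the body.
// The expected output lengths asserted below follow from that layout.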
98 | TEST(TestMessage, SerializeTo) { |
99 | const int64_t body_length = 64; |
100 | |
101 | flatbuffers::FlatBufferBuilder fbb; |
102 | fbb.Finish(flatbuf::CreateMessage(fbb, internal::kCurrentMetadataVersion, |
103 | flatbuf::MessageHeader_RecordBatch, 0 /* header */, |
104 | body_length)); |
105 | |
106 | std::shared_ptr<Buffer> metadata; |
107 | ASSERT_OK(internal::WriteFlatbufferBuilder(fbb, &metadata)); |
108 | |
  std::string body = "abcdef";
110 | |
111 | std::unique_ptr<Message> message; |
112 | ASSERT_OK(Message::Open(metadata, std::make_shared<Buffer>(body), &message)); |
113 | |
114 | int64_t output_length = 0; |
115 | int64_t position = 0; |
116 | |
117 | std::shared_ptr<io::BufferOutputStream> stream; |
118 | |
119 | { |
120 | const int32_t alignment = 8; |
121 | |
122 | ASSERT_OK(io::BufferOutputStream::Create(1 << 10, default_memory_pool(), &stream)); |
123 | ASSERT_OK(message->SerializeTo(stream.get(), alignment, &output_length)); |
124 | ASSERT_OK(stream->Tell(&position)); |
125 | ASSERT_EQ(BitUtil::RoundUp(metadata->size() + 4, alignment) + body_length, |
126 | output_length); |
127 | ASSERT_EQ(output_length, position); |
128 | } |
129 | |
130 | { |
131 | const int32_t alignment = 64; |
132 | |
133 | ASSERT_OK(io::BufferOutputStream::Create(1 << 10, default_memory_pool(), &stream)); |
134 | ASSERT_OK(message->SerializeTo(stream.get(), alignment, &output_length)); |
135 | ASSERT_OK(stream->Tell(&position)); |
136 | ASSERT_EQ(BitUtil::RoundUp(metadata->size() + 4, alignment) + body_length, |
137 | output_length); |
138 | ASSERT_EQ(output_length, position); |
139 | } |
140 | } |
141 | |
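// A message whose metadata buffer is not a valid flatbuffer must fail Verify().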
142 | TEST(TestMessage, Verify) { |
  std::string metadata = "invalid";
  std::string body = "abcdef";
145 | |
146 | Message message(std::make_shared<Buffer>(metadata), std::make_shared<Buffer>(body)); |
147 | ASSERT_FALSE(message.Verify()); |
148 | } |
149 | |
150 | const std::shared_ptr<DataType> INT32 = std::make_shared<Int32Type>(); |
151 | |
152 | TEST_F(TestSchemaMetadata, PrimitiveFields) { |
  auto f0 = field("f0", std::make_shared<Int8Type>());
  auto f1 = field("f1", std::make_shared<Int16Type>(), false);
  auto f2 = field("f2", std::make_shared<Int32Type>());
  auto f3 = field("f3", std::make_shared<Int64Type>());
  auto f4 = field("f4", std::make_shared<UInt8Type>());
  auto f5 = field("f5", std::make_shared<UInt16Type>());
  auto f6 = field("f6", std::make_shared<UInt32Type>());
  auto f7 = field("f7", std::make_shared<UInt64Type>());
  auto f8 = field("f8", std::make_shared<FloatType>());
  auto f9 = field("f9", std::make_shared<DoubleType>(), false);
  auto f10 = field("f10", std::make_shared<BooleanType>());
164 | |
165 | Schema schema({f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10}); |
166 | CheckRoundtrip(schema); |
167 | } |
168 | |
169 | TEST_F(TestSchemaMetadata, NestedFields) { |
170 | auto type = list(int32()); |
  auto f0 = field("f0", type);

  std::shared_ptr<StructType> type2(
      new StructType({field("k1", INT32), field("k2", INT32), field("k3", INT32)}));
  auto f1 = field("f1", type2);
176 | |
177 | Schema schema({f0, f1}); |
178 | CheckRoundtrip(schema); |
179 | } |
180 | |
181 | TEST_F(TestSchemaMetadata, KeyValueMetadata) { |
  auto field_metadata = key_value_metadata({{"key", "value"}});
  auto schema_metadata = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});

  auto f0 = field("f0", std::make_shared<Int8Type>());
  auto f1 = field("f1", std::make_shared<Int16Type>(), false, field_metadata);
187 | |
188 | Schema schema({f0, f1}, schema_metadata); |
189 | CheckRoundtrip(schema); |
190 | } |
191 | |
192 | #define BATCH_CASES() \ |
193 | ::testing::Values(&MakeIntRecordBatch, &MakeListRecordBatch, &MakeNonNullRecordBatch, \ |
194 | &MakeZeroLengthRecordBatch, &MakeDeeplyNestedList, \ |
195 | &MakeStringTypesRecordBatchWithNulls, &MakeStruct, &MakeUnion, \ |
196 | &MakeDictionary, &MakeDates, &MakeTimestamps, &MakeTimes, \ |
197 | &MakeFWBinary, &MakeNull, &MakeDecimal, &MakeBooleanBatch); |
198 | |
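// Counter used to give each memory-mapped test file a unique name.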
199 | static int g_file_number = 0; |
200 | |
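// Common machinery for round-tripping schemas and record batches through both
// the in-memory stream format and the memory-mapped file format.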
201 | class IpcTestFixture : public io::MemoryMapFixture { |
202 | public: |
203 | Status DoSchemaRoundTrip(const Schema& schema, std::shared_ptr<Schema>* result) { |
204 | std::shared_ptr<Buffer> serialized_schema; |
205 | RETURN_NOT_OK(SerializeSchema(schema, pool_, &serialized_schema)); |
206 | |
207 | io::BufferReader buf_reader(serialized_schema); |
208 | return ReadSchema(&buf_reader, result); |
209 | } |
210 | |
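  // Serialize the batch to a buffer and read it back with the stream reader.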
211 | Status DoStandardRoundTrip(const RecordBatch& batch, |
212 | std::shared_ptr<RecordBatch>* batch_result) { |
213 | std::shared_ptr<Buffer> serialized_batch; |
214 | RETURN_NOT_OK(SerializeRecordBatch(batch, pool_, &serialized_batch)); |
215 | |
216 | io::BufferReader buf_reader(serialized_batch); |
217 | return ReadRecordBatch(batch.schema(), &buf_reader, batch_result); |
218 | } |
219 | |
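  // Write the batch through the file-format writer into the memory map and read
  // it back; optionally zero the map first so stale bytes cannot mask errors.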
220 | Status DoLargeRoundTrip(const RecordBatch& batch, bool zero_data, |
221 | std::shared_ptr<RecordBatch>* result) { |
222 | if (zero_data) { |
223 | RETURN_NOT_OK(ZeroMemoryMap(mmap_.get())); |
224 | } |
225 | RETURN_NOT_OK(mmap_->Seek(0)); |
226 | |
227 | std::shared_ptr<RecordBatchWriter> file_writer; |
228 | RETURN_NOT_OK(RecordBatchFileWriter::Open(mmap_.get(), batch.schema(), &file_writer)); |
229 | RETURN_NOT_OK(file_writer->WriteRecordBatch(batch, true)); |
230 | RETURN_NOT_OK(file_writer->Close()); |
231 | |
232 | int64_t offset; |
233 | RETURN_NOT_OK(mmap_->Tell(&offset)); |
234 | |
235 | std::shared_ptr<RecordBatchFileReader> file_reader; |
236 | RETURN_NOT_OK(RecordBatchFileReader::Open(mmap_.get(), offset, &file_reader)); |
237 | |
238 | return file_reader->ReadRecordBatch(0, result); |
239 | } |
240 | |
241 | void CheckReadResult(const RecordBatch& result, const RecordBatch& expected) { |
242 | EXPECT_EQ(expected.num_rows(), result.num_rows()); |
243 | |
244 | ASSERT_TRUE(expected.schema()->Equals(*result.schema())); |
245 | ASSERT_EQ(expected.num_columns(), result.num_columns()) |
246 | << expected.schema()->ToString() << " result: " << result.schema()->ToString(); |
247 | |
248 | CompareBatchColumnsDetailed(result, expected); |
249 | } |
250 | |
251 | void CheckRoundtrip(const RecordBatch& batch, int64_t buffer_size) { |
252 | std::stringstream ss; |
253 | ss << "test-write-row-batch-" << g_file_number++; |
254 | ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(buffer_size, ss.str(), &mmap_)); |
255 | |
256 | std::shared_ptr<Schema> schema_result; |
257 | ASSERT_OK(DoSchemaRoundTrip(*batch.schema(), &schema_result)); |
258 | ASSERT_TRUE(batch.schema()->Equals(*schema_result)); |
259 | |
260 | std::shared_ptr<RecordBatch> result; |
261 | ASSERT_OK(DoStandardRoundTrip(batch, &result)); |
262 | CheckReadResult(*result, batch); |
263 | |
264 | ASSERT_OK(DoLargeRoundTrip(batch, true, &result)); |
265 | CheckReadResult(*result, batch); |
266 | } |
267 | |
268 | void CheckRoundtrip(const std::shared_ptr<Array>& array, int64_t buffer_size) { |
    auto f0 = arrow::field("f0", array->type());
270 | std::vector<std::shared_ptr<Field>> fields = {f0}; |
271 | auto schema = std::make_shared<Schema>(fields); |
272 | |
273 | auto batch = RecordBatch::Make(schema, 0, {array}); |
274 | CheckRoundtrip(*batch, buffer_size); |
275 | } |
276 | |
277 | protected: |
278 | std::shared_ptr<io::MemoryMappedFile> mmap_; |
279 | MemoryPool* pool_; |
280 | }; |
281 | |
282 | class TestWriteRecordBatch : public ::testing::Test, public IpcTestFixture { |
283 | public: |
284 | void SetUp() { pool_ = default_memory_pool(); } |
285 | void TearDown() { io::MemoryMapFixture::TearDown(); } |
286 | }; |
287 | |
288 | class TestIpcRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*>, |
289 | public IpcTestFixture { |
290 | public: |
291 | void SetUp() { pool_ = default_memory_pool(); } |
292 | void TearDown() { io::MemoryMapFixture::TearDown(); } |
293 | }; |
294 | |
295 | TEST_P(TestIpcRoundTrip, RoundTrip) { |
296 | std::shared_ptr<RecordBatch> batch; |
297 | ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue |
298 | |
299 | CheckRoundtrip(*batch, 1 << 20); |
300 | } |
301 | |
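// Record batch messages should be stamped with the current metadata version.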
302 | TEST_F(TestIpcRoundTrip, MetadataVersion) { |
303 | std::shared_ptr<RecordBatch> batch; |
304 | ASSERT_OK(MakeIntRecordBatch(&batch)); |
305 | |
  ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(1 << 16, "test-metadata", &mmap_));
307 | |
308 | int32_t metadata_length; |
309 | int64_t body_length; |
310 | |
311 | const int64_t buffer_offset = 0; |
312 | |
313 | ASSERT_OK(WriteRecordBatch(*batch, buffer_offset, mmap_.get(), &metadata_length, |
314 | &body_length, pool_)); |
315 | |
316 | std::unique_ptr<Message> message; |
317 | ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); |
318 | |
319 | ASSERT_EQ(MetadataVersion::V4, message->metadata_version()); |
320 | } |
321 | |
322 | TEST_P(TestIpcRoundTrip, SliceRoundTrip) { |
323 | std::shared_ptr<RecordBatch> batch; |
324 | ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue |
325 | |
326 | // Skip the zero-length case |
327 | if (batch->num_rows() < 2) { |
328 | return; |
329 | } |
330 | |
331 | auto sliced_batch = batch->Slice(2, 10); |
332 | CheckRoundtrip(*sliced_batch, 1 << 20); |
333 | } |
334 | |
335 | TEST_P(TestIpcRoundTrip, ZeroLengthArrays) { |
336 | std::shared_ptr<RecordBatch> batch; |
337 | ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue |
338 | |
339 | std::shared_ptr<RecordBatch> zero_length_batch; |
340 | if (batch->num_rows() > 2) { |
341 | zero_length_batch = batch->Slice(2, 0); |
342 | } else { |
343 | zero_length_batch = batch->Slice(0, 0); |
344 | } |
345 | |
346 | CheckRoundtrip(*zero_length_batch, 1 << 20); |
347 | |
348 | // ARROW-544: check binary array |
349 | std::shared_ptr<Buffer> value_offsets; |
350 | ASSERT_OK(AllocateBuffer(pool_, sizeof(int32_t), &value_offsets)); |
351 | *reinterpret_cast<int32_t*>(value_offsets->mutable_data()) = 0; |
352 | |
353 | std::shared_ptr<Array> bin_array = std::make_shared<BinaryArray>( |
354 | 0, value_offsets, std::make_shared<Buffer>(nullptr, 0), |
355 | std::make_shared<Buffer>(nullptr, 0)); |
356 | |
357 | // null value_offsets |
358 | std::shared_ptr<Array> bin_array2 = std::make_shared<BinaryArray>(0, nullptr, nullptr); |
359 | |
360 | CheckRoundtrip(bin_array, 1 << 20); |
361 | CheckRoundtrip(bin_array2, 1 << 20); |
362 | } |
363 | |
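// Writing a sliced batch should truncate the underlying buffers instead of
// serializing the parent arrays in full.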
364 | TEST_F(TestWriteRecordBatch, SliceTruncatesBuffers) { |
365 | auto CheckArray = [this](const std::shared_ptr<Array>& array) { |
    auto f0 = field("f0", array->type());
367 | auto schema = ::arrow::schema({f0}); |
368 | auto batch = RecordBatch::Make(schema, array->length(), {array}); |
369 | auto sliced_batch = batch->Slice(0, 5); |
370 | |
371 | int64_t full_size; |
372 | int64_t sliced_size; |
373 | |
374 | ASSERT_OK(GetRecordBatchSize(*batch, &full_size)); |
375 | ASSERT_OK(GetRecordBatchSize(*sliced_batch, &sliced_size)); |
376 | ASSERT_TRUE(sliced_size < full_size) << sliced_size << " " << full_size; |
377 | |
378 | // make sure we can write and read it |
379 | this->CheckRoundtrip(*sliced_batch, 1 << 20); |
380 | }; |
381 | |
382 | std::shared_ptr<Array> a0, a1; |
383 | auto pool = default_memory_pool(); |
384 | |
385 | // Integer |
386 | ASSERT_OK(MakeRandomInt32Array(500, false, pool, &a0)); |
387 | CheckArray(a0); |
388 | |
389 | // String / Binary |
390 | { |
391 | auto s = MakeRandomBinaryArray<StringBuilder, char>(500, false, pool, &a0); |
392 | ASSERT_TRUE(s.ok()); |
393 | } |
394 | CheckArray(a0); |
395 | |
396 | // Boolean |
397 | ASSERT_OK(MakeRandomBooleanArray(10000, false, &a0)); |
398 | CheckArray(a0); |
399 | |
400 | // List |
401 | ASSERT_OK(MakeRandomInt32Array(500, false, pool, &a0)); |
402 | ASSERT_OK(MakeRandomListArray(a0, 200, false, pool, &a1)); |
403 | CheckArray(a1); |
404 | |
405 | // Struct |
  auto struct_type = struct_({field("f0", a0->type())});
407 | std::vector<std::shared_ptr<Array>> struct_children = {a0}; |
408 | a1 = std::make_shared<StructArray>(struct_type, a0->length(), struct_children); |
409 | CheckArray(a1); |
410 | |
411 | // Sparse Union |
  auto union_type = union_({field("f0", a0->type())}, {0});
413 | std::vector<int32_t> type_ids(a0->length()); |
414 | std::shared_ptr<Buffer> ids_buffer; |
415 | ASSERT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &ids_buffer)); |
416 | a1 = |
417 | std::make_shared<UnionArray>(union_type, a0->length(), struct_children, ids_buffer); |
418 | CheckArray(a1); |
419 | |
420 | // Dense union |
  auto dense_union_type = union_({field("f0", a0->type())}, {0}, UnionMode::DENSE);
422 | std::vector<int32_t> type_offsets; |
423 | for (int32_t i = 0; i < a0->length(); ++i) { |
424 | type_offsets.push_back(i); |
425 | } |
426 | std::shared_ptr<Buffer> offsets_buffer; |
427 | ASSERT_OK(CopyBufferFromVector(type_offsets, default_memory_pool(), &offsets_buffer)); |
428 | a1 = std::make_shared<UnionArray>(dense_union_type, a0->length(), struct_children, |
429 | ids_buffer, offsets_buffer); |
430 | CheckArray(a1); |
431 | } |
432 | |
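// GetRecordBatchSize should match the number of bytes actually written by
// WriteRecordBatch, as observed through a mock output stream.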
433 | void TestGetRecordBatchSize(std::shared_ptr<RecordBatch> batch) { |
434 | io::MockOutputStream mock; |
435 | int32_t mock_metadata_length = -1; |
436 | int64_t mock_body_length = -1; |
437 | int64_t size = -1; |
438 | ASSERT_OK(WriteRecordBatch(*batch, 0, &mock, &mock_metadata_length, &mock_body_length, |
439 | default_memory_pool())); |
440 | ASSERT_OK(GetRecordBatchSize(*batch, &size)); |
441 | ASSERT_EQ(mock.GetExtentBytesWritten(), size); |
442 | } |
443 | |
444 | TEST_F(TestWriteRecordBatch, IntegerGetRecordBatchSize) { |
445 | std::shared_ptr<RecordBatch> batch; |
446 | |
447 | ASSERT_OK(MakeIntRecordBatch(&batch)); |
448 | TestGetRecordBatchSize(batch); |
449 | |
450 | ASSERT_OK(MakeListRecordBatch(&batch)); |
451 | TestGetRecordBatchSize(batch); |
452 | |
453 | ASSERT_OK(MakeZeroLengthRecordBatch(&batch)); |
454 | TestGetRecordBatchSize(batch); |
455 | |
456 | ASSERT_OK(MakeNonNullRecordBatch(&batch)); |
457 | TestGetRecordBatchSize(batch); |
458 | |
459 | ASSERT_OK(MakeDeeplyNestedList(&batch)); |
460 | TestGetRecordBatchSize(batch); |
461 | } |
462 | |
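// Wraps a random array in `recursion_level` nested list types to exercise the
// reader and writer recursion limits.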
463 | class RecursionLimits : public ::testing::Test, public io::MemoryMapFixture { |
464 | public: |
465 | void SetUp() { pool_ = default_memory_pool(); } |
466 | void TearDown() { io::MemoryMapFixture::TearDown(); } |
467 | |
468 | Status WriteToMmap(int recursion_level, bool override_level, int32_t* metadata_length, |
469 | int64_t* body_length, std::shared_ptr<RecordBatch>* batch, |
470 | std::shared_ptr<Schema>* schema) { |
471 | const int batch_length = 5; |
472 | auto type = int32(); |
473 | std::shared_ptr<Array> array; |
474 | const bool include_nulls = true; |
475 | RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool_, &array)); |
476 | for (int i = 0; i < recursion_level; ++i) { |
477 | type = list(type); |
478 | RETURN_NOT_OK( |
479 | MakeRandomListArray(array, batch_length, include_nulls, pool_, &array)); |
480 | } |
481 | |
    auto f0 = field("f0", type);
483 | |
484 | *schema = ::arrow::schema({f0}); |
485 | |
486 | *batch = RecordBatch::Make(*schema, batch_length, {array}); |
487 | |
488 | std::stringstream ss; |
489 | ss << "test-write-past-max-recursion-" << g_file_number++; |
490 | const int memory_map_size = 1 << 20; |
491 | RETURN_NOT_OK(io::MemoryMapFixture::InitMemoryMap(memory_map_size, ss.str(), &mmap_)); |
492 | |
493 | if (override_level) { |
494 | return WriteRecordBatch(**batch, 0, mmap_.get(), metadata_length, body_length, |
495 | pool_, recursion_level + 1); |
496 | } else { |
497 | return WriteRecordBatch(**batch, 0, mmap_.get(), metadata_length, body_length, |
498 | pool_); |
499 | } |
500 | } |
501 | |
502 | protected: |
503 | std::shared_ptr<io::MemoryMappedFile> mmap_; |
504 | MemoryPool* pool_; |
505 | }; |
506 | |
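// Writing a batch nested beyond the maximum recursion depth must fail.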
507 | TEST_F(RecursionLimits, WriteLimit) { |
508 | int32_t metadata_length = -1; |
509 | int64_t body_length = -1; |
510 | std::shared_ptr<Schema> schema; |
511 | std::shared_ptr<RecordBatch> batch; |
512 | ASSERT_RAISES(Invalid, WriteToMmap((1 << 8) + 1, false, &metadata_length, &body_length, |
513 | &batch, &schema)); |
514 | } |
515 | |
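// A batch written with an overridden recursion limit must still be rejected by
// a reader using the default limit.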
516 | TEST_F(RecursionLimits, ReadLimit) { |
517 | int32_t metadata_length = -1; |
518 | int64_t body_length = -1; |
519 | std::shared_ptr<Schema> schema; |
520 | |
521 | const int recursion_depth = 64; |
522 | |
523 | std::shared_ptr<RecordBatch> batch; |
524 | ASSERT_OK(WriteToMmap(recursion_depth, true, &metadata_length, &body_length, &batch, |
525 | &schema)); |
526 | |
527 | std::unique_ptr<Message> message; |
528 | ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); |
529 | |
530 | io::BufferReader reader(message->body()); |
531 | |
532 | std::shared_ptr<RecordBatch> result; |
533 | ASSERT_RAISES(Invalid, ReadRecordBatch(*message->metadata(), schema, &reader, &result)); |
534 | } |
535 | |
536 | // Test fails with a structured exception on Windows + Debug |
537 | #if !defined(_WIN32) || defined(NDEBUG) |
538 | TEST_F(RecursionLimits, StressLimit) { |
539 | auto CheckDepth = [this](int recursion_depth, bool* it_works) { |
540 | int32_t metadata_length = -1; |
541 | int64_t body_length = -1; |
542 | std::shared_ptr<Schema> schema; |
543 | std::shared_ptr<RecordBatch> batch; |
544 | ASSERT_OK(WriteToMmap(recursion_depth, true, &metadata_length, &body_length, &batch, |
545 | &schema)); |
546 | |
547 | std::unique_ptr<Message> message; |
548 | ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); |
549 | |
550 | io::BufferReader reader(message->body()); |
551 | std::shared_ptr<RecordBatch> result; |
552 | ASSERT_OK(ReadRecordBatch(*message->metadata(), schema, recursion_depth + 1, &reader, |
553 | &result)); |
554 | *it_works = result->Equals(*batch); |
555 | }; |
556 | |
557 | bool it_works = false; |
558 | CheckDepth(100, &it_works); |
559 | ASSERT_TRUE(it_works); |
560 | |
561 | // Mitigate Valgrind's slowness |
562 | #if !defined(ARROW_VALGRIND) |
563 | CheckDepth(500, &it_works); |
564 | ASSERT_TRUE(it_works); |
565 | #endif |
566 | } |
567 | #endif // !defined(_WIN32) || defined(NDEBUG) |
568 | |
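// Round-trip tests for the random-access file format, backed by an in-memory
// buffer.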
569 | class TestFileFormat : public ::testing::TestWithParam<MakeRecordBatch*> { |
570 | public: |
571 | void SetUp() { |
572 | pool_ = default_memory_pool(); |
573 | ASSERT_OK(AllocateResizableBuffer(pool_, 0, &buffer_)); |
574 | sink_.reset(new io::BufferOutputStream(buffer_)); |
575 | } |
576 | void TearDown() {} |
577 | |
578 | Status RoundTripHelper(const BatchVector& in_batches, BatchVector* out_batches) { |
579 | // Write the file |
580 | std::shared_ptr<RecordBatchWriter> writer; |
581 | RETURN_NOT_OK( |
582 | RecordBatchFileWriter::Open(sink_.get(), in_batches[0]->schema(), &writer)); |
583 | |
584 | const int num_batches = static_cast<int>(in_batches.size()); |
585 | |
586 | for (const auto& batch : in_batches) { |
587 | RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); |
588 | } |
589 | RETURN_NOT_OK(writer->Close()); |
590 | RETURN_NOT_OK(sink_->Close()); |
591 | |
592 | // Current offset into stream is the end of the file |
    int64_t footer_offset;
594 | RETURN_NOT_OK(sink_->Tell(&footer_offset)); |
595 | |
596 | // Open the file |
597 | auto buf_reader = std::make_shared<io::BufferReader>(buffer_); |
598 | std::shared_ptr<RecordBatchFileReader> reader; |
599 | RETURN_NOT_OK(RecordBatchFileReader::Open(buf_reader.get(), footer_offset, &reader)); |
600 | |
601 | EXPECT_EQ(num_batches, reader->num_record_batches()); |
602 | for (int i = 0; i < num_batches; ++i) { |
603 | std::shared_ptr<RecordBatch> chunk; |
604 | RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk)); |
605 | out_batches->emplace_back(chunk); |
606 | } |
607 | |
608 | return Status::OK(); |
609 | } |
610 | |
611 | protected: |
612 | MemoryPool* pool_; |
613 | |
614 | std::unique_ptr<io::BufferOutputStream> sink_; |
615 | std::shared_ptr<ResizableBuffer> buffer_; |
616 | }; |
617 | |
618 | TEST_P(TestFileFormat, RoundTrip) { |
619 | std::shared_ptr<RecordBatch> batch1; |
620 | std::shared_ptr<RecordBatch> batch2; |
621 | ASSERT_OK((*GetParam())(&batch1)); // NOLINT clang-tidy gtest issue |
622 | ASSERT_OK((*GetParam())(&batch2)); // NOLINT clang-tidy gtest issue |
623 | |
624 | BatchVector in_batches = {batch1, batch2}; |
625 | BatchVector out_batches; |
626 | |
627 | ASSERT_OK(RoundTripHelper(in_batches, &out_batches)); |
628 | |
629 | // Compare batches |
630 | for (size_t i = 0; i < in_batches.size(); ++i) { |
631 | CompareBatch(*in_batches[i], *out_batches[i]); |
632 | } |
633 | } |
634 | |
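// Round-trip tests for the streaming format, backed by an in-memory buffer.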
635 | class TestStreamFormat : public ::testing::TestWithParam<MakeRecordBatch*> { |
636 | public: |
637 | void SetUp() { |
638 | pool_ = default_memory_pool(); |
639 | ASSERT_OK(AllocateResizableBuffer(pool_, 0, &buffer_)); |
640 | sink_.reset(new io::BufferOutputStream(buffer_)); |
641 | } |
642 | void TearDown() {} |
643 | |
644 | Status RoundTripHelper(const BatchVector& batches, BatchVector* out_batches) { |
645 | // Write the file |
646 | std::shared_ptr<RecordBatchWriter> writer; |
647 | RETURN_NOT_OK( |
648 | RecordBatchStreamWriter::Open(sink_.get(), batches[0]->schema(), &writer)); |
649 | |
650 | for (const auto& batch : batches) { |
651 | RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); |
652 | } |
653 | RETURN_NOT_OK(writer->Close()); |
654 | RETURN_NOT_OK(sink_->Close()); |
655 | |
656 | // Open the file |
657 | io::BufferReader buf_reader(buffer_); |
658 | |
659 | std::shared_ptr<RecordBatchReader> reader; |
660 | RETURN_NOT_OK(RecordBatchStreamReader::Open(&buf_reader, &reader)); |
661 | return reader->ReadAll(out_batches); |
662 | } |
663 | |
664 | protected: |
665 | MemoryPool* pool_; |
666 | |
667 | std::unique_ptr<io::BufferOutputStream> sink_; |
668 | std::shared_ptr<ResizableBuffer> buffer_; |
669 | }; |
670 | |
671 | TEST_P(TestStreamFormat, RoundTrip) { |
672 | std::shared_ptr<RecordBatch> batch; |
673 | ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue |
674 | |
675 | BatchVector out_batches; |
676 | |
677 | ASSERT_OK(RoundTripHelper({batch, batch, batch}, &out_batches)); |
678 | |
  // Compare batches; the same batch was written three times
680 | for (size_t i = 0; i < out_batches.size(); ++i) { |
681 | CompareBatch(*batch, *out_batches[i]); |
682 | } |
683 | } |
684 | |
685 | INSTANTIATE_TEST_CASE_P(GenericIpcRoundTripTests, TestIpcRoundTrip, BATCH_CASES()); |
686 | INSTANTIATE_TEST_CASE_P(FileRoundTripTests, TestFileFormat, BATCH_CASES()); |
687 | INSTANTIATE_TEST_CASE_P(StreamRoundTripTests, TestStreamFormat, BATCH_CASES()); |
688 | |
689 | // This test uses uninitialized memory |
690 | |
691 | #if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)) |
692 | TEST_F(TestIpcRoundTrip, LargeRecordBatch) { |
693 | const int64_t length = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1; |
694 | |
695 | BooleanBuilder builder(default_memory_pool()); |
696 | ASSERT_OK(builder.Reserve(length)); |
697 | ASSERT_OK(builder.Advance(length)); |
698 | |
699 | std::shared_ptr<Array> array; |
700 | ASSERT_OK(builder.Finish(&array)); |
701 | |
  auto f0 = arrow::field("f0", array->type());
703 | std::vector<std::shared_ptr<Field>> fields = {f0}; |
704 | auto schema = std::make_shared<Schema>(fields); |
705 | |
706 | auto batch = RecordBatch::Make(schema, length, {array}); |
707 | |
  std::string path = "test-write-large-record_batch";
709 | |
710 | // 512 MB |
711 | constexpr int64_t kBufferSize = 1 << 29; |
712 | ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); |
713 | |
714 | std::shared_ptr<RecordBatch> result; |
715 | ASSERT_OK(DoLargeRoundTrip(*batch, false, &result)); |
716 | CheckReadResult(*result, *batch); |
717 | |
718 | ASSERT_EQ(length, result->num_rows()); |
719 | } |
720 | #endif |
721 | |
722 | void CheckBatchDictionaries(const RecordBatch& batch) { |
723 | // Check that dictionaries that should be the same are the same |
724 | auto schema = batch.schema(); |
725 | |
726 | const auto& t0 = checked_cast<const DictionaryType&>(*schema->field(0)->type()); |
727 | const auto& t1 = checked_cast<const DictionaryType&>(*schema->field(1)->type()); |
728 | |
729 | ASSERT_EQ(t0.dictionary().get(), t1.dictionary().get()); |
730 | |
731 | // Same dictionary used for list values |
732 | const auto& t3 = checked_cast<const ListType&>(*schema->field(3)->type()); |
733 | const auto& t3_value = checked_cast<const DictionaryType&>(*t3.value_type()); |
734 | ASSERT_EQ(t0.dictionary().get(), t3_value.dictionary().get()); |
735 | } |
736 | |
737 | TEST_F(TestStreamFormat, DictionaryRoundTrip) { |
738 | std::shared_ptr<RecordBatch> batch; |
739 | ASSERT_OK(MakeDictionary(&batch)); |
740 | |
741 | BatchVector out_batches; |
742 | ASSERT_OK(RoundTripHelper({batch}, &out_batches)); |
743 | |
744 | CheckBatchDictionaries(*out_batches[0]); |
745 | } |
746 | |
747 | TEST_F(TestStreamFormat, WriteTable) { |
748 | std::shared_ptr<RecordBatch> b1, b2, b3; |
749 | ASSERT_OK(MakeIntRecordBatch(&b1)); |
750 | ASSERT_OK(MakeIntRecordBatch(&b2)); |
751 | ASSERT_OK(MakeIntRecordBatch(&b3)); |
752 | |
753 | BatchVector out_batches; |
754 | ASSERT_OK(RoundTripHelper({b1, b2, b3}, &out_batches)); |
755 | |
756 | ASSERT_TRUE(b1->Equals(*out_batches[0])); |
757 | ASSERT_TRUE(b2->Equals(*out_batches[1])); |
758 | ASSERT_TRUE(b3->Equals(*out_batches[2])); |
759 | } |
760 | |
761 | TEST_F(TestFileFormat, DictionaryRoundTrip) { |
762 | std::shared_ptr<RecordBatch> batch; |
763 | ASSERT_OK(MakeDictionary(&batch)); |
764 | |
765 | BatchVector out_batches; |
766 | ASSERT_OK(RoundTripHelper({batch}, &out_batches)); |
767 | |
768 | CheckBatchDictionaries(*out_batches[0]); |
769 | } |
770 | |
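// Round-trip tests for Tensor IPC messages written to a memory-mapped file.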
771 | class TestTensorRoundTrip : public ::testing::Test, public IpcTestFixture { |
772 | public: |
773 | void SetUp() { pool_ = default_memory_pool(); } |
774 | void TearDown() { io::MemoryMapFixture::TearDown(); } |
775 | |
776 | void CheckTensorRoundTrip(const Tensor& tensor) { |
777 | int32_t metadata_length; |
778 | int64_t body_length; |
779 | |
780 | const auto& type = checked_cast<const FixedWidthType&>(*tensor.type()); |
781 | const int elem_size = type.bit_width() / 8; |
782 | |
783 | ASSERT_OK(mmap_->Seek(0)); |
784 | |
785 | ASSERT_OK(WriteTensor(tensor, mmap_.get(), &metadata_length, &body_length)); |
786 | |
787 | const int64_t expected_body_length = elem_size * tensor.size(); |
788 | ASSERT_EQ(expected_body_length, body_length); |
789 | |
790 | ASSERT_OK(mmap_->Seek(0)); |
791 | |
792 | std::shared_ptr<Tensor> result; |
793 | ASSERT_OK(ReadTensor(mmap_.get(), &result)); |
794 | |
795 | ASSERT_EQ(result->data()->size(), expected_body_length); |
796 | ASSERT_TRUE(tensor.Equals(*result)); |
797 | } |
798 | }; |
799 | |
800 | TEST_F(TestTensorRoundTrip, BasicRoundtrip) { |
  std::string path = "test-write-tensor";
802 | constexpr int64_t kBufferSize = 1 << 20; |
803 | ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); |
804 | |
805 | std::vector<int64_t> shape = {4, 6}; |
806 | std::vector<int64_t> strides = {48, 8}; |
  std::vector<std::string> dim_names = {"foo", "bar"};
808 | int64_t size = 24; |
809 | |
810 | std::vector<int64_t> values; |
811 | randint(size, 0, 100, &values); |
812 | |
813 | auto data = Buffer::Wrap(values); |
814 | |
815 | Tensor t0(int64(), data, shape, strides, dim_names); |
816 | Tensor t_no_dims(int64(), data, {}, {}, {}); |
  Tensor t_zero_length_dim(int64(), data, {0}, {8}, {"foo"});
818 | |
819 | CheckTensorRoundTrip(t0); |
820 | CheckTensorRoundTrip(t_no_dims); |
821 | CheckTensorRoundTrip(t_zero_length_dim); |
822 | |
823 | int64_t serialized_size; |
824 | ASSERT_OK(GetTensorSize(t0, &serialized_size)); |
825 | ASSERT_TRUE(serialized_size > static_cast<int64_t>(size * sizeof(int64_t))); |
826 | |
  // ARROW-2840: check that padding/alignment is taken into account
828 | std::vector<int64_t> shape_2 = {1, 1}; |
829 | std::vector<int64_t> strides_2 = {8, 8}; |
830 | Tensor t0_not_multiple_64(int64(), data, shape_2, strides_2, dim_names); |
831 | CheckTensorRoundTrip(t0_not_multiple_64); |
832 | } |
833 | |
834 | TEST_F(TestTensorRoundTrip, NonContiguous) { |
  std::string path = "test-write-tensor-strided";
836 | constexpr int64_t kBufferSize = 1 << 20; |
837 | ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); |
838 | |
839 | std::vector<int64_t> values; |
840 | randint(24, 0, 100, &values); |
841 | |
842 | auto data = Buffer::Wrap(values); |
843 | Tensor tensor(int64(), data, {4, 3}, {48, 16}); |
844 | |
845 | CheckTensorRoundTrip(tensor); |
846 | } |
847 | |
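// Round-trip tests for SparseTensor IPC messages. The unspecialized template
// fails deliberately; only the COO and CSR specializations below are valid.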
848 | class TestSparseTensorRoundTrip : public ::testing::Test, public IpcTestFixture { |
849 | public: |
850 | void SetUp() { pool_ = default_memory_pool(); } |
851 | void TearDown() { io::MemoryMapFixture::TearDown(); } |
852 | |
853 | template <typename SparseIndexType> |
854 | void CheckSparseTensorRoundTrip(const SparseTensorImpl<SparseIndexType>& tensor) { |
855 | GTEST_FAIL(); |
856 | } |
857 | }; |
858 | |
859 | template <> |
860 | void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCOOIndex>( |
861 | const SparseTensorImpl<SparseCOOIndex>& tensor) { |
862 | const auto& type = checked_cast<const FixedWidthType&>(*tensor.type()); |
863 | const int elem_size = type.bit_width() / 8; |
864 | |
865 | int32_t metadata_length; |
866 | int64_t body_length; |
867 | |
868 | ASSERT_OK(mmap_->Seek(0)); |
869 | |
870 | ASSERT_OK(WriteSparseTensor(tensor, mmap_.get(), &metadata_length, &body_length, |
871 | default_memory_pool())); |
872 | |
873 | const auto& sparse_index = checked_cast<const SparseCOOIndex&>(*tensor.sparse_index()); |
874 | const int64_t indices_length = elem_size * sparse_index.indices()->size(); |
875 | const int64_t data_length = elem_size * tensor.non_zero_length(); |
876 | const int64_t expected_body_length = indices_length + data_length; |
877 | ASSERT_EQ(expected_body_length, body_length); |
878 | |
879 | ASSERT_OK(mmap_->Seek(0)); |
880 | |
881 | std::shared_ptr<SparseTensor> result; |
882 | ASSERT_OK(ReadSparseTensor(mmap_.get(), &result)); |
883 | |
884 | const auto& resulted_sparse_index = |
885 | checked_cast<const SparseCOOIndex&>(*result->sparse_index()); |
886 | ASSERT_EQ(resulted_sparse_index.indices()->data()->size(), indices_length); |
887 | ASSERT_EQ(result->data()->size(), data_length); |
  ASSERT_TRUE(result->Equals(tensor));
889 | } |
890 | |
891 | template <> |
892 | void TestSparseTensorRoundTrip::CheckSparseTensorRoundTrip<SparseCSRIndex>( |
893 | const SparseTensorImpl<SparseCSRIndex>& tensor) { |
894 | const auto& type = checked_cast<const FixedWidthType&>(*tensor.type()); |
895 | const int elem_size = type.bit_width() / 8; |
896 | |
897 | int32_t metadata_length; |
898 | int64_t body_length; |
899 | |
900 | ASSERT_OK(mmap_->Seek(0)); |
901 | |
902 | ASSERT_OK(WriteSparseTensor(tensor, mmap_.get(), &metadata_length, &body_length, |
903 | default_memory_pool())); |
904 | |
905 | const auto& sparse_index = checked_cast<const SparseCSRIndex&>(*tensor.sparse_index()); |
906 | const int64_t indptr_length = elem_size * sparse_index.indptr()->size(); |
907 | const int64_t indices_length = elem_size * sparse_index.indices()->size(); |
908 | const int64_t data_length = elem_size * tensor.non_zero_length(); |
909 | const int64_t expected_body_length = indptr_length + indices_length + data_length; |
910 | ASSERT_EQ(expected_body_length, body_length); |
911 | |
912 | ASSERT_OK(mmap_->Seek(0)); |
913 | |
914 | std::shared_ptr<SparseTensor> result; |
915 | ASSERT_OK(ReadSparseTensor(mmap_.get(), &result)); |
916 | |
917 | const auto& resulted_sparse_index = |
918 | checked_cast<const SparseCSRIndex&>(*result->sparse_index()); |
919 | ASSERT_EQ(resulted_sparse_index.indptr()->data()->size(), indptr_length); |
920 | ASSERT_EQ(resulted_sparse_index.indices()->data()->size(), indices_length); |
921 | ASSERT_EQ(result->data()->size(), data_length); |
  ASSERT_TRUE(result->Equals(tensor));
923 | } |
924 | |
925 | TEST_F(TestSparseTensorRoundTrip, WithSparseCOOIndex) { |
  std::string path = "test-write-sparse-coo-tensor";
927 | constexpr int64_t kBufferSize = 1 << 20; |
928 | ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); |
929 | |
930 | std::vector<int64_t> shape = {2, 3, 4}; |
  std::vector<std::string> dim_names = {"foo", "bar", "baz"};
932 | std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, |
933 | 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16}; |
934 | |
935 | auto data = Buffer::Wrap(values); |
936 | NumericTensor<Int64Type> t(data, shape, {}, dim_names); |
937 | SparseTensorImpl<SparseCOOIndex> st(t); |
938 | |
939 | CheckSparseTensorRoundTrip(st); |
940 | } |
941 | |
942 | TEST_F(TestSparseTensorRoundTrip, WithSparseCSRIndex) { |
  std::string path = "test-write-sparse-csr-matrix";
944 | constexpr int64_t kBufferSize = 1 << 20; |
945 | ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); |
946 | |
947 | std::vector<int64_t> shape = {4, 6}; |
  std::vector<std::string> dim_names = {"foo", "bar", "baz"};
949 | std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, |
950 | 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16}; |
951 | |
952 | auto data = Buffer::Wrap(values); |
953 | NumericTensor<Int64Type> t(data, shape, {}, dim_names); |
954 | SparseTensorImpl<SparseCSRIndex> st(t); |
955 | |
956 | CheckSparseTensorRoundTrip(st); |
957 | } |
958 | |
959 | TEST(TestRecordBatchStreamReader, MalformedInput) { |
  const std::string empty_str = "";
  const std::string garbage_str = "12345678";
962 | |
963 | auto empty = std::make_shared<Buffer>(empty_str); |
964 | auto garbage = std::make_shared<Buffer>(garbage_str); |
965 | |
966 | std::shared_ptr<RecordBatchReader> batch_reader; |
967 | |
968 | io::BufferReader empty_reader(empty); |
969 | ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&empty_reader, &batch_reader)); |
970 | |
971 | io::BufferReader garbage_reader(garbage); |
972 | ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader, &batch_reader)); |
973 | } |
974 | |
975 | } // namespace ipc |
976 | } // namespace arrow |
977 | |