// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/ipc/writer.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <sstream>
#include <vector>

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/io/interfaces.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/message.h"
#include "arrow/ipc/metadata-internal.h"
#include "arrow/ipc/util.h"
#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/tensor.h"
#include "arrow/type.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/visitor.h"

namespace arrow {

using internal::checked_cast;
using internal::CopyBitmap;

namespace ipc {

using internal::FileBlock;
using internal::kArrowMagicBytes;

// ----------------------------------------------------------------------
// Record batch write path

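// Return the validity bitmap to write for a possibly sliced array. When the
// array has a non-zero offset, or the input buffer is longer than the padded
// length actually needed, the relevant bits are copied into a new buffer so
// that no unrelated bytes end up in the IPC body.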
static inline Status GetTruncatedBitmap(int64_t offset, int64_t length,
                                        const std::shared_ptr<Buffer> input,
                                        MemoryPool* pool,
                                        std::shared_ptr<Buffer>* buffer) {
  if (!input) {
    *buffer = input;
    return Status::OK();
  }
  int64_t min_length = PaddedLength(BitUtil::BytesForBits(length));
  if (offset != 0 || min_length < input->size()) {
    // With a sliced array / non-zero offset, we must copy the bitmap
    RETURN_NOT_OK(CopyBitmap(pool, input->data(), offset, length, buffer));
  } else {
    *buffer = input;
  }
  return Status::OK();
}

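// Analogous helper for fixed-width value buffers (e.g. union type ids or
// offsets): slice the input down to the padded range referenced by a possibly
// sliced array instead of copying it.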
template <typename T>
inline Status GetTruncatedBuffer(int64_t offset, int64_t length,
                                 const std::shared_ptr<Buffer> input, MemoryPool* pool,
                                 std::shared_ptr<Buffer>* buffer) {
  if (!input) {
    *buffer = input;
    return Status::OK();
  }
  int32_t byte_width = static_cast<int32_t>(sizeof(T));
  int64_t padded_length = PaddedLength(length * byte_width);
  if (offset != 0 || padded_length < input->size()) {
    *buffer =
        SliceBuffer(input, offset * byte_width, std::min(padded_length, input->size()));
  } else {
    *buffer = input;
  }
  return Status::OK();
}

static inline bool NeedTruncate(int64_t offset, const Buffer* buffer,
                                int64_t min_length) {
  // buffer can be NULL
  if (buffer == nullptr) {
    return false;
  }
  return offset != 0 || min_length < buffer->size();
}

namespace internal {

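// Visits each column of a record batch, truncating or copying sliced buffers
// where needed, and accumulates the body buffers together with the field and
// buffer metadata that make up an IpcPayload.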
class RecordBatchSerializer : public ArrayVisitor {
 public:
  RecordBatchSerializer(MemoryPool* pool, int64_t buffer_start_offset,
                        int max_recursion_depth, bool allow_64bit, IpcPayload* out)
      : out_(out),
        pool_(pool),
        max_recursion_depth_(max_recursion_depth),
        buffer_start_offset_(buffer_start_offset),
        allow_64bit_(allow_64bit) {
    DCHECK_GT(max_recursion_depth, 0);
  }

  ~RecordBatchSerializer() override = default;

  Status VisitArray(const Array& arr) {
    if (max_recursion_depth_ <= 0) {
      return Status::Invalid("Max recursion depth reached");
    }

    if (!allow_64bit_ && arr.length() > std::numeric_limits<int32_t>::max()) {
      return Status::CapacityError("Cannot write arrays larger than 2^31 - 1 in length");
    }

    // push back all common elements
    field_nodes_.push_back({arr.length(), arr.null_count(), 0});

    if (arr.null_count() > 0) {
      std::shared_ptr<Buffer> bitmap;
      RETURN_NOT_OK(GetTruncatedBitmap(arr.offset(), arr.length(), arr.null_bitmap(),
                                       pool_, &bitmap));
      out_->body_buffers.emplace_back(bitmap);
    } else {
      // Push a dummy zero-length buffer, not to be copied
      out_->body_buffers.emplace_back(std::make_shared<Buffer>(nullptr, 0));
    }
    return arr.Accept(this);
  }

  // Override this for writing dictionary metadata
  virtual Status SerializeMetadata(int64_t num_rows) {
    return WriteRecordBatchMessage(num_rows, out_->body_length, field_nodes_,
                                   buffer_meta_, &out_->metadata);
  }

  Status Assemble(const RecordBatch& batch) {
    if (field_nodes_.size() > 0) {
      field_nodes_.clear();
      buffer_meta_.clear();
      out_->body_buffers.clear();
    }

    // Perform depth-first traversal of the row-batch
    for (int i = 0; i < batch.num_columns(); ++i) {
      RETURN_NOT_OK(VisitArray(*batch.column(i)));
    }

    // The position for the start of a buffer relative to the passed frame of
    // reference. May be 0 or some other position in an address space
    int64_t offset = buffer_start_offset_;

    buffer_meta_.reserve(out_->body_buffers.size());

    // Construct the buffer metadata for the record batch header
    for (size_t i = 0; i < out_->body_buffers.size(); ++i) {
      const Buffer* buffer = out_->body_buffers[i].get();
      int64_t size = 0;
      int64_t padding = 0;

      // The buffer might be null if we are handling zero row lengths.
      if (buffer) {
        size = buffer->size();
        padding = BitUtil::RoundUpToMultipleOf8(size) - size;
      }

      buffer_meta_.push_back({offset, size + padding});
      offset += size + padding;
    }

    out_->body_length = offset - buffer_start_offset_;
    DCHECK(BitUtil::IsMultipleOf8(out_->body_length));

    // Now that we have computed the locations of all of the buffers in shared
    // memory, the data header can be converted to a flatbuffer and written out
    //
    // Note: The memory written here is prefixed by the size of the flatbuffer
    // itself as an int32_t.
    return SerializeMetadata(batch.num_rows());
  }

 protected:
  template <typename ArrayType>
  Status VisitFixedWidth(const ArrayType& array) {
    std::shared_ptr<Buffer> data = array.values();

    const auto& fw_type = checked_cast<const FixedWidthType&>(*array.type());
    const int64_t type_width = fw_type.bit_width() / 8;
    int64_t min_length = PaddedLength(array.length() * type_width);

    if (NeedTruncate(array.offset(), data.get(), min_length)) {
      // Non-zero offset, slice the buffer
      const int64_t byte_offset = array.offset() * type_width;

      // Send padding if it's available
      const int64_t buffer_length =
          std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width),
                   data->size() - byte_offset);
      data = SliceBuffer(data, byte_offset, buffer_length);
    }
    out_->body_buffers.emplace_back(data);
    return Status::OK();
  }

  template <typename ArrayType>
  Status GetZeroBasedValueOffsets(const ArrayType& array,
                                  std::shared_ptr<Buffer>* value_offsets) {
    // Share slicing logic between ListArray and BinaryArray

    auto offsets = array.value_offsets();

    if (array.offset() != 0) {
      // If we have a non-zero offset, then the value offsets do not start at
      // zero. We must a) create a new offsets array with shifted offsets and
      // b) slice the values array accordingly

      std::shared_ptr<Buffer> shifted_offsets;
      RETURN_NOT_OK(AllocateBuffer(pool_, sizeof(int32_t) * (array.length() + 1),
                                   &shifted_offsets));

      int32_t* dest_offsets =
          reinterpret_cast<int32_t*>(shifted_offsets->mutable_data());
      const int32_t start_offset = array.value_offset(0);

      for (int i = 0; i < array.length(); ++i) {
        dest_offsets[i] = array.value_offset(i) - start_offset;
      }
      // Final offset
      dest_offsets[array.length()] = array.value_offset(array.length()) - start_offset;
      offsets = shifted_offsets;
    }

    *value_offsets = offsets;
    return Status::OK();
  }

  Status VisitBinary(const BinaryArray& array) {
    std::shared_ptr<Buffer> value_offsets;
    RETURN_NOT_OK(GetZeroBasedValueOffsets<BinaryArray>(array, &value_offsets));
    auto data = array.value_data();

    int64_t total_data_bytes = 0;
    if (value_offsets) {
      total_data_bytes = array.value_offset(array.length()) - array.value_offset(0);
    }
    if (NeedTruncate(array.offset(), data.get(), total_data_bytes)) {
      // Slice the data buffer to include only the range we need now
      const int64_t start_offset = array.value_offset(0);
      const int64_t slice_length =
          std::min(PaddedLength(total_data_bytes), data->size() - start_offset);
      data = SliceBuffer(data, start_offset, slice_length);
    }

    out_->body_buffers.emplace_back(value_offsets);
    out_->body_buffers.emplace_back(data);
    return Status::OK();
  }

  Status Visit(const BooleanArray& array) override {
    std::shared_ptr<Buffer> data;
    RETURN_NOT_OK(
        GetTruncatedBitmap(array.offset(), array.length(), array.values(), pool_, &data));
    out_->body_buffers.emplace_back(data);
    return Status::OK();
  }

  Status Visit(const NullArray& array) override {
    out_->body_buffers.emplace_back(nullptr);
    return Status::OK();
  }

#define VISIT_FIXED_WIDTH(TYPE) \
  Status Visit(const TYPE& array) override { return VisitFixedWidth<TYPE>(array); }

  VISIT_FIXED_WIDTH(Int8Array)
  VISIT_FIXED_WIDTH(Int16Array)
  VISIT_FIXED_WIDTH(Int32Array)
  VISIT_FIXED_WIDTH(Int64Array)
  VISIT_FIXED_WIDTH(UInt8Array)
  VISIT_FIXED_WIDTH(UInt16Array)
  VISIT_FIXED_WIDTH(UInt32Array)
  VISIT_FIXED_WIDTH(UInt64Array)
  VISIT_FIXED_WIDTH(HalfFloatArray)
  VISIT_FIXED_WIDTH(FloatArray)
  VISIT_FIXED_WIDTH(DoubleArray)
  VISIT_FIXED_WIDTH(Date32Array)
  VISIT_FIXED_WIDTH(Date64Array)
  VISIT_FIXED_WIDTH(TimestampArray)
  VISIT_FIXED_WIDTH(Time32Array)
  VISIT_FIXED_WIDTH(Time64Array)
  VISIT_FIXED_WIDTH(FixedSizeBinaryArray)
  VISIT_FIXED_WIDTH(Decimal128Array)

#undef VISIT_FIXED_WIDTH

  Status Visit(const StringArray& array) override { return VisitBinary(array); }

  Status Visit(const BinaryArray& array) override { return VisitBinary(array); }

  Status Visit(const ListArray& array) override {
    std::shared_ptr<Buffer> value_offsets;
    RETURN_NOT_OK(GetZeroBasedValueOffsets<ListArray>(array, &value_offsets));
    out_->body_buffers.emplace_back(value_offsets);

    --max_recursion_depth_;
    std::shared_ptr<Array> values = array.values();

    int32_t values_offset = 0;
    int32_t values_length = 0;
    if (value_offsets) {
      values_offset = array.value_offset(0);
      values_length = array.value_offset(array.length()) - values_offset;
    }

    if (array.offset() != 0 || values_length < values->length()) {
      // Must also slice the values
      values = values->Slice(values_offset, values_length);
    }
    RETURN_NOT_OK(VisitArray(*values));
    ++max_recursion_depth_;
    return Status::OK();
  }

  Status Visit(const StructArray& array) override {
    --max_recursion_depth_;
    for (int i = 0; i < array.num_fields(); ++i) {
      std::shared_ptr<Array> field = array.field(i);
      RETURN_NOT_OK(VisitArray(*field));
    }
    ++max_recursion_depth_;
    return Status::OK();
  }

  Status Visit(const UnionArray& array) override {
    const int64_t offset = array.offset();
    const int64_t length = array.length();

    std::shared_ptr<Buffer> type_ids;
    RETURN_NOT_OK(GetTruncatedBuffer<UnionArray::type_id_t>(
        offset, length, array.type_ids(), pool_, &type_ids));
    out_->body_buffers.emplace_back(type_ids);

    --max_recursion_depth_;
    if (array.mode() == UnionMode::DENSE) {
      const auto& type = checked_cast<const UnionType&>(*array.type());

      std::shared_ptr<Buffer> value_offsets;
      RETURN_NOT_OK(GetTruncatedBuffer<int32_t>(offset, length, array.value_offsets(),
                                                pool_, &value_offsets));

      // The Union type codes are not necessarily 0-indexed
      uint8_t max_code = 0;
      for (uint8_t code : type.type_codes()) {
        if (code > max_code) {
          max_code = code;
        }
      }

      // Allocate an array of child offsets. Set all to -1 to indicate that we
      // haven't observed a first occurrence of a particular child yet
      std::vector<int32_t> child_offsets(max_code + 1, -1);
      std::vector<int32_t> child_lengths(max_code + 1, 0);

      if (offset != 0) {
        // This is an unpleasant case. Because the offsets are different for
        // each child array, when we have a sliced array, we need to "rebase"
        // the value_offsets for each array

        const int32_t* unshifted_offsets = array.raw_value_offsets();
        const uint8_t* type_ids = array.raw_type_ids();

        // Allocate the shifted offsets
        std::shared_ptr<Buffer> shifted_offsets_buffer;
        RETURN_NOT_OK(
            AllocateBuffer(pool_, length * sizeof(int32_t), &shifted_offsets_buffer));
        int32_t* shifted_offsets =
            reinterpret_cast<int32_t*>(shifted_offsets_buffer->mutable_data());

        // Offsets may not be ascending, so we need to find out the start offset
        // for each child
        for (int64_t i = 0; i < length; ++i) {
          const uint8_t code = type_ids[i];
          if (child_offsets[code] == -1) {
            child_offsets[code] = unshifted_offsets[i];
          } else {
            child_offsets[code] = std::min(child_offsets[code], unshifted_offsets[i]);
          }
        }

        // Now compute shifted offsets by subtracting child offset
        for (int64_t i = 0; i < length; ++i) {
          const uint8_t code = type_ids[i];
          shifted_offsets[i] = unshifted_offsets[i] - child_offsets[code];
          // Update the child length to account for observed value
          child_lengths[code] = std::max(child_lengths[code], shifted_offsets[i] + 1);
        }

        value_offsets = shifted_offsets_buffer;
      }
      out_->body_buffers.emplace_back(value_offsets);

      // Visit children and slice accordingly
      for (int i = 0; i < type.num_children(); ++i) {
        std::shared_ptr<Array> child = array.child(i);

        // TODO: ARROW-809, for sliced unions, tricky to know how much to
        // truncate the children. For now, we are truncating the children to be
        // no longer than the parent union.
        if (offset != 0) {
          const uint8_t code = type.type_codes()[i];
          const int64_t child_offset = child_offsets[code];
          const int64_t child_length = child_lengths[code];

          if (child_offset > 0) {
            child = child->Slice(child_offset, child_length);
          } else if (child_length < child->length()) {
            // This case includes when child is not encountered at all
            child = child->Slice(0, child_length);
          }
        }
        RETURN_NOT_OK(VisitArray(*child));
      }
    } else {
      for (int i = 0; i < array.num_fields(); ++i) {
        // Sparse union, slicing is done for us by child()
        RETURN_NOT_OK(VisitArray(*array.child(i)));
      }
    }
    ++max_recursion_depth_;
    return Status::OK();
  }

  Status Visit(const DictionaryArray& array) override {
    // Dictionary written out separately. Slice offset contained in the indices
    return array.indices()->Accept(this);
  }

  // Destination for output buffers
  IpcPayload* out_;

  // In some cases, intermediate buffers may need to be allocated (with sliced arrays)
  MemoryPool* pool_;

  std::vector<internal::FieldMetadata> field_nodes_;
  std::vector<internal::BufferMetadata> buffer_meta_;

  int64_t max_recursion_depth_;
  int64_t buffer_start_offset_;
  bool allow_64bit_;
};

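// Serializes a dictionary by wrapping it in a single-column record batch and
// reusing the record batch path, but emitting a dictionary batch message
// (keyed by dictionary id) instead of a record batch message.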
class DictionaryWriter : public RecordBatchSerializer {
 public:
  DictionaryWriter(int64_t dictionary_id, MemoryPool* pool, int64_t buffer_start_offset,
                   int max_recursion_depth, bool allow_64bit, IpcPayload* out)
      : RecordBatchSerializer(pool, buffer_start_offset, max_recursion_depth,
                              allow_64bit, out),
        dictionary_id_(dictionary_id) {}

  Status SerializeMetadata(int64_t num_rows) override {
    return WriteDictionaryMessage(dictionary_id_, num_rows, out_->body_length,
                                  field_nodes_, buffer_meta_, &out_->metadata);
  }

  Status Assemble(const std::shared_ptr<Array>& dictionary) {
    // Make a dummy record batch. A bit tedious as we have to make a schema
    auto schema = arrow::schema({arrow::field("dictionary", dictionary->type())});
    auto batch = RecordBatch::Make(schema, dictionary->length(), {dictionary});
    return RecordBatchSerializer::Assemble(*batch);
  }

 private:
  int64_t dictionary_id_;
};

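// Write a previously assembled payload: first the metadata message, then each
// body buffer padded out to an 8-byte boundary.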
Status WriteIpcPayload(const IpcPayload& payload, io::OutputStream* dst,
                       int32_t* metadata_length) {
  RETURN_NOT_OK(internal::WriteMessage(*payload.metadata, kArrowIpcAlignment, dst,
                                       metadata_length));

#ifndef NDEBUG
  RETURN_NOT_OK(CheckAligned(dst));
#endif

  // Now write the buffers
  for (size_t i = 0; i < payload.body_buffers.size(); ++i) {
    const Buffer* buffer = payload.body_buffers[i].get();
    int64_t size = 0;
    int64_t padding = 0;

    // The buffer might be null if we are handling zero row lengths.
    if (buffer) {
      size = buffer->size();
      padding = BitUtil::RoundUpToMultipleOf8(size) - size;
    }

    if (size > 0) {
      RETURN_NOT_OK(dst->Write(buffer->data(), size));
    }

    if (padding > 0) {
      RETURN_NOT_OK(dst->Write(kPaddingBytes, padding));
    }
  }

#ifndef NDEBUG
  RETURN_NOT_OK(CheckAligned(dst));
#endif

  return Status::OK();
}

Status GetRecordBatchPayload(const RecordBatch& batch, MemoryPool* pool,
                             IpcPayload* out) {
  RecordBatchSerializer writer(pool, 0, kMaxNestingDepth, true, out);
  return writer.Assemble(batch);
}

}  // namespace internal

Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
                        io::OutputStream* dst, int32_t* metadata_length,
                        int64_t* body_length, MemoryPool* pool, int max_recursion_depth,
                        bool allow_64bit) {
  internal::IpcPayload payload;
  internal::RecordBatchSerializer writer(pool, buffer_start_offset, max_recursion_depth,
                                         allow_64bit, &payload);
  RETURN_NOT_OK(writer.Assemble(batch));

  // TODO(wesm): it's a rough edge that the metadata and body length here are
  // computed separately

  // The body size is computed in the payload
  *body_length = payload.body_length;

  return internal::WriteIpcPayload(payload, dst, metadata_length);
}

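// Convenience function: open a stream writer on dst, write all batches (which
// are expected to share the schema of batches[0]), then close the stream.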
Status WriteRecordBatchStream(const std::vector<std::shared_ptr<RecordBatch>>& batches,
                              io::OutputStream* dst) {
  std::shared_ptr<RecordBatchWriter> writer;
  RETURN_NOT_OK(RecordBatchStreamWriter::Open(dst, batches[0]->schema(), &writer));
  for (const auto& batch : batches) {
    DCHECK(batch->schema()->Equals(*batches[0]->schema())) << "Schemas unequal";
    // allow sizes > INT32_MAX
    RETURN_NOT_OK(writer->WriteRecordBatch(*batch, true));
  }
  RETURN_NOT_OK(writer->Close());
  return Status::OK();
}

Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
                             io::OutputStream* dst, int32_t* metadata_length,
                             int64_t* body_length, MemoryPool* pool) {
  return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length,
                          pool, kMaxNestingDepth, true);
}

namespace {

Status WriteTensorHeader(const Tensor& tensor, io::OutputStream* dst,
                         int32_t* metadata_length) {
  std::shared_ptr<Buffer> metadata;
  RETURN_NOT_OK(internal::WriteTensorMessage(tensor, 0, &metadata));
  return internal::WriteMessage(*metadata, kTensorAlignment, dst, metadata_length);
}

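// Recursively write the data of a strided (non-contiguous) tensor in row-major
// order. At the innermost dimension the elements are gathered one by one into
// scratch_space and written as a single contiguous run; outer dimensions
// recurse, advancing the byte offset by their stride.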
Status WriteStridedTensorData(int dim_index, int64_t offset, int elem_size,
                              const Tensor& tensor, uint8_t* scratch_space,
                              io::OutputStream* dst) {
  if (dim_index == tensor.ndim() - 1) {
    const uint8_t* data_ptr = tensor.raw_data() + offset;
    const int64_t stride = tensor.strides()[dim_index];
    for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
      memcpy(scratch_space + i * elem_size, data_ptr, elem_size);
      data_ptr += stride;
    }
    return dst->Write(scratch_space, elem_size * tensor.shape()[dim_index]);
  }
  for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
    RETURN_NOT_OK(WriteStridedTensorData(dim_index + 1, offset, elem_size, tensor,
                                         scratch_space, dst));
    offset += tensor.strides()[dim_index];
  }
  return Status::OK();
}

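// Produce a row-major copy of a strided tensor by streaming its data through
// WriteStridedTensorData into a freshly allocated resizable buffer.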
Status GetContiguousTensor(const Tensor& tensor, MemoryPool* pool,
                           std::unique_ptr<Tensor>* out) {
  const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
  const int elem_size = type.bit_width() / 8;

  std::shared_ptr<Buffer> scratch_space;
  RETURN_NOT_OK(AllocateBuffer(pool, tensor.shape()[tensor.ndim() - 1] * elem_size,
                               &scratch_space));

  std::shared_ptr<ResizableBuffer> contiguous_data;
  RETURN_NOT_OK(
      AllocateResizableBuffer(pool, tensor.size() * elem_size, &contiguous_data));

  io::BufferOutputStream stream(contiguous_data);
  RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
                                       scratch_space->mutable_data(), &stream));

  out->reset(new Tensor(tensor.type(), contiguous_data, tensor.shape()));

  return Status::OK();
}

}  // namespace

Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
                   int64_t* body_length) {
  const auto& type = checked_cast<const FixedWidthType&>(*tensor.type());
  const int elem_size = type.bit_width() / 8;

  *body_length = tensor.size() * elem_size;

  // Tensor metadata accounts for padding
  if (tensor.is_contiguous()) {
    RETURN_NOT_OK(WriteTensorHeader(tensor, dst, metadata_length));
    auto data = tensor.data();
    if (data && data->data()) {
      RETURN_NOT_OK(dst->Write(data->data(), *body_length));
    } else {
      *body_length = 0;
    }
  } else {
    // The tensor written is made contiguous
    Tensor dummy(tensor.type(), nullptr, tensor.shape());
    RETURN_NOT_OK(WriteTensorHeader(dummy, dst, metadata_length));

    // TODO(wesm): Do we care enough about this temporary allocation to pass in
    // a MemoryPool to this function?
    std::shared_ptr<Buffer> scratch_space;
    RETURN_NOT_OK(
        AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size, &scratch_space));

    RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
                                         scratch_space->mutable_data(), dst));
  }

  return Status::OK();
}

Status GetTensorMessage(const Tensor& tensor, MemoryPool* pool,
                        std::unique_ptr<Message>* out) {
  const Tensor* tensor_to_write = &tensor;
  std::unique_ptr<Tensor> temp_tensor;

  if (!tensor.is_contiguous()) {
    RETURN_NOT_OK(GetContiguousTensor(tensor, pool, &temp_tensor));
    tensor_to_write = temp_tensor.get();
  }

  std::shared_ptr<Buffer> metadata;
  RETURN_NOT_OK(internal::WriteTensorMessage(*tensor_to_write, 0, &metadata));
  out->reset(new Message(metadata, tensor_to_write->data()));
  return Status::OK();
}

namespace internal {

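// Collects the index buffer(s) and data buffer of a sparse tensor into an
// IpcPayload, mirroring the buffer/metadata bookkeeping done by
// RecordBatchSerializer::Assemble.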
class SparseTensorSerializer {
 public:
  SparseTensorSerializer(int64_t buffer_start_offset, IpcPayload* out)
      : out_(out), buffer_start_offset_(buffer_start_offset) {}

  ~SparseTensorSerializer() = default;

  Status VisitSparseIndex(const SparseIndex& sparse_index) {
    switch (sparse_index.format_id()) {
      case SparseTensorFormat::COO:
        RETURN_NOT_OK(
            VisitSparseCOOIndex(checked_cast<const SparseCOOIndex&>(sparse_index)));
        break;

      case SparseTensorFormat::CSR:
        RETURN_NOT_OK(
            VisitSparseCSRIndex(checked_cast<const SparseCSRIndex&>(sparse_index)));
        break;

      default:
        std::stringstream ss;
        ss << "Unable to convert type: " << sparse_index.ToString() << std::endl;
        return Status::NotImplemented(ss.str());
    }

    return Status::OK();
  }

  Status SerializeMetadata(const SparseTensor& sparse_tensor) {
    return WriteSparseTensorMessage(sparse_tensor, out_->body_length, buffer_meta_,
                                    &out_->metadata);
  }

  Status Assemble(const SparseTensor& sparse_tensor) {
    if (buffer_meta_.size() > 0) {
      buffer_meta_.clear();
      out_->body_buffers.clear();
    }

    RETURN_NOT_OK(VisitSparseIndex(*sparse_tensor.sparse_index()));
    out_->body_buffers.emplace_back(sparse_tensor.data());

    int64_t offset = buffer_start_offset_;
    buffer_meta_.reserve(out_->body_buffers.size());

    for (size_t i = 0; i < out_->body_buffers.size(); ++i) {
      const Buffer* buffer = out_->body_buffers[i].get();
      int64_t size = buffer->size();
      int64_t padding = BitUtil::RoundUpToMultipleOf8(size) - size;
      buffer_meta_.push_back({offset, size + padding});
      offset += size + padding;
    }

    out_->body_length = offset - buffer_start_offset_;
    DCHECK(BitUtil::IsMultipleOf8(out_->body_length));

    return SerializeMetadata(sparse_tensor);
  }

 private:
  Status VisitSparseCOOIndex(const SparseCOOIndex& sparse_index) {
    out_->body_buffers.emplace_back(sparse_index.indices()->data());
    return Status::OK();
  }

  Status VisitSparseCSRIndex(const SparseCSRIndex& sparse_index) {
    out_->body_buffers.emplace_back(sparse_index.indptr()->data());
    out_->body_buffers.emplace_back(sparse_index.indices()->data());
    return Status::OK();
  }

  IpcPayload* out_;

  std::vector<internal::BufferMetadata> buffer_meta_;

  int64_t buffer_start_offset_;
};

Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool,
                              IpcPayload* out) {
  SparseTensorSerializer writer(0, out);
  return writer.Assemble(sparse_tensor);
}

}  // namespace internal

Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
                         int32_t* metadata_length, int64_t* body_length,
                         MemoryPool* pool) {
  internal::IpcPayload payload;
  internal::SparseTensorSerializer writer(0, &payload);
  RETURN_NOT_OK(writer.Assemble(sparse_tensor));

  *body_length = payload.body_length;
  return internal::WriteIpcPayload(payload, dst, metadata_length);
}

Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary,
                       int64_t buffer_start_offset, io::OutputStream* dst,
                       int32_t* metadata_length, int64_t* body_length,
                       MemoryPool* pool) {
  internal::IpcPayload payload;
  internal::DictionaryWriter writer(dictionary_id, pool, buffer_start_offset,
                                    kMaxNestingDepth, true, &payload);
  RETURN_NOT_OK(writer.Assemble(dictionary));

  // The body size is computed in the payload
  *body_length = payload.body_length;
  return internal::WriteIpcPayload(payload, dst, metadata_length);
}

Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
  // emulates the behavior of Write without actually writing
  int32_t metadata_length = 0;
  int64_t body_length = 0;
  io::MockOutputStream dst;
  RETURN_NOT_OK(WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length,
                                 default_memory_pool(), kMaxNestingDepth, true));
  *size = dst.GetExtentBytesWritten();
  return Status::OK();
}

Status GetTensorSize(const Tensor& tensor, int64_t* size) {
  // emulates the behavior of Write without actually writing
  int32_t metadata_length = 0;
  int64_t body_length = 0;
  io::MockOutputStream dst;
  RETURN_NOT_OK(WriteTensor(tensor, &dst, &metadata_length, &body_length));
  *size = dst.GetExtentBytesWritten();
  return Status::OK();
}

// ----------------------------------------------------------------------

RecordBatchWriter::~RecordBatchWriter() {}

Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
  TableBatchReader reader(table);

  if (max_chunksize > 0) {
    reader.set_chunksize(max_chunksize);
  }

  std::shared_ptr<RecordBatch> batch;
  while (true) {
    RETURN_NOT_OK(reader.ReadNext(&batch));
    if (batch == nullptr) {
      break;
    }
    RETURN_NOT_OK(WriteRecordBatch(*batch, true));
  }

  return Status::OK();
}

Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); }

// ----------------------------------------------------------------------
// Stream writer implementation

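// Keeps track of the current position in the output stream and provides
// helpers for writing with the 8-byte alignment the IPC format requires.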
class StreamBookKeeper {
 public:
  StreamBookKeeper() : sink_(nullptr), position_(-1) {}
  explicit StreamBookKeeper(io::OutputStream* sink) : sink_(sink), position_(-1) {}

  Status UpdatePosition() { return sink_->Tell(&position_); }

  Status UpdatePositionCheckAligned() {
    RETURN_NOT_OK(UpdatePosition());
    DCHECK_EQ(0, position_ % 8) << "Stream is not aligned";
    return Status::OK();
  }

  Status Align(int32_t alignment = kArrowIpcAlignment) {
    // Adds padding bytes if necessary to ensure all memory blocks are written on
    // 8-byte (or other alignment) boundaries.
    int64_t remainder = PaddedLength(position_, alignment) - position_;
    if (remainder > 0) {
      return Write(kPaddingBytes, remainder);
    }
    return Status::OK();
  }

  // Write data and update position
  Status Write(const void* data, int64_t nbytes) {
    RETURN_NOT_OK(sink_->Write(data, nbytes));
    position_ += nbytes;
    return Status::OK();
  }

 protected:
  io::OutputStream* sink_;
  int64_t position_;
};

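// Writes the schema message and then any dictionaries recorded in the
// DictionaryMemo, returning their FileBlock locations to the caller.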
class SchemaWriter : public StreamBookKeeper {
 public:
  SchemaWriter(const Schema& schema, DictionaryMemo* dictionary_memo, MemoryPool* pool,
               io::OutputStream* sink)
      : StreamBookKeeper(sink),
        pool_(pool),
        schema_(schema),
        dictionary_memo_(dictionary_memo) {}

  Status WriteSchema() {
#ifndef NDEBUG
    // Catch bug fixed in ARROW-3236
    RETURN_NOT_OK(UpdatePositionCheckAligned());
#endif

    std::shared_ptr<Buffer> schema_fb;
    RETURN_NOT_OK(internal::WriteSchemaMessage(schema_, dictionary_memo_, &schema_fb));

    int32_t metadata_length = 0;
    RETURN_NOT_OK(internal::WriteMessage(*schema_fb, 8, sink_, &metadata_length));
    RETURN_NOT_OK(UpdatePositionCheckAligned());
    return Status::OK();
  }

  Status WriteDictionaries(std::vector<FileBlock>* dictionaries) {
    const DictionaryMap& id_to_dictionary = dictionary_memo_->id_to_dictionary();

    dictionaries->resize(id_to_dictionary.size());

    // TODO(wesm): does sorting by id yield any benefit?
    int dict_index = 0;
    for (const auto& entry : id_to_dictionary) {
      FileBlock* block = &(*dictionaries)[dict_index++];

      block->offset = position_;

      // Frame of reference in file format is 0, see ARROW-384
      const int64_t buffer_start_offset = 0;
      RETURN_NOT_OK(WriteDictionary(entry.first, entry.second, buffer_start_offset,
                                    sink_, &block->metadata_length, &block->body_length,
                                    pool_));
      RETURN_NOT_OK(UpdatePositionCheckAligned());
    }

    return Status::OK();
  }

  Status Write(std::vector<FileBlock>* dictionaries) {
    RETURN_NOT_OK(WriteSchema());

    // If there are any dictionaries, write them as the next messages
    return WriteDictionaries(dictionaries);
  }

 private:
  MemoryPool* pool_;
  const Schema& schema_;
  DictionaryMemo* dictionary_memo_;
};

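// Stream format implementation: lazily writes the schema (and dictionaries) on
// first use, then appends aligned record batch messages, remembering their
// FileBlock locations.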
class RecordBatchStreamWriter::RecordBatchStreamWriterImpl : public StreamBookKeeper {
 public:
  RecordBatchStreamWriterImpl(io::OutputStream* sink,
                              const std::shared_ptr<Schema>& schema)
      : StreamBookKeeper(sink),
        schema_(schema),
        pool_(default_memory_pool()),
        started_(false) {}

  virtual ~RecordBatchStreamWriterImpl() = default;

  virtual Status Start() {
    SchemaWriter schema_writer(*schema_, &dictionary_memo_, pool_, sink_);
    RETURN_NOT_OK(schema_writer.Write(&dictionaries_));
    started_ = true;
    return Status::OK();
  }

  virtual Status Close() {
    // Write the schema if not already written
    // User is responsible for closing the OutputStream
    RETURN_NOT_OK(CheckStarted());

    // Write 0 EOS message
    const int32_t kEos = 0;
    return Write(&kEos, sizeof(int32_t));
  }

  Status CheckStarted() {
    if (!started_) {
      return Start();
    }
    return Status::OK();
  }

  Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit, FileBlock* block) {
    RETURN_NOT_OK(CheckStarted());
    RETURN_NOT_OK(UpdatePosition());

    block->offset = position_;

    // Frame of reference in file format is 0, see ARROW-384
    const int64_t buffer_start_offset = 0;
    RETURN_NOT_OK(arrow::ipc::WriteRecordBatch(
        batch, buffer_start_offset, sink_, &block->metadata_length, &block->body_length,
        pool_, kMaxNestingDepth, allow_64bit));
    RETURN_NOT_OK(UpdatePositionCheckAligned());

    return Status::OK();
  }

  Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) {
    // Push an empty FileBlock. Can be written in the footer later
    record_batches_.push_back({0, 0, 0});
    return WriteRecordBatch(batch, allow_64bit,
                            &record_batches_[record_batches_.size() - 1]);
  }

  void set_memory_pool(MemoryPool* pool) { pool_ = pool; }

 protected:
  std::shared_ptr<Schema> schema_;
  MemoryPool* pool_;
  bool started_;

  // When writing out the schema, we keep track of all the dictionaries we
  // encounter, as they must be written out first in the stream
  DictionaryMemo dictionary_memo_;

  std::vector<FileBlock> dictionaries_;
  std::vector<FileBlock> record_batches_;
};

RecordBatchStreamWriter::RecordBatchStreamWriter() {}

RecordBatchStreamWriter::~RecordBatchStreamWriter() {}

Status RecordBatchStreamWriter::WriteRecordBatch(const RecordBatch& batch,
                                                 bool allow_64bit) {
  return impl_->WriteRecordBatch(batch, allow_64bit);
}

void RecordBatchStreamWriter::set_memory_pool(MemoryPool* pool) {
  impl_->set_memory_pool(pool);
}

Status RecordBatchStreamWriter::Open(io::OutputStream* sink,
                                     const std::shared_ptr<Schema>& schema,
                                     std::shared_ptr<RecordBatchWriter>* out) {
  // ctor is private
  auto result = std::shared_ptr<RecordBatchStreamWriter>(new RecordBatchStreamWriter());
  result->impl_.reset(new RecordBatchStreamWriterImpl(sink, schema));
  *out = result;
  return Status::OK();
}

Status RecordBatchStreamWriter::Close() { return impl_->Close(); }

// ----------------------------------------------------------------------
// File writer implementation

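// File format implementation: brackets the stream-format content with the
// leading magic bytes and a trailing footer (schema, dictionary and record
// batch blocks, footer length, and magic bytes again).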
class RecordBatchFileWriter::RecordBatchFileWriterImpl
    : public RecordBatchStreamWriter::RecordBatchStreamWriterImpl {
 public:
  using BASE = RecordBatchStreamWriter::RecordBatchStreamWriterImpl;

  RecordBatchFileWriterImpl(io::OutputStream* sink,
                            const std::shared_ptr<Schema>& schema)
      : BASE(sink, schema) {}

  Status Start() override {
    // ARROW-3236: The initial position -1 needs to be updated to the stream's
    // current position otherwise an incorrect amount of padding will be
    // written to new files.
    RETURN_NOT_OK(UpdatePosition());

    // It is only necessary to align to 8-byte boundary at the start of the file
    RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes)));
    RETURN_NOT_OK(Align());

    // We write the schema at the start of the file (and the end). This also
    // writes all the dictionaries at the beginning of the file
    return BASE::Start();
  }

  Status Close() override {
    // Write the schema if not already written
    // User is responsible for closing the OutputStream
    RETURN_NOT_OK(CheckStarted());

    // Write metadata
    RETURN_NOT_OK(UpdatePosition());

    int64_t initial_position = position_;
    RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_,
                                  &dictionary_memo_, sink_));
    RETURN_NOT_OK(UpdatePosition());

    // Write footer length
    int32_t footer_length = static_cast<int32_t>(position_ - initial_position);

    if (footer_length <= 0) {
      return Status::Invalid("Invalid file footer");
    }

    RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t)));

    // Write magic bytes to end file
    return Write(kArrowMagicBytes, strlen(kArrowMagicBytes));
  }
};

RecordBatchFileWriter::RecordBatchFileWriter() {}

RecordBatchFileWriter::~RecordBatchFileWriter() {}

Status RecordBatchFileWriter::Open(io::OutputStream* sink,
                                   const std::shared_ptr<Schema>& schema,
                                   std::shared_ptr<RecordBatchWriter>* out) {
  // ctor is private
  auto result = std::shared_ptr<RecordBatchFileWriter>(new RecordBatchFileWriter());
  result->file_impl_.reset(new RecordBatchFileWriterImpl(sink, schema));
  *out = result;
  return Status::OK();
}

Status RecordBatchFileWriter::WriteRecordBatch(const RecordBatch& batch,
                                               bool allow_64bit) {
  return file_impl_->WriteRecordBatch(batch, allow_64bit);
}

Status RecordBatchFileWriter::Close() { return file_impl_->Close(); }

// ----------------------------------------------------------------------
// Serialization public APIs

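// A minimal usage sketch for the buffer-returning overload below (the batch
// and memory pool are assumed to come from the caller):
//
//   std::shared_ptr<Buffer> serialized;
//   RETURN_NOT_OK(SerializeRecordBatch(batch, default_memory_pool(), &serialized));
//   // `serialized` now holds the metadata message followed by the padded body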
Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool,
                            std::shared_ptr<Buffer>* out) {
  int64_t size = 0;
  RETURN_NOT_OK(GetRecordBatchSize(batch, &size));
  std::shared_ptr<Buffer> buffer;
  RETURN_NOT_OK(AllocateBuffer(pool, size, &buffer));

  io::FixedSizeBufferWriter stream(buffer);
  RETURN_NOT_OK(SerializeRecordBatch(batch, pool, &stream));
  *out = buffer;
  return Status::OK();
}

Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool,
                            io::OutputStream* out) {
  int32_t metadata_length = 0;
  int64_t body_length = 0;
  return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, pool,
                          kMaxNestingDepth, true);
}

Status SerializeSchema(const Schema& schema, MemoryPool* pool,
                       std::shared_ptr<Buffer>* out) {
  std::shared_ptr<io::BufferOutputStream> stream;
  RETURN_NOT_OK(io::BufferOutputStream::Create(1024, pool, &stream));

  DictionaryMemo memo;
  SchemaWriter schema_writer(schema, &memo, pool, stream.get());

  // Unused
  std::vector<FileBlock> dictionary_blocks;

  RETURN_NOT_OK(schema_writer.Write(&dictionary_blocks));
  return stream->Finish(out);
}

}  // namespace ipc
}  // namespace arrow