1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/ipc/writer.h" |
19 | |
20 | #include <algorithm> |
21 | #include <cstdint> |
22 | #include <cstring> |
23 | #include <limits> |
24 | #include <sstream> |
25 | #include <vector> |
26 | |
27 | #include "arrow/array.h" |
28 | #include "arrow/buffer.h" |
29 | #include "arrow/io/interfaces.h" |
30 | #include "arrow/io/memory.h" |
31 | #include "arrow/ipc/dictionary.h" |
32 | #include "arrow/ipc/message.h" |
33 | #include "arrow/ipc/metadata-internal.h" |
34 | #include "arrow/ipc/util.h" |
35 | #include "arrow/memory_pool.h" |
36 | #include "arrow/record_batch.h" |
37 | #include "arrow/sparse_tensor.h" |
38 | #include "arrow/status.h" |
39 | #include "arrow/table.h" |
40 | #include "arrow/tensor.h" |
41 | #include "arrow/type.h" |
42 | #include "arrow/util/bit-util.h" |
43 | #include "arrow/util/checked_cast.h" |
44 | #include "arrow/util/logging.h" |
45 | #include "arrow/visitor.h" |
46 | |
47 | namespace arrow { |
48 | |
49 | using internal::checked_cast; |
50 | using internal::CopyBitmap; |
51 | |
52 | namespace ipc { |
53 | |
54 | using internal::FileBlock; |
55 | using internal::kArrowMagicBytes; |
56 | |
57 | // ---------------------------------------------------------------------- |
58 | // Record batch write path |
59 | |
60 | static inline Status GetTruncatedBitmap(int64_t offset, int64_t length, |
61 | const std::shared_ptr<Buffer> input, |
62 | MemoryPool* pool, |
63 | std::shared_ptr<Buffer>* buffer) { |
64 | if (!input) { |
65 | *buffer = input; |
66 | return Status::OK(); |
67 | } |
68 | int64_t min_length = PaddedLength(BitUtil::BytesForBits(length)); |
69 | if (offset != 0 || min_length < input->size()) { |
70 | // With a sliced array / non-zero offset, we must copy the bitmap |
71 | RETURN_NOT_OK(CopyBitmap(pool, input->data(), offset, length, buffer)); |
72 | } else { |
73 | *buffer = input; |
74 | } |
75 | return Status::OK(); |
76 | } |
77 | |
78 | template <typename T> |
79 | inline Status GetTruncatedBuffer(int64_t offset, int64_t length, |
80 | const std::shared_ptr<Buffer> input, MemoryPool* pool, |
81 | std::shared_ptr<Buffer>* buffer) { |
82 | if (!input) { |
83 | *buffer = input; |
84 | return Status::OK(); |
85 | } |
86 | int32_t byte_width = static_cast<int32_t>(sizeof(T)); |
87 | int64_t padded_length = PaddedLength(length * byte_width); |
88 | if (offset != 0 || padded_length < input->size()) { |
89 | *buffer = |
90 | SliceBuffer(input, offset * byte_width, std::min(padded_length, input->size())); |
91 | } else { |
92 | *buffer = input; |
93 | } |
94 | return Status::OK(); |
95 | } |
96 | |
97 | static inline bool NeedTruncate(int64_t offset, const Buffer* buffer, |
98 | int64_t min_length) { |
99 | // buffer can be NULL |
100 | if (buffer == nullptr) { |
101 | return false; |
102 | } |
103 | return offset != 0 || min_length < buffer->size(); |
104 | } |
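
// Illustrative example of the truncation rule above (values assumed): an
// int32 buffer viewed through a slice with offset 10 and length 20 has
// byte_width 4, so GetTruncatedBuffer<int32_t> slices the input at byte
// 10 * 4 = 40 and keeps at most PaddedLength(20 * 4) = 80 bytes. Only the
// visible values (plus their 8-byte padding) end up in the IPC body.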
105 | |
106 | namespace internal { |
107 | |
108 | class RecordBatchSerializer : public ArrayVisitor { |
109 | public: |
110 | RecordBatchSerializer(MemoryPool* pool, int64_t buffer_start_offset, |
111 | int max_recursion_depth, bool allow_64bit, IpcPayload* out) |
112 | : out_(out), |
113 | pool_(pool), |
114 | max_recursion_depth_(max_recursion_depth), |
115 | buffer_start_offset_(buffer_start_offset), |
116 | allow_64bit_(allow_64bit) { |
117 | DCHECK_GT(max_recursion_depth, 0); |
118 | } |
119 | |
120 | ~RecordBatchSerializer() override = default; |
121 | |
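// Each visited array contributes one field node (length, null count) plus a
// leading validity-bitmap buffer (either the possibly-truncated null bitmap
// or a zero-length placeholder), followed by the type-specific buffers
// appended by the Visit() overrides below.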
122 | Status VisitArray(const Array& arr) { |
123 | if (max_recursion_depth_ <= 0) { |
124 | return Status::Invalid("Max recursion depth reached"); |
125 | } |
126 | |
127 | if (!allow_64bit_ && arr.length() > std::numeric_limits<int32_t>::max()) { |
128 | return Status::CapacityError("Cannot write arrays larger than 2^31 - 1 in length"); |
129 | } |
130 | |
131 | // push back all common elements |
132 | field_nodes_.push_back({arr.length(), arr.null_count(), 0}); |
133 | |
134 | if (arr.null_count() > 0) { |
135 | std::shared_ptr<Buffer> bitmap; |
136 | RETURN_NOT_OK(GetTruncatedBitmap(arr.offset(), arr.length(), arr.null_bitmap(), |
137 | pool_, &bitmap)); |
138 | out_->body_buffers.emplace_back(bitmap); |
139 | } else { |
140 | // Push a dummy zero-length buffer, not to be copied |
141 | out_->body_buffers.emplace_back(std::make_shared<Buffer>(nullptr, 0)); |
142 | } |
143 | return arr.Accept(this); |
144 | } |
145 | |
146 | // Override this for writing dictionary metadata |
147 | virtual Status SerializeMetadata(int64_t num_rows) { |
148 | return WriteRecordBatchMessage(num_rows, out_->body_length, field_nodes_, |
149 | buffer_meta_, &out_->metadata); |
150 | } |
151 | |
152 | Status Assemble(const RecordBatch& batch) { |
153 | if (field_nodes_.size() > 0) { |
154 | field_nodes_.clear(); |
155 | buffer_meta_.clear(); |
156 | out_->body_buffers.clear(); |
157 | } |
158 | |
159 | // Perform depth-first traversal of the row-batch |
160 | for (int i = 0; i < batch.num_columns(); ++i) { |
161 | RETURN_NOT_OK(VisitArray(*batch.column(i))); |
162 | } |
163 | |
164 | // The position for the start of a buffer relative to the passed frame of |
165 | // reference. May be 0 or some other position in an address space |
166 | int64_t offset = buffer_start_offset_; |
167 | |
168 | buffer_meta_.reserve(out_->body_buffers.size()); |
169 | |
170 | // Construct the buffer metadata for the record batch header |
171 | for (size_t i = 0; i < out_->body_buffers.size(); ++i) { |
172 | const Buffer* buffer = out_->body_buffers[i].get(); |
173 | int64_t size = 0; |
174 | int64_t padding = 0; |
175 | |
176 | // The buffer might be null if we are handling zero row lengths. |
177 | if (buffer) { |
178 | size = buffer->size(); |
179 | padding = BitUtil::RoundUpToMultipleOf8(size) - size; |
180 | } |
181 | |
182 | buffer_meta_.push_back({offset, size + padding}); |
183 | offset += size + padding; |
184 | } |
185 | |
186 | out_->body_length = offset - buffer_start_offset_; |
187 | DCHECK(BitUtil::IsMultipleOf8(out_->body_length)); |
188 | |
189 | // Now that we have computed the locations of all of the buffers in shared |
190 | // memory, the data header can be converted to a flatbuffer and written out |
191 | // |
192 | // Note: The memory written here is prefixed by the size of the flatbuffer |
193 | // itself as an int32_t. |
194 | return SerializeMetadata(batch.num_rows()); |
195 | } |
196 | |
197 | protected: |
198 | template <typename ArrayType> |
199 | Status VisitFixedWidth(const ArrayType& array) { |
200 | std::shared_ptr<Buffer> data = array.values(); |
201 | |
202 | const auto& fw_type = checked_cast<const FixedWidthType&>(*array.type()); |
203 | const int64_t type_width = fw_type.bit_width() / 8; |
204 | int64_t min_length = PaddedLength(array.length() * type_width); |
205 | |
206 | if (NeedTruncate(array.offset(), data.get(), min_length)) { |
207 | // Non-zero offset, slice the buffer |
208 | const int64_t byte_offset = array.offset() * type_width; |
209 | |
210 | // Send padding if it's available |
211 | const int64_t buffer_length = |
212 | std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width), |
213 | data->size() - byte_offset); |
214 | data = SliceBuffer(data, byte_offset, buffer_length); |
215 | } |
216 | out_->body_buffers.emplace_back(data); |
217 | return Status::OK(); |
218 | } |
219 | |
220 | template <typename ArrayType> |
221 | Status GetZeroBasedValueOffsets(const ArrayType& array, |
222 | std::shared_ptr<Buffer>* value_offsets) { |
223 | // Share slicing logic between ListArray and BinaryArray |
224 | |
225 | auto offsets = array.value_offsets(); |
226 | |
227 | if (array.offset() != 0) { |
228 | // If we have a non-zero offset, then the value offsets do not start at |
229 | // zero. We must a) create a new offsets array with shifted offsets and |
230 | // b) slice the values array accordingly |
231 | |
232 | std::shared_ptr<Buffer> shifted_offsets; |
233 | RETURN_NOT_OK(AllocateBuffer(pool_, sizeof(int32_t) * (array.length() + 1), |
234 | &shifted_offsets)); |
235 | |
236 | int32_t* dest_offsets = reinterpret_cast<int32_t*>(shifted_offsets->mutable_data()); |
237 | const int32_t start_offset = array.value_offset(0); |
238 | |
239 | for (int i = 0; i < array.length(); ++i) { |
240 | dest_offsets[i] = array.value_offset(i) - start_offset; |
241 | } |
242 | // Final offset |
243 | dest_offsets[array.length()] = array.value_offset(array.length()) - start_offset; |
244 | offsets = shifted_offsets; |
245 | } |
246 | |
247 | *value_offsets = offsets; |
248 | return Status::OK(); |
249 | } |
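
// For example (values assumed): a ListArray with raw offsets [0, 3, 5, 9, 12],
// sliced to offset 2 and length 2, rebases its offsets to [0, 4, 7] (each
// offset minus value_offset(0) == 5); the child values are then sliced
// starting at 5 by the List/Binary visitors.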
250 | |
251 | Status VisitBinary(const BinaryArray& array) { |
252 | std::shared_ptr<Buffer> value_offsets; |
253 | RETURN_NOT_OK(GetZeroBasedValueOffsets<BinaryArray>(array, &value_offsets)); |
254 | auto data = array.value_data(); |
255 | |
256 | int64_t total_data_bytes = 0; |
257 | if (value_offsets) { |
258 | total_data_bytes = array.value_offset(array.length()) - array.value_offset(0); |
259 | } |
260 | if (NeedTruncate(array.offset(), data.get(), total_data_bytes)) { |
261 | // Slice the data buffer to include only the range we need now |
262 | const int64_t start_offset = array.value_offset(0); |
263 | const int64_t slice_length = |
264 | std::min(PaddedLength(total_data_bytes), data->size() - start_offset); |
265 | data = SliceBuffer(data, start_offset, slice_length); |
266 | } |
267 | |
268 | out_->body_buffers.emplace_back(value_offsets); |
269 | out_->body_buffers.emplace_back(data); |
270 | return Status::OK(); |
271 | } |
272 | |
273 | Status Visit(const BooleanArray& array) override { |
274 | std::shared_ptr<Buffer> data; |
275 | RETURN_NOT_OK( |
276 | GetTruncatedBitmap(array.offset(), array.length(), array.values(), pool_, &data)); |
277 | out_->body_buffers.emplace_back(data); |
278 | return Status::OK(); |
279 | } |
280 | |
281 | Status Visit(const NullArray& array) override { |
282 | out_->body_buffers.emplace_back(nullptr); |
283 | return Status::OK(); |
284 | } |
285 | |
286 | #define VISIT_FIXED_WIDTH(TYPE) \ |
287 | Status Visit(const TYPE& array) override { return VisitFixedWidth<TYPE>(array); } |
288 | |
289 | VISIT_FIXED_WIDTH(Int8Array) |
290 | VISIT_FIXED_WIDTH(Int16Array) |
291 | VISIT_FIXED_WIDTH(Int32Array) |
292 | VISIT_FIXED_WIDTH(Int64Array) |
293 | VISIT_FIXED_WIDTH(UInt8Array) |
294 | VISIT_FIXED_WIDTH(UInt16Array) |
295 | VISIT_FIXED_WIDTH(UInt32Array) |
296 | VISIT_FIXED_WIDTH(UInt64Array) |
297 | VISIT_FIXED_WIDTH(HalfFloatArray) |
298 | VISIT_FIXED_WIDTH(FloatArray) |
299 | VISIT_FIXED_WIDTH(DoubleArray) |
300 | VISIT_FIXED_WIDTH(Date32Array) |
301 | VISIT_FIXED_WIDTH(Date64Array) |
302 | VISIT_FIXED_WIDTH(TimestampArray) |
303 | VISIT_FIXED_WIDTH(Time32Array) |
304 | VISIT_FIXED_WIDTH(Time64Array) |
305 | VISIT_FIXED_WIDTH(FixedSizeBinaryArray) |
306 | VISIT_FIXED_WIDTH(Decimal128Array) |
307 | |
308 | #undef VISIT_FIXED_WIDTH |
309 | |
310 | Status Visit(const StringArray& array) override { return VisitBinary(array); } |
311 | |
312 | Status Visit(const BinaryArray& array) override { return VisitBinary(array); } |
313 | |
314 | Status Visit(const ListArray& array) override { |
315 | std::shared_ptr<Buffer> value_offsets; |
316 | RETURN_NOT_OK(GetZeroBasedValueOffsets<ListArray>(array, &value_offsets)); |
317 | out_->body_buffers.emplace_back(value_offsets); |
318 | |
319 | --max_recursion_depth_; |
320 | std::shared_ptr<Array> values = array.values(); |
321 | |
322 | int32_t values_offset = 0; |
323 | int32_t values_length = 0; |
324 | if (value_offsets) { |
325 | values_offset = array.value_offset(0); |
326 | values_length = array.value_offset(array.length()) - values_offset; |
327 | } |
328 | |
329 | if (array.offset() != 0 || values_length < values->length()) { |
330 | // Must also slice the values |
331 | values = values->Slice(values_offset, values_length); |
332 | } |
333 | RETURN_NOT_OK(VisitArray(*values)); |
334 | ++max_recursion_depth_; |
335 | return Status::OK(); |
336 | } |
337 | |
338 | Status Visit(const StructArray& array) override { |
339 | --max_recursion_depth_; |
340 | for (int i = 0; i < array.num_fields(); ++i) { |
341 | std::shared_ptr<Array> field = array.field(i); |
342 | RETURN_NOT_OK(VisitArray(*field)); |
343 | } |
344 | ++max_recursion_depth_; |
345 | return Status::OK(); |
346 | } |
347 | |
348 | Status Visit(const UnionArray& array) override { |
349 | const int64_t offset = array.offset(); |
350 | const int64_t length = array.length(); |
351 | |
352 | std::shared_ptr<Buffer> type_ids; |
353 | RETURN_NOT_OK(GetTruncatedBuffer<UnionArray::type_id_t>( |
354 | offset, length, array.type_ids(), pool_, &type_ids)); |
355 | out_->body_buffers.emplace_back(type_ids); |
356 | |
357 | --max_recursion_depth_; |
358 | if (array.mode() == UnionMode::DENSE) { |
359 | const auto& type = checked_cast<const UnionType&>(*array.type()); |
360 | |
361 | std::shared_ptr<Buffer> value_offsets; |
362 | RETURN_NOT_OK(GetTruncatedBuffer<int32_t>(offset, length, array.value_offsets(), |
363 | pool_, &value_offsets)); |
364 | |
365 | // The Union type codes are not necessarily 0-indexed |
366 | uint8_t max_code = 0; |
367 | for (uint8_t code : type.type_codes()) { |
368 | if (code > max_code) { |
369 | max_code = code; |
370 | } |
371 | } |
372 | |
373 | // Allocate an array of child offsets. Set all to -1 to indicate that we |
374 | // haven't observed a first occurrence of a particular child yet |
375 | std::vector<int32_t> child_offsets(max_code + 1, -1); |
376 | std::vector<int32_t> child_lengths(max_code + 1, 0); |
377 | |
378 | if (offset != 0) { |
379 | // This is an unpleasant case. Because the offsets are different for |
380 | // each child array, when we have a sliced array, we need to "rebase" |
381 | // the value_offsets for each array |
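// Worked example (values assumed): with type_ids [0, 1, 0] and raw value
// offsets [3, 7, 4], the loops below find child_offsets = {3, 7}, produce
// shifted_offsets [0, 0, 1] and child_lengths = {2, 1}, so child 0 is later
// sliced to [3, 5) and child 1 to [7, 8).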
382 | |
383 | const int32_t* unshifted_offsets = array.raw_value_offsets(); |
384 | const uint8_t* type_ids = array.raw_type_ids(); |
385 | |
386 | // Allocate the shifted offsets |
387 | std::shared_ptr<Buffer> shifted_offsets_buffer; |
388 | RETURN_NOT_OK( |
389 | AllocateBuffer(pool_, length * sizeof(int32_t), &shifted_offsets_buffer)); |
390 | int32_t* shifted_offsets = |
391 | reinterpret_cast<int32_t*>(shifted_offsets_buffer->mutable_data()); |
392 | |
393 | // Offsets may not be ascending, so we need to find out the start offset |
394 | // for each child |
395 | for (int64_t i = 0; i < length; ++i) { |
396 | const uint8_t code = type_ids[i]; |
397 | if (child_offsets[code] == -1) { |
398 | child_offsets[code] = unshifted_offsets[i]; |
399 | } else { |
400 | child_offsets[code] = std::min(child_offsets[code], unshifted_offsets[i]); |
401 | } |
402 | } |
403 | |
404 | // Now compute shifted offsets by subtracting child offset |
405 | for (int64_t i = 0; i < length; ++i) { |
406 | const uint8_t code = type_ids[i]; |
407 | shifted_offsets[i] = unshifted_offsets[i] - child_offsets[code]; |
408 | // Update the child length to account for observed value |
409 | child_lengths[code] = std::max(child_lengths[code], shifted_offsets[i] + 1); |
410 | } |
411 | |
412 | value_offsets = shifted_offsets_buffer; |
413 | } |
414 | out_->body_buffers.emplace_back(value_offsets); |
415 | |
416 | // Visit children and slice accordingly |
417 | for (int i = 0; i < type.num_children(); ++i) { |
418 | std::shared_ptr<Array> child = array.child(i); |
419 | |
420 | // TODO: ARROW-809, for sliced unions, tricky to know how much to |
421 | // truncate the children. For now, we are truncating the children to be |
422 | // no longer than the parent union. |
423 | if (offset != 0) { |
424 | const uint8_t code = type.type_codes()[i]; |
425 | const int64_t child_offset = child_offsets[code]; |
426 | const int64_t child_length = child_lengths[code]; |
427 | |
428 | if (child_offset > 0) { |
429 | child = child->Slice(child_offset, child_length); |
430 | } else if (child_length < child->length()) { |
431 | // This case includes when child is not encountered at all |
432 | child = child->Slice(0, child_length); |
433 | } |
434 | } |
435 | RETURN_NOT_OK(VisitArray(*child)); |
436 | } |
437 | } else { |
438 | for (int i = 0; i < array.num_fields(); ++i) { |
439 | // Sparse union, slicing is done for us by child() |
440 | RETURN_NOT_OK(VisitArray(*array.child(i))); |
441 | } |
442 | } |
443 | ++max_recursion_depth_; |
444 | return Status::OK(); |
445 | } |
446 | |
447 | Status Visit(const DictionaryArray& array) override { |
448 | // The dictionary itself is written out separately; only the indices are serialized here, and the slice offset is carried by the indices array |
449 | return array.indices()->Accept(this); |
450 | } |
451 | |
452 | // Destination for output buffers |
453 | IpcPayload* out_; |
454 | |
455 | // In some cases, intermediate buffers may need to be allocated (with sliced arrays) |
456 | MemoryPool* pool_; |
457 | |
458 | std::vector<internal::FieldMetadata> field_nodes_; |
459 | std::vector<internal::BufferMetadata> buffer_meta_; |
460 | |
461 | int64_t max_recursion_depth_; |
462 | int64_t buffer_start_offset_; |
463 | bool allow_64bit_; |
464 | }; |
465 | |
466 | class DictionaryWriter : public RecordBatchSerializer { |
467 | public: |
468 | DictionaryWriter(int64_t dictionary_id, MemoryPool* pool, int64_t buffer_start_offset, |
469 | int max_recursion_depth, bool allow_64bit, IpcPayload* out) |
470 | : RecordBatchSerializer(pool, buffer_start_offset, max_recursion_depth, allow_64bit, |
471 | out), |
472 | dictionary_id_(dictionary_id) {} |
473 | |
474 | Status SerializeMetadata(int64_t num_rows) override { |
475 | return WriteDictionaryMessage(dictionary_id_, num_rows, out_->body_length, |
476 | field_nodes_, buffer_meta_, &out_->metadata); |
477 | } |
478 | |
479 | Status Assemble(const std::shared_ptr<Array>& dictionary) { |
480 | // Make a dummy record batch. A bit tedious as we have to make a schema |
481 | auto schema = arrow::schema({arrow::field("dictionary", dictionary->type())}); |
482 | auto batch = RecordBatch::Make(schema, dictionary->length(), {dictionary}); |
483 | return RecordBatchSerializer::Assemble(*batch); |
484 | } |
485 | |
486 | private: |
487 | int64_t dictionary_id_; |
488 | }; |
489 | |
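// Writes a payload as an encapsulated IPC message: the flatbuffer metadata
// (prefixed by its int32 size and padded to 8 bytes by WriteMessage),
// followed by each body buffer padded out to an 8-byte boundary. The layout
// is, roughly:
//
//   <int32: metadata size> <metadata flatbuffer> <padding> <body buffers...>
//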
490 | Status WriteIpcPayload(const IpcPayload& payload, io::OutputStream* dst, |
491 | int32_t* metadata_length) { |
492 | RETURN_NOT_OK(internal::WriteMessage(*payload.metadata, kArrowIpcAlignment, dst, |
493 | metadata_length)); |
494 | |
495 | #ifndef NDEBUG |
496 | RETURN_NOT_OK(CheckAligned(dst)); |
497 | #endif |
498 | |
499 | // Now write the buffers |
500 | for (size_t i = 0; i < payload.body_buffers.size(); ++i) { |
501 | const Buffer* buffer = payload.body_buffers[i].get(); |
502 | int64_t size = 0; |
503 | int64_t padding = 0; |
504 | |
505 | // The buffer might be null if we are handling zero row lengths. |
506 | if (buffer) { |
507 | size = buffer->size(); |
508 | padding = BitUtil::RoundUpToMultipleOf8(size) - size; |
509 | } |
510 | |
511 | if (size > 0) { |
512 | RETURN_NOT_OK(dst->Write(buffer->data(), size)); |
513 | } |
514 | |
515 | if (padding > 0) { |
516 | RETURN_NOT_OK(dst->Write(kPaddingBytes, padding)); |
517 | } |
518 | } |
519 | |
520 | #ifndef NDEBUG |
521 | RETURN_NOT_OK(CheckAligned(dst)); |
522 | #endif |
523 | |
524 | return Status::OK(); |
525 | } |
526 | |
527 | Status GetRecordBatchPayload(const RecordBatch& batch, MemoryPool* pool, |
528 | IpcPayload* out) { |
529 | RecordBatchSerializer writer(pool, 0, kMaxNestingDepth, true, out); |
530 | return writer.Assemble(batch); |
531 | } |
532 | |
533 | } // namespace internal |
534 | |
535 | Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, |
536 | io::OutputStream* dst, int32_t* metadata_length, |
537 | int64_t* body_length, MemoryPool* pool, int max_recursion_depth, |
538 | bool allow_64bit) { |
539 | internal::IpcPayload payload; |
540 | internal::RecordBatchSerializer writer(pool, buffer_start_offset, max_recursion_depth, |
541 | allow_64bit, &payload); |
542 | RETURN_NOT_OK(writer.Assemble(batch)); |
543 | |
544 | // TODO(wesm): it's a rough edge that the metadata and body length here are |
545 | // computed separately |
546 | |
547 | // The body size is computed in the payload |
548 | *body_length = payload.body_length; |
549 | |
550 | return internal::WriteIpcPayload(payload, dst, metadata_length); |
551 | } |
552 | |
553 | Status WriteRecordBatchStream(const std::vector<std::shared_ptr<RecordBatch>>& batches, |
554 | io::OutputStream* dst) { |
555 | std::shared_ptr<RecordBatchWriter> writer; |
556 | RETURN_NOT_OK(RecordBatchStreamWriter::Open(dst, batches[0]->schema(), &writer)); |
557 | for (const auto& batch : batches) { |
558 | // allow sizes > INT32_MAX |
559 | DCHECK(batch->schema()->Equals(*batches[0]->schema())) << "Schemas unequal"; |
560 | RETURN_NOT_OK(writer->WriteRecordBatch(*batch, true)); |
561 | } |
562 | RETURN_NOT_OK(writer->Close()); |
563 | return Status::OK(); |
564 | } |
565 | |
566 | Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, |
567 | io::OutputStream* dst, int32_t* metadata_length, |
568 | int64_t* body_length, MemoryPool* pool) { |
569 | return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length, |
570 | pool, kMaxNestingDepth, true); |
571 | } |
572 | |
573 | namespace { |
574 | |
575 | Status WriteTensorHeader(const Tensor& tensor, io::OutputStream* dst, |
576 | int32_t* metadata_length) { |
577 | std::shared_ptr<Buffer> metadata; |
578 | RETURN_NOT_OK(internal::WriteTensorMessage(tensor, 0, &metadata)); |
579 | return internal::WriteMessage(*metadata, kTensorAlignment, dst, metadata_length); |
580 | } |
581 | |
582 | Status WriteStridedTensorData(int dim_index, int64_t offset, int elem_size, |
583 | const Tensor& tensor, uint8_t* scratch_space, |
584 | io::OutputStream* dst) { |
585 | if (dim_index == tensor.ndim() - 1) { |
586 | const uint8_t* data_ptr = tensor.raw_data() + offset; |
587 | const int64_t stride = tensor.strides()[dim_index]; |
588 | for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) { |
589 | memcpy(scratch_space + i * elem_size, data_ptr, elem_size); |
590 | data_ptr += stride; |
591 | } |
592 | return dst->Write(scratch_space, elem_size * tensor.shape()[dim_index]); |
593 | } |
594 | for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) { |
595 | RETURN_NOT_OK(WriteStridedTensorData(dim_index + 1, offset, elem_size, tensor, |
596 | scratch_space, dst)); |
597 | offset += tensor.strides()[dim_index]; |
598 | } |
599 | return Status::OK(); |
600 | } |
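
// Illustrative example (values assumed): a column-major 2 x 3 tensor of
// doubles has strides {8, 16}. The recursion above visits dim 0 twice, and
// for each row gathers 3 elements spaced 16 bytes apart into scratch_space
// before writing them as one contiguous 24-byte run.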
601 | |
602 | Status GetContiguousTensor(const Tensor& tensor, MemoryPool* pool, |
603 | std::unique_ptr<Tensor>* out) { |
604 | const auto& type = checked_cast<const FixedWidthType&>(*tensor.type()); |
605 | const int elem_size = type.bit_width() / 8; |
606 | |
607 | std::shared_ptr<Buffer> scratch_space; |
608 | RETURN_NOT_OK(AllocateBuffer(pool, tensor.shape()[tensor.ndim() - 1] * elem_size, |
609 | &scratch_space)); |
610 | |
611 | std::shared_ptr<ResizableBuffer> contiguous_data; |
612 | RETURN_NOT_OK( |
613 | AllocateResizableBuffer(pool, tensor.size() * elem_size, &contiguous_data)); |
614 | |
615 | io::BufferOutputStream stream(contiguous_data); |
616 | RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor, |
617 | scratch_space->mutable_data(), &stream)); |
618 | |
619 | out->reset(new Tensor(tensor.type(), contiguous_data, tensor.shape())); |
620 | |
621 | return Status::OK(); |
622 | } |
623 | |
624 | } // namespace |
625 | |
626 | Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length, |
627 | int64_t* body_length) { |
628 | const auto& type = checked_cast<const FixedWidthType&>(*tensor.type()); |
629 | const int elem_size = type.bit_width() / 8; |
630 | |
631 | *body_length = tensor.size() * elem_size; |
632 | |
633 | // Tensor metadata accounts for padding |
634 | if (tensor.is_contiguous()) { |
635 | RETURN_NOT_OK(WriteTensorHeader(tensor, dst, metadata_length)); |
636 | auto data = tensor.data(); |
637 | if (data && data->data()) { |
638 | RETURN_NOT_OK(dst->Write(data->data(), *body_length)); |
639 | } else { |
640 | *body_length = 0; |
641 | } |
642 | } else { |
643 | // Not contiguous: write a header describing the equivalent contiguous (row-major) tensor, then emit the elements in that order |
644 | Tensor dummy(tensor.type(), nullptr, tensor.shape()); |
645 | RETURN_NOT_OK(WriteTensorHeader(dummy, dst, metadata_length)); |
646 | |
647 | // TODO(wesm): Do we care enough about this temporary allocation to pass in |
648 | // a MemoryPool to this function? |
649 | std::shared_ptr<Buffer> scratch_space; |
650 | RETURN_NOT_OK( |
651 | AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size, &scratch_space)); |
652 | |
653 | RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor, |
654 | scratch_space->mutable_data(), dst)); |
655 | } |
656 | |
657 | return Status::OK(); |
658 | } |
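
// Usage sketch (illustrative; "sink" is an assumed name for any
// io::OutputStream, e.g. an io::FileOutputStream):
//
//   int32_t metadata_length = 0;
//   int64_t body_length = 0;
//   RETURN_NOT_OK(
//       arrow::ipc::WriteTensor(tensor, sink.get(), &metadata_length, &body_length));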
659 | |
660 | Status GetTensorMessage(const Tensor& tensor, MemoryPool* pool, |
661 | std::unique_ptr<Message>* out) { |
662 | const Tensor* tensor_to_write = &tensor; |
663 | std::unique_ptr<Tensor> temp_tensor; |
664 | |
665 | if (!tensor.is_contiguous()) { |
666 | RETURN_NOT_OK(GetContiguousTensor(tensor, pool, &temp_tensor)); |
667 | tensor_to_write = temp_tensor.get(); |
668 | } |
669 | |
670 | std::shared_ptr<Buffer> metadata; |
671 | RETURN_NOT_OK(internal::WriteTensorMessage(*tensor_to_write, 0, &metadata)); |
672 | out->reset(new Message(metadata, tensor_to_write->data())); |
673 | return Status::OK(); |
674 | } |
675 | |
676 | namespace internal { |
677 | |
678 | class SparseTensorSerializer { |
679 | public: |
680 | SparseTensorSerializer(int64_t buffer_start_offset, IpcPayload* out) |
681 | : out_(out), buffer_start_offset_(buffer_start_offset) {} |
682 | |
683 | ~SparseTensorSerializer() = default; |
684 | |
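// Dispatches on the sparse index format. The resulting body buffers are
// ordered [indices, data] for COO and [indptr, indices, data] for CSR,
// matching the order in which Assemble() appends them.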
685 | Status VisitSparseIndex(const SparseIndex& sparse_index) { |
686 | switch (sparse_index.format_id()) { |
687 | case SparseTensorFormat::COO: |
688 | RETURN_NOT_OK( |
689 | VisitSparseCOOIndex(checked_cast<const SparseCOOIndex&>(sparse_index))); |
690 | break; |
691 | |
692 | case SparseTensorFormat::CSR: |
693 | RETURN_NOT_OK( |
694 | VisitSparseCSRIndex(checked_cast<const SparseCSRIndex&>(sparse_index))); |
695 | break; |
696 | |
697 | default: |
698 | std::stringstream ss; |
699 | ss << "Unable to convert type: " << sparse_index.ToString() << std::endl; |
700 | return Status::NotImplemented(ss.str()); |
701 | } |
702 | |
703 | return Status::OK(); |
704 | } |
705 | |
706 | Status SerializeMetadata(const SparseTensor& sparse_tensor) { |
707 | return WriteSparseTensorMessage(sparse_tensor, out_->body_length, buffer_meta_, |
708 | &out_->metadata); |
709 | } |
710 | |
711 | Status Assemble(const SparseTensor& sparse_tensor) { |
712 | if (buffer_meta_.size() > 0) { |
713 | buffer_meta_.clear(); |
714 | out_->body_buffers.clear(); |
715 | } |
716 | |
717 | RETURN_NOT_OK(VisitSparseIndex(*sparse_tensor.sparse_index())); |
718 | out_->body_buffers.emplace_back(sparse_tensor.data()); |
719 | |
720 | int64_t offset = buffer_start_offset_; |
721 | buffer_meta_.reserve(out_->body_buffers.size()); |
722 | |
723 | for (size_t i = 0; i < out_->body_buffers.size(); ++i) { |
724 | const Buffer* buffer = out_->body_buffers[i].get(); |
725 | int64_t size = buffer->size(); |
726 | int64_t padding = BitUtil::RoundUpToMultipleOf8(size) - size; |
727 | buffer_meta_.push_back({offset, size + padding}); |
728 | offset += size + padding; |
729 | } |
730 | |
731 | out_->body_length = offset - buffer_start_offset_; |
732 | DCHECK(BitUtil::IsMultipleOf8(out_->body_length)); |
733 | |
734 | return SerializeMetadata(sparse_tensor); |
735 | } |
736 | |
737 | private: |
738 | Status VisitSparseCOOIndex(const SparseCOOIndex& sparse_index) { |
739 | out_->body_buffers.emplace_back(sparse_index.indices()->data()); |
740 | return Status::OK(); |
741 | } |
742 | |
743 | Status VisitSparseCSRIndex(const SparseCSRIndex& sparse_index) { |
744 | out_->body_buffers.emplace_back(sparse_index.indptr()->data()); |
745 | out_->body_buffers.emplace_back(sparse_index.indices()->data()); |
746 | return Status::OK(); |
747 | } |
748 | |
749 | IpcPayload* out_; |
750 | |
751 | std::vector<internal::BufferMetadata> buffer_meta_; |
752 | |
753 | int64_t buffer_start_offset_; |
754 | }; |
755 | |
756 | Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool, |
757 | IpcPayload* out) { |
758 | SparseTensorSerializer writer(0, out); |
759 | return writer.Assemble(sparse_tensor); |
760 | } |
761 | |
762 | } // namespace internal |
763 | |
764 | Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst, |
765 | int32_t* metadata_length, int64_t* body_length, |
766 | MemoryPool* pool) { |
767 | internal::IpcPayload payload; |
768 | internal::SparseTensorSerializer writer(0, &payload); |
769 | RETURN_NOT_OK(writer.Assemble(sparse_tensor)); |
770 | |
771 | *body_length = payload.body_length; |
772 | return internal::WriteIpcPayload(payload, dst, metadata_length); |
773 | } |
774 | |
775 | Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary, |
776 | int64_t buffer_start_offset, io::OutputStream* dst, |
777 | int32_t* metadata_length, int64_t* body_length, MemoryPool* pool) { |
778 | internal::IpcPayload payload; |
779 | internal::DictionaryWriter writer(dictionary_id, pool, buffer_start_offset, |
780 | kMaxNestingDepth, true, &payload); |
781 | RETURN_NOT_OK(writer.Assemble(dictionary)); |
782 | |
783 | // The body size is computed in the payload |
784 | *body_length = payload.body_length; |
785 | return internal::WriteIpcPayload(payload, dst, metadata_length); |
786 | } |
787 | |
788 | Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) { |
789 | // emulates the behavior of Write without actually writing |
790 | int32_t metadata_length = 0; |
791 | int64_t body_length = 0; |
792 | io::MockOutputStream dst; |
793 | RETURN_NOT_OK(WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length, |
794 | default_memory_pool(), kMaxNestingDepth, true)); |
795 | *size = dst.GetExtentBytesWritten(); |
796 | return Status::OK(); |
797 | } |
798 | |
799 | Status GetTensorSize(const Tensor& tensor, int64_t* size) { |
800 | // emulates the behavior of Write without actually writing |
801 | int32_t metadata_length = 0; |
802 | int64_t body_length = 0; |
803 | io::MockOutputStream dst; |
804 | RETURN_NOT_OK(WriteTensor(tensor, &dst, &metadata_length, &body_length)); |
805 | *size = dst.GetExtentBytesWritten(); |
806 | return Status::OK(); |
807 | } |
808 | |
809 | // ---------------------------------------------------------------------- |
810 | |
811 | RecordBatchWriter::~RecordBatchWriter() {} |
812 | |
813 | Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) { |
814 | TableBatchReader reader(table); |
815 | |
816 | if (max_chunksize > 0) { |
817 | reader.set_chunksize(max_chunksize); |
818 | } |
819 | |
820 | std::shared_ptr<RecordBatch> batch; |
821 | while (true) { |
822 | RETURN_NOT_OK(reader.ReadNext(&batch)); |
823 | if (batch == nullptr) { |
824 | break; |
825 | } |
826 | RETURN_NOT_OK(WriteRecordBatch(*batch, true)); |
827 | } |
828 | |
829 | return Status::OK(); |
830 | } |
831 | |
832 | Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); } |
833 | |
834 | // ---------------------------------------------------------------------- |
835 | // Stream writer implementation |
836 | |
837 | class StreamBookKeeper { |
838 | public: |
839 | StreamBookKeeper() : sink_(nullptr), position_(-1) {} |
840 | explicit StreamBookKeeper(io::OutputStream* sink) : sink_(sink), position_(-1) {} |
841 | |
842 | Status UpdatePosition() { return sink_->Tell(&position_); } |
843 | |
844 | Status UpdatePositionCheckAligned() { |
845 | RETURN_NOT_OK(UpdatePosition()); |
846 | DCHECK_EQ(0, position_ % 8) << "Stream is not aligned"; |
847 | return Status::OK(); |
848 | } |
849 | |
850 | Status Align(int32_t alignment = kArrowIpcAlignment) { |
851 | // Adds padding bytes if necessary to ensure all memory blocks are written on |
852 | // 8-byte (or other alignment) boundaries. |
853 | int64_t remainder = PaddedLength(position_, alignment) - position_; |
854 | if (remainder > 0) { |
855 | return Write(kPaddingBytes, remainder); |
856 | } |
857 | return Status::OK(); |
858 | } |
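
// For example, with the default 8-byte alignment a position_ of 13 pads to
// PaddedLength(13, 8) == 16, so 3 padding bytes are written.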
859 | |
860 | // Write data and update position |
861 | Status Write(const void* data, int64_t nbytes) { |
862 | RETURN_NOT_OK(sink_->Write(data, nbytes)); |
863 | position_ += nbytes; |
864 | return Status::OK(); |
865 | } |
866 | |
867 | protected: |
868 | io::OutputStream* sink_; |
869 | int64_t position_; |
870 | }; |
871 | |
872 | class SchemaWriter : public StreamBookKeeper { |
873 | public: |
874 | SchemaWriter(const Schema& schema, DictionaryMemo* dictionary_memo, MemoryPool* pool, |
875 | io::OutputStream* sink) |
876 | : StreamBookKeeper(sink), |
877 | pool_(pool), |
878 | schema_(schema), |
879 | dictionary_memo_(dictionary_memo) {} |
880 | |
881 | Status WriteSchema() { |
882 | #ifndef NDEBUG |
883 | // Catch bug fixed in ARROW-3236 |
884 | RETURN_NOT_OK(UpdatePositionCheckAligned()); |
885 | #endif |
886 | |
887 | std::shared_ptr<Buffer> schema_fb; |
888 | RETURN_NOT_OK(internal::WriteSchemaMessage(schema_, dictionary_memo_, &schema_fb)); |
889 | |
890 | int32_t metadata_length = 0; |
891 | RETURN_NOT_OK(internal::WriteMessage(*schema_fb, 8, sink_, &metadata_length)); |
892 | RETURN_NOT_OK(UpdatePositionCheckAligned()); |
893 | return Status::OK(); |
894 | } |
895 | |
896 | Status WriteDictionaries(std::vector<FileBlock>* dictionaries) { |
897 | const DictionaryMap& id_to_dictionary = dictionary_memo_->id_to_dictionary(); |
898 | |
899 | dictionaries->resize(id_to_dictionary.size()); |
900 | |
901 | // TODO(wesm): does sorting by id yield any benefit? |
902 | int dict_index = 0; |
903 | for (const auto& entry : id_to_dictionary) { |
904 | FileBlock* block = &(*dictionaries)[dict_index++]; |
905 | |
906 | block->offset = position_; |
907 | |
908 | // Frame of reference in file format is 0, see ARROW-384 |
909 | const int64_t buffer_start_offset = 0; |
910 | RETURN_NOT_OK(WriteDictionary(entry.first, entry.second, buffer_start_offset, sink_, |
911 | &block->metadata_length, &block->body_length, pool_)); |
912 | RETURN_NOT_OK(UpdatePositionCheckAligned()); |
913 | } |
914 | |
915 | return Status::OK(); |
916 | } |
917 | |
918 | Status Write(std::vector<FileBlock>* dictionaries) { |
919 | RETURN_NOT_OK(WriteSchema()); |
920 | |
921 | // If there are any dictionaries, write them as the next messages |
922 | return WriteDictionaries(dictionaries); |
923 | } |
924 | |
925 | private: |
926 | MemoryPool* pool_; |
927 | const Schema& schema_; |
928 | DictionaryMemo* dictionary_memo_; |
929 | }; |
930 | |
931 | class RecordBatchStreamWriter::RecordBatchStreamWriterImpl : public StreamBookKeeper { |
932 | public: |
933 | RecordBatchStreamWriterImpl(io::OutputStream* sink, |
934 | const std::shared_ptr<Schema>& schema) |
935 | : StreamBookKeeper(sink), |
936 | schema_(schema), |
937 | pool_(default_memory_pool()), |
938 | started_(false) {} |
939 | |
940 | virtual ~RecordBatchStreamWriterImpl() = default; |
941 | |
942 | virtual Status Start() { |
943 | SchemaWriter schema_writer(*schema_, &dictionary_memo_, pool_, sink_); |
944 | RETURN_NOT_OK(schema_writer.Write(&dictionaries_)); |
945 | started_ = true; |
946 | return Status::OK(); |
947 | } |
948 | |
949 | virtual Status Close() { |
950 | // Write the schema if not already written |
951 | // User is responsible for closing the OutputStream |
952 | RETURN_NOT_OK(CheckStarted()); |
953 | |
954 | // Write 0 EOS message |
955 | const int32_t kEos = 0; |
956 | return Write(&kEos, sizeof(int32_t)); |
957 | } |
958 | |
959 | Status CheckStarted() { |
960 | if (!started_) { |
961 | return Start(); |
962 | } |
963 | return Status::OK(); |
964 | } |
965 | |
966 | Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit, FileBlock* block) { |
967 | RETURN_NOT_OK(CheckStarted()); |
968 | RETURN_NOT_OK(UpdatePosition()); |
969 | |
970 | block->offset = position_; |
971 | |
972 | // Frame of reference in file format is 0, see ARROW-384 |
973 | const int64_t buffer_start_offset = 0; |
974 | RETURN_NOT_OK(arrow::ipc::WriteRecordBatch( |
975 | batch, buffer_start_offset, sink_, &block->metadata_length, &block->body_length, |
976 | pool_, kMaxNestingDepth, allow_64bit)); |
977 | RETURN_NOT_OK(UpdatePositionCheckAligned()); |
978 | |
979 | return Status::OK(); |
980 | } |
981 | |
982 | Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { |
983 | // Push an empty FileBlock. Can be written in the footer later |
984 | record_batches_.push_back({0, 0, 0}); |
985 | return WriteRecordBatch(batch, allow_64bit, |
986 | &record_batches_[record_batches_.size() - 1]); |
987 | } |
988 | |
989 | void set_memory_pool(MemoryPool* pool) { pool_ = pool; } |
990 | |
991 | protected: |
992 | std::shared_ptr<Schema> schema_; |
993 | MemoryPool* pool_; |
994 | bool started_; |
995 | |
996 | // When writing out the schema, we keep track of all the dictionaries we |
997 | // encounter, as they must be written out first in the stream |
998 | DictionaryMemo dictionary_memo_; |
999 | |
1000 | std::vector<FileBlock> dictionaries_; |
1001 | std::vector<FileBlock> record_batches_; |
1002 | }; |
1003 | |
1004 | RecordBatchStreamWriter::RecordBatchStreamWriter() {} |
1005 | |
1006 | RecordBatchStreamWriter::~RecordBatchStreamWriter() {} |
1007 | |
1008 | Status RecordBatchStreamWriter::WriteRecordBatch(const RecordBatch& batch, |
1009 | bool allow_64bit) { |
1010 | return impl_->WriteRecordBatch(batch, allow_64bit); |
1011 | } |
1012 | |
1013 | void RecordBatchStreamWriter::set_memory_pool(MemoryPool* pool) { |
1014 | impl_->set_memory_pool(pool); |
1015 | } |
1016 | |
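// Usage sketch (illustrative; "sink", "schema", and "batch" are assumed
// names): open a writer on an output stream, write one or more batches that
// share the schema, then Close() to emit the end-of-stream marker.
//
//   std::shared_ptr<RecordBatchWriter> writer;
//   RETURN_NOT_OK(RecordBatchStreamWriter::Open(sink.get(), schema, &writer));
//   RETURN_NOT_OK(writer->WriteRecordBatch(*batch, /*allow_64bit=*/false));
//   RETURN_NOT_OK(writer->Close());
//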
1017 | Status RecordBatchStreamWriter::Open(io::OutputStream* sink, |
1018 | const std::shared_ptr<Schema>& schema, |
1019 | std::shared_ptr<RecordBatchWriter>* out) { |
1020 | // ctor is private |
1021 | auto result = std::shared_ptr<RecordBatchStreamWriter>(new RecordBatchStreamWriter()); |
1022 | result->impl_.reset(new RecordBatchStreamWriterImpl(sink, schema)); |
1023 | *out = result; |
1024 | return Status::OK(); |
1025 | } |
1026 | |
1027 | Status RecordBatchStreamWriter::Close() { return impl_->Close(); } |
1028 | |
1029 | // ---------------------------------------------------------------------- |
1030 | // File writer implementation |
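//
// The file layout produced by this writer is, roughly:
//
//   <magic "ARROW1"> <padding to 8 bytes>
//   <stream format: schema message, dictionary batches, record batches>
//   <file footer flatbuffer> <int32: footer length> <magic "ARROW1">
//
// Start() writes the leading magic and delegates to the stream writer;
// Close() appends the footer, its length, and the trailing magic.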
1031 | |
1032 | class RecordBatchFileWriter::RecordBatchFileWriterImpl |
1033 | : public RecordBatchStreamWriter::RecordBatchStreamWriterImpl { |
1034 | public: |
1035 | using BASE = RecordBatchStreamWriter::RecordBatchStreamWriterImpl; |
1036 | |
1037 | RecordBatchFileWriterImpl(io::OutputStream* sink, const std::shared_ptr<Schema>& schema) |
1038 | : BASE(sink, schema) {} |
1039 | |
1040 | Status Start() override { |
1041 | // ARROW-3236: The initial position -1 needs to be updated to the stream's |
1042 | // current position otherwise an incorrect amount of padding will be |
1043 | // written to new files. |
1044 | RETURN_NOT_OK(UpdatePosition()); |
1045 | |
1046 | // It is only necessary to align to 8-byte boundary at the start of the file |
1047 | RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes))); |
1048 | RETURN_NOT_OK(Align()); |
1049 | |
1050 | // We write the schema at the start of the file (and the end). This also |
1051 | // writes all the dictionaries at the beginning of the file |
1052 | return BASE::Start(); |
1053 | } |
1054 | |
1055 | Status Close() override { |
1056 | // Write the schema if not already written |
1057 | // User is responsible for closing the OutputStream |
1058 | RETURN_NOT_OK(CheckStarted()); |
1059 | |
1060 | // Write metadata |
1061 | RETURN_NOT_OK(UpdatePosition()); |
1062 | |
1063 | int64_t initial_position = position_; |
1064 | RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_, |
1065 | &dictionary_memo_, sink_)); |
1066 | RETURN_NOT_OK(UpdatePosition()); |
1067 | |
1068 | // Write footer length |
1069 | int32_t footer_length = static_cast<int32_t>(position_ - initial_position); |
1070 | |
1071 | if (footer_length <= 0) { |
1072 | return Status::Invalid("Invalid file footer"); |
1073 | } |
1074 | |
1075 | RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t))); |
1076 | |
1077 | // Write magic bytes to end file |
1078 | return Write(kArrowMagicBytes, strlen(kArrowMagicBytes)); |
1079 | } |
1080 | }; |
1081 | |
1082 | RecordBatchFileWriter::RecordBatchFileWriter() {} |
1083 | |
1084 | RecordBatchFileWriter::~RecordBatchFileWriter() {} |
1085 | |
1086 | Status RecordBatchFileWriter::Open(io::OutputStream* sink, |
1087 | const std::shared_ptr<Schema>& schema, |
1088 | std::shared_ptr<RecordBatchWriter>* out) { |
1089 | // ctor is private |
1090 | auto result = std::shared_ptr<RecordBatchFileWriter>(new RecordBatchFileWriter()); |
1091 | result->file_impl_.reset(new RecordBatchFileWriterImpl(sink, schema)); |
1092 | *out = result; |
1093 | return Status::OK(); |
1094 | } |
1095 | |
1096 | Status RecordBatchFileWriter::WriteRecordBatch(const RecordBatch& batch, |
1097 | bool allow_64bit) { |
1098 | return file_impl_->WriteRecordBatch(batch, allow_64bit); |
1099 | } |
1100 | |
1101 | Status RecordBatchFileWriter::Close() { return file_impl_->Close(); } |
1102 | |
1103 | // ---------------------------------------------------------------------- |
1104 | // Serialization public APIs |
1105 | |
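// Serializes a record batch into a single buffer by first measuring the
// encoded size with a MockOutputStream (via GetRecordBatchSize) and then
// writing into a FixedSizeBufferWriter over a buffer of exactly that size.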
1106 | Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, |
1107 | std::shared_ptr<Buffer>* out) { |
1108 | int64_t size = 0; |
1109 | RETURN_NOT_OK(GetRecordBatchSize(batch, &size)); |
1110 | std::shared_ptr<Buffer> buffer; |
1111 | RETURN_NOT_OK(AllocateBuffer(pool, size, &buffer)); |
1112 | |
1113 | io::FixedSizeBufferWriter stream(buffer); |
1114 | RETURN_NOT_OK(SerializeRecordBatch(batch, pool, &stream)); |
1115 | *out = buffer; |
1116 | return Status::OK(); |
1117 | } |
1118 | |
1119 | Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, |
1120 | io::OutputStream* out) { |
1121 | int32_t metadata_length = 0; |
1122 | int64_t body_length = 0; |
1123 | return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, pool, |
1124 | kMaxNestingDepth, true); |
1125 | } |
1126 | |
1127 | Status SerializeSchema(const Schema& schema, MemoryPool* pool, |
1128 | std::shared_ptr<Buffer>* out) { |
1129 | std::shared_ptr<io::BufferOutputStream> stream; |
1130 | RETURN_NOT_OK(io::BufferOutputStream::Create(1024, pool, &stream)); |
1131 | |
1132 | DictionaryMemo memo; |
1133 | SchemaWriter schema_writer(schema, &memo, pool, stream.get()); |
1134 | |
1135 | // Unused |
1136 | std::vector<FileBlock> dictionary_blocks; |
1137 | |
1138 | RETURN_NOT_OK(schema_writer.Write(&dictionary_blocks)); |
1139 | return stream->Finish(out); |
1140 | } |
1141 | |
1142 | } // namespace ipc |
1143 | } // namespace arrow |
1144 | |