1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/ipc/message.h" |
19 | |
20 | #include <algorithm> |
21 | #include <cstdint> |
22 | #include <memory> |
23 | #include <sstream> |
24 | #include <string> |
25 | |
26 | #include <flatbuffers/flatbuffers.h> |
27 | |
28 | #include "arrow/buffer.h" |
29 | #include "arrow/io/interfaces.h" |
30 | #include "arrow/ipc/Message_generated.h" |
31 | #include "arrow/ipc/metadata_internal.h" |
32 | #include "arrow/ipc/util.h" |
33 | #include "arrow/status.h" |
34 | #include "arrow/util/logging.h" |
35 | #include "arrow/util/ubsan.h" |
36 | |
37 | namespace arrow { |
38 | namespace ipc { |
39 | |
40 | class Message::MessageImpl { |
41 | public: |
42 | explicit MessageImpl(const std::shared_ptr<Buffer>& metadata, |
43 | const std::shared_ptr<Buffer>& body) |
44 | : metadata_(metadata), message_(nullptr), body_(body) {} |
45 | |
46 | Status Open() { |
47 | RETURN_NOT_OK( |
48 | internal::VerifyMessage(metadata_->data(), metadata_->size(), &message_)); |
49 | |
50 | // Check that the metadata version is supported |
51 | if (message_->version() < internal::kMinMetadataVersion) { |
52 | return Status::Invalid("Old metadata version not supported" ); |
53 | } |
54 | |
55 | return Status::OK(); |
56 | } |
57 | |
58 | Message::Type type() const { |
59 | switch (message_->header_type()) { |
60 | case flatbuf::MessageHeader_Schema: |
61 | return Message::SCHEMA; |
62 | case flatbuf::MessageHeader_DictionaryBatch: |
63 | return Message::DICTIONARY_BATCH; |
64 | case flatbuf::MessageHeader_RecordBatch: |
65 | return Message::RECORD_BATCH; |
66 | case flatbuf::MessageHeader_Tensor: |
67 | return Message::TENSOR; |
68 | case flatbuf::MessageHeader_SparseTensor: |
69 | return Message::SPARSE_TENSOR; |
70 | default: |
71 | return Message::NONE; |
72 | } |
73 | } |
74 | |
75 | MetadataVersion version() const { |
76 | return internal::GetMetadataVersion(message_->version()); |
77 | } |
78 | |
79 | const void* () const { return message_->header(); } |
80 | |
81 | int64_t body_length() const { return message_->bodyLength(); } |
82 | |
83 | std::shared_ptr<Buffer> body() const { return body_; } |
84 | |
85 | std::shared_ptr<Buffer> metadata() const { return metadata_; } |
86 | |
87 | private: |
88 | // The Flatbuffer metadata |
89 | std::shared_ptr<Buffer> metadata_; |
90 | const flatbuf::Message* message_; |
91 | |
92 | // The message body, if any |
93 | std::shared_ptr<Buffer> body_; |
94 | }; |
95 | |
96 | Message::Message(const std::shared_ptr<Buffer>& metadata, |
97 | const std::shared_ptr<Buffer>& body) { |
98 | impl_.reset(new MessageImpl(metadata, body)); |
99 | } |
100 | |
101 | Status Message::Open(const std::shared_ptr<Buffer>& metadata, |
102 | const std::shared_ptr<Buffer>& body, std::unique_ptr<Message>* out) { |
103 | out->reset(new Message(metadata, body)); |
104 | return (*out)->impl_->Open(); |
105 | } |
106 | |
107 | Message::~Message() {} |
108 | |
109 | std::shared_ptr<Buffer> Message::body() const { return impl_->body(); } |
110 | |
111 | int64_t Message::body_length() const { return impl_->body_length(); } |
112 | |
113 | std::shared_ptr<Buffer> Message::metadata() const { return impl_->metadata(); } |
114 | |
115 | Message::Type Message::type() const { return impl_->type(); } |
116 | |
117 | MetadataVersion Message::metadata_version() const { return impl_->version(); } |
118 | |
119 | const void* Message::() const { return impl_->header(); } |
120 | |
121 | bool Message::Equals(const Message& other) const { |
122 | int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size()); |
123 | |
124 | if (!metadata()->Equals(*other.metadata(), metadata_bytes)) { |
125 | return false; |
126 | } |
127 | |
128 | // Compare bodies, if they have them |
129 | auto this_body = body(); |
130 | auto other_body = other.body(); |
131 | |
132 | const bool this_has_body = (this_body != nullptr) && (this_body->size() > 0); |
133 | const bool other_has_body = (other_body != nullptr) && (other_body->size() > 0); |
134 | |
135 | if (this_has_body && other_has_body) { |
136 | return this_body->Equals(*other_body); |
137 | } else if (this_has_body ^ other_has_body) { |
138 | // One has a body but not the other |
139 | return false; |
140 | } else { |
141 | // Neither has a body |
142 | return true; |
143 | } |
144 | } |
145 | |
146 | Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) { |
147 | if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) { |
148 | // If the metadata memory is not aligned, we copy it here to avoid |
149 | // potential UBSAN issues from Flatbuffers |
150 | RETURN_NOT_OK((*metadata)->Copy(0, (*metadata)->size(), metadata)); |
151 | } |
152 | return Status::OK(); |
153 | } |
154 | |
155 | Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) { |
156 | const flatbuf::Message* fb_message; |
157 | RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message)); |
158 | *body_length = fb_message->bodyLength(); |
159 | return Status::OK(); |
160 | } |
161 | |
162 | Status Message::ReadFrom(std::shared_ptr<Buffer> metadata, io::InputStream* stream, |
163 | std::unique_ptr<Message>* out) { |
164 | RETURN_NOT_OK(MaybeAlignMetadata(&metadata)); |
165 | int64_t body_length = -1; |
166 | RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata, &body_length)); |
167 | |
168 | std::shared_ptr<Buffer> body; |
169 | RETURN_NOT_OK(stream->Read(body_length, &body)); |
170 | if (body->size() < body_length) { |
171 | return Status::IOError("Expected to be able to read " , body_length, |
172 | " bytes for message body, got " , body->size()); |
173 | } |
174 | |
175 | return Message::Open(metadata, body, out); |
176 | } |
177 | |
178 | Status Message::ReadFrom(const int64_t offset, std::shared_ptr<Buffer> metadata, |
179 | io::RandomAccessFile* file, std::unique_ptr<Message>* out) { |
180 | RETURN_NOT_OK(MaybeAlignMetadata(&metadata)); |
181 | int64_t body_length = -1; |
182 | RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata, &body_length)); |
183 | |
184 | std::shared_ptr<Buffer> body; |
185 | RETURN_NOT_OK(file->ReadAt(offset, body_length, &body)); |
186 | if (body->size() < body_length) { |
187 | return Status::IOError("Expected to be able to read " , body_length, |
188 | " bytes for message body, got " , body->size()); |
189 | } |
190 | |
191 | return Message::Open(metadata, body, out); |
192 | } |
193 | |
194 | Status WritePadding(io::OutputStream* stream, int64_t nbytes) { |
195 | while (nbytes > 0) { |
196 | const int64_t bytes_to_write = std::min<int64_t>(nbytes, kArrowAlignment); |
197 | RETURN_NOT_OK(stream->Write(kPaddingBytes, bytes_to_write)); |
198 | nbytes -= bytes_to_write; |
199 | } |
200 | return Status::OK(); |
201 | } |
202 | |
203 | Status Message::SerializeTo(io::OutputStream* stream, const IpcOptions& options, |
204 | int64_t* output_length) const { |
205 | int32_t metadata_length = 0; |
206 | RETURN_NOT_OK(WriteMessage(*metadata(), options, stream, &metadata_length)); |
207 | |
208 | *output_length = metadata_length; |
209 | |
210 | auto body_buffer = body(); |
211 | if (body_buffer) { |
212 | RETURN_NOT_OK(stream->Write(body_buffer)); |
213 | *output_length += body_buffer->size(); |
214 | |
215 | DCHECK_GE(this->body_length(), body_buffer->size()); |
216 | |
217 | int64_t remainder = this->body_length() - body_buffer->size(); |
218 | RETURN_NOT_OK(WritePadding(stream, remainder)); |
219 | *output_length += remainder; |
220 | } |
221 | return Status::OK(); |
222 | } |
223 | |
224 | bool Message::Verify() const { |
225 | const flatbuf::Message* unused; |
226 | return internal::VerifyMessage(metadata()->data(), metadata()->size(), &unused).ok(); |
227 | } |
228 | |
229 | std::string FormatMessageType(Message::Type type) { |
230 | switch (type) { |
231 | case Message::SCHEMA: |
232 | return "schema" ; |
233 | case Message::RECORD_BATCH: |
234 | return "record batch" ; |
235 | case Message::DICTIONARY_BATCH: |
236 | return "dictionary" ; |
237 | default: |
238 | break; |
239 | } |
240 | return "unknown" ; |
241 | } |
242 | |
243 | Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file, |
244 | std::unique_ptr<Message>* message) { |
245 | ARROW_CHECK_GT(static_cast<size_t>(metadata_length), sizeof(int32_t)) |
246 | << "metadata_length should be at least 4" ; |
247 | |
248 | std::shared_ptr<Buffer> buffer; |
249 | RETURN_NOT_OK(file->ReadAt(offset, metadata_length, &buffer)); |
250 | |
251 | if (buffer->size() < metadata_length) { |
252 | return Status::Invalid("Expected to read " , metadata_length, |
253 | " metadata bytes but got " , buffer->size()); |
254 | } |
255 | |
256 | const int32_t continuation = util::SafeLoadAs<int32_t>(buffer->data()); |
257 | |
258 | // The size of the Flatbuffer including padding |
259 | int32_t flatbuffer_length = -1; |
260 | int32_t prefix_size = -1; |
261 | if (continuation == internal::kIpcContinuationToken) { |
262 | if (metadata_length < 8) { |
263 | return Status::Invalid( |
264 | "Corrupted IPC message, had continuation token " |
265 | " but length " , |
266 | metadata_length); |
267 | } |
268 | |
269 | // Valid IPC message, parse the message length now |
270 | flatbuffer_length = util::SafeLoadAs<int32_t>(buffer->data() + 4); |
271 | prefix_size = 8; |
272 | } else { |
273 | // ARROW-6314: Backwards compatibility for reading old IPC |
274 | // messages produced prior to version 0.15.0 |
275 | flatbuffer_length = continuation; |
276 | prefix_size = 4; |
277 | } |
278 | |
279 | if (flatbuffer_length == 0) { |
280 | // EOS |
281 | *message = nullptr; |
282 | return Status::OK(); |
283 | } |
284 | |
285 | if (flatbuffer_length + prefix_size != metadata_length) { |
286 | return Status::Invalid("flatbuffer size " , flatbuffer_length, |
287 | " invalid. File offset: " , offset, |
288 | ", metadata length: " , metadata_length); |
289 | } |
290 | |
291 | std::shared_ptr<Buffer> metadata = |
292 | SliceBuffer(buffer, prefix_size, buffer->size() - prefix_size); |
293 | return Message::ReadFrom(offset + metadata_length, metadata, file, message); |
294 | } |
295 | |
296 | Status AlignStream(io::InputStream* stream, int32_t alignment) { |
297 | int64_t position = -1; |
298 | RETURN_NOT_OK(stream->Tell(&position)); |
299 | return stream->Advance(PaddedLength(position, alignment) - position); |
300 | } |
301 | |
302 | Status AlignStream(io::OutputStream* stream, int32_t alignment) { |
303 | int64_t position = -1; |
304 | RETURN_NOT_OK(stream->Tell(&position)); |
305 | int64_t remainder = PaddedLength(position, alignment) - position; |
306 | if (remainder > 0) { |
307 | return stream->Write(kPaddingBytes, remainder); |
308 | } |
309 | return Status::OK(); |
310 | } |
311 | |
312 | Status CheckAligned(io::FileInterface* stream, int32_t alignment) { |
313 | int64_t current_position; |
314 | ARROW_RETURN_NOT_OK(stream->Tell(¤t_position)); |
315 | if (current_position % alignment != 0) { |
316 | return Status::Invalid("Stream is not aligned pos: " , current_position, |
317 | " alignment: " , alignment); |
318 | } else { |
319 | return Status::OK(); |
320 | } |
321 | } |
322 | |
323 | namespace { |
324 | |
325 | Status ReadMessage(io::InputStream* file, MemoryPool* pool, bool copy_metadata, |
326 | std::unique_ptr<Message>* message) { |
327 | int32_t continuation = 0; |
328 | int64_t bytes_read = 0; |
329 | RETURN_NOT_OK(file->Read(sizeof(int32_t), &bytes_read, |
330 | reinterpret_cast<uint8_t*>(&continuation))); |
331 | |
332 | if (bytes_read == 0) { |
333 | // EOS without indication |
334 | *message = nullptr; |
335 | return Status::OK(); |
336 | } else if (bytes_read != sizeof(int32_t)) { |
337 | return Status::Invalid("Corrupted message, only " , bytes_read, " bytes available" ); |
338 | } |
339 | |
340 | int32_t flatbuffer_length = -1; |
341 | if (continuation == internal::kIpcContinuationToken) { |
342 | // Valid IPC message, read the message length now |
343 | RETURN_NOT_OK(file->Read(sizeof(int32_t), &bytes_read, |
344 | reinterpret_cast<uint8_t*>(&flatbuffer_length))); |
345 | } else { |
346 | // ARROW-6314: Backwards compatibility for reading old IPC |
347 | // messages produced prior to version 0.15.0 |
348 | flatbuffer_length = continuation; |
349 | } |
350 | |
351 | if (flatbuffer_length == 0) { |
352 | // EOS |
353 | *message = nullptr; |
354 | return Status::OK(); |
355 | } |
356 | |
357 | std::shared_ptr<Buffer> metadata; |
358 | if (copy_metadata) { |
359 | DCHECK_NE(pool, nullptr); |
360 | RETURN_NOT_OK(AllocateBuffer(pool, flatbuffer_length, &metadata)); |
361 | RETURN_NOT_OK(file->Read(flatbuffer_length, &bytes_read, metadata->mutable_data())); |
362 | } else { |
363 | RETURN_NOT_OK(file->Read(flatbuffer_length, &metadata)); |
364 | bytes_read = metadata->size(); |
365 | } |
366 | if (bytes_read != flatbuffer_length) { |
367 | return Status::Invalid("Expected to read " , flatbuffer_length, |
368 | " metadata bytes, but " , "only read " , bytes_read); |
369 | } |
370 | |
371 | return Message::ReadFrom(metadata, file, message); |
372 | } |
373 | |
374 | } // namespace |
375 | |
376 | Status ReadMessage(io::InputStream* file, std::unique_ptr<Message>* out) { |
377 | return ReadMessage(file, default_memory_pool(), /*copy_metadata=*/false, out); |
378 | } |
379 | |
380 | Status ReadMessageCopy(io::InputStream* file, MemoryPool* pool, |
381 | std::unique_ptr<Message>* out) { |
382 | return ReadMessage(file, pool, /*copy_metadata=*/true, out); |
383 | } |
384 | |
385 | Status WriteMessage(const Buffer& message, const IpcOptions& options, |
386 | io::OutputStream* file, int32_t* message_length) { |
387 | const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8; |
388 | const int32_t flatbuffer_size = static_cast<int32_t>(message.size()); |
389 | |
390 | int32_t padded_message_length = static_cast<int32_t>( |
391 | PaddedLength(flatbuffer_size + prefix_size, options.alignment)); |
392 | |
393 | int32_t padding = padded_message_length - flatbuffer_size - prefix_size; |
394 | |
395 | // The returned message size includes the length prefix, the flatbuffer, |
396 | // plus padding |
397 | *message_length = padded_message_length; |
398 | |
399 | // ARROW-6314: Write continuation / padding token |
400 | if (!options.write_legacy_ipc_format) { |
401 | RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t))); |
402 | } |
403 | |
404 | // Write the flatbuffer size prefix including padding |
405 | int32_t padded_flatbuffer_size = padded_message_length - prefix_size; |
406 | RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t))); |
407 | |
408 | // Write the flatbuffer |
409 | RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size)); |
410 | if (padding > 0) { |
411 | RETURN_NOT_OK(file->Write(kPaddingBytes, padding)); |
412 | } |
413 | |
414 | return Status::OK(); |
415 | } |
416 | |
417 | // ---------------------------------------------------------------------- |
418 | // Implement InputStream message reader |
419 | |
420 | /// \brief Implementation of MessageReader that reads from InputStream |
421 | class InputStreamMessageReader : public MessageReader { |
422 | public: |
423 | explicit InputStreamMessageReader(io::InputStream* stream) : stream_(stream) {} |
424 | |
425 | explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream) |
426 | : InputStreamMessageReader(owned_stream.get()) { |
427 | owned_stream_ = owned_stream; |
428 | } |
429 | |
430 | ~InputStreamMessageReader() {} |
431 | |
432 | Status ReadNextMessage(std::unique_ptr<Message>* message) { |
433 | return ReadMessage(stream_, message); |
434 | } |
435 | |
436 | private: |
437 | io::InputStream* stream_; |
438 | std::shared_ptr<io::InputStream> owned_stream_; |
439 | }; |
440 | |
441 | std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) { |
442 | return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream)); |
443 | } |
444 | |
445 | std::unique_ptr<MessageReader> MessageReader::Open( |
446 | const std::shared_ptr<io::InputStream>& owned_stream) { |
447 | return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream)); |
448 | } |
449 | |
450 | } // namespace ipc |
451 | } // namespace arrow |
452 | |