| 1 | // Licensed to the Apache Software Foundation (ASF) under one |
| 2 | // or more contributor license agreements. See the NOTICE file |
| 3 | // distributed with this work for additional information |
| 4 | // regarding copyright ownership. The ASF licenses this file |
| 5 | // to you under the Apache License, Version 2.0 (the |
| 6 | // "License"); you may not use this file except in compliance |
| 7 | // with the License. You may obtain a copy of the License at |
| 8 | // |
| 9 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | // |
| 11 | // Unless required by applicable law or agreed to in writing, |
| 12 | // software distributed under the License is distributed on an |
| 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | // KIND, either express or implied. See the License for the |
| 15 | // specific language governing permissions and limitations |
| 16 | // under the License. |
| 17 | |
| 18 | #include "arrow/ipc/message.h" |
| 19 | |
| 20 | #include <algorithm> |
| 21 | #include <cstdint> |
| 22 | #include <memory> |
| 23 | #include <sstream> |
| 24 | #include <string> |
| 25 | |
| 26 | #include <flatbuffers/flatbuffers.h> |
| 27 | |
| 28 | #include "arrow/buffer.h" |
| 29 | #include "arrow/io/interfaces.h" |
| 30 | #include "arrow/ipc/Message_generated.h" |
| 31 | #include "arrow/ipc/metadata_internal.h" |
| 32 | #include "arrow/ipc/util.h" |
| 33 | #include "arrow/status.h" |
| 34 | #include "arrow/util/logging.h" |
| 35 | #include "arrow/util/ubsan.h" |
| 36 | |
| 37 | namespace arrow { |
| 38 | namespace ipc { |
| 39 | |
| 40 | class Message::MessageImpl { |
| 41 | public: |
| 42 | explicit MessageImpl(const std::shared_ptr<Buffer>& metadata, |
| 43 | const std::shared_ptr<Buffer>& body) |
| 44 | : metadata_(metadata), message_(nullptr), body_(body) {} |
| 45 | |
| 46 | Status Open() { |
| 47 | RETURN_NOT_OK( |
| 48 | internal::VerifyMessage(metadata_->data(), metadata_->size(), &message_)); |
| 49 | |
| 50 | // Check that the metadata version is supported |
| 51 | if (message_->version() < internal::kMinMetadataVersion) { |
| 52 | return Status::Invalid("Old metadata version not supported" ); |
| 53 | } |
| 54 | |
| 55 | return Status::OK(); |
| 56 | } |
| 57 | |
| 58 | Message::Type type() const { |
| 59 | switch (message_->header_type()) { |
| 60 | case flatbuf::MessageHeader_Schema: |
| 61 | return Message::SCHEMA; |
| 62 | case flatbuf::MessageHeader_DictionaryBatch: |
| 63 | return Message::DICTIONARY_BATCH; |
| 64 | case flatbuf::MessageHeader_RecordBatch: |
| 65 | return Message::RECORD_BATCH; |
| 66 | case flatbuf::MessageHeader_Tensor: |
| 67 | return Message::TENSOR; |
| 68 | case flatbuf::MessageHeader_SparseTensor: |
| 69 | return Message::SPARSE_TENSOR; |
| 70 | default: |
| 71 | return Message::NONE; |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | MetadataVersion version() const { |
| 76 | return internal::GetMetadataVersion(message_->version()); |
| 77 | } |
| 78 | |
| 79 | const void* () const { return message_->header(); } |
| 80 | |
| 81 | int64_t body_length() const { return message_->bodyLength(); } |
| 82 | |
| 83 | std::shared_ptr<Buffer> body() const { return body_; } |
| 84 | |
| 85 | std::shared_ptr<Buffer> metadata() const { return metadata_; } |
| 86 | |
| 87 | private: |
| 88 | // The Flatbuffer metadata |
| 89 | std::shared_ptr<Buffer> metadata_; |
| 90 | const flatbuf::Message* message_; |
| 91 | |
| 92 | // The message body, if any |
| 93 | std::shared_ptr<Buffer> body_; |
| 94 | }; |
| 95 | |
| 96 | Message::Message(const std::shared_ptr<Buffer>& metadata, |
| 97 | const std::shared_ptr<Buffer>& body) { |
| 98 | impl_.reset(new MessageImpl(metadata, body)); |
| 99 | } |
| 100 | |
| 101 | Status Message::Open(const std::shared_ptr<Buffer>& metadata, |
| 102 | const std::shared_ptr<Buffer>& body, std::unique_ptr<Message>* out) { |
| 103 | out->reset(new Message(metadata, body)); |
| 104 | return (*out)->impl_->Open(); |
| 105 | } |
| 106 | |
| 107 | Message::~Message() {} |
| 108 | |
| 109 | std::shared_ptr<Buffer> Message::body() const { return impl_->body(); } |
| 110 | |
| 111 | int64_t Message::body_length() const { return impl_->body_length(); } |
| 112 | |
| 113 | std::shared_ptr<Buffer> Message::metadata() const { return impl_->metadata(); } |
| 114 | |
| 115 | Message::Type Message::type() const { return impl_->type(); } |
| 116 | |
| 117 | MetadataVersion Message::metadata_version() const { return impl_->version(); } |
| 118 | |
| 119 | const void* Message::() const { return impl_->header(); } |
| 120 | |
| 121 | bool Message::Equals(const Message& other) const { |
| 122 | int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size()); |
| 123 | |
| 124 | if (!metadata()->Equals(*other.metadata(), metadata_bytes)) { |
| 125 | return false; |
| 126 | } |
| 127 | |
| 128 | // Compare bodies, if they have them |
| 129 | auto this_body = body(); |
| 130 | auto other_body = other.body(); |
| 131 | |
| 132 | const bool this_has_body = (this_body != nullptr) && (this_body->size() > 0); |
| 133 | const bool other_has_body = (other_body != nullptr) && (other_body->size() > 0); |
| 134 | |
| 135 | if (this_has_body && other_has_body) { |
| 136 | return this_body->Equals(*other_body); |
| 137 | } else if (this_has_body ^ other_has_body) { |
| 138 | // One has a body but not the other |
| 139 | return false; |
| 140 | } else { |
| 141 | // Neither has a body |
| 142 | return true; |
| 143 | } |
| 144 | } |
| 145 | |
| 146 | Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) { |
| 147 | if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) { |
| 148 | // If the metadata memory is not aligned, we copy it here to avoid |
| 149 | // potential UBSAN issues from Flatbuffers |
| 150 | RETURN_NOT_OK((*metadata)->Copy(0, (*metadata)->size(), metadata)); |
| 151 | } |
| 152 | return Status::OK(); |
| 153 | } |
| 154 | |
| 155 | Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) { |
| 156 | const flatbuf::Message* fb_message; |
| 157 | RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message)); |
| 158 | *body_length = fb_message->bodyLength(); |
| 159 | return Status::OK(); |
| 160 | } |
| 161 | |
| 162 | Status Message::ReadFrom(std::shared_ptr<Buffer> metadata, io::InputStream* stream, |
| 163 | std::unique_ptr<Message>* out) { |
| 164 | RETURN_NOT_OK(MaybeAlignMetadata(&metadata)); |
| 165 | int64_t body_length = -1; |
| 166 | RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata, &body_length)); |
| 167 | |
| 168 | std::shared_ptr<Buffer> body; |
| 169 | RETURN_NOT_OK(stream->Read(body_length, &body)); |
| 170 | if (body->size() < body_length) { |
| 171 | return Status::IOError("Expected to be able to read " , body_length, |
| 172 | " bytes for message body, got " , body->size()); |
| 173 | } |
| 174 | |
| 175 | return Message::Open(metadata, body, out); |
| 176 | } |
| 177 | |
| 178 | Status Message::ReadFrom(const int64_t offset, std::shared_ptr<Buffer> metadata, |
| 179 | io::RandomAccessFile* file, std::unique_ptr<Message>* out) { |
| 180 | RETURN_NOT_OK(MaybeAlignMetadata(&metadata)); |
| 181 | int64_t body_length = -1; |
| 182 | RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata, &body_length)); |
| 183 | |
| 184 | std::shared_ptr<Buffer> body; |
| 185 | RETURN_NOT_OK(file->ReadAt(offset, body_length, &body)); |
| 186 | if (body->size() < body_length) { |
| 187 | return Status::IOError("Expected to be able to read " , body_length, |
| 188 | " bytes for message body, got " , body->size()); |
| 189 | } |
| 190 | |
| 191 | return Message::Open(metadata, body, out); |
| 192 | } |
| 193 | |
| 194 | Status WritePadding(io::OutputStream* stream, int64_t nbytes) { |
| 195 | while (nbytes > 0) { |
| 196 | const int64_t bytes_to_write = std::min<int64_t>(nbytes, kArrowAlignment); |
| 197 | RETURN_NOT_OK(stream->Write(kPaddingBytes, bytes_to_write)); |
| 198 | nbytes -= bytes_to_write; |
| 199 | } |
| 200 | return Status::OK(); |
| 201 | } |
| 202 | |
| 203 | Status Message::SerializeTo(io::OutputStream* stream, const IpcOptions& options, |
| 204 | int64_t* output_length) const { |
| 205 | int32_t metadata_length = 0; |
| 206 | RETURN_NOT_OK(WriteMessage(*metadata(), options, stream, &metadata_length)); |
| 207 | |
| 208 | *output_length = metadata_length; |
| 209 | |
| 210 | auto body_buffer = body(); |
| 211 | if (body_buffer) { |
| 212 | RETURN_NOT_OK(stream->Write(body_buffer)); |
| 213 | *output_length += body_buffer->size(); |
| 214 | |
| 215 | DCHECK_GE(this->body_length(), body_buffer->size()); |
| 216 | |
| 217 | int64_t remainder = this->body_length() - body_buffer->size(); |
| 218 | RETURN_NOT_OK(WritePadding(stream, remainder)); |
| 219 | *output_length += remainder; |
| 220 | } |
| 221 | return Status::OK(); |
| 222 | } |
| 223 | |
| 224 | bool Message::Verify() const { |
| 225 | const flatbuf::Message* unused; |
| 226 | return internal::VerifyMessage(metadata()->data(), metadata()->size(), &unused).ok(); |
| 227 | } |
| 228 | |
| 229 | std::string FormatMessageType(Message::Type type) { |
| 230 | switch (type) { |
| 231 | case Message::SCHEMA: |
| 232 | return "schema" ; |
| 233 | case Message::RECORD_BATCH: |
| 234 | return "record batch" ; |
| 235 | case Message::DICTIONARY_BATCH: |
| 236 | return "dictionary" ; |
| 237 | default: |
| 238 | break; |
| 239 | } |
| 240 | return "unknown" ; |
| 241 | } |
| 242 | |
| 243 | Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file, |
| 244 | std::unique_ptr<Message>* message) { |
| 245 | ARROW_CHECK_GT(static_cast<size_t>(metadata_length), sizeof(int32_t)) |
| 246 | << "metadata_length should be at least 4" ; |
| 247 | |
| 248 | std::shared_ptr<Buffer> buffer; |
| 249 | RETURN_NOT_OK(file->ReadAt(offset, metadata_length, &buffer)); |
| 250 | |
| 251 | if (buffer->size() < metadata_length) { |
| 252 | return Status::Invalid("Expected to read " , metadata_length, |
| 253 | " metadata bytes but got " , buffer->size()); |
| 254 | } |
| 255 | |
| 256 | const int32_t continuation = util::SafeLoadAs<int32_t>(buffer->data()); |
| 257 | |
| 258 | // The size of the Flatbuffer including padding |
| 259 | int32_t flatbuffer_length = -1; |
| 260 | int32_t prefix_size = -1; |
| 261 | if (continuation == internal::kIpcContinuationToken) { |
| 262 | if (metadata_length < 8) { |
| 263 | return Status::Invalid( |
| 264 | "Corrupted IPC message, had continuation token " |
| 265 | " but length " , |
| 266 | metadata_length); |
| 267 | } |
| 268 | |
| 269 | // Valid IPC message, parse the message length now |
| 270 | flatbuffer_length = util::SafeLoadAs<int32_t>(buffer->data() + 4); |
| 271 | prefix_size = 8; |
| 272 | } else { |
| 273 | // ARROW-6314: Backwards compatibility for reading old IPC |
| 274 | // messages produced prior to version 0.15.0 |
| 275 | flatbuffer_length = continuation; |
| 276 | prefix_size = 4; |
| 277 | } |
| 278 | |
| 279 | if (flatbuffer_length == 0) { |
| 280 | // EOS |
| 281 | *message = nullptr; |
| 282 | return Status::OK(); |
| 283 | } |
| 284 | |
| 285 | if (flatbuffer_length + prefix_size != metadata_length) { |
| 286 | return Status::Invalid("flatbuffer size " , flatbuffer_length, |
| 287 | " invalid. File offset: " , offset, |
| 288 | ", metadata length: " , metadata_length); |
| 289 | } |
| 290 | |
| 291 | std::shared_ptr<Buffer> metadata = |
| 292 | SliceBuffer(buffer, prefix_size, buffer->size() - prefix_size); |
| 293 | return Message::ReadFrom(offset + metadata_length, metadata, file, message); |
| 294 | } |
| 295 | |
| 296 | Status AlignStream(io::InputStream* stream, int32_t alignment) { |
| 297 | int64_t position = -1; |
| 298 | RETURN_NOT_OK(stream->Tell(&position)); |
| 299 | return stream->Advance(PaddedLength(position, alignment) - position); |
| 300 | } |
| 301 | |
| 302 | Status AlignStream(io::OutputStream* stream, int32_t alignment) { |
| 303 | int64_t position = -1; |
| 304 | RETURN_NOT_OK(stream->Tell(&position)); |
| 305 | int64_t remainder = PaddedLength(position, alignment) - position; |
| 306 | if (remainder > 0) { |
| 307 | return stream->Write(kPaddingBytes, remainder); |
| 308 | } |
| 309 | return Status::OK(); |
| 310 | } |
| 311 | |
| 312 | Status CheckAligned(io::FileInterface* stream, int32_t alignment) { |
| 313 | int64_t current_position; |
| 314 | ARROW_RETURN_NOT_OK(stream->Tell(¤t_position)); |
| 315 | if (current_position % alignment != 0) { |
| 316 | return Status::Invalid("Stream is not aligned pos: " , current_position, |
| 317 | " alignment: " , alignment); |
| 318 | } else { |
| 319 | return Status::OK(); |
| 320 | } |
| 321 | } |
| 322 | |
| 323 | namespace { |
| 324 | |
| 325 | Status ReadMessage(io::InputStream* file, MemoryPool* pool, bool copy_metadata, |
| 326 | std::unique_ptr<Message>* message) { |
| 327 | int32_t continuation = 0; |
| 328 | int64_t bytes_read = 0; |
| 329 | RETURN_NOT_OK(file->Read(sizeof(int32_t), &bytes_read, |
| 330 | reinterpret_cast<uint8_t*>(&continuation))); |
| 331 | |
| 332 | if (bytes_read == 0) { |
| 333 | // EOS without indication |
| 334 | *message = nullptr; |
| 335 | return Status::OK(); |
| 336 | } else if (bytes_read != sizeof(int32_t)) { |
| 337 | return Status::Invalid("Corrupted message, only " , bytes_read, " bytes available" ); |
| 338 | } |
| 339 | |
| 340 | int32_t flatbuffer_length = -1; |
| 341 | if (continuation == internal::kIpcContinuationToken) { |
| 342 | // Valid IPC message, read the message length now |
| 343 | RETURN_NOT_OK(file->Read(sizeof(int32_t), &bytes_read, |
| 344 | reinterpret_cast<uint8_t*>(&flatbuffer_length))); |
| 345 | } else { |
| 346 | // ARROW-6314: Backwards compatibility for reading old IPC |
| 347 | // messages produced prior to version 0.15.0 |
| 348 | flatbuffer_length = continuation; |
| 349 | } |
| 350 | |
| 351 | if (flatbuffer_length == 0) { |
| 352 | // EOS |
| 353 | *message = nullptr; |
| 354 | return Status::OK(); |
| 355 | } |
| 356 | |
| 357 | std::shared_ptr<Buffer> metadata; |
| 358 | if (copy_metadata) { |
| 359 | DCHECK_NE(pool, nullptr); |
| 360 | RETURN_NOT_OK(AllocateBuffer(pool, flatbuffer_length, &metadata)); |
| 361 | RETURN_NOT_OK(file->Read(flatbuffer_length, &bytes_read, metadata->mutable_data())); |
| 362 | } else { |
| 363 | RETURN_NOT_OK(file->Read(flatbuffer_length, &metadata)); |
| 364 | bytes_read = metadata->size(); |
| 365 | } |
| 366 | if (bytes_read != flatbuffer_length) { |
| 367 | return Status::Invalid("Expected to read " , flatbuffer_length, |
| 368 | " metadata bytes, but " , "only read " , bytes_read); |
| 369 | } |
| 370 | |
| 371 | return Message::ReadFrom(metadata, file, message); |
| 372 | } |
| 373 | |
| 374 | } // namespace |
| 375 | |
| 376 | Status ReadMessage(io::InputStream* file, std::unique_ptr<Message>* out) { |
| 377 | return ReadMessage(file, default_memory_pool(), /*copy_metadata=*/false, out); |
| 378 | } |
| 379 | |
| 380 | Status ReadMessageCopy(io::InputStream* file, MemoryPool* pool, |
| 381 | std::unique_ptr<Message>* out) { |
| 382 | return ReadMessage(file, pool, /*copy_metadata=*/true, out); |
| 383 | } |
| 384 | |
| 385 | Status WriteMessage(const Buffer& message, const IpcOptions& options, |
| 386 | io::OutputStream* file, int32_t* message_length) { |
| 387 | const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8; |
| 388 | const int32_t flatbuffer_size = static_cast<int32_t>(message.size()); |
| 389 | |
| 390 | int32_t padded_message_length = static_cast<int32_t>( |
| 391 | PaddedLength(flatbuffer_size + prefix_size, options.alignment)); |
| 392 | |
| 393 | int32_t padding = padded_message_length - flatbuffer_size - prefix_size; |
| 394 | |
| 395 | // The returned message size includes the length prefix, the flatbuffer, |
| 396 | // plus padding |
| 397 | *message_length = padded_message_length; |
| 398 | |
| 399 | // ARROW-6314: Write continuation / padding token |
| 400 | if (!options.write_legacy_ipc_format) { |
| 401 | RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t))); |
| 402 | } |
| 403 | |
| 404 | // Write the flatbuffer size prefix including padding |
| 405 | int32_t padded_flatbuffer_size = padded_message_length - prefix_size; |
| 406 | RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t))); |
| 407 | |
| 408 | // Write the flatbuffer |
| 409 | RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size)); |
| 410 | if (padding > 0) { |
| 411 | RETURN_NOT_OK(file->Write(kPaddingBytes, padding)); |
| 412 | } |
| 413 | |
| 414 | return Status::OK(); |
| 415 | } |
| 416 | |
| 417 | // ---------------------------------------------------------------------- |
| 418 | // Implement InputStream message reader |
| 419 | |
| 420 | /// \brief Implementation of MessageReader that reads from InputStream |
| 421 | class InputStreamMessageReader : public MessageReader { |
| 422 | public: |
| 423 | explicit InputStreamMessageReader(io::InputStream* stream) : stream_(stream) {} |
| 424 | |
| 425 | explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream) |
| 426 | : InputStreamMessageReader(owned_stream.get()) { |
| 427 | owned_stream_ = owned_stream; |
| 428 | } |
| 429 | |
| 430 | ~InputStreamMessageReader() {} |
| 431 | |
| 432 | Status ReadNextMessage(std::unique_ptr<Message>* message) { |
| 433 | return ReadMessage(stream_, message); |
| 434 | } |
| 435 | |
| 436 | private: |
| 437 | io::InputStream* stream_; |
| 438 | std::shared_ptr<io::InputStream> owned_stream_; |
| 439 | }; |
| 440 | |
| 441 | std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) { |
| 442 | return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream)); |
| 443 | } |
| 444 | |
| 445 | std::unique_ptr<MessageReader> MessageReader::Open( |
| 446 | const std::shared_ptr<io::InputStream>& owned_stream) { |
| 447 | return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream)); |
| 448 | } |
| 449 | |
| 450 | } // namespace ipc |
| 451 | } // namespace arrow |
| 452 | |