1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/ipc/message.h" |
19 | |
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
25 | |
26 | #include <flatbuffers/flatbuffers.h> |
27 | |
28 | #include "arrow/buffer.h" |
29 | #include "arrow/io/interfaces.h" |
30 | #include "arrow/ipc/Message_generated.h" |
31 | #include "arrow/ipc/metadata-internal.h" |
32 | #include "arrow/ipc/util.h" |
33 | #include "arrow/status.h" |
34 | #include "arrow/util/logging.h" |
35 | |
36 | namespace arrow { |
37 | namespace ipc { |
38 | |
39 | class Message::MessageImpl { |
40 | public: |
41 | explicit MessageImpl(const std::shared_ptr<Buffer>& metadata, |
42 | const std::shared_ptr<Buffer>& body) |
43 | : metadata_(metadata), message_(nullptr), body_(body) {} |
44 | |
45 | Status Open() { |
46 | message_ = flatbuf::GetMessage(metadata_->data()); |
47 | |
48 | // Check that the metadata version is supported |
49 | if (message_->version() < internal::kMinMetadataVersion) { |
50 | return Status::Invalid("Old metadata version not supported" ); |
51 | } |
52 | |
53 | return Status::OK(); |
54 | } |
55 | |
56 | Message::Type type() const { |
57 | switch (message_->header_type()) { |
58 | case flatbuf::MessageHeader_Schema: |
59 | return Message::SCHEMA; |
60 | case flatbuf::MessageHeader_DictionaryBatch: |
61 | return Message::DICTIONARY_BATCH; |
62 | case flatbuf::MessageHeader_RecordBatch: |
63 | return Message::RECORD_BATCH; |
64 | case flatbuf::MessageHeader_Tensor: |
65 | return Message::TENSOR; |
66 | case flatbuf::MessageHeader_SparseTensor: |
67 | return Message::SPARSE_TENSOR; |
68 | default: |
69 | return Message::NONE; |
70 | } |
71 | } |
72 | |
73 | MetadataVersion version() const { |
74 | return internal::GetMetadataVersion(message_->version()); |
75 | } |
76 | |
77 | const void* () const { return message_->header(); } |
78 | |
79 | int64_t body_length() const { return message_->bodyLength(); } |
80 | |
81 | std::shared_ptr<Buffer> body() const { return body_; } |
82 | |
83 | std::shared_ptr<Buffer> metadata() const { return metadata_; } |
84 | |
85 | private: |
86 | // The Flatbuffer metadata |
87 | std::shared_ptr<Buffer> metadata_; |
88 | const flatbuf::Message* message_; |
89 | |
90 | // The message body, if any |
91 | std::shared_ptr<Buffer> body_; |
92 | }; |
93 | |
94 | Message::Message(const std::shared_ptr<Buffer>& metadata, |
95 | const std::shared_ptr<Buffer>& body) { |
96 | impl_.reset(new MessageImpl(metadata, body)); |
97 | } |
98 | |
99 | Status Message::Open(const std::shared_ptr<Buffer>& metadata, |
100 | const std::shared_ptr<Buffer>& body, std::unique_ptr<Message>* out) { |
101 | out->reset(new Message(metadata, body)); |
102 | return (*out)->impl_->Open(); |
103 | } |
104 | |
105 | Message::~Message() {} |
106 | |
107 | std::shared_ptr<Buffer> Message::body() const { return impl_->body(); } |
108 | |
109 | int64_t Message::body_length() const { return impl_->body_length(); } |
110 | |
111 | std::shared_ptr<Buffer> Message::metadata() const { return impl_->metadata(); } |
112 | |
113 | Message::Type Message::type() const { return impl_->type(); } |
114 | |
115 | MetadataVersion Message::metadata_version() const { return impl_->version(); } |
116 | |
117 | const void* Message::() const { return impl_->header(); } |
118 | |
119 | bool Message::Equals(const Message& other) const { |
120 | int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size()); |
121 | |
122 | if (!metadata()->Equals(*other.metadata(), metadata_bytes)) { |
123 | return false; |
124 | } |
125 | |
126 | // Compare bodies, if they have them |
127 | auto this_body = body(); |
128 | auto other_body = other.body(); |
129 | |
130 | const bool this_has_body = (this_body != nullptr) && (this_body->size() > 0); |
131 | const bool other_has_body = (other_body != nullptr) && (other_body->size() > 0); |
132 | |
133 | if (this_has_body && other_has_body) { |
134 | return this_body->Equals(*other_body); |
135 | } else if (this_has_body ^ other_has_body) { |
136 | // One has a body but not the other |
137 | return false; |
138 | } else { |
139 | // Neither has a body |
140 | return true; |
141 | } |
142 | } |
143 | |
144 | Status Message::ReadFrom(const std::shared_ptr<Buffer>& metadata, io::InputStream* stream, |
145 | std::unique_ptr<Message>* out) { |
146 | auto data = metadata->data(); |
147 | flatbuffers::Verifier verifier(data, metadata->size(), 128); |
148 | if (!flatbuf::VerifyMessageBuffer(verifier)) { |
149 | return Status::IOError("Invalid flatbuffers message." ); |
150 | } |
151 | auto fb_message = flatbuf::GetMessage(data); |
152 | |
153 | int64_t body_length = fb_message->bodyLength(); |
154 | |
155 | std::shared_ptr<Buffer> body; |
156 | RETURN_NOT_OK(stream->Read(body_length, &body)); |
157 | if (body->size() < body_length) { |
158 | return Status::IOError("Expected to be able to read " , body_length, |
159 | " bytes for message body, got " , body->size()); |
160 | } |
161 | |
162 | return Message::Open(metadata, body, out); |
163 | } |
164 | |
165 | Status Message::ReadFrom(const int64_t offset, const std::shared_ptr<Buffer>& metadata, |
166 | io::RandomAccessFile* file, std::unique_ptr<Message>* out) { |
167 | auto fb_message = flatbuf::GetMessage(metadata->data()); |
168 | |
169 | int64_t body_length = fb_message->bodyLength(); |
170 | |
171 | std::shared_ptr<Buffer> body; |
172 | RETURN_NOT_OK(file->ReadAt(offset, body_length, &body)); |
173 | if (body->size() < body_length) { |
174 | return Status::IOError("Expected to be able to read " , body_length, |
175 | " bytes for message body, got " , body->size()); |
176 | } |
177 | |
178 | return Message::Open(metadata, body, out); |
179 | } |
180 | |
181 | Status WritePadding(io::OutputStream* stream, int64_t nbytes) { |
182 | while (nbytes > 0) { |
183 | const int64_t bytes_to_write = std::min<int64_t>(nbytes, kArrowAlignment); |
184 | RETURN_NOT_OK(stream->Write(kPaddingBytes, bytes_to_write)); |
185 | nbytes -= bytes_to_write; |
186 | } |
187 | return Status::OK(); |
188 | } |
189 | |
190 | Status Message::SerializeTo(io::OutputStream* stream, int32_t alignment, |
191 | int64_t* output_length) const { |
192 | int32_t metadata_length = 0; |
193 | RETURN_NOT_OK(internal::WriteMessage(*metadata(), alignment, stream, &metadata_length)); |
194 | |
195 | *output_length = metadata_length; |
196 | |
197 | auto body_buffer = body(); |
198 | if (body_buffer) { |
199 | RETURN_NOT_OK(stream->Write(body_buffer->data(), body_buffer->size())); |
200 | *output_length += body_buffer->size(); |
201 | |
202 | DCHECK_GE(this->body_length(), body_buffer->size()); |
203 | |
204 | int64_t remainder = this->body_length() - body_buffer->size(); |
205 | RETURN_NOT_OK(WritePadding(stream, remainder)); |
206 | *output_length += remainder; |
207 | } |
208 | return Status::OK(); |
209 | } |
210 | |
211 | bool Message::Verify() const { |
212 | std::shared_ptr<Buffer> meta = this->metadata(); |
213 | flatbuffers::Verifier verifier(meta->data(), meta->size(), 128); |
214 | return flatbuf::VerifyMessageBuffer(verifier); |
215 | } |
216 | |
217 | std::string FormatMessageType(Message::Type type) { |
218 | switch (type) { |
219 | case Message::SCHEMA: |
220 | return "schema" ; |
221 | case Message::RECORD_BATCH: |
222 | return "record batch" ; |
223 | case Message::DICTIONARY_BATCH: |
224 | return "dictionary" ; |
225 | default: |
226 | break; |
227 | } |
228 | return "unknown" ; |
229 | } |
230 | |
231 | Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file, |
232 | std::unique_ptr<Message>* message) { |
233 | DCHECK_GT(static_cast<size_t>(metadata_length), sizeof(int32_t)); |
234 | |
235 | std::shared_ptr<Buffer> buffer; |
236 | RETURN_NOT_OK(file->ReadAt(offset, metadata_length, &buffer)); |
237 | |
238 | if (buffer->size() < metadata_length) { |
239 | return Status::Invalid("Expected to read " , metadata_length, |
240 | " metadata bytes but got " , buffer->size()); |
241 | } |
242 | |
243 | int32_t flatbuffer_size = *reinterpret_cast<const int32_t*>(buffer->data()); |
244 | |
245 | if (flatbuffer_size + static_cast<int>(sizeof(int32_t)) > metadata_length) { |
246 | return Status::Invalid("flatbuffer size " , metadata_length, |
247 | " invalid. File offset: " , offset, |
248 | ", metadata length: " , metadata_length); |
249 | } |
250 | |
251 | auto metadata = SliceBuffer(buffer, 4, buffer->size() - 4); |
252 | return Message::ReadFrom(offset + metadata_length, metadata, file, message); |
253 | } |
254 | |
255 | Status AlignStream(io::InputStream* stream, int32_t alignment) { |
256 | int64_t position = -1; |
257 | RETURN_NOT_OK(stream->Tell(&position)); |
258 | return stream->Advance(PaddedLength(position, alignment) - position); |
259 | } |
260 | |
261 | Status AlignStream(io::OutputStream* stream, int32_t alignment) { |
262 | int64_t position = -1; |
263 | RETURN_NOT_OK(stream->Tell(&position)); |
264 | int64_t remainder = PaddedLength(position, alignment) - position; |
265 | if (remainder > 0) { |
266 | return stream->Write(kPaddingBytes, remainder); |
267 | } |
268 | return Status::OK(); |
269 | } |
270 | |
271 | Status CheckAligned(io::FileInterface* stream, int32_t alignment) { |
272 | int64_t current_position; |
273 | ARROW_RETURN_NOT_OK(stream->Tell(¤t_position)); |
274 | if (current_position % alignment != 0) { |
275 | return Status::Invalid("Stream is not aligned" ); |
276 | } else { |
277 | return Status::OK(); |
278 | } |
279 | } |
280 | |
281 | Status ReadMessage(io::InputStream* file, std::unique_ptr<Message>* message) { |
282 | int32_t message_length = 0; |
283 | int64_t bytes_read = 0; |
284 | RETURN_NOT_OK(file->Read(sizeof(int32_t), &bytes_read, |
285 | reinterpret_cast<uint8_t*>(&message_length))); |
286 | |
287 | if (bytes_read != sizeof(int32_t)) { |
288 | *message = nullptr; |
289 | return Status::OK(); |
290 | } |
291 | |
292 | if (message_length == 0) { |
293 | // Optional 0 EOS control message |
294 | *message = nullptr; |
295 | return Status::OK(); |
296 | } |
297 | |
298 | std::shared_ptr<Buffer> metadata; |
299 | RETURN_NOT_OK(file->Read(message_length, &metadata)); |
300 | if (metadata->size() != message_length) { |
301 | return Status::Invalid("Expected to read " , message_length, " metadata bytes, but " , |
302 | "only read " , metadata->size()); |
303 | } |
304 | |
305 | return Message::ReadFrom(metadata, file, message); |
306 | } |
307 | |
308 | // ---------------------------------------------------------------------- |
309 | // Implement InputStream message reader |
310 | |
311 | /// \brief Implementation of MessageReader that reads from InputStream |
312 | class InputStreamMessageReader : public MessageReader { |
313 | public: |
314 | explicit InputStreamMessageReader(io::InputStream* stream) : stream_(stream) {} |
315 | |
316 | explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream) |
317 | : InputStreamMessageReader(owned_stream.get()) { |
318 | owned_stream_ = owned_stream; |
319 | } |
320 | |
321 | ~InputStreamMessageReader() {} |
322 | |
323 | Status ReadNextMessage(std::unique_ptr<Message>* message) { |
324 | return ReadMessage(stream_, message); |
325 | } |
326 | |
327 | private: |
328 | io::InputStream* stream_; |
329 | std::shared_ptr<io::InputStream> owned_stream_; |
330 | }; |
331 | |
332 | std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) { |
333 | return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream)); |
334 | } |
335 | |
336 | std::unique_ptr<MessageReader> MessageReader::Open( |
337 | const std::shared_ptr<io::InputStream>& owned_stream) { |
338 | return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream)); |
339 | } |
340 | |
341 | } // namespace ipc |
342 | } // namespace arrow |
343 | |