1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/ipc/metadata-internal.h" |
19 | |
20 | #include <cstdint> |
21 | #include <memory> |
22 | #include <sstream> |
23 | #include <utility> |
24 | |
25 | #include <flatbuffers/flatbuffers.h> |
26 | |
27 | #include "arrow/array.h" |
28 | #include "arrow/io/interfaces.h" |
29 | #include "arrow/ipc/File_generated.h" // IWYU pragma: keep |
30 | #include "arrow/ipc/Message_generated.h" |
31 | #include "arrow/ipc/Tensor_generated.h" // IWYU pragma: keep |
32 | #include "arrow/ipc/message.h" |
33 | #include "arrow/ipc/util.h" |
34 | #include "arrow/sparse_tensor.h" |
35 | #include "arrow/status.h" |
36 | #include "arrow/tensor.h" |
37 | #include "arrow/type.h" |
38 | #include "arrow/util/checked_cast.h" |
39 | #include "arrow/util/logging.h" |
40 | |
41 | namespace arrow { |
42 | |
43 | namespace flatbuf = org::apache::arrow::flatbuf; |
44 | using internal::checked_cast; |
45 | |
46 | namespace ipc { |
47 | namespace internal { |
48 | |
49 | using FBB = flatbuffers::FlatBufferBuilder; |
50 | using DictionaryOffset = flatbuffers::Offset<flatbuf::DictionaryEncoding>; |
51 | using FieldOffset = flatbuffers::Offset<flatbuf::Field>; |
52 | using KeyValueOffset = flatbuffers::Offset<flatbuf::KeyValue>; |
53 | using RecordBatchOffset = flatbuffers::Offset<flatbuf::RecordBatch>; |
54 | using SparseTensorOffset = flatbuffers::Offset<flatbuf::SparseTensor>; |
55 | using Offset = flatbuffers::Offset<void>; |
56 | using FBString = flatbuffers::Offset<flatbuffers::String>; |
57 | |
58 | MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version) { |
59 | switch (version) { |
60 | case flatbuf::MetadataVersion_V1: |
61 | // Arrow 0.1 |
62 | return MetadataVersion::V1; |
63 | case flatbuf::MetadataVersion_V2: |
64 | // Arrow 0.2 |
65 | return MetadataVersion::V2; |
66 | case flatbuf::MetadataVersion_V3: |
67 | // Arrow 0.3 to 0.7.1 |
68 | return MetadataVersion::V4; |
69 | case flatbuf::MetadataVersion_V4: |
70 | // Arrow >= 0.8 |
71 | return MetadataVersion::V4; |
72 | // Add cases as other versions become available |
73 | default: |
74 | return MetadataVersion::V4; |
75 | } |
76 | } |
77 | |
78 | static Status IntFromFlatbuffer(const flatbuf::Int* int_data, |
79 | std::shared_ptr<DataType>* out) { |
80 | if (int_data->bitWidth() > 64) { |
81 | return Status::NotImplemented("Integers with more than 64 bits not implemented" ); |
82 | } |
83 | if (int_data->bitWidth() < 8) { |
84 | return Status::NotImplemented("Integers with less than 8 bits not implemented" ); |
85 | } |
86 | |
87 | switch (int_data->bitWidth()) { |
88 | case 8: |
89 | *out = int_data->is_signed() ? int8() : uint8(); |
90 | break; |
91 | case 16: |
92 | *out = int_data->is_signed() ? int16() : uint16(); |
93 | break; |
94 | case 32: |
95 | *out = int_data->is_signed() ? int32() : uint32(); |
96 | break; |
97 | case 64: |
98 | *out = int_data->is_signed() ? int64() : uint64(); |
99 | break; |
100 | default: |
101 | return Status::NotImplemented("Integers not in cstdint are not implemented" ); |
102 | } |
103 | return Status::OK(); |
104 | } |
105 | |
106 | static Status FloatFromFlatbuffer(const flatbuf::FloatingPoint* float_data, |
107 | std::shared_ptr<DataType>* out) { |
108 | if (float_data->precision() == flatbuf::Precision_HALF) { |
109 | *out = float16(); |
110 | } else if (float_data->precision() == flatbuf::Precision_SINGLE) { |
111 | *out = float32(); |
112 | } else { |
113 | *out = float64(); |
114 | } |
115 | return Status::OK(); |
116 | } |
117 | |
118 | // Forward declaration |
119 | static Status FieldToFlatbuffer(FBB& fbb, const Field& field, |
120 | DictionaryMemo* dictionary_memo, FieldOffset* offset); |
121 | |
122 | static Offset IntToFlatbuffer(FBB& fbb, int bitWidth, bool is_signed) { |
123 | return flatbuf::CreateInt(fbb, bitWidth, is_signed).Union(); |
124 | } |
125 | |
126 | static Offset FloatToFlatbuffer(FBB& fbb, flatbuf::Precision precision) { |
127 | return flatbuf::CreateFloatingPoint(fbb, precision).Union(); |
128 | } |
129 | |
130 | static Status AppendChildFields(FBB& fbb, const DataType& type, |
131 | std::vector<FieldOffset>* out_children, |
132 | DictionaryMemo* dictionary_memo) { |
133 | FieldOffset field; |
134 | for (int i = 0; i < type.num_children(); ++i) { |
135 | RETURN_NOT_OK(FieldToFlatbuffer(fbb, *type.child(i), dictionary_memo, &field)); |
136 | out_children->push_back(field); |
137 | } |
138 | return Status::OK(); |
139 | } |
140 | |
141 | static Status ListToFlatbuffer(FBB& fbb, const DataType& type, |
142 | std::vector<FieldOffset>* out_children, |
143 | DictionaryMemo* dictionary_memo, Offset* offset) { |
144 | RETURN_NOT_OK(AppendChildFields(fbb, type, out_children, dictionary_memo)); |
145 | *offset = flatbuf::CreateList(fbb).Union(); |
146 | return Status::OK(); |
147 | } |
148 | |
149 | static Status StructToFlatbuffer(FBB& fbb, const DataType& type, |
150 | std::vector<FieldOffset>* out_children, |
151 | DictionaryMemo* dictionary_memo, Offset* offset) { |
152 | RETURN_NOT_OK(AppendChildFields(fbb, type, out_children, dictionary_memo)); |
153 | *offset = flatbuf::CreateStruct_(fbb).Union(); |
154 | return Status::OK(); |
155 | } |
156 | |
157 | // ---------------------------------------------------------------------- |
158 | // Union implementation |
159 | |
160 | static Status UnionFromFlatbuffer(const flatbuf::Union* union_data, |
161 | const std::vector<std::shared_ptr<Field>>& children, |
162 | std::shared_ptr<DataType>* out) { |
163 | UnionMode::type mode = |
164 | (union_data->mode() == flatbuf::UnionMode_Sparse ? UnionMode::SPARSE |
165 | : UnionMode::DENSE); |
166 | |
167 | std::vector<uint8_t> type_codes; |
168 | |
169 | const flatbuffers::Vector<int32_t>* fb_type_ids = union_data->typeIds(); |
170 | if (fb_type_ids == nullptr) { |
171 | for (uint8_t i = 0; i < children.size(); ++i) { |
172 | type_codes.push_back(i); |
173 | } |
174 | } else { |
175 | for (int32_t id : (*fb_type_ids)) { |
176 | // TODO(wesm): can these values exceed 255? |
177 | type_codes.push_back(static_cast<uint8_t>(id)); |
178 | } |
179 | } |
180 | |
181 | *out = union_(children, type_codes, mode); |
182 | return Status::OK(); |
183 | } |
184 | |
185 | static Status UnionToFlatBuffer(FBB& fbb, const DataType& type, |
186 | std::vector<FieldOffset>* out_children, |
187 | DictionaryMemo* dictionary_memo, Offset* offset) { |
188 | RETURN_NOT_OK(AppendChildFields(fbb, type, out_children, dictionary_memo)); |
189 | |
190 | const auto& union_type = checked_cast<const UnionType&>(type); |
191 | |
192 | flatbuf::UnionMode mode = union_type.mode() == UnionMode::SPARSE |
193 | ? flatbuf::UnionMode_Sparse |
194 | : flatbuf::UnionMode_Dense; |
195 | |
196 | std::vector<int32_t> type_ids; |
197 | type_ids.reserve(union_type.type_codes().size()); |
198 | for (uint8_t code : union_type.type_codes()) { |
199 | type_ids.push_back(code); |
200 | } |
201 | |
202 | auto fb_type_ids = fbb.CreateVector(type_ids); |
203 | |
204 | *offset = flatbuf::CreateUnion(fbb, mode, fb_type_ids).Union(); |
205 | return Status::OK(); |
206 | } |
207 | |
// Helper for the switch statements in TypeToFlatbuffer and
// TensorTypeToFlatbuffer below: sets *out_type to the Int type enum,
// serializes a Flatbuffer Int table with the given bit width and signedness
// into *offset, then breaks out of the case. Relies on `fbb`, `out_type`,
// and `offset` being in scope at the expansion site.
#define INT_TO_FB_CASE(BIT_WIDTH, IS_SIGNED) \
  *out_type = flatbuf::Type_Int; \
  *offset = IntToFlatbuffer(fbb, BIT_WIDTH, IS_SIGNED); \
  break;
212 | |
213 | static inline flatbuf::TimeUnit ToFlatbufferUnit(TimeUnit::type unit) { |
214 | switch (unit) { |
215 | case TimeUnit::SECOND: |
216 | return flatbuf::TimeUnit_SECOND; |
217 | case TimeUnit::MILLI: |
218 | return flatbuf::TimeUnit_MILLISECOND; |
219 | case TimeUnit::MICRO: |
220 | return flatbuf::TimeUnit_MICROSECOND; |
221 | case TimeUnit::NANO: |
222 | return flatbuf::TimeUnit_NANOSECOND; |
223 | default: |
224 | break; |
225 | } |
226 | return flatbuf::TimeUnit_MIN; |
227 | } |
228 | |
229 | static inline TimeUnit::type FromFlatbufferUnit(flatbuf::TimeUnit unit) { |
230 | switch (unit) { |
231 | case flatbuf::TimeUnit_SECOND: |
232 | return TimeUnit::SECOND; |
233 | case flatbuf::TimeUnit_MILLISECOND: |
234 | return TimeUnit::MILLI; |
235 | case flatbuf::TimeUnit_MICROSECOND: |
236 | return TimeUnit::MICRO; |
237 | case flatbuf::TimeUnit_NANOSECOND: |
238 | return TimeUnit::NANO; |
239 | default: |
240 | break; |
241 | } |
242 | // cannot reach |
243 | return TimeUnit::SECOND; |
244 | } |
245 | |
// Reconstruct an Arrow DataType from its Flatbuffer representation.
//
// `type` is the Type union tag stored on the Field table; `type_data` points
// at the corresponding Flatbuffer type table and is cast per-case below.
// `children` holds the already-converted child fields, consumed by the
// nested types (list, struct, union).
static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
                                 const std::vector<std::shared_ptr<Field>>& children,
                                 std::shared_ptr<DataType>* out) {
  switch (type) {
    case flatbuf::Type_NONE:
      return Status::Invalid("Type metadata cannot be none" );
    case flatbuf::Type_Null:
      *out = null();
      return Status::OK();
    case flatbuf::Type_Int:
      return IntFromFlatbuffer(static_cast<const flatbuf::Int*>(type_data), out);
    case flatbuf::Type_FloatingPoint:
      return FloatFromFlatbuffer(static_cast<const flatbuf::FloatingPoint*>(type_data),
                                 out);
    case flatbuf::Type_Binary:
      *out = binary();
      return Status::OK();
    case flatbuf::Type_FixedSizeBinary: {
      auto fw_binary = static_cast<const flatbuf::FixedSizeBinary*>(type_data);
      *out = fixed_size_binary(fw_binary->byteWidth());
      return Status::OK();
    }
    case flatbuf::Type_Utf8:
      *out = utf8();
      return Status::OK();
    case flatbuf::Type_Bool:
      *out = boolean();
      return Status::OK();
    case flatbuf::Type_Decimal: {
      auto dec_type = static_cast<const flatbuf::Decimal*>(type_data);
      *out = decimal(dec_type->precision(), dec_type->scale());
      return Status::OK();
    }
    case flatbuf::Type_Date: {
      // DAY unit maps to 32-bit dates, any other unit to 64-bit dates
      auto date_type = static_cast<const flatbuf::Date*>(type_data);
      if (date_type->unit() == flatbuf::DateUnit_DAY) {
        *out = date32();
      } else {
        *out = date64();
      }
      return Status::OK();
    }
    case flatbuf::Type_Time: {
      auto time_type = static_cast<const flatbuf::Time*>(type_data);
      TimeUnit::type unit = FromFlatbufferUnit(time_type->unit());
      int32_t bit_width = time_type->bitWidth();
      // The declared bit width must match the unit: 32 bits for
      // second/milli, 64 bits for micro/nano
      switch (unit) {
        case TimeUnit::SECOND:
        case TimeUnit::MILLI:
          if (bit_width != 32) {
            return Status::Invalid("Time is 32 bits for second/milli unit" );
          }
          *out = time32(unit);
          break;
        default:
          if (bit_width != 64) {
            return Status::Invalid("Time is 64 bits for micro/nano unit" );
          }
          *out = time64(unit);
          break;
      }
      return Status::OK();
    }
    case flatbuf::Type_Timestamp: {
      auto ts_type = static_cast<const flatbuf::Timestamp*>(type_data);
      TimeUnit::type unit = FromFlatbufferUnit(ts_type->unit());
      // The timezone string is optional; attach it only when present and
      // non-empty
      if (ts_type->timezone() != 0 && ts_type->timezone()->Length() > 0) {
        *out = timestamp(unit, ts_type->timezone()->str());
      } else {
        *out = timestamp(unit);
      }
      return Status::OK();
    }
    case flatbuf::Type_Interval:
      return Status::NotImplemented("Interval" );
    case flatbuf::Type_List:
      if (children.size() != 1) {
        return Status::Invalid("List must have exactly 1 child field" );
      }
      *out = std::make_shared<ListType>(children[0]);
      return Status::OK();
    case flatbuf::Type_Struct_:
      *out = std::make_shared<StructType>(children);
      return Status::OK();
    case flatbuf::Type_Union:
      return UnionFromFlatbuffer(static_cast<const flatbuf::Union*>(type_data), children,
                                 out);
    default:
      return Status::Invalid("Unrecognized type" );
  }
}
337 | |
// TODO(wesm): Convert this to visitor pattern
// Serialize an Arrow DataType to its Flatbuffer representation. Writes the
// Type enum tag to *out_type and the offset of the serialized type table to
// *offset; child field offsets (for list/struct/union) are appended to
// *children. Dictionary types are serialized as their value type.
static Status TypeToFlatbuffer(FBB& fbb, const DataType& type,
                               std::vector<FieldOffset>* children,
                               flatbuf::Type* out_type, DictionaryMemo* dictionary_memo,
                               Offset* offset) {
  const DataType* value_type = &type;

  if (type.id() == Type::DICTIONARY) {
    // In this library, the dictionary "type" is a logical construct. Here we
    // pass through to the value type, as we've already captured the index
    // type in the DictionaryEncoding metadata in the parent field
    value_type = checked_cast<const DictionaryType&>(type).dictionary()->type().get();
  }

  switch (value_type->id()) {
    case Type::NA:
      *out_type = flatbuf::Type_Null;
      *offset = flatbuf::CreateNull(fbb).Union();
      break;
    case Type::BOOL:
      *out_type = flatbuf::Type_Bool;
      *offset = flatbuf::CreateBool(fbb).Union();
      break;
    // Each INT_TO_FB_CASE sets *out_type/*offset and breaks (see macro above)
    case Type::UINT8:
      INT_TO_FB_CASE(8, false);
    case Type::INT8:
      INT_TO_FB_CASE(8, true);
    case Type::UINT16:
      INT_TO_FB_CASE(16, false);
    case Type::INT16:
      INT_TO_FB_CASE(16, true);
    case Type::UINT32:
      INT_TO_FB_CASE(32, false);
    case Type::INT32:
      INT_TO_FB_CASE(32, true);
    case Type::UINT64:
      INT_TO_FB_CASE(64, false);
    case Type::INT64:
      INT_TO_FB_CASE(64, true);
    case Type::HALF_FLOAT:
      *out_type = flatbuf::Type_FloatingPoint;
      *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_HALF);
      break;
    case Type::FLOAT:
      *out_type = flatbuf::Type_FloatingPoint;
      *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_SINGLE);
      break;
    case Type::DOUBLE:
      *out_type = flatbuf::Type_FloatingPoint;
      *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_DOUBLE);
      break;
    case Type::FIXED_SIZE_BINARY: {
      const auto& fw_type = checked_cast<const FixedSizeBinaryType&>(*value_type);
      *out_type = flatbuf::Type_FixedSizeBinary;
      *offset = flatbuf::CreateFixedSizeBinary(fbb, fw_type.byte_width()).Union();
    } break;
    case Type::BINARY:
      *out_type = flatbuf::Type_Binary;
      *offset = flatbuf::CreateBinary(fbb).Union();
      break;
    case Type::STRING:
      *out_type = flatbuf::Type_Utf8;
      *offset = flatbuf::CreateUtf8(fbb).Union();
      break;
    case Type::DATE32:
      *out_type = flatbuf::Type_Date;
      *offset = flatbuf::CreateDate(fbb, flatbuf::DateUnit_DAY).Union();
      break;
    case Type::DATE64:
      *out_type = flatbuf::Type_Date;
      *offset = flatbuf::CreateDate(fbb, flatbuf::DateUnit_MILLISECOND).Union();
      break;
    case Type::TIME32: {
      const auto& time_type = checked_cast<const Time32Type&>(*value_type);
      *out_type = flatbuf::Type_Time;
      *offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit()), 32).Union();
    } break;
    case Type::TIME64: {
      const auto& time_type = checked_cast<const Time64Type&>(*value_type);
      *out_type = flatbuf::Type_Time;
      *offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit()), 64).Union();
    } break;
    case Type::TIMESTAMP: {
      const auto& ts_type = checked_cast<const TimestampType&>(*value_type);
      *out_type = flatbuf::Type_Timestamp;

      flatbuf::TimeUnit fb_unit = ToFlatbufferUnit(ts_type.unit());
      // Only serialize the timezone string when one is set
      FBString fb_timezone = 0;
      if (ts_type.timezone().size() > 0) {
        fb_timezone = fbb.CreateString(ts_type.timezone());
      }
      *offset = flatbuf::CreateTimestamp(fbb, fb_unit, fb_timezone).Union();
    } break;
    case Type::DECIMAL: {
      const auto& dec_type = checked_cast<const Decimal128Type&>(*value_type);
      *out_type = flatbuf::Type_Decimal;
      *offset =
          flatbuf::CreateDecimal(fbb, dec_type.precision(), dec_type.scale()).Union();
    } break;
    // The nested types delegate to helpers that also serialize the children
    case Type::LIST:
      *out_type = flatbuf::Type_List;
      return ListToFlatbuffer(fbb, *value_type, children, dictionary_memo, offset);
    case Type::STRUCT:
      *out_type = flatbuf::Type_Struct_;
      return StructToFlatbuffer(fbb, *value_type, children, dictionary_memo, offset);
    case Type::UNION:
      *out_type = flatbuf::Type_Union;
      return UnionToFlatBuffer(fbb, *value_type, children, dictionary_memo, offset);
    default:
      *out_type = flatbuf::Type_NONE;  // Make clang-tidy happy
      return Status::NotImplemented("Unable to convert type: " , type.ToString());
  }
  return Status::OK();
}
452 | |
// Serialize a tensor element type into its Flatbuffer representation. Only
// the fixed-width numeric types handled below are supported; anything else
// yields NotImplemented.
static Status TensorTypeToFlatbuffer(FBB& fbb, const DataType& type,
                                     flatbuf::Type* out_type, Offset* offset) {
  switch (type.id()) {
    // Each INT_TO_FB_CASE sets *out_type/*offset and breaks (see macro above)
    case Type::UINT8:
      INT_TO_FB_CASE(8, false);
    case Type::INT8:
      INT_TO_FB_CASE(8, true);
    case Type::UINT16:
      INT_TO_FB_CASE(16, false);
    case Type::INT16:
      INT_TO_FB_CASE(16, true);
    case Type::UINT32:
      INT_TO_FB_CASE(32, false);
    case Type::INT32:
      INT_TO_FB_CASE(32, true);
    case Type::UINT64:
      INT_TO_FB_CASE(64, false);
    case Type::INT64:
      INT_TO_FB_CASE(64, true);
    case Type::HALF_FLOAT:
      *out_type = flatbuf::Type_FloatingPoint;
      *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_HALF);
      break;
    case Type::FLOAT:
      *out_type = flatbuf::Type_FloatingPoint;
      *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_SINGLE);
      break;
    case Type::DOUBLE:
      *out_type = flatbuf::Type_FloatingPoint;
      *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_DOUBLE);
      break;
    default:
      *out_type = flatbuf::Type_NONE;  // Make clang-tidy happy
      return Status::NotImplemented("Unable to convert type: " , type.ToString());
  }
  return Status::OK();
}
490 | |
491 | static DictionaryOffset GetDictionaryEncoding(FBB& fbb, const DictionaryType& type, |
492 | DictionaryMemo* memo) { |
493 | int64_t dictionary_id = memo->GetId(type.dictionary()); |
494 | |
495 | // We assume that the dictionary index type (as an integer) has already been |
496 | // validated elsewhere, and can safely assume we are dealing with signed |
497 | // integers |
498 | const auto& fw_index_type = checked_cast<const FixedWidthType&>(*type.index_type()); |
499 | |
500 | auto index_type_offset = flatbuf::CreateInt(fbb, fw_index_type.bit_width(), true); |
501 | |
502 | // TODO(wesm): ordered dictionaries |
503 | return flatbuf::CreateDictionaryEncoding(fbb, dictionary_id, index_type_offset, |
504 | type.ordered()); |
505 | } |
506 | |
507 | static flatbuffers::Offset<flatbuffers::Vector<KeyValueOffset>> |
508 | KeyValueMetadataToFlatbuffer(FBB& fbb, const KeyValueMetadata& metadata) { |
509 | std::vector<KeyValueOffset> key_value_offsets; |
510 | |
511 | size_t metadata_size = metadata.size(); |
512 | key_value_offsets.reserve(metadata_size); |
513 | |
514 | for (size_t i = 0; i < metadata_size; ++i) { |
515 | const auto& key = metadata.key(i); |
516 | const auto& value = metadata.value(i); |
517 | key_value_offsets.push_back( |
518 | flatbuf::CreateKeyValue(fbb, fbb.CreateString(key), fbb.CreateString(value))); |
519 | } |
520 | |
521 | return fbb.CreateVector(key_value_offsets); |
522 | } |
523 | |
524 | static Status KeyValueMetadataFromFlatbuffer( |
525 | const flatbuffers::Vector<KeyValueOffset>* fb_metadata, |
526 | std::shared_ptr<KeyValueMetadata>* out) { |
527 | auto metadata = std::make_shared<KeyValueMetadata>(); |
528 | |
529 | metadata->reserve(fb_metadata->size()); |
530 | for (const auto& pair : *fb_metadata) { |
531 | if (pair->key() == nullptr) { |
532 | return Status::IOError( |
533 | "Key-pointer in custom metadata of flatbuffer-encoded Schema is null." ); |
534 | } |
535 | if (pair->value() == nullptr) { |
536 | return Status::IOError( |
537 | "Value-pointer in custom metadata of flatbuffer-encoded Schema is null." ); |
538 | } |
539 | metadata->Append(pair->key()->str(), pair->value()->str()); |
540 | } |
541 | |
542 | *out = metadata; |
543 | |
544 | return Status::OK(); |
545 | } |
546 | |
547 | static Status FieldToFlatbuffer(FBB& fbb, const Field& field, |
548 | DictionaryMemo* dictionary_memo, FieldOffset* offset) { |
549 | auto fb_name = fbb.CreateString(field.name()); |
550 | |
551 | flatbuf::Type type_enum; |
552 | Offset type_offset; |
553 | std::vector<FieldOffset> children; |
554 | |
555 | RETURN_NOT_OK(TypeToFlatbuffer(fbb, *field.type(), &children, &type_enum, |
556 | dictionary_memo, &type_offset)); |
557 | auto fb_children = fbb.CreateVector(children); |
558 | |
559 | DictionaryOffset dictionary = 0; |
560 | if (field.type()->id() == Type::DICTIONARY) { |
561 | dictionary = GetDictionaryEncoding( |
562 | fbb, checked_cast<const DictionaryType&>(*field.type()), dictionary_memo); |
563 | } |
564 | |
565 | auto metadata = field.metadata(); |
566 | if (metadata != nullptr) { |
567 | auto fb_custom_metadata = KeyValueMetadataToFlatbuffer(fbb, *metadata); |
568 | *offset = flatbuf::CreateField(fbb, fb_name, field.nullable(), type_enum, type_offset, |
569 | dictionary, fb_children, fb_custom_metadata); |
570 | } else { |
571 | *offset = flatbuf::CreateField(fbb, fb_name, field.nullable(), type_enum, type_offset, |
572 | dictionary, fb_children); |
573 | } |
574 | |
575 | return Status::OK(); |
576 | } |
577 | |
// Reconstruct an Arrow Field from its Flatbuffer representation.
//
// Non-dictionary fields are converted by recursing into the children first
// and then building the type; dictionary-encoded fields resolve their values
// through `dictionary_memo`, which must already contain the dictionary for
// the serialized encoding id.
static Status FieldFromFlatbuffer(const flatbuf::Field* field,
                                  const DictionaryMemo& dictionary_memo,
                                  std::shared_ptr<Field>* out) {
  std::shared_ptr<DataType> type;

  const flatbuf::DictionaryEncoding* encoding = field->dictionary();

  if (encoding == nullptr) {
    // The field is not dictionary encoded. We must potentially visit its
    // children to fully reconstruct the data type
    auto children = field->children();
    std::vector<std::shared_ptr<Field>> child_fields(children->size());
    for (int i = 0; i < static_cast<int>(children->size()); ++i) {
      RETURN_NOT_OK(
          FieldFromFlatbuffer(children->Get(i), dictionary_memo, &child_fields[i]));
    }
    RETURN_NOT_OK(
        TypeFromFlatbuffer(field->type_type(), field->type(), child_fields, &type));
  } else {
    // The field is dictionary encoded. The type of the dictionary values has
    // been determined elsewhere, and is stored in the DictionaryMemo. Here we
    // construct the logical DictionaryType object

    std::shared_ptr<Array> dictionary;
    RETURN_NOT_OK(dictionary_memo.GetDictionary(encoding->id(), &dictionary));

    std::shared_ptr<DataType> index_type;
    RETURN_NOT_OK(IntFromFlatbuffer(encoding->indexType(), &index_type));
    type = ::arrow::dictionary(index_type, dictionary, encoding->isOrdered());
  }

  // Optional field-level key-value metadata; stays null when absent
  auto fb_metadata = field->custom_metadata();
  std::shared_ptr<KeyValueMetadata> metadata;

  if (fb_metadata != nullptr) {
    RETURN_NOT_OK(KeyValueMetadataFromFlatbuffer(fb_metadata, &metadata));
  }

  *out = std::make_shared<Field>(field->name()->str(), type, field->nullable(), metadata);

  return Status::OK();
}
620 | |
621 | static Status FieldFromFlatbufferDictionary(const flatbuf::Field* field, |
622 | std::shared_ptr<Field>* out) { |
623 | // Need an empty memo to pass down for constructing children |
624 | DictionaryMemo dummy_memo; |
625 | |
626 | // Any DictionaryEncoding set is ignored here |
627 | |
628 | std::shared_ptr<DataType> type; |
629 | auto children = field->children(); |
630 | std::vector<std::shared_ptr<Field>> child_fields(children->size()); |
631 | for (int i = 0; i < static_cast<int>(children->size()); ++i) { |
632 | RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), dummy_memo, &child_fields[i])); |
633 | } |
634 | |
635 | RETURN_NOT_OK( |
636 | TypeFromFlatbuffer(field->type_type(), field->type(), child_fields, &type)); |
637 | |
638 | *out = std::make_shared<Field>(field->name()->str(), type, field->nullable()); |
639 | return Status::OK(); |
640 | } |
641 | |
642 | // will return the endianness of the system we are running on |
643 | // based the NUMPY_API function. See NOTICE.txt |
644 | flatbuf::Endianness endianness() { |
645 | union { |
646 | uint32_t i; |
647 | char c[4]; |
648 | } bint = {0x01020304}; |
649 | |
650 | return bint.c[0] == 1 ? flatbuf::Endianness_Big : flatbuf::Endianness_Little; |
651 | } |
652 | |
653 | static Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, |
654 | DictionaryMemo* dictionary_memo, |
655 | flatbuffers::Offset<flatbuf::Schema>* out) { |
656 | /// Fields |
657 | std::vector<FieldOffset> field_offsets; |
658 | for (int i = 0; i < schema.num_fields(); ++i) { |
659 | FieldOffset offset; |
660 | RETURN_NOT_OK(FieldToFlatbuffer(fbb, *schema.field(i), dictionary_memo, &offset)); |
661 | field_offsets.push_back(offset); |
662 | } |
663 | |
664 | auto fb_offsets = fbb.CreateVector(field_offsets); |
665 | |
666 | /// Custom metadata |
667 | auto metadata = schema.metadata(); |
668 | if (metadata != nullptr) { |
669 | auto fb_custom_metadata = KeyValueMetadataToFlatbuffer(fbb, *metadata); |
670 | *out = flatbuf::CreateSchema(fbb, endianness(), fb_offsets, fb_custom_metadata); |
671 | } else { |
672 | *out = flatbuf::CreateSchema(fbb, endianness(), fb_offsets); |
673 | } |
674 | |
675 | return Status::OK(); |
676 | } |
677 | |
678 | static Status (FBB& fbb, flatbuf::MessageHeader , |
679 | flatbuffers::Offset<void> , int64_t body_length, |
680 | std::shared_ptr<Buffer>* out) { |
681 | auto message = flatbuf::CreateMessage(fbb, kCurrentMetadataVersion, header_type, header, |
682 | body_length); |
683 | fbb.Finish(message); |
684 | return WriteFlatbufferBuilder(fbb, out); |
685 | } |
686 | |
687 | Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo, |
688 | std::shared_ptr<Buffer>* out) { |
689 | FBB fbb; |
690 | flatbuffers::Offset<flatbuf::Schema> fb_schema; |
691 | RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); |
692 | return WriteFBMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out); |
693 | } |
694 | |
695 | using FieldNodeVector = |
696 | flatbuffers::Offset<flatbuffers::Vector<const flatbuf::FieldNode*>>; |
697 | using BufferVector = flatbuffers::Offset<flatbuffers::Vector<const flatbuf::Buffer*>>; |
698 | |
699 | static Status WriteFieldNodes(FBB& fbb, const std::vector<FieldMetadata>& nodes, |
700 | FieldNodeVector* out) { |
701 | std::vector<flatbuf::FieldNode> fb_nodes; |
702 | fb_nodes.reserve(nodes.size()); |
703 | |
704 | for (size_t i = 0; i < nodes.size(); ++i) { |
705 | const FieldMetadata& node = nodes[i]; |
706 | if (node.offset != 0) { |
707 | return Status::Invalid("Field metadata for IPC must have offset 0" ); |
708 | } |
709 | fb_nodes.emplace_back(node.length, node.null_count); |
710 | } |
711 | *out = fbb.CreateVectorOfStructs(fb_nodes); |
712 | return Status::OK(); |
713 | } |
714 | |
715 | static Status WriteBuffers(FBB& fbb, const std::vector<BufferMetadata>& buffers, |
716 | BufferVector* out) { |
717 | std::vector<flatbuf::Buffer> fb_buffers; |
718 | fb_buffers.reserve(buffers.size()); |
719 | |
720 | for (size_t i = 0; i < buffers.size(); ++i) { |
721 | const BufferMetadata& buffer = buffers[i]; |
722 | fb_buffers.emplace_back(buffer.offset, buffer.length); |
723 | } |
724 | *out = fbb.CreateVectorOfStructs(fb_buffers); |
725 | return Status::OK(); |
726 | } |
727 | |
728 | static Status MakeRecordBatch(FBB& fbb, int64_t length, int64_t body_length, |
729 | const std::vector<FieldMetadata>& nodes, |
730 | const std::vector<BufferMetadata>& buffers, |
731 | RecordBatchOffset* offset) { |
732 | FieldNodeVector fb_nodes; |
733 | BufferVector fb_buffers; |
734 | |
735 | RETURN_NOT_OK(WriteFieldNodes(fbb, nodes, &fb_nodes)); |
736 | RETURN_NOT_OK(WriteBuffers(fbb, buffers, &fb_buffers)); |
737 | |
738 | *offset = flatbuf::CreateRecordBatch(fbb, length, fb_nodes, fb_buffers); |
739 | return Status::OK(); |
740 | } |
741 | |
742 | Status WriteRecordBatchMessage(int64_t length, int64_t body_length, |
743 | const std::vector<FieldMetadata>& nodes, |
744 | const std::vector<BufferMetadata>& buffers, |
745 | std::shared_ptr<Buffer>* out) { |
746 | FBB fbb; |
747 | RecordBatchOffset record_batch; |
748 | RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch)); |
749 | return WriteFBMessage(fbb, flatbuf::MessageHeader_RecordBatch, record_batch.Union(), |
750 | body_length, out); |
751 | } |
752 | |
753 | Status WriteTensorMessage(const Tensor& tensor, int64_t buffer_start_offset, |
754 | std::shared_ptr<Buffer>* out) { |
755 | using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>; |
756 | using TensorOffset = flatbuffers::Offset<flatbuf::Tensor>; |
757 | |
758 | FBB fbb; |
759 | |
760 | const auto& type = checked_cast<const FixedWidthType&>(*tensor.type()); |
761 | const int elem_size = type.bit_width() / 8; |
762 | |
763 | flatbuf::Type fb_type_type; |
764 | Offset fb_type; |
765 | RETURN_NOT_OK(TensorTypeToFlatbuffer(fbb, *tensor.type(), &fb_type_type, &fb_type)); |
766 | |
767 | std::vector<TensorDimOffset> dims; |
768 | for (int i = 0; i < tensor.ndim(); ++i) { |
769 | FBString name = fbb.CreateString(tensor.dim_name(i)); |
770 | dims.push_back(flatbuf::CreateTensorDim(fbb, tensor.shape()[i], name)); |
771 | } |
772 | |
773 | auto fb_shape = fbb.CreateVector(dims); |
774 | auto fb_strides = fbb.CreateVector(tensor.strides()); |
775 | |
776 | int64_t body_length = tensor.size() * elem_size; |
777 | flatbuf::Buffer buffer(buffer_start_offset, body_length); |
778 | |
779 | TensorOffset fb_tensor = |
780 | flatbuf::CreateTensor(fbb, fb_type_type, fb_type, fb_shape, fb_strides, &buffer); |
781 | |
782 | return WriteFBMessage(fbb, flatbuf::MessageHeader_Tensor, fb_tensor.Union(), |
783 | body_length, out); |
784 | } |
785 | |
786 | Status MakeSparseTensorIndexCOO(FBB& fbb, const SparseCOOIndex& sparse_index, |
787 | const std::vector<BufferMetadata>& buffers, |
788 | flatbuf::SparseTensorIndex* fb_sparse_index_type, |
789 | Offset* fb_sparse_index, size_t* num_buffers) { |
790 | *fb_sparse_index_type = flatbuf::SparseTensorIndex_SparseTensorIndexCOO; |
791 | const BufferMetadata& indices_metadata = buffers[0]; |
792 | flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length); |
793 | *fb_sparse_index = flatbuf::CreateSparseTensorIndexCOO(fbb, &indices).Union(); |
794 | *num_buffers = 1; |
795 | return Status::OK(); |
796 | } |
797 | |
798 | Status MakeSparseMatrixIndexCSR(FBB& fbb, const SparseCSRIndex& sparse_index, |
799 | const std::vector<BufferMetadata>& buffers, |
800 | flatbuf::SparseTensorIndex* fb_sparse_index_type, |
801 | Offset* fb_sparse_index, size_t* num_buffers) { |
802 | *fb_sparse_index_type = flatbuf::SparseTensorIndex_SparseMatrixIndexCSR; |
803 | const BufferMetadata& indptr_metadata = buffers[0]; |
804 | const BufferMetadata& indices_metadata = buffers[1]; |
805 | flatbuf::Buffer indptr(indptr_metadata.offset, indptr_metadata.length); |
806 | flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length); |
807 | *fb_sparse_index = flatbuf::CreateSparseMatrixIndexCSR(fbb, &indptr, &indices).Union(); |
808 | *num_buffers = 2; |
809 | return Status::OK(); |
810 | } |
811 | |
812 | Status MakeSparseTensorIndex(FBB& fbb, const SparseIndex& sparse_index, |
813 | const std::vector<BufferMetadata>& buffers, |
814 | flatbuf::SparseTensorIndex* fb_sparse_index_type, |
815 | Offset* fb_sparse_index, size_t* num_buffers) { |
816 | switch (sparse_index.format_id()) { |
817 | case SparseTensorFormat::COO: |
818 | RETURN_NOT_OK(MakeSparseTensorIndexCOO( |
819 | fbb, checked_cast<const SparseCOOIndex&>(sparse_index), buffers, |
820 | fb_sparse_index_type, fb_sparse_index, num_buffers)); |
821 | break; |
822 | |
823 | case SparseTensorFormat::CSR: |
824 | RETURN_NOT_OK(MakeSparseMatrixIndexCSR( |
825 | fbb, checked_cast<const SparseCSRIndex&>(sparse_index), buffers, |
826 | fb_sparse_index_type, fb_sparse_index, num_buffers)); |
827 | break; |
828 | |
829 | default: |
830 | std::stringstream ss; |
831 | ss << "Unsupporoted sparse tensor format:: " << sparse_index.ToString() |
832 | << std::endl; |
833 | return Status::NotImplemented(ss.str()); |
834 | } |
835 | |
836 | return Status::OK(); |
837 | } |
838 | |
// Builds the flatbuf::SparseTensor table describing `sparse_tensor`.
//
// `buffers` holds body-buffer metadata in order: first the index buffer(s)
// (one for COO, two for CSR), then the data buffer. `body_length` is unused
// here; the caller passes it to WriteFBMessage separately. The result is
// returned via *offset.
//
// NOTE: the flatbuffer child objects (strings, vectors, index tables) are
// all created before CreateSparseTensor, as required by the FlatBuffers
// builder (no nesting of table construction).
Status MakeSparseTensor(FBB& fbb, const SparseTensor& sparse_tensor, int64_t body_length,
                        const std::vector<BufferMetadata>& buffers,
                        SparseTensorOffset* offset) {
  // Encode the value type (union tag + table offset).
  flatbuf::Type fb_type_type;
  Offset fb_type;
  RETURN_NOT_OK(
      TensorTypeToFlatbuffer(fbb, *sparse_tensor.type(), &fb_type_type, &fb_type));

  // One TensorDim (size + optional name) per logical dimension.
  using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
  std::vector<TensorDimOffset> dims;
  for (int i = 0; i < sparse_tensor.ndim(); ++i) {
    FBString name = fbb.CreateString(sparse_tensor.dim_name(i));
    dims.push_back(flatbuf::CreateTensorDim(fbb, sparse_tensor.shape()[i], name));
  }

  auto fb_shape = fbb.CreateVector(dims);

  // Encode the sparse index; num_index_buffers tells us how many leading
  // entries of `buffers` the index consumed.
  flatbuf::SparseTensorIndex fb_sparse_index_type;
  Offset fb_sparse_index;
  size_t num_index_buffers = 0;
  RETURN_NOT_OK(MakeSparseTensorIndex(fbb, *sparse_tensor.sparse_index(), buffers,
                                      &fb_sparse_index_type, &fb_sparse_index,
                                      &num_index_buffers));

  // The data buffer immediately follows the index buffer(s).
  const BufferMetadata& data_metadata = buffers[num_index_buffers];
  flatbuf::Buffer data(data_metadata.offset, data_metadata.length);

  const int64_t non_zero_length = sparse_tensor.non_zero_length();

  *offset =
      flatbuf::CreateSparseTensor(fbb, fb_type_type, fb_type, fb_shape, non_zero_length,
                                  fb_sparse_index_type, fb_sparse_index, &data);

  return Status::OK();
}
874 | |
875 | Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor, int64_t body_length, |
876 | const std::vector<BufferMetadata>& buffers, |
877 | std::shared_ptr<Buffer>* out) { |
878 | FBB fbb; |
879 | SparseTensorOffset fb_sparse_tensor; |
880 | RETURN_NOT_OK( |
881 | MakeSparseTensor(fbb, sparse_tensor, body_length, buffers, &fb_sparse_tensor)); |
882 | return WriteFBMessage(fbb, flatbuf::MessageHeader_SparseTensor, |
883 | fb_sparse_tensor.Union(), body_length, out); |
884 | } |
885 | |
886 | Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length, |
887 | const std::vector<FieldMetadata>& nodes, |
888 | const std::vector<BufferMetadata>& buffers, |
889 | std::shared_ptr<Buffer>* out) { |
890 | FBB fbb; |
891 | RecordBatchOffset record_batch; |
892 | RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch)); |
893 | auto dictionary_batch = flatbuf::CreateDictionaryBatch(fbb, id, record_batch).Union(); |
894 | return WriteFBMessage(fbb, flatbuf::MessageHeader_DictionaryBatch, dictionary_batch, |
895 | body_length, out); |
896 | } |
897 | |
898 | static flatbuffers::Offset<flatbuffers::Vector<const flatbuf::Block*>> |
899 | FileBlocksToFlatbuffer(FBB& fbb, const std::vector<FileBlock>& blocks) { |
900 | std::vector<flatbuf::Block> fb_blocks; |
901 | |
902 | for (const FileBlock& block : blocks) { |
903 | fb_blocks.emplace_back(block.offset, block.metadata_length, block.body_length); |
904 | } |
905 | |
906 | return fbb.CreateVectorOfStructs(fb_blocks); |
907 | } |
908 | |
909 | Status (const Schema& schema, const std::vector<FileBlock>& dictionaries, |
910 | const std::vector<FileBlock>& record_batches, |
911 | DictionaryMemo* dictionary_memo, io::OutputStream* out) { |
912 | FBB fbb; |
913 | |
914 | flatbuffers::Offset<flatbuf::Schema> fb_schema; |
915 | RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); |
916 | |
917 | #ifndef NDEBUG |
918 | for (size_t i = 0; i < dictionaries.size(); ++i) { |
919 | DCHECK(BitUtil::IsMultipleOf8(dictionaries[i].offset)) << i; |
920 | DCHECK(BitUtil::IsMultipleOf8(dictionaries[i].metadata_length)) << i; |
921 | DCHECK(BitUtil::IsMultipleOf8(dictionaries[i].body_length)) << i; |
922 | } |
923 | |
924 | for (size_t i = 0; i < record_batches.size(); ++i) { |
925 | DCHECK(BitUtil::IsMultipleOf8(record_batches[i].offset)) << i; |
926 | DCHECK(BitUtil::IsMultipleOf8(record_batches[i].metadata_length)) << i; |
927 | DCHECK(BitUtil::IsMultipleOf8(record_batches[i].body_length)) << i; |
928 | } |
929 | #endif |
930 | |
931 | auto fb_dictionaries = FileBlocksToFlatbuffer(fbb, dictionaries); |
932 | auto fb_record_batches = FileBlocksToFlatbuffer(fbb, record_batches); |
933 | |
934 | auto = flatbuf::CreateFooter(fbb, kCurrentMetadataVersion, fb_schema, |
935 | fb_dictionaries, fb_record_batches); |
936 | |
937 | fbb.Finish(footer); |
938 | |
939 | int32_t size = fbb.GetSize(); |
940 | |
941 | return out->Write(fbb.GetBufferPointer(), size); |
942 | } |
943 | |
944 | // ---------------------------------------------------------------------- |
945 | |
946 | static Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) { |
947 | const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary(); |
948 | if (dict_metadata == nullptr) { |
949 | // Field is not dictionary encoded. Visit children |
950 | auto children = field->children(); |
951 | if (children == nullptr) { |
952 | return Status::IOError("Children-pointer of flatbuffer-encoded Field is null." ); |
953 | } |
954 | for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) { |
955 | RETURN_NOT_OK(VisitField(children->Get(i), id_to_field)); |
956 | } |
957 | } else { |
958 | // Field is dictionary encoded. Construct the data type for the |
959 | // dictionary (no descendents can be dictionary encoded) |
960 | std::shared_ptr<Field> dictionary_field; |
961 | RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field)); |
962 | (*id_to_field)[dict_metadata->id()] = dictionary_field; |
963 | } |
964 | return Status::OK(); |
965 | } |
966 | |
967 | Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field) { |
968 | auto schema = static_cast<const flatbuf::Schema*>(opaque_schema); |
969 | if (schema->fields() == nullptr) { |
970 | return Status::IOError("Fields-pointer of flatbuffer-encoded Schema is null." ); |
971 | } |
972 | int num_fields = static_cast<int>(schema->fields()->size()); |
973 | for (int i = 0; i < num_fields; ++i) { |
974 | auto field = schema->fields()->Get(i); |
975 | if (field == nullptr) { |
976 | return Status::IOError("Field-pointer of flatbuffer-encoded Schema is null." ); |
977 | } |
978 | RETURN_NOT_OK(VisitField(field, id_to_field)); |
979 | } |
980 | return Status::OK(); |
981 | } |
982 | |
983 | Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo, |
984 | std::shared_ptr<Schema>* out) { |
985 | auto schema = static_cast<const flatbuf::Schema*>(opaque_schema); |
986 | if (schema->fields() == nullptr) { |
987 | return Status::IOError("Fields-pointer of flatbuffer-encoded Schema is null." ); |
988 | } |
989 | int num_fields = static_cast<int>(schema->fields()->size()); |
990 | |
991 | std::vector<std::shared_ptr<Field>> fields(num_fields); |
992 | for (int i = 0; i < num_fields; ++i) { |
993 | const flatbuf::Field* field = schema->fields()->Get(i); |
994 | RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i])); |
995 | } |
996 | |
997 | auto fb_metadata = schema->custom_metadata(); |
998 | std::shared_ptr<KeyValueMetadata> metadata; |
999 | |
1000 | if (fb_metadata != nullptr) { |
1001 | RETURN_NOT_OK(KeyValueMetadataFromFlatbuffer(fb_metadata, &metadata)); |
1002 | } |
1003 | |
1004 | *out = ::arrow::schema(std::move(fields), metadata); |
1005 | |
1006 | return Status::OK(); |
1007 | } |
1008 | |
1009 | Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type, |
1010 | std::vector<int64_t>* shape, std::vector<int64_t>* strides, |
1011 | std::vector<std::string>* dim_names) { |
1012 | auto message = flatbuf::GetMessage(metadata.data()); |
1013 | auto tensor = reinterpret_cast<const flatbuf::Tensor*>(message->header()); |
1014 | |
1015 | int ndim = static_cast<int>(tensor->shape()->size()); |
1016 | |
1017 | for (int i = 0; i < ndim; ++i) { |
1018 | auto dim = tensor->shape()->Get(i); |
1019 | |
1020 | shape->push_back(dim->size()); |
1021 | auto fb_name = dim->name(); |
1022 | if (fb_name == 0) { |
1023 | dim_names->push_back("" ); |
1024 | } else { |
1025 | dim_names->push_back(fb_name->str()); |
1026 | } |
1027 | } |
1028 | |
1029 | if (tensor->strides()->size() > 0) { |
1030 | for (int i = 0; i < ndim; ++i) { |
1031 | strides->push_back(tensor->strides()->Get(i)); |
1032 | } |
1033 | } |
1034 | |
1035 | return TypeFromFlatbuffer(tensor->type_type(), tensor->type(), {}, type); |
1036 | } |
1037 | |
1038 | Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type, |
1039 | std::vector<int64_t>* shape, |
1040 | std::vector<std::string>* dim_names, |
1041 | int64_t* non_zero_length, |
1042 | SparseTensorFormat::type* sparse_tensor_format_id) { |
1043 | auto message = flatbuf::GetMessage(metadata.data()); |
1044 | if (message->header_type() != flatbuf::MessageHeader_SparseTensor) { |
1045 | return Status::IOError("Header of flatbuffer-encoded Message is not SparseTensor." ); |
1046 | } |
1047 | if (message->header() == nullptr) { |
1048 | return Status::IOError("Header-pointer of flatbuffer-encoded Message is null." ); |
1049 | } |
1050 | |
1051 | auto sparse_tensor = reinterpret_cast<const flatbuf::SparseTensor*>(message->header()); |
1052 | int ndim = static_cast<int>(sparse_tensor->shape()->size()); |
1053 | |
1054 | for (int i = 0; i < ndim; ++i) { |
1055 | auto dim = sparse_tensor->shape()->Get(i); |
1056 | |
1057 | shape->push_back(dim->size()); |
1058 | auto fb_name = dim->name(); |
1059 | if (fb_name == 0) { |
1060 | dim_names->push_back("" ); |
1061 | } else { |
1062 | dim_names->push_back(fb_name->str()); |
1063 | } |
1064 | } |
1065 | |
1066 | *non_zero_length = sparse_tensor->non_zero_length(); |
1067 | |
1068 | switch (sparse_tensor->sparseIndex_type()) { |
1069 | case flatbuf::SparseTensorIndex_SparseTensorIndexCOO: |
1070 | *sparse_tensor_format_id = SparseTensorFormat::COO; |
1071 | break; |
1072 | |
1073 | case flatbuf::SparseTensorIndex_SparseMatrixIndexCSR: |
1074 | *sparse_tensor_format_id = SparseTensorFormat::CSR; |
1075 | break; |
1076 | |
1077 | default: |
1078 | return Status::Invalid("Unrecognized sparse index type" ); |
1079 | } |
1080 | |
1081 | return TypeFromFlatbuffer(sparse_tensor->type_type(), sparse_tensor->type(), {}, type); |
1082 | } |
1083 | |
1084 | // ---------------------------------------------------------------------- |
1085 | // Implement message writing |
1086 | |
// Writes an encapsulated IPC message to `file`: a 4-byte little-endian
// length prefix, the flatbuffer metadata in `message`, then zero padding so
// the total written span is a multiple of `alignment`.
//
// *message_length receives the total bytes written (prefix + flatbuffer +
// padding). The prefix value itself excludes the 4 prefix bytes but
// includes the padding, so a reader can skip straight to the body.
Status WriteMessage(const Buffer& message, int32_t alignment, io::OutputStream* file,
                    int32_t* message_length) {
  // ARROW-3212: We do not make assumptions that the output stream is aligned
  int32_t padded_message_length = static_cast<int32_t>(message.size()) + 4;
  const int32_t remainder = padded_message_length % alignment;
  if (remainder != 0) {
    padded_message_length += alignment - remainder;
  }

  // The returned message size includes the length prefix, the flatbuffer,
  // plus padding
  *message_length = padded_message_length;

  // Write the flatbuffer size prefix including padding
  int32_t flatbuffer_size = padded_message_length - 4;
  RETURN_NOT_OK(file->Write(&flatbuffer_size, sizeof(int32_t)));

  // Write the flatbuffer
  RETURN_NOT_OK(file->Write(message.data(), message.size()));

  // Write any padding
  // (padding = total - flatbuffer - 4 prefix bytes; may be zero)
  int32_t padding = padded_message_length - static_cast<int32_t>(message.size()) - 4;
  if (padding > 0) {
    RETURN_NOT_OK(file->Write(kPaddingBytes, padding));
  }

  return Status::OK();
}
1115 | |
1116 | } // namespace internal |
1117 | } // namespace ipc |
1118 | } // namespace arrow |
1119 | |