1 | #include "config_formats.h" |
2 | #if USE_PROTOBUF |
3 | |
4 | #include "ProtobufReader.h" |
5 | |
6 | #include <AggregateFunctions/IAggregateFunction.h> |
7 | #include <boost/numeric/conversion/cast.hpp> |
8 | #include <DataTypes/DataTypesDecimal.h> |
9 | #include <IO/ReadBufferFromString.h> |
10 | #include <IO/ReadHelpers.h> |
11 | #include <IO/WriteBufferFromVector.h> |
12 | #include <IO/WriteHelpers.h> |
13 | #include <optional> |
14 | |
15 | |
16 | namespace DB |
17 | { |
namespace ErrorCodes
{
    /// Error codes thrown from this file; the constants themselves are defined elsewhere in the project.
    extern const int PROTOBUF_BAD_CAST;
    extern const int UNKNOWN_PROTOBUF_FORMAT;
}
23 | |
24 | |
25 | namespace |
26 | { |
27 | enum WireType |
28 | { |
29 | VARINT = 0, |
30 | BITS64 = 1, |
31 | LENGTH_DELIMITED = 2, |
32 | GROUP_START = 3, |
33 | GROUP_END = 4, |
34 | BITS32 = 5, |
35 | }; |
36 | |
37 | // The following condition must always be true: |
38 | // any_cursor_position < min(END_OF_VARINT, END_OF_GROUP) |
39 | // This inequation helps to check conditions in SimpleReader. |
40 | constexpr UInt64 END_OF_VARINT = static_cast<UInt64>(-1); |
41 | constexpr UInt64 END_OF_GROUP = static_cast<UInt64>(-2); |
42 | |
43 | Int64 decodeZigZag(UInt64 n) { return static_cast<Int64>((n >> 1) ^ (~(n & 1) + 1)); } |
44 | |
45 | [[noreturn]] void throwUnknownFormat() |
46 | { |
47 | throw Exception("Protobuf messages are corrupted or don't match the provided schema. Please note that Protobuf stream is length-delimited: every message is prefixed by its length in varint." , ErrorCodes::UNKNOWN_PROTOBUF_FORMAT); |
48 | } |
49 | } |
50 | |
51 | |
// SimpleReader is a utility class to deserialize protobufs.
// It knows nothing about protobuf schemas; it just provides useful functions to deserialize data.
54 | ProtobufReader::SimpleReader::SimpleReader(ReadBuffer & in_) |
55 | : in(in_) |
56 | , cursor(0) |
57 | , current_message_level(0) |
58 | , current_message_end(0) |
59 | , field_end(0) |
60 | , last_string_pos(-1) |
61 | { |
62 | } |
63 | |
64 | bool ProtobufReader::SimpleReader::startMessage() |
65 | { |
66 | // Start reading a root message. |
67 | assert(!current_message_level); |
68 | if (unlikely(in.eof())) |
69 | return false; |
70 | size_t size_of_message = readVarint(); |
71 | current_message_end = cursor + size_of_message; |
72 | ++current_message_level; |
73 | field_end = cursor; |
74 | return true; |
75 | } |
76 | |
77 | void ProtobufReader::SimpleReader::endMessage(bool ignore_errors) |
78 | { |
79 | if (!current_message_level) |
80 | return; |
81 | |
82 | UInt64 root_message_end = (current_message_level == 1) ? current_message_end : parent_message_ends.front(); |
83 | if (cursor != root_message_end) |
84 | { |
85 | if (cursor < root_message_end) |
86 | ignore(root_message_end - cursor); |
87 | else if (ignore_errors) |
88 | moveCursorBackward(cursor - root_message_end); |
89 | else |
90 | throwUnknownFormat(); |
91 | } |
92 | |
93 | current_message_level = 0; |
94 | parent_message_ends.clear(); |
95 | } |
96 | |
97 | void ProtobufReader::SimpleReader::startNestedMessage() |
98 | { |
99 | assert(current_message_level >= 1); |
100 | // Start reading a nested message which is located inside a length-delimited field |
101 | // of another message. |
102 | parent_message_ends.emplace_back(current_message_end); |
103 | current_message_end = field_end; |
104 | ++current_message_level; |
105 | field_end = cursor; |
106 | } |
107 | |
108 | void ProtobufReader::SimpleReader::endNestedMessage() |
109 | { |
110 | assert(current_message_level >= 2); |
111 | if (cursor != current_message_end) |
112 | { |
113 | if (current_message_end == END_OF_GROUP) |
114 | { |
115 | ignoreGroup(); |
116 | current_message_end = cursor; |
117 | } |
118 | else if (cursor < current_message_end) |
119 | ignore(current_message_end - cursor); |
120 | else |
121 | throwUnknownFormat(); |
122 | } |
123 | |
124 | --current_message_level; |
125 | current_message_end = parent_message_ends.back(); |
126 | parent_message_ends.pop_back(); |
127 | field_end = cursor; |
128 | } |
129 | |
130 | bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number) |
131 | { |
132 | assert(current_message_level); |
133 | if (field_end != cursor) |
134 | { |
135 | if (field_end == END_OF_VARINT) |
136 | { |
137 | ignoreVarint(); |
138 | field_end = cursor; |
139 | } |
140 | else if (field_end == END_OF_GROUP) |
141 | { |
142 | ignoreGroup(); |
143 | field_end = cursor; |
144 | } |
145 | else if (cursor < field_end) |
146 | ignore(field_end - cursor); |
147 | else |
148 | throwUnknownFormat(); |
149 | } |
150 | |
151 | if (cursor >= current_message_end) |
152 | return false; |
153 | |
154 | UInt64 varint = readVarint(); |
155 | if (unlikely(varint & (static_cast<UInt64>(0xFFFFFFFF) << 32))) |
156 | throwUnknownFormat(); |
157 | UInt32 key = static_cast<UInt32>(varint); |
158 | field_number = (key >> 3); |
159 | WireType wire_type = static_cast<WireType>(key & 0x07); |
160 | switch (wire_type) |
161 | { |
162 | case BITS32: |
163 | { |
164 | field_end = cursor + 4; |
165 | return true; |
166 | } |
167 | case BITS64: |
168 | { |
169 | field_end = cursor + 8; |
170 | return true; |
171 | } |
172 | case LENGTH_DELIMITED: |
173 | { |
174 | size_t length = readVarint(); |
175 | field_end = cursor + length; |
176 | return true; |
177 | } |
178 | case VARINT: |
179 | { |
180 | field_end = END_OF_VARINT; |
181 | return true; |
182 | } |
183 | case GROUP_START: |
184 | { |
185 | field_end = END_OF_GROUP; |
186 | return true; |
187 | } |
188 | case GROUP_END: |
189 | { |
190 | if (current_message_end != END_OF_GROUP) |
191 | throwUnknownFormat(); |
192 | current_message_end = cursor; |
193 | return false; |
194 | } |
195 | } |
196 | throwUnknownFormat(); |
197 | } |
198 | |
199 | bool ProtobufReader::SimpleReader::readUInt(UInt64 & value) |
200 | { |
201 | if (unlikely(cursor >= field_end)) |
202 | return false; |
203 | value = readVarint(); |
204 | if (field_end == END_OF_VARINT) |
205 | field_end = cursor; |
206 | return true; |
207 | } |
208 | |
209 | bool ProtobufReader::SimpleReader::readInt(Int64 & value) |
210 | { |
211 | UInt64 varint; |
212 | if (!readUInt(varint)) |
213 | return false; |
214 | value = static_cast<Int64>(varint); |
215 | return true; |
216 | } |
217 | |
218 | bool ProtobufReader::SimpleReader::readSInt(Int64 & value) |
219 | { |
220 | UInt64 varint; |
221 | if (!readUInt(varint)) |
222 | return false; |
223 | value = decodeZigZag(varint); |
224 | return true; |
225 | } |
226 | |
227 | template<typename T> |
228 | bool ProtobufReader::SimpleReader::readFixed(T & value) |
229 | { |
230 | if (unlikely(cursor >= field_end)) |
231 | return false; |
232 | readBinary(&value, sizeof(T)); |
233 | return true; |
234 | } |
235 | |
236 | bool ProtobufReader::SimpleReader::readStringInto(PaddedPODArray<UInt8> & str) |
237 | { |
238 | if (unlikely(cursor == last_string_pos)) |
239 | return false; /// We don't want to read the same empty string again. |
240 | last_string_pos = cursor; |
241 | if (unlikely(cursor > field_end)) |
242 | throwUnknownFormat(); |
243 | size_t length = field_end - cursor; |
244 | size_t old_size = str.size(); |
245 | str.resize(old_size + length); |
246 | readBinary(reinterpret_cast<char*>(str.data() + old_size), length); |
247 | return true; |
248 | } |
249 | |
250 | void ProtobufReader::SimpleReader::readBinary(void* data, size_t size) |
251 | { |
252 | in.readStrict(reinterpret_cast<char*>(data), size); |
253 | cursor += size; |
254 | } |
255 | |
/// Skips `num_bytes` bytes of input and advances `cursor` by the same amount.
/// The buffer is advanced first; if in.ignore() throws, `cursor` is left unchanged.
void ProtobufReader::SimpleReader::ignore(UInt64 num_bytes)
{
    in.ignore(num_bytes);
    cursor += num_bytes;
}
261 | |
262 | void ProtobufReader::SimpleReader::moveCursorBackward(UInt64 num_bytes) |
263 | { |
264 | if (in.offset() < num_bytes) |
265 | throwUnknownFormat(); |
266 | in.position() -= num_bytes; |
267 | cursor -= num_bytes; |
268 | } |
269 | |
/// Finishes decoding a base-128 varint whose first byte is passed in as `first_byte`.
/// That byte is not read here; reading starts with byte 2. The caller is presumably readVarint()
/// (declared in the header), after it saw the 0x80 continuation bit set.
/// Bytes 2..10 are read one at a time, each advancing `cursor`. The decoded value is returned as
/// soon as a byte without the continuation bit is seen. Throws UNKNOWN_PROTOBUF_FORMAT if the
/// 10th byte is anything other than 0x01.
UInt64 ProtobufReader::SimpleReader::continueReadingVarint(UInt64 first_byte)
{
    /// Keep only the 7 payload bits of the first byte.
    UInt64 result = (first_byte & ~static_cast<UInt64>(0x80));
    char c;

/// One step of the decoder; it is unrolled for each byte position below.
/// - Bytes 2..9: the whole byte is ORed in at bit 7 * (byteNo - 1), and decoding stops if its 0x80
///   bit is clear. Otherwise the ORed-in 0x80 bit is cleared again, because it overlaps the next
///   byte's payload. Byte 9 is the exception: its 0x80 bit lands on bit 63, the value's top bit,
///   and is kept.
/// - Byte 10: only 0x01 is accepted, which is that same top bit (already set by byte 9).
#define PROTOBUF_READER_READ_VARINT_BYTE(byteNo) \
    in.readStrict(c); \
    ++cursor; \
    if constexpr (byteNo < 10) \
    { \
        result |= static_cast<UInt64>(static_cast<UInt8>(c)) << (7 * (byteNo - 1)); \
        if (likely(!(c & 0x80))) \
            return result; \
    } \
    else \
    { \
        if (likely(c == 1)) \
            return result; \
    } \
    if constexpr (byteNo < 9) \
        result &= ~(static_cast<UInt64>(0x80) << (7 * (byteNo - 1)));
    PROTOBUF_READER_READ_VARINT_BYTE(2)
    PROTOBUF_READER_READ_VARINT_BYTE(3)
    PROTOBUF_READER_READ_VARINT_BYTE(4)
    PROTOBUF_READER_READ_VARINT_BYTE(5)
    PROTOBUF_READER_READ_VARINT_BYTE(6)
    PROTOBUF_READER_READ_VARINT_BYTE(7)
    PROTOBUF_READER_READ_VARINT_BYTE(8)
    PROTOBUF_READER_READ_VARINT_BYTE(9)
    PROTOBUF_READER_READ_VARINT_BYTE(10)
#undef PROTOBUF_READER_READ_VARINT_BYTE

    /// Reached only when the 10th byte is not 0x01: the varint doesn't fit into 64 bits.
    throwUnknownFormat();
}
304 | |
/// Skips one whole base-128 varint (up to 10 bytes), advancing `cursor` for each byte.
/// Uses the same acceptance rules as continueReadingVarint(): bytes 1..9 end the varint when
/// their 0x80 bit is clear, and a 10th byte must be exactly 0x01. Anything else throws
/// UNKNOWN_PROTOBUF_FORMAT.
void ProtobufReader::SimpleReader::ignoreVarint()
{
    char c;

/// One step of the skipper; it is unrolled for each byte position below.
#define PROTOBUF_READER_IGNORE_VARINT_BYTE(byteNo) \
    in.readStrict(c); \
    ++cursor; \
    if constexpr (byteNo < 10) \
    { \
        if (likely(!(c & 0x80))) \
            return; \
    } \
    else \
    { \
        if (likely(c == 1)) \
            return; \
    }
    PROTOBUF_READER_IGNORE_VARINT_BYTE(1)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(2)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(3)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(4)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(5)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(6)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(7)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(8)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(9)
    PROTOBUF_READER_IGNORE_VARINT_BYTE(10)
#undef PROTOBUF_READER_IGNORE_VARINT_BYTE

    /// Reached only when the 10th byte is not 0x01: the varint doesn't fit into 64 bits.
    throwUnknownFormat();
}
336 | |
337 | void ProtobufReader::SimpleReader::ignoreGroup() |
338 | { |
339 | size_t level = 1; |
340 | while (true) |
341 | { |
342 | UInt64 varint = readVarint(); |
343 | WireType wire_type = static_cast<WireType>(varint & 0x07); |
344 | switch (wire_type) |
345 | { |
346 | case VARINT: |
347 | { |
348 | ignoreVarint(); |
349 | break; |
350 | } |
351 | case BITS64: |
352 | { |
353 | ignore(8); |
354 | break; |
355 | } |
356 | case LENGTH_DELIMITED: |
357 | { |
358 | ignore(readVarint()); |
359 | break; |
360 | } |
361 | case GROUP_START: |
362 | { |
363 | ++level; |
364 | break; |
365 | } |
366 | case GROUP_END: |
367 | { |
368 | if (!--level) |
369 | return; |
370 | break; |
371 | } |
372 | case BITS32: |
373 | { |
374 | ignore(4); |
375 | break; |
376 | } |
377 | } |
378 | throwUnknownFormat(); |
379 | } |
380 | } |
381 | |
382 | // Implementation for a converter from any protobuf field type to any DB data type. |
383 | class ProtobufReader::ConverterBaseImpl : public ProtobufReader::IConverter |
384 | { |
385 | public: |
386 | ConverterBaseImpl(SimpleReader & simple_reader_, const google::protobuf::FieldDescriptor * field_) |
387 | : simple_reader(simple_reader_), field(field_) {} |
388 | |
389 | bool readStringInto(PaddedPODArray<UInt8> &) override |
390 | { |
391 | cannotConvertType("String" ); |
392 | } |
393 | |
394 | bool readInt8(Int8 &) override |
395 | { |
396 | cannotConvertType("Int8" ); |
397 | } |
398 | |
399 | bool readUInt8(UInt8 &) override |
400 | { |
401 | cannotConvertType("UInt8" ); |
402 | } |
403 | |
404 | bool readInt16(Int16 &) override |
405 | { |
406 | cannotConvertType("Int16" ); |
407 | } |
408 | |
409 | bool readUInt16(UInt16 &) override |
410 | { |
411 | cannotConvertType("UInt16" ); |
412 | } |
413 | |
414 | bool readInt32(Int32 &) override |
415 | { |
416 | cannotConvertType("Int32" ); |
417 | } |
418 | |
419 | bool readUInt32(UInt32 &) override |
420 | { |
421 | cannotConvertType("UInt32" ); |
422 | } |
423 | |
424 | bool readInt64(Int64 &) override |
425 | { |
426 | cannotConvertType("Int64" ); |
427 | } |
428 | |
429 | bool readUInt64(UInt64 &) override |
430 | { |
431 | cannotConvertType("UInt64" ); |
432 | } |
433 | |
434 | bool readUInt128(UInt128 &) override |
435 | { |
436 | cannotConvertType("UInt128" ); |
437 | } |
438 | |
439 | bool readFloat32(Float32 &) override |
440 | { |
441 | cannotConvertType("Float32" ); |
442 | } |
443 | |
444 | bool readFloat64(Float64 &) override |
445 | { |
446 | cannotConvertType("Float64" ); |
447 | } |
448 | |
449 | void prepareEnumMapping8(const std::vector<std::pair<std::string, Int8>> &) override {} |
450 | void prepareEnumMapping16(const std::vector<std::pair<std::string, Int16>> &) override {} |
451 | |
452 | bool readEnum8(Int8 &) override |
453 | { |
454 | cannotConvertType("Enum" ); |
455 | } |
456 | |
457 | bool readEnum16(Int16 &) override |
458 | { |
459 | cannotConvertType("Enum" ); |
460 | } |
461 | |
462 | bool readUUID(UUID &) override |
463 | { |
464 | cannotConvertType("UUID" ); |
465 | } |
466 | |
467 | bool readDate(DayNum &) override |
468 | { |
469 | cannotConvertType("Date" ); |
470 | } |
471 | |
472 | bool readDateTime(time_t &) override |
473 | { |
474 | cannotConvertType("DateTime" ); |
475 | } |
476 | |
477 | bool readDateTime64(DateTime64 &, UInt32) override |
478 | { |
479 | cannotConvertType("DateTime64" ); |
480 | } |
481 | |
482 | bool readDecimal32(Decimal32 &, UInt32, UInt32) override |
483 | { |
484 | cannotConvertType("Decimal32" ); |
485 | } |
486 | |
487 | bool readDecimal64(Decimal64 &, UInt32, UInt32) override |
488 | { |
489 | cannotConvertType("Decimal64" ); |
490 | } |
491 | |
492 | bool readDecimal128(Decimal128 &, UInt32, UInt32) override |
493 | { |
494 | cannotConvertType("Decimal128" ); |
495 | } |
496 | |
497 | bool readAggregateFunction(const AggregateFunctionPtr &, AggregateDataPtr, Arena &) override |
498 | { |
499 | cannotConvertType("AggregateFunction" ); |
500 | } |
501 | |
502 | protected: |
503 | [[noreturn]] void cannotConvertType(const String & type_name) |
504 | { |
505 | throw Exception( |
506 | String("Could not convert type '" ) + field->type_name() + "' from protobuf field '" + field->name() + "' to data type '" |
507 | + type_name + "'" , |
508 | ErrorCodes::PROTOBUF_BAD_CAST); |
509 | } |
510 | |
511 | [[noreturn]] void cannotConvertValue(const String & value, const String & type_name) |
512 | { |
513 | throw Exception( |
514 | "Could not convert value '" + value + "' from protobuf field '" + field->name() + "' to data type '" + type_name + "'" , |
515 | ErrorCodes::PROTOBUF_BAD_CAST); |
516 | } |
517 | |
518 | template <typename To, typename From> |
519 | To numericCast(From value) |
520 | { |
521 | if constexpr (std::is_same_v<To, From>) |
522 | return value; |
523 | To result; |
524 | try |
525 | { |
526 | result = boost::numeric_cast<To>(value); |
527 | } |
528 | catch (boost::numeric::bad_numeric_cast &) |
529 | { |
530 | cannotConvertValue(toString(value), TypeName<To>::get()); |
531 | } |
532 | return result; |
533 | } |
534 | |
535 | template <typename To> |
536 | To parseFromString(const PaddedPODArray<UInt8> & str) |
537 | { |
538 | try |
539 | { |
540 | To result; |
541 | ReadBufferFromString buf(str); |
542 | readText(result, buf); |
543 | return result; |
544 | } |
545 | catch (...) |
546 | { |
547 | cannotConvertValue(StringRef(str.data(), str.size()).toString(), TypeName<To>::get()); |
548 | } |
549 | } |
550 | |
551 | SimpleReader & simple_reader; |
552 | const google::protobuf::FieldDescriptor * field; |
553 | }; |
554 | |
555 | |
556 | |
557 | class ProtobufReader::ConverterFromString : public ConverterBaseImpl |
558 | { |
559 | public: |
560 | using ConverterBaseImpl::ConverterBaseImpl; |
561 | |
562 | bool readStringInto(PaddedPODArray<UInt8> & str) override { return simple_reader.readStringInto(str); } |
563 | |
564 | bool readInt8(Int8 & value) override { return readNumeric(value); } |
565 | bool readUInt8(UInt8 & value) override { return readNumeric(value); } |
566 | bool readInt16(Int16 & value) override { return readNumeric(value); } |
567 | bool readUInt16(UInt16 & value) override { return readNumeric(value); } |
568 | bool readInt32(Int32 & value) override { return readNumeric(value); } |
569 | bool readUInt32(UInt32 & value) override { return readNumeric(value); } |
570 | bool readInt64(Int64 & value) override { return readNumeric(value); } |
571 | bool readUInt64(UInt64 & value) override { return readNumeric(value); } |
572 | bool readFloat32(Float32 & value) override { return readNumeric(value); } |
573 | bool readFloat64(Float64 & value) override { return readNumeric(value); } |
574 | |
575 | void prepareEnumMapping8(const std::vector<std::pair<String, Int8>> & name_value_pairs) override |
576 | { |
577 | prepareEnumNameToValueMap(name_value_pairs); |
578 | } |
579 | void prepareEnumMapping16(const std::vector<std::pair<String, Int16>> & name_value_pairs) override |
580 | { |
581 | prepareEnumNameToValueMap(name_value_pairs); |
582 | } |
583 | |
584 | bool readEnum8(Int8 & value) override { return readEnum(value); } |
585 | bool readEnum16(Int16 & value) override { return readEnum(value); } |
586 | |
587 | bool readUUID(UUID & uuid) override |
588 | { |
589 | if (!readTempString()) |
590 | return false; |
591 | ReadBufferFromString buf(temp_string); |
592 | readUUIDText(uuid, buf); |
593 | return true; |
594 | } |
595 | |
596 | bool readDate(DayNum & date) override |
597 | { |
598 | if (!readTempString()) |
599 | return false; |
600 | ReadBufferFromString buf(temp_string); |
601 | readDateText(date, buf); |
602 | return true; |
603 | } |
604 | |
605 | bool readDateTime(time_t & tm) override |
606 | { |
607 | if (!readTempString()) |
608 | return false; |
609 | ReadBufferFromString buf(temp_string); |
610 | readDateTimeText(tm, buf); |
611 | return true; |
612 | } |
613 | |
614 | bool readDateTime64(DateTime64 & date_time, UInt32 scale) override |
615 | { |
616 | if (!readTempString()) |
617 | return false; |
618 | ReadBufferFromString buf(temp_string); |
619 | readDateTime64Text(date_time, scale, buf); |
620 | return true; |
621 | } |
622 | |
623 | bool readDecimal32(Decimal32 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } |
624 | bool readDecimal64(Decimal64 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } |
625 | bool readDecimal128(Decimal128 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } |
626 | |
627 | bool readAggregateFunction(const AggregateFunctionPtr & function, AggregateDataPtr place, Arena & arena) override |
628 | { |
629 | if (!readTempString()) |
630 | return false; |
631 | ReadBufferFromString buf(temp_string); |
632 | function->deserialize(place, buf, &arena); |
633 | return true; |
634 | } |
635 | |
636 | private: |
637 | bool readTempString() |
638 | { |
639 | temp_string.clear(); |
640 | return simple_reader.readStringInto(temp_string); |
641 | } |
642 | |
643 | template <typename T> |
644 | bool readNumeric(T & value) |
645 | { |
646 | if (!readTempString()) |
647 | return false; |
648 | value = parseFromString<T>(temp_string); |
649 | return true; |
650 | } |
651 | |
652 | template<typename T> |
653 | bool readEnum(T & value) |
654 | { |
655 | if (!readTempString()) |
656 | return false; |
657 | StringRef ref(temp_string.data(), temp_string.size()); |
658 | auto it = enum_name_to_value_map->find(ref); |
659 | if (it == enum_name_to_value_map->end()) |
660 | cannotConvertValue(ref.toString(), "Enum" ); |
661 | value = static_cast<T>(it->second); |
662 | return true; |
663 | } |
664 | |
665 | template <typename T> |
666 | bool readDecimal(Decimal<T> & decimal, UInt32 precision, UInt32 scale) |
667 | { |
668 | if (!readTempString()) |
669 | return false; |
670 | ReadBufferFromString buf(temp_string); |
671 | DataTypeDecimal<Decimal<T>>::readText(decimal, buf, precision, scale); |
672 | return true; |
673 | } |
674 | |
675 | template <typename T> |
676 | void prepareEnumNameToValueMap(const std::vector<std::pair<String, T>> & name_value_pairs) |
677 | { |
678 | if (likely(enum_name_to_value_map.has_value())) |
679 | return; |
680 | enum_name_to_value_map.emplace(); |
681 | for (const auto & name_value_pair : name_value_pairs) |
682 | enum_name_to_value_map->emplace(name_value_pair.first, name_value_pair.second); |
683 | } |
684 | |
685 | PaddedPODArray<UInt8> temp_string; |
686 | std::optional<std::unordered_map<StringRef, Int16>> enum_name_to_value_map; |
687 | }; |
688 | |
689 | #define PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(field_type_id) \ |
690 | template <> \ |
691 | std::unique_ptr<ProtobufReader::IConverter> ProtobufReader::createConverter<field_type_id>( \ |
692 | const google::protobuf::FieldDescriptor * field) \ |
693 | { \ |
694 | return std::make_unique<ConverterFromString>(simple_reader, field); \ |
695 | } |
696 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_STRING) |
697 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_BYTES) |
698 | #undef PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS |
699 | |
700 | |
701 | template <int field_type_id, typename FromType> |
702 | class ProtobufReader::ConverterFromNumber : public ConverterBaseImpl |
703 | { |
704 | public: |
705 | using ConverterBaseImpl::ConverterBaseImpl; |
706 | |
707 | bool readStringInto(PaddedPODArray<UInt8> & str) override |
708 | { |
709 | FromType number; |
710 | if (!readField(number)) |
711 | return false; |
712 | WriteBufferFromVector<PaddedPODArray<UInt8>> buf(str); |
713 | writeText(number, buf); |
714 | return true; |
715 | } |
716 | |
717 | bool readInt8(Int8 & value) override { return readNumeric(value); } |
718 | bool readUInt8(UInt8 & value) override { return readNumeric(value); } |
719 | bool readInt16(Int16 & value) override { return readNumeric(value); } |
720 | bool readUInt16(UInt16 & value) override { return readNumeric(value); } |
721 | bool readInt32(Int32 & value) override { return readNumeric(value); } |
722 | bool readUInt32(UInt32 & value) override { return readNumeric(value); } |
723 | bool readInt64(Int64 & value) override { return readNumeric(value); } |
724 | bool readUInt64(UInt64 & value) override { return readNumeric(value); } |
725 | bool readFloat32(Float32 & value) override { return readNumeric(value); } |
726 | bool readFloat64(Float64 & value) override { return readNumeric(value); } |
727 | |
728 | bool readEnum8(Int8 & value) override { return readEnum(value); } |
729 | bool readEnum16(Int16 & value) override { return readEnum(value); } |
730 | |
731 | void prepareEnumMapping8(const std::vector<std::pair<String, Int8>> & name_value_pairs) override |
732 | { |
733 | prepareSetOfEnumValues(name_value_pairs); |
734 | } |
735 | void prepareEnumMapping16(const std::vector<std::pair<String, Int16>> & name_value_pairs) override |
736 | { |
737 | prepareSetOfEnumValues(name_value_pairs); |
738 | } |
739 | |
740 | bool readDate(DayNum & date) override |
741 | { |
742 | UInt16 number; |
743 | if (!readNumeric(number)) |
744 | return false; |
745 | date = DayNum(number); |
746 | return true; |
747 | } |
748 | |
749 | bool readDateTime(time_t & tm) override |
750 | { |
751 | UInt32 number; |
752 | if (!readNumeric(number)) |
753 | return false; |
754 | tm = number; |
755 | return true; |
756 | } |
757 | |
758 | bool readDateTime64(DateTime64 & date_time, UInt32 scale) override |
759 | { |
760 | return readDecimal(date_time, scale); |
761 | } |
762 | |
763 | bool readDecimal32(Decimal32 & decimal, UInt32, UInt32 scale) override { return readDecimal(decimal, scale); } |
764 | bool readDecimal64(Decimal64 & decimal, UInt32, UInt32 scale) override { return readDecimal(decimal, scale); } |
765 | bool readDecimal128(Decimal128 & decimal, UInt32, UInt32 scale) override { return readDecimal(decimal, scale); } |
766 | |
767 | private: |
768 | template <typename To> |
769 | bool readNumeric(To & value) |
770 | { |
771 | FromType number; |
772 | if (!readField(number)) |
773 | return false; |
774 | value = numericCast<To>(number); |
775 | return true; |
776 | } |
777 | |
778 | template<typename EnumType> |
779 | bool readEnum(EnumType & value) |
780 | { |
781 | if constexpr (!is_integral_v<FromType>) |
782 | cannotConvertType("Enum" ); // It's not correct to convert floating point to enum. |
783 | FromType number; |
784 | if (!readField(number)) |
785 | return false; |
786 | value = numericCast<EnumType>(number); |
787 | if (set_of_enum_values->find(value) == set_of_enum_values->end()) |
788 | cannotConvertValue(toString(value), "Enum" ); |
789 | return true; |
790 | } |
791 | |
792 | template<typename EnumType> |
793 | void prepareSetOfEnumValues(const std::vector<std::pair<String, EnumType>> & name_value_pairs) |
794 | { |
795 | if (likely(set_of_enum_values.has_value())) |
796 | return; |
797 | set_of_enum_values.emplace(); |
798 | for (const auto & name_value_pair : name_value_pairs) |
799 | set_of_enum_values->emplace(name_value_pair.second); |
800 | } |
801 | |
802 | template <typename S> |
803 | bool readDecimal(Decimal<S> & decimal, UInt32 scale) |
804 | { |
805 | FromType number; |
806 | if (!readField(number)) |
807 | return false; |
808 | decimal.value = convertToDecimal<DataTypeNumber<FromType>, DataTypeDecimal<Decimal<S>>>(number, scale); |
809 | return true; |
810 | } |
811 | |
812 | bool readField(FromType & value) |
813 | { |
814 | if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT32) && std::is_same_v<FromType, Int64>) |
815 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT64) && std::is_same_v<FromType, Int64>)) |
816 | { |
817 | return simple_reader.readInt(value); |
818 | } |
819 | else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT32) && std::is_same_v<FromType, UInt64>) |
820 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT64) && std::is_same_v<FromType, UInt64>)) |
821 | { |
822 | return simple_reader.readUInt(value); |
823 | } |
824 | |
825 | else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT32) && std::is_same_v<FromType, Int64>) |
826 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT64) && std::is_same_v<FromType, Int64>)) |
827 | { |
828 | return simple_reader.readSInt(value); |
829 | } |
830 | else |
831 | { |
832 | static_assert(((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED32) && std::is_same_v<FromType, UInt32>) |
833 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED32) && std::is_same_v<FromType, Int32>) |
834 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED64) && std::is_same_v<FromType, UInt64>) |
835 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED64) && std::is_same_v<FromType, Int64>) |
836 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FLOAT) && std::is_same_v<FromType, float>) |
837 | || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_DOUBLE) && std::is_same_v<FromType, double>)); |
838 | return simple_reader.readFixed(value); |
839 | } |
840 | } |
841 | |
842 | std::optional<std::unordered_set<Int16>> set_of_enum_values; |
843 | }; |
844 | |
845 | #define PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(field_type_id, field_type) \ |
846 | template <> \ |
847 | std::unique_ptr<ProtobufReader::IConverter> ProtobufReader::createConverter<field_type_id>( \ |
848 | const google::protobuf::FieldDescriptor * field) \ |
849 | { \ |
850 | return std::make_unique<ConverterFromNumber<field_type_id, field_type>>(simple_reader, field); \ |
851 | } |
852 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT32, Int64); |
853 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT32, Int64); |
854 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT32, UInt64); |
855 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT64, Int64); |
856 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT64, Int64); |
857 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT64, UInt64); |
858 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED32, UInt32); |
859 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED32, Int32); |
860 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED64, UInt64); |
861 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED64, Int64); |
862 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FLOAT, float); |
863 | PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_DOUBLE, double); |
864 | #undef PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS |
865 | |
866 | |
867 | |
868 | class ProtobufReader::ConverterFromBool : public ConverterBaseImpl |
869 | { |
870 | public: |
871 | using ConverterBaseImpl::ConverterBaseImpl; |
872 | |
873 | bool readStringInto(PaddedPODArray<UInt8> & str) override |
874 | { |
875 | bool b; |
876 | if (!readField(b)) |
877 | return false; |
878 | StringRef ref(b ? "true" : "false" ); |
879 | str.insert(ref.data, ref.data + ref.size); |
880 | return true; |
881 | } |
882 | |
883 | bool readInt8(Int8 & value) override { return readNumeric(value); } |
884 | bool readUInt8(UInt8 & value) override { return readNumeric(value); } |
885 | bool readInt16(Int16 & value) override { return readNumeric(value); } |
886 | bool readUInt16(UInt16 & value) override { return readNumeric(value); } |
887 | bool readInt32(Int32 & value) override { return readNumeric(value); } |
888 | bool readUInt32(UInt32 & value) override { return readNumeric(value); } |
889 | bool readInt64(Int64 & value) override { return readNumeric(value); } |
890 | bool readUInt64(UInt64 & value) override { return readNumeric(value); } |
891 | bool readFloat32(Float32 & value) override { return readNumeric(value); } |
892 | bool readFloat64(Float64 & value) override { return readNumeric(value); } |
893 | bool readDecimal32(Decimal32 & decimal, UInt32, UInt32) override { return readNumeric(decimal.value); } |
894 | bool readDecimal64(Decimal64 & decimal, UInt32, UInt32) override { return readNumeric(decimal.value); } |
895 | bool readDecimal128(Decimal128 & decimal, UInt32, UInt32) override { return readNumeric(decimal.value); } |
896 | |
897 | private: |
898 | template<typename T> |
899 | bool readNumeric(T & value) |
900 | { |
901 | bool b; |
902 | if (!readField(b)) |
903 | return false; |
904 | value = b ? 1 : 0; |
905 | return true; |
906 | } |
907 | |
908 | bool readField(bool & b) |
909 | { |
910 | UInt64 number; |
911 | if (!simple_reader.readUInt(number)) |
912 | return false; |
913 | b = static_cast<bool>(number); |
914 | return true; |
915 | } |
916 | }; |
917 | |
918 | template <> |
919 | std::unique_ptr<ProtobufReader::IConverter> ProtobufReader::createConverter<google::protobuf::FieldDescriptor::TYPE_BOOL>( |
920 | const google::protobuf::FieldDescriptor * field) |
921 | { |
922 | return std::make_unique<ConverterFromBool>(simple_reader, field); |
923 | } |
924 | |
925 | |
926 | class ProtobufReader:: : public ConverterBaseImpl |
927 | { |
928 | public: |
929 | using ConverterBaseImpl::ConverterBaseImpl; |
930 | |
931 | bool (PaddedPODArray<UInt8> & str) override |
932 | { |
933 | prepareEnumPbNumberToNameMap(); |
934 | Int64 pbnumber; |
935 | if (!readField(pbnumber)) |
936 | return false; |
937 | auto it = enum_pbnumber_to_name_map->find(pbnumber); |
938 | if (it == enum_pbnumber_to_name_map->end()) |
939 | cannotConvertValue(toString(pbnumber), "Enum" ); |
940 | const auto & ref = it->second; |
941 | str.insert(ref.data, ref.data + ref.size); |
942 | return true; |
943 | } |
944 | |
945 | bool (Int8 & value) override { return readNumeric(value); } |
946 | bool (UInt8 & value) override { return readNumeric(value); } |
947 | bool (Int16 & value) override { return readNumeric(value); } |
948 | bool (UInt16 & value) override { return readNumeric(value); } |
949 | bool (Int32 & value) override { return readNumeric(value); } |
950 | bool (UInt32 & value) override { return readNumeric(value); } |
951 | bool (Int64 & value) override { return readNumeric(value); } |
952 | bool (UInt64 & value) override { return readNumeric(value); } |
953 | |
954 | void (const std::vector<std::pair<String, Int8>> & name_value_pairs) override |
955 | { |
956 | prepareEnumPbNumberToValueMap(name_value_pairs); |
957 | } |
958 | void (const std::vector<std::pair<String, Int16>> & name_value_pairs) override |
959 | { |
960 | prepareEnumPbNumberToValueMap(name_value_pairs); |
961 | } |
962 | |
963 | bool (Int8 & value) override { return readEnum(value); } |
964 | bool (Int16 & value) override { return readEnum(value); } |
965 | |
966 | private: |
967 | template <typename T> |
968 | bool (T & value) |
969 | { |
970 | Int64 pbnumber; |
971 | if (!readField(pbnumber)) |
972 | return false; |
973 | value = numericCast<T>(pbnumber); |
974 | return true; |
975 | } |
976 | |
977 | template<typename T> |
978 | bool (T & value) |
979 | { |
980 | Int64 pbnumber; |
981 | if (!readField(pbnumber)) |
982 | return false; |
983 | if (enum_pbnumber_always_equals_value) |
984 | value = static_cast<T>(pbnumber); |
985 | else |
986 | { |
987 | auto it = enum_pbnumber_to_value_map->find(pbnumber); |
988 | if (it == enum_pbnumber_to_value_map->end()) |
989 | cannotConvertValue(toString(pbnumber), "Enum" ); |
990 | value = static_cast<T>(it->second); |
991 | } |
992 | return true; |
993 | } |
994 | |
995 | void () |
996 | { |
997 | if (likely(enum_pbnumber_to_name_map.has_value())) |
998 | return; |
999 | enum_pbnumber_to_name_map.emplace(); |
1000 | const auto * enum_type = field->enum_type(); |
1001 | for (int i = 0; i != enum_type->value_count(); ++i) |
1002 | { |
1003 | const auto * enum_value = enum_type->value(i); |
1004 | enum_pbnumber_to_name_map->emplace(enum_value->number(), enum_value->name()); |
1005 | } |
1006 | } |
1007 | |
1008 | template <typename T> |
1009 | void (const std::vector<std::pair<String, T>> & name_value_pairs) |
1010 | { |
1011 | if (likely(enum_pbnumber_to_value_map.has_value())) |
1012 | return; |
1013 | enum_pbnumber_to_value_map.emplace(); |
1014 | enum_pbnumber_always_equals_value = true; |
1015 | for (const auto & name_value_pair : name_value_pairs) |
1016 | { |
1017 | Int16 value = name_value_pair.second; |
1018 | const auto * enum_descriptor = field->enum_type()->FindValueByName(name_value_pair.first); |
1019 | if (enum_descriptor) |
1020 | { |
1021 | enum_pbnumber_to_value_map->emplace(enum_descriptor->number(), value); |
1022 | if (enum_descriptor->number() != value) |
1023 | enum_pbnumber_always_equals_value = false; |
1024 | } |
1025 | else |
1026 | enum_pbnumber_always_equals_value = false; |
1027 | } |
1028 | } |
1029 | |
1030 | bool (Int64 & enum_pbnumber) |
1031 | { |
1032 | return simple_reader.readInt(enum_pbnumber); |
1033 | } |
1034 | |
1035 | std::optional<std::unordered_map<Int64, StringRef>> ; |
1036 | std::optional<std::unordered_map<Int64, Int16>> ; |
1037 | bool ; |
1038 | }; |
1039 | |
1040 | template <> |
1041 | std::unique_ptr<ProtobufReader::IConverter> ProtobufReader::createConverter<google::protobuf::FieldDescriptor::TYPE_ENUM>( |
1042 | const google::protobuf::FieldDescriptor * field) |
1043 | { |
1044 | return std::make_unique<ConverterFromEnum>(simple_reader, field); |
1045 | } |
1046 | |
1047 | |
1048 | ProtobufReader::ProtobufReader( |
1049 | ReadBuffer & in_, const google::protobuf::Descriptor * message_type, const std::vector<String> & column_names) |
1050 | : simple_reader(in_) |
1051 | { |
1052 | root_message = ProtobufColumnMatcher::matchColumns<ColumnMatcherTraits>(column_names, message_type); |
1053 | setTraitsDataAfterMatchingColumns(root_message.get()); |
1054 | } |
1055 | |
// Destructor defaulted out-of-line in this file.
// NOTE(review): presumably kept here rather than in the header so that members whose types are
// incomplete in the header are destroyed where they are complete — confirm against ProtobufReader.h.
ProtobufReader::~ProtobufReader() = default;
1057 | |
1058 | void ProtobufReader::setTraitsDataAfterMatchingColumns(Message * message) |
1059 | { |
1060 | for (Field & field : message->fields) |
1061 | { |
1062 | if (field.nested_message) |
1063 | { |
1064 | setTraitsDataAfterMatchingColumns(field.nested_message.get()); |
1065 | continue; |
1066 | } |
1067 | switch (field.field_descriptor->type()) |
1068 | { |
1069 | #define PROTOBUF_READER_CONVERTER_CREATING_CASE(field_type_id) \ |
1070 | case field_type_id: \ |
1071 | field.data.converter = createConverter<field_type_id>(field.field_descriptor); \ |
1072 | break |
1073 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_STRING); |
1074 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BYTES); |
1075 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT32); |
1076 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT32); |
1077 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT32); |
1078 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED32); |
1079 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED32); |
1080 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT64); |
1081 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT64); |
1082 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT64); |
1083 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED64); |
1084 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED64); |
1085 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FLOAT); |
1086 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_DOUBLE); |
1087 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BOOL); |
1088 | PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_ENUM); |
1089 | #undef PROTOBUF_READER_CONVERTER_CREATING_CASE |
1090 | default: __builtin_unreachable(); |
1091 | } |
1092 | message->data.field_number_to_field_map.emplace(field.field_number, &field); |
1093 | } |
1094 | } |
1095 | |
1096 | bool ProtobufReader::startMessage() |
1097 | { |
1098 | if (!simple_reader.startMessage()) |
1099 | return false; |
1100 | current_message = root_message.get(); |
1101 | current_field_index = 0; |
1102 | return true; |
1103 | } |
1104 | |
1105 | void ProtobufReader::endMessage(bool try_ignore_errors) |
1106 | { |
1107 | simple_reader.endMessage(try_ignore_errors); |
1108 | current_message = nullptr; |
1109 | current_converter = nullptr; |
1110 | } |
1111 | |
1112 | bool ProtobufReader::readColumnIndex(size_t & column_index) |
1113 | { |
1114 | while (true) |
1115 | { |
1116 | UInt32 field_number; |
1117 | if (!simple_reader.readFieldNumber(field_number)) |
1118 | { |
1119 | if (!current_message->parent) |
1120 | { |
1121 | current_converter = nullptr; |
1122 | return false; |
1123 | } |
1124 | simple_reader.endNestedMessage(); |
1125 | current_field_index = current_message->index_in_parent; |
1126 | current_message = current_message->parent; |
1127 | continue; |
1128 | } |
1129 | |
1130 | const Field * field = nullptr; |
1131 | for (; current_field_index < current_message->fields.size(); ++current_field_index) |
1132 | { |
1133 | const Field & f = current_message->fields[current_field_index]; |
1134 | if (f.field_number == field_number) |
1135 | { |
1136 | field = &f; |
1137 | break; |
1138 | } |
1139 | if (f.field_number > field_number) |
1140 | break; |
1141 | } |
1142 | |
1143 | if (!field) |
1144 | { |
1145 | const auto & field_number_to_field_map = current_message->data.field_number_to_field_map; |
1146 | auto it = field_number_to_field_map.find(field_number); |
1147 | if (it == field_number_to_field_map.end()) |
1148 | continue; |
1149 | field = it->second; |
1150 | } |
1151 | |
1152 | if (field->nested_message) |
1153 | { |
1154 | simple_reader.startNestedMessage(); |
1155 | current_message = field->nested_message.get(); |
1156 | current_field_index = 0; |
1157 | continue; |
1158 | } |
1159 | |
1160 | column_index = field->column_index; |
1161 | current_converter = field->data.converter.get(); |
1162 | return true; |
1163 | } |
1164 | } |
1165 | |
1166 | } |
1167 | #endif |
1168 | |