1#include "config_formats.h"
2#if USE_PROTOBUF
3
4#include "ProtobufWriter.h"
5
#include <cassert>
#include <cmath>
#include <optional>
#include <math.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypesDecimal.h>
#include <boost/numeric/conversion/cast.hpp>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/descriptor.pb.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
16
17
18namespace DB
19{
namespace ErrorCodes
{
    // Error codes referenced by this file; they are only declared here (`extern`) and defined in another translation unit.
    extern const int NOT_IMPLEMENTED;
    extern const int NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD;
    extern const int PROTOBUF_BAD_CAST;
    extern const int PROTOBUF_FIELD_NOT_REPEATED;
}
27
28
29namespace
30{
31 constexpr size_t MAX_VARINT_SIZE = 10;
32 constexpr size_t REPEATED_PACK_PADDING = 2 * MAX_VARINT_SIZE;
33 constexpr size_t NESTED_MESSAGE_PADDING = 2 * MAX_VARINT_SIZE;
34
35 // Note: There is a difference between this function and writeVarUInt() from IO/VarInt.h:
36 // Google protobuf's representation of 64-bit integer contains from 1 to 10 bytes,
37 // whileas writeVarUInt() writes from 1 to 9 bytes because it omits the tenth byte (which is not necessary to decode actually).
38 void writeVarint(UInt64 value, WriteBuffer & out)
39 {
40 while (value >= 0x80)
41 {
42 out.write(static_cast<char>(value | 0x80));
43 value >>= 7;
44 }
45 out.write(static_cast<char>(value));
46 }
47
48 UInt8 * writeVarint(UInt64 value, UInt8 * ptr)
49 {
50 while (value >= 0x80)
51 {
52 *ptr++ = static_cast<UInt8>(value | 0x80);
53 value >>= 7;
54 }
55 *ptr++ = static_cast<UInt8>(value);
56 return ptr;
57 }
58
59 void writeVarint(UInt64 value, PODArray<UInt8> & buf)
60 {
61 size_t old_size = buf.size();
62 buf.reserve(old_size + MAX_VARINT_SIZE);
63 UInt8 * ptr = buf.data() + old_size;
64 ptr = writeVarint(value, ptr);
65 buf.resize_assume_reserved(ptr - buf.data());
66 }
67
68 UInt64 encodeZigZag(Int64 value) { return (static_cast<UInt64>(value) << 1) ^ static_cast<UInt64>(value >> 63); }
69
70 enum WireType
71 {
72 VARINT = 0,
73 BITS64 = 1,
74 LENGTH_DELIMITED = 2,
75 GROUP_START = 3,
76 GROUP_END = 4,
77 BITS32 = 5
78 };
79
80 UInt8 * writeFieldNumber(UInt32 field_number, WireType wire_type, UInt8 * ptr)
81 {
82 return writeVarint((field_number << 3) | wire_type, ptr);
83 }
84
85 void writeFieldNumber(UInt32 field_number, WireType wire_type, PODArray<UInt8> & buf) { writeVarint((field_number << 3) | wire_type, buf); }
86
87 // Should we pack repeated values while storing them.
88 // It depends on type of the field in the protobuf schema and the syntax of that schema.
89 bool shouldPackRepeated(const google::protobuf::FieldDescriptor * field)
90 {
91 if (!field->is_repeated())
92 return false;
93 switch (field->type())
94 {
95 case google::protobuf::FieldDescriptor::TYPE_INT32:
96 case google::protobuf::FieldDescriptor::TYPE_UINT32:
97 case google::protobuf::FieldDescriptor::TYPE_SINT32:
98 case google::protobuf::FieldDescriptor::TYPE_INT64:
99 case google::protobuf::FieldDescriptor::TYPE_UINT64:
100 case google::protobuf::FieldDescriptor::TYPE_SINT64:
101 case google::protobuf::FieldDescriptor::TYPE_FIXED32:
102 case google::protobuf::FieldDescriptor::TYPE_SFIXED32:
103 case google::protobuf::FieldDescriptor::TYPE_FIXED64:
104 case google::protobuf::FieldDescriptor::TYPE_SFIXED64:
105 case google::protobuf::FieldDescriptor::TYPE_FLOAT:
106 case google::protobuf::FieldDescriptor::TYPE_DOUBLE:
107 case google::protobuf::FieldDescriptor::TYPE_BOOL:
108 case google::protobuf::FieldDescriptor::TYPE_ENUM:
109 break;
110 default:
111 return false;
112 }
113 if (field->options().has_packed())
114 return field->options().packed();
115 return field->file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO3;
116 }
117
118 // Should we omit null values (zero for numbers / empty string for strings) while storing them.
119 bool shouldSkipNullValue(const google::protobuf::FieldDescriptor * field)
120 {
121 return field->is_optional() && (field->file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO3);
122 }
123}
124
125
126// SimpleWriter is an utility class to serialize protobufs.
127// Knows nothing about protobuf schemas, just provides useful functions to serialize data.
128ProtobufWriter::SimpleWriter::SimpleWriter(WriteBuffer & out_) : out(out_), current_piece_start(0), num_bytes_skipped(0)
129{
130}
131
// The destructor is defaulted, but defined out of line in this file.
ProtobufWriter::SimpleWriter::~SimpleWriter() = default;
133
// Nothing to do at the start of a message: the per-message state (buffer, pieces, skip counter,
// current piece start) is set up by the constructor and reset by endMessage().
void ProtobufWriter::SimpleWriter::startMessage()
{
}
137
138void ProtobufWriter::SimpleWriter::endMessage()
139{
140 pieces.emplace_back(current_piece_start, buffer.size());
141 size_t size_of_message = buffer.size() - num_bytes_skipped;
142 writeVarint(size_of_message, out);
143 for (const auto & piece : pieces)
144 if (piece.end > piece.start)
145 out.write(reinterpret_cast<char *>(&buffer[piece.start]), piece.end - piece.start);
146 buffer.clear();
147 pieces.clear();
148 num_bytes_skipped = 0;
149 current_piece_start = 0;
150}
151
152void ProtobufWriter::SimpleWriter::startNestedMessage()
153{
154 nested_infos.emplace_back(pieces.size(), num_bytes_skipped);
155 pieces.emplace_back(current_piece_start, buffer.size());
156
157 // We skip enough bytes to have place for inserting the field number and the size of the nested message afterwards
158 // when we finish writing the nested message itself. We don't know the size of the nested message at the point of
159 // calling startNestedMessage(), that's why we have to do this skipping.
160 current_piece_start = buffer.size() + NESTED_MESSAGE_PADDING;
161 buffer.resize(current_piece_start);
162 num_bytes_skipped = NESTED_MESSAGE_PADDING;
163}
164
// Finishes the nested message started by the matching startNestedMessage().
//   field_number  - number of the field of the enclosing message which holds the nested message;
//   is_group      - if true, the nested message is encoded as a group (GROUP_START ... GROUP_END tags)
//                   instead of a length-delimited field;
//   skip_if_empty - if true and the nested message has no payload, everything done by
//                   startNestedMessage() is rolled back and nothing is written for this field.
void ProtobufWriter::SimpleWriter::endNestedMessage(UInt32 field_number, bool is_group, bool skip_if_empty)
{
    // Copy the saved values out before pop_back() invalidates the reference.
    const auto & nested_info = nested_infos.back();
    size_t num_pieces_at_start = nested_info.num_pieces_at_start;
    size_t num_bytes_skipped_at_start = nested_info.num_bytes_skipped_at_start;
    nested_infos.pop_back();
    // The piece closed by startNestedMessage(); the reserved gap begins at its end.
    auto & piece_before_message = pieces[num_pieces_at_start];
    size_t message_start = piece_before_message.end;
    // Here num_bytes_skipped counts the reserved gap plus everything skipped inside the nested message,
    // so subtracting it leaves exactly the number of bytes the nested message will occupy in the output.
    size_t message_size = buffer.size() - message_start - num_bytes_skipped;
    if (!message_size && skip_if_empty)
    {
        // Roll back to the state right before startNestedMessage().
        current_piece_start = piece_before_message.start;
        buffer.resize(piece_before_message.end);
        pieces.resize(num_pieces_at_start);
        num_bytes_skipped = num_bytes_skipped_at_start;
        return;
    }
    size_t num_bytes_inserted;
    if (is_group)
    {
        // Groups have no length prefix: append the end tag and put the start tag into the gap.
        // `ptr` is taken only after the append because appending may reallocate `buffer`.
        writeFieldNumber(field_number, GROUP_END, buffer);
        UInt8 * ptr = &buffer[piece_before_message.end];
        UInt8 * endptr = writeFieldNumber(field_number, GROUP_START, ptr);
        num_bytes_inserted = endptr - ptr;
    }
    else
    {
        // Fill the beginning of the gap with the tag and the length of the nested message.
        UInt8 * ptr = &buffer[piece_before_message.end];
        UInt8 * endptr = writeFieldNumber(field_number, LENGTH_DELIMITED, ptr);
        endptr = writeVarint(message_size, endptr);
        num_bytes_inserted = endptr - ptr;
    }
    // The inserted header becomes part of the preceding piece; the unused rest of the gap stays skipped.
    piece_before_message.end += num_bytes_inserted;
    // Restore the enclosing message's skip count. It now also includes the unused part of the gap and
    // everything skipped inside the nested message.
    num_bytes_skipped += num_bytes_skipped_at_start - num_bytes_inserted;
}
200
201void ProtobufWriter::SimpleWriter::writeUInt(UInt32 field_number, UInt64 value)
202{
203 size_t old_size = buffer.size();
204 buffer.reserve(old_size + 2 * MAX_VARINT_SIZE);
205 UInt8 * ptr = buffer.data() + old_size;
206 ptr = writeFieldNumber(field_number, VARINT, ptr);
207 ptr = writeVarint(value, ptr);
208 buffer.resize_assume_reserved(ptr - buffer.data());
209}
210
211void ProtobufWriter::SimpleWriter::writeInt(UInt32 field_number, Int64 value)
212{
213 writeUInt(field_number, static_cast<UInt64>(value));
214}
215
216void ProtobufWriter::SimpleWriter::writeSInt(UInt32 field_number, Int64 value)
217{
218 writeUInt(field_number, encodeZigZag(value));
219}
220
221template <typename T>
222void ProtobufWriter::SimpleWriter::writeFixed(UInt32 field_number, T value)
223{
224 static_assert((sizeof(T) == 4) || (sizeof(T) == 8));
225 constexpr WireType wire_type = (sizeof(T) == 4) ? BITS32 : BITS64;
226 size_t old_size = buffer.size();
227 buffer.reserve(old_size + MAX_VARINT_SIZE + sizeof(T));
228 UInt8 * ptr = buffer.data() + old_size;
229 ptr = writeFieldNumber(field_number, wire_type, ptr);
230 memcpy(ptr, &value, sizeof(T));
231 ptr += sizeof(T);
232 buffer.resize_assume_reserved(ptr - buffer.data());
233}
234
235void ProtobufWriter::SimpleWriter::writeString(UInt32 field_number, const StringRef & str)
236{
237 size_t old_size = buffer.size();
238 buffer.reserve(old_size + 2 * MAX_VARINT_SIZE + str.size);
239 UInt8 * ptr = buffer.data() + old_size;
240 ptr = writeFieldNumber(field_number, LENGTH_DELIMITED, ptr);
241 ptr = writeVarint(str.size, ptr);
242 memcpy(ptr, str.data, str.size);
243 ptr += str.size;
244 buffer.resize_assume_reserved(ptr - buffer.data());
245}
246
247void ProtobufWriter::SimpleWriter::startRepeatedPack()
248{
249 pieces.emplace_back(current_piece_start, buffer.size());
250
251 // We skip enough bytes to have place for inserting the field number and the size of the repeated pack afterwards
252 // when we finish writing the repeated pack itself. We don't know the size of the repeated pack at the point of
253 // calling startRepeatedPack(), that's why we have to do this skipping.
254 current_piece_start = buffer.size() + REPEATED_PACK_PADDING;
255 buffer.resize(current_piece_start);
256 num_bytes_skipped += REPEATED_PACK_PADDING;
257}
258
259void ProtobufWriter::SimpleWriter::endRepeatedPack(UInt32 field_number)
260{
261 size_t size = buffer.size() - current_piece_start;
262 if (!size)
263 {
264 current_piece_start = pieces.back().start;
265 buffer.resize(pieces.back().end);
266 pieces.pop_back();
267 num_bytes_skipped -= REPEATED_PACK_PADDING;
268 return;
269 }
270 UInt8 * ptr = &buffer[pieces.back().end];
271 UInt8 * endptr = writeFieldNumber(field_number, LENGTH_DELIMITED, ptr);
272 endptr = writeVarint(size, endptr);
273 size_t num_bytes_inserted = endptr - ptr;
274 pieces.back().end += num_bytes_inserted;
275 num_bytes_skipped -= num_bytes_inserted;
276}
277
// Appends `value` to `buffer` as a bare varint (no tag), as elements of a packed repeated record are stored.
// Presumably called between startRepeatedPack() and endRepeatedPack() (call sites are not visible here).
void ProtobufWriter::SimpleWriter::addUIntToRepeatedPack(UInt64 value)
{
    writeVarint(value, buffer);
}
282
283void ProtobufWriter::SimpleWriter::addIntToRepeatedPack(Int64 value)
284{
285 writeVarint(static_cast<UInt64>(value), buffer);
286}
287
288void ProtobufWriter::SimpleWriter::addSIntToRepeatedPack(Int64 value)
289{
290 writeVarint(encodeZigZag(value), buffer);
291}
292
293template <typename T>
294void ProtobufWriter::SimpleWriter::addFixedToRepeatedPack(T value)
295{
296 static_assert((sizeof(T) == 4) || (sizeof(T) == 8));
297 size_t old_size = buffer.size();
298 buffer.resize(old_size + sizeof(T));
299 memcpy(buffer.data() + old_size, &value, sizeof(T));
300}
301
302
303// Implementation for a converter from any DB data type to any protobuf field type.
304class ProtobufWriter::ConverterBaseImpl : public IConverter
305{
306public:
307 ConverterBaseImpl(SimpleWriter & simple_writer_, const google::protobuf::FieldDescriptor * field_)
308 : simple_writer(simple_writer_), field(field_)
309 {
310 field_number = field->number();
311 }
312
313 virtual void writeString(const StringRef &) override { cannotConvertType("String"); }
314 virtual void writeInt8(Int8) override { cannotConvertType("Int8"); }
315 virtual void writeUInt8(UInt8) override { cannotConvertType("UInt8"); }
316 virtual void writeInt16(Int16) override { cannotConvertType("Int16"); }
317 virtual void writeUInt16(UInt16) override { cannotConvertType("UInt16"); }
318 virtual void writeInt32(Int32) override { cannotConvertType("Int32"); }
319 virtual void writeUInt32(UInt32) override { cannotConvertType("UInt32"); }
320 virtual void writeInt64(Int64) override { cannotConvertType("Int64"); }
321 virtual void writeUInt64(UInt64) override { cannotConvertType("UInt64"); }
322 virtual void writeUInt128(const UInt128 &) override { cannotConvertType("UInt128"); }
323 virtual void writeFloat32(Float32) override { cannotConvertType("Float32"); }
324 virtual void writeFloat64(Float64) override { cannotConvertType("Float64"); }
325 virtual void prepareEnumMapping8(const std::vector<std::pair<std::string, Int8>> &) override {}
326 virtual void prepareEnumMapping16(const std::vector<std::pair<std::string, Int16>> &) override {}
327 virtual void writeEnum8(Int8) override { cannotConvertType("Enum"); }
328 virtual void writeEnum16(Int16) override { cannotConvertType("Enum"); }
329 virtual void writeUUID(const UUID &) override { cannotConvertType("UUID"); }
330 virtual void writeDate(DayNum) override { cannotConvertType("Date"); }
331 virtual void writeDateTime(time_t) override { cannotConvertType("DateTime"); }
332 virtual void writeDateTime64(DateTime64, UInt32) override { cannotConvertType("DateTime64"); }
333 virtual void writeDecimal32(Decimal32, UInt32) override { cannotConvertType("Decimal32"); }
334 virtual void writeDecimal64(Decimal64, UInt32) override { cannotConvertType("Decimal64"); }
335 virtual void writeDecimal128(const Decimal128 &, UInt32) override { cannotConvertType("Decimal128"); }
336
337 virtual void writeAggregateFunction(const AggregateFunctionPtr &, ConstAggregateDataPtr) override { cannotConvertType("AggregateFunction"); }
338
339protected:
340 [[noreturn]] void cannotConvertType(const String & type_name)
341 {
342 throw Exception(
343 "Could not convert data type '" + type_name + "' to protobuf type '" + field->type_name() + "' (field: " + field->name() + ")",
344 ErrorCodes::PROTOBUF_BAD_CAST);
345 }
346
347 [[noreturn]] void cannotConvertValue(const String & value)
348 {
349 throw Exception(
350 "Could not convert value '" + value + "' to protobuf type '" + field->type_name() + "' (field: " + field->name() + ")",
351 ErrorCodes::PROTOBUF_BAD_CAST);
352 }
353
354 template <typename To, typename From>
355 To numericCast(From value)
356 {
357 if constexpr (std::is_same_v<To, From>)
358 return value;
359 To result;
360 try
361 {
362 result = boost::numeric_cast<To>(value);
363 }
364 catch (boost::numeric::bad_numeric_cast &)
365 {
366 cannotConvertValue(toString(value));
367 }
368 return result;
369 }
370
371 template <typename To>
372 To parseFromString(const StringRef & str)
373 {
374 To result;
375 try
376 {
377 result = ::DB::parse<To>(str.data, str.size);
378 }
379 catch (...)
380 {
381 cannotConvertValue(str.toString());
382 }
383 return result;
384 }
385
386 SimpleWriter & simple_writer;
387 const google::protobuf::FieldDescriptor * field;
388 UInt32 field_number;
389};
390
391
392template <bool skip_null_value>
393class ProtobufWriter::ConverterToString : public ConverterBaseImpl
394{
395public:
396 using ConverterBaseImpl::ConverterBaseImpl;
397
398 void writeString(const StringRef & str) override { writeField(str); }
399
400 void writeInt8(Int8 value) override { convertToStringAndWriteField(value); }
401 void writeUInt8(UInt8 value) override { convertToStringAndWriteField(value); }
402 void writeInt16(Int16 value) override { convertToStringAndWriteField(value); }
403 void writeUInt16(UInt16 value) override { convertToStringAndWriteField(value); }
404 void writeInt32(Int32 value) override { convertToStringAndWriteField(value); }
405 void writeUInt32(UInt32 value) override { convertToStringAndWriteField(value); }
406 void writeInt64(Int64 value) override { convertToStringAndWriteField(value); }
407 void writeUInt64(UInt64 value) override { convertToStringAndWriteField(value); }
408 void writeFloat32(Float32 value) override { convertToStringAndWriteField(value); }
409 void writeFloat64(Float64 value) override { convertToStringAndWriteField(value); }
410
411 void prepareEnumMapping8(const std::vector<std::pair<String, Int8>> & name_value_pairs) override
412 {
413 prepareEnumValueToNameMap(name_value_pairs);
414 }
415 void prepareEnumMapping16(const std::vector<std::pair<String, Int16>> & name_value_pairs) override
416 {
417 prepareEnumValueToNameMap(name_value_pairs);
418 }
419
420 void writeEnum8(Int8 value) override { writeEnum16(value); }
421
422 void writeEnum16(Int16 value) override
423 {
424 auto it = enum_value_to_name_map->find(value);
425 if (it == enum_value_to_name_map->end())
426 cannotConvertValue(toString(value));
427 writeField(it->second);
428 }
429
430 void writeUUID(const UUID & uuid) override { convertToStringAndWriteField(uuid); }
431 void writeDate(DayNum date) override { convertToStringAndWriteField(date); }
432
433 void writeDateTime(time_t tm) override
434 {
435 writeDateTimeText(tm, text_buffer);
436 writeField(text_buffer.stringRef());
437 text_buffer.restart();
438 }
439
440 void writeDateTime64(DateTime64 date_time, UInt32 scale) override
441 {
442 writeDateTimeText(date_time, scale, text_buffer);
443 writeField(text_buffer.stringRef());
444 text_buffer.restart();
445 }
446
447 void writeDecimal32(Decimal32 decimal, UInt32 scale) override { writeDecimal(decimal, scale); }
448 void writeDecimal64(Decimal64 decimal, UInt32 scale) override { writeDecimal(decimal, scale); }
449 void writeDecimal128(const Decimal128 & decimal, UInt32 scale) override { writeDecimal(decimal, scale); }
450
451 void writeAggregateFunction(const AggregateFunctionPtr & function, ConstAggregateDataPtr place) override
452 {
453 function->serialize(place, text_buffer);
454 writeField(text_buffer.stringRef());
455 text_buffer.restart();
456 }
457
458private:
459 template <typename T>
460 void convertToStringAndWriteField(T value)
461 {
462 writeText(value, text_buffer);
463 writeField(text_buffer.stringRef());
464 text_buffer.restart();
465 }
466
467 template <typename T>
468 void writeDecimal(const Decimal<T> & decimal, UInt32 scale)
469 {
470 writeText(decimal, scale, text_buffer);
471 writeField(text_buffer.stringRef());
472 text_buffer.restart();
473 }
474
475 template <typename T>
476 void prepareEnumValueToNameMap(const std::vector<std::pair<String, T>> & name_value_pairs)
477 {
478 if (enum_value_to_name_map.has_value())
479 return;
480 enum_value_to_name_map.emplace();
481 for (const auto & name_value_pair : name_value_pairs)
482 enum_value_to_name_map->emplace(name_value_pair.second, name_value_pair.first);
483 }
484
485 void writeField(const StringRef & str)
486 {
487 if constexpr (skip_null_value)
488 {
489 if (!str.size)
490 return;
491 }
492 simple_writer.writeString(field_number, str);
493 }
494
495 WriteBufferFromOwnString text_buffer;
496 std::optional<std::unordered_map<Int16, String>> enum_value_to_name_map;
497};
498
499#define PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(field_type_id) \
500 template <> \
501 std::unique_ptr<ProtobufWriter::IConverter> ProtobufWriter::createConverter<field_type_id>( \
502 const google::protobuf::FieldDescriptor * field) \
503 { \
504 if (shouldSkipNullValue(field)) \
505 return std::make_unique<ConverterToString<true>>(simple_writer, field); \
506 else \
507 return std::make_unique<ConverterToString<false>>(simple_writer, field); \
508 }
509PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_STRING)
510PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_BYTES)
511#undef PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS
512
513
// Converter to a numeric protobuf field (int32/64, sint32/64, uint32/64, fixed32/64, sfixed32/64, float, double).
//   field_type_id   - protobuf field type of the target field; together with ToType it selects the encoding in writeField();
//   ToType          - C++ type every value is converted to before encoding;
//   skip_null_value - if true, values equal to zero are not written;
//   pack_repeated   - if true, values are appended to a packed repeated record instead of being written as tagged fields.
// Numeric inputs go through numericCast(), which throws PROTOBUF_BAD_CAST if a value does not fit into ToType.
template <int field_type_id, typename ToType, bool skip_null_value, bool pack_repeated>
class ProtobufWriter::ConverterToNumber : public ConverterBaseImpl
{
public:
    using ConverterBaseImpl::ConverterBaseImpl;

    // Strings are parsed as ToType; a parse failure throws PROTOBUF_BAD_CAST.
    void writeString(const StringRef & str) override { writeField(parseFromString<ToType>(str)); }

    void writeInt8(Int8 value) override { castNumericAndWriteField(value); }
    void writeUInt8(UInt8 value) override { castNumericAndWriteField(value); }
    void writeInt16(Int16 value) override { castNumericAndWriteField(value); }
    void writeUInt16(UInt16 value) override { castNumericAndWriteField(value); }
    void writeInt32(Int32 value) override { castNumericAndWriteField(value); }
    void writeUInt32(UInt32 value) override { castNumericAndWriteField(value); }
    void writeInt64(Int64 value) override { castNumericAndWriteField(value); }
    void writeUInt64(UInt64 value) override { castNumericAndWriteField(value); }
    void writeFloat32(Float32 value) override { castNumericAndWriteField(value); }
    void writeFloat64(Float64 value) override { castNumericAndWriteField(value); }

    void writeEnum8(Int8 value) override { writeEnum16(value); }

    // Enum elements are written as their numeric value; only integer target types are allowed.
    void writeEnum16(Int16 value) override
    {
        if constexpr (!is_integral_v<ToType>)
            cannotConvertType("Enum"); // It's not correct to convert enum to floating point.
        castNumericAndWriteField(value);
    }

    // Date is written as its underlying UInt16 value, DateTime as its time_t value.
    void writeDate(DayNum date) override { castNumericAndWriteField(static_cast<UInt16>(date)); }
    void writeDateTime(time_t tm) override { castNumericAndWriteField(tm); }
    // DateTime64 and decimals are converted to ToType using their scale.
    void writeDateTime64(DateTime64 date_time, UInt32 scale) override { writeDecimal(date_time, scale); }
    void writeDecimal32(Decimal32 decimal, UInt32 scale) override { writeDecimal(decimal, scale); }
    void writeDecimal64(Decimal64 decimal, UInt32 scale) override { writeDecimal(decimal, scale); }
    void writeDecimal128(const Decimal128 & decimal, UInt32 scale) override { writeDecimal(decimal, scale); }

private:
    template <typename FromType>
    void castNumericAndWriteField(FromType value)
    {
        writeField(numericCast<ToType>(value));
    }

    template <typename S>
    void writeDecimal(const Decimal<S> & decimal, UInt32 scale)
    {
        castNumericAndWriteField(convertFromDecimal<DataTypeDecimal<Decimal<S>>, DataTypeNumber<ToType>>(decimal.value, scale));
    }

    // Encodes one value. The (field_type_id, ToType) pair picks the encoding at compile time.
    void writeField(ToType value)
    {
        if constexpr (skip_null_value)
        {
            if (value == 0)
                return;
        }
        // int32/int64: plain two's-complement varint.
        if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT32) && std::is_same_v<ToType, Int32>)
            || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT64) && std::is_same_v<ToType, Int64>))
        {
            if constexpr (pack_repeated)
                simple_writer.addIntToRepeatedPack(value);
            else
                simple_writer.writeInt(field_number, value);
        }
        // sint32/sint64: ZigZag-encoded varint.
        else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT32) && std::is_same_v<ToType, Int32>)
            || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT64) && std::is_same_v<ToType, Int64>))
        {
            if constexpr (pack_repeated)
                simple_writer.addSIntToRepeatedPack(value);
            else
                simple_writer.writeSInt(field_number, value);
        }
        // uint32/uint64: unsigned varint.
        else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT32) && std::is_same_v<ToType, UInt32>)
            || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT64) && std::is_same_v<ToType, UInt64>))
        {
            if constexpr (pack_repeated)
                simple_writer.addUIntToRepeatedPack(value);
            else
                simple_writer.writeUInt(field_number, value);
        }
        // Everything else must be a fixed-width type; the static_assert rejects any other combination at compile time.
        else
        {
            static_assert(((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED32) && std::is_same_v<ToType, UInt32>)
                || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED32) && std::is_same_v<ToType, Int32>)
                || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED64) && std::is_same_v<ToType, UInt64>)
                || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED64) && std::is_same_v<ToType, Int64>)
                || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FLOAT) && std::is_same_v<ToType, float>)
                || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_DOUBLE) && std::is_same_v<ToType, double>));
            if constexpr (pack_repeated)
                simple_writer.addFixedToRepeatedPack(value);
            else
                simple_writer.writeFixed(field_number, value);
        }
    }
};
608
609#define PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(field_type_id, field_type) \
610 template <> \
611 std::unique_ptr<ProtobufWriter::IConverter> ProtobufWriter::createConverter<field_type_id>( \
612 const google::protobuf::FieldDescriptor * field) \
613 { \
614 if (shouldSkipNullValue(field)) \
615 return std::make_unique<ConverterToNumber<field_type_id, field_type, true, false>>(simple_writer, field); \
616 else if (shouldPackRepeated(field)) \
617 return std::make_unique<ConverterToNumber<field_type_id, field_type, false, true>>(simple_writer, field); \
618 else \
619 return std::make_unique<ConverterToNumber<field_type_id, field_type, false, false>>(simple_writer, field); \
620 }
621PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT32, Int32);
622PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT32, Int32);
623PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT32, UInt32);
624PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT64, Int64);
625PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT64, Int64);
626PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT64, UInt64);
627PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED32, UInt32);
628PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED32, Int32);
629PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED64, UInt64);
630PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED64, Int64);
631PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FLOAT, float);
632PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_DOUBLE, double);
633#undef PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS
634
635
636template <bool skip_null_value, bool pack_repeated>
637class ProtobufWriter::ConverterToBool : public ConverterBaseImpl
638{
639public:
640 using ConverterBaseImpl::ConverterBaseImpl;
641
642 void writeString(const StringRef & str) override
643 {
644 if (str == "true")
645 writeField(true);
646 else if (str == "false")
647 writeField(false);
648 else
649 cannotConvertValue(str.toString());
650 }
651
652 void writeInt8(Int8 value) override { convertToBoolAndWriteField(value); }
653 void writeUInt8(UInt8 value) override { convertToBoolAndWriteField(value); }
654 void writeInt16(Int16 value) override { convertToBoolAndWriteField(value); }
655 void writeUInt16(UInt16 value) override { convertToBoolAndWriteField(value); }
656 void writeInt32(Int32 value) override { convertToBoolAndWriteField(value); }
657 void writeUInt32(UInt32 value) override { convertToBoolAndWriteField(value); }
658 void writeInt64(Int64 value) override { convertToBoolAndWriteField(value); }
659 void writeUInt64(UInt64 value) override { convertToBoolAndWriteField(value); }
660 void writeFloat32(Float32 value) override { convertToBoolAndWriteField(value); }
661 void writeFloat64(Float64 value) override { convertToBoolAndWriteField(value); }
662 void writeDecimal32(Decimal32 decimal, UInt32) override { convertToBoolAndWriteField(decimal.value); }
663 void writeDecimal64(Decimal64 decimal, UInt32) override { convertToBoolAndWriteField(decimal.value); }
664 void writeDecimal128(const Decimal128 & decimal, UInt32) override { convertToBoolAndWriteField(decimal.value); }
665
666private:
667 template <typename T>
668 void convertToBoolAndWriteField(T value)
669 {
670 writeField(static_cast<bool>(value));
671 }
672
673 void writeField(bool b)
674 {
675 if constexpr (skip_null_value)
676 {
677 if (!b)
678 return;
679 }
680 if constexpr (pack_repeated)
681 simple_writer.addUIntToRepeatedPack(b);
682 else
683 simple_writer.writeUInt(field_number, b);
684 }
685};
686
687template <>
688std::unique_ptr<ProtobufWriter::IConverter> ProtobufWriter::createConverter<google::protobuf::FieldDescriptor::TYPE_BOOL>(
689 const google::protobuf::FieldDescriptor * field)
690{
691 if (shouldSkipNullValue(field))
692 return std::make_unique<ConverterToBool<true, false>>(simple_writer, field);
693 else if (shouldPackRepeated(field))
694 return std::make_unique<ConverterToBool<false, true>>(simple_writer, field);
695 else
696 return std::make_unique<ConverterToBool<false, false>>(simple_writer, field);
697}
698
699
700template <bool skip_null_value, bool pack_repeated>
701class ProtobufWriter::ConverterToEnum : public ConverterBaseImpl
702{
703public:
704 using ConverterBaseImpl::ConverterBaseImpl;
705
706 void writeString(const StringRef & str) override
707 {
708 prepareEnumNameToPbNumberMap();
709 auto it = enum_name_to_pbnumber_map->find(str);
710 if (it == enum_name_to_pbnumber_map->end())
711 cannotConvertValue(str.toString());
712 writeField(it->second);
713 }
714
715 void writeInt8(Int8 value) override { convertToEnumAndWriteField(value); }
716 void writeUInt8(UInt8 value) override { convertToEnumAndWriteField(value); }
717 void writeInt16(Int16 value) override { convertToEnumAndWriteField(value); }
718 void writeUInt16(UInt16 value) override { convertToEnumAndWriteField(value); }
719 void writeInt32(Int32 value) override { convertToEnumAndWriteField(value); }
720 void writeUInt32(UInt32 value) override { convertToEnumAndWriteField(value); }
721 void writeInt64(Int64 value) override { convertToEnumAndWriteField(value); }
722 void writeUInt64(UInt64 value) override { convertToEnumAndWriteField(value); }
723
724 void prepareEnumMapping8(const std::vector<std::pair<String, Int8>> & name_value_pairs) override
725 {
726 prepareEnumValueToPbNumberMap(name_value_pairs);
727 }
728 void prepareEnumMapping16(const std::vector<std::pair<String, Int16>> & name_value_pairs) override
729 {
730 prepareEnumValueToPbNumberMap(name_value_pairs);
731 }
732
733 void writeEnum8(Int8 value) override { writeEnum16(value); }
734
735 void writeEnum16(Int16 value) override
736 {
737 int pbnumber;
738 if (enum_value_always_equals_pbnumber)
739 pbnumber = value;
740 else
741 {
742 auto it = enum_value_to_pbnumber_map->find(value);
743 if (it == enum_value_to_pbnumber_map->end())
744 cannotConvertValue(toString(value));
745 pbnumber = it->second;
746 }
747 writeField(pbnumber);
748 }
749
750private:
751 template <typename T>
752 void convertToEnumAndWriteField(T value)
753 {
754 const auto * enum_descriptor = field->enum_type()->FindValueByNumber(numericCast<int>(value));
755 if (!enum_descriptor)
756 cannotConvertValue(toString(value));
757 writeField(enum_descriptor->number());
758 }
759
760 void prepareEnumNameToPbNumberMap()
761 {
762 if (enum_name_to_pbnumber_map.has_value())
763 return;
764 enum_name_to_pbnumber_map.emplace();
765 const auto * enum_type = field->enum_type();
766 for (int i = 0; i != enum_type->value_count(); ++i)
767 {
768 const auto * enum_value = enum_type->value(i);
769 enum_name_to_pbnumber_map->emplace(enum_value->name(), enum_value->number());
770 }
771 }
772
773 template <typename T>
774 void prepareEnumValueToPbNumberMap(const std::vector<std::pair<String, T>> & name_value_pairs)
775 {
776 if (enum_value_to_pbnumber_map.has_value())
777 return;
778 enum_value_to_pbnumber_map.emplace();
779 enum_value_always_equals_pbnumber = true;
780 for (const auto & name_value_pair : name_value_pairs)
781 {
782 Int16 value = name_value_pair.second;
783 const auto * enum_descriptor = field->enum_type()->FindValueByName(name_value_pair.first);
784 if (enum_descriptor)
785 {
786 enum_value_to_pbnumber_map->emplace(value, enum_descriptor->number());
787 if (value != enum_descriptor->number())
788 enum_value_always_equals_pbnumber = false;
789 }
790 else
791 enum_value_always_equals_pbnumber = false;
792 }
793 }
794
795 void writeField(int enum_pbnumber)
796 {
797 if constexpr (skip_null_value)
798 {
799 if (!enum_pbnumber)
800 return;
801 }
802 if constexpr (pack_repeated)
803 simple_writer.addUIntToRepeatedPack(enum_pbnumber);
804 else
805 simple_writer.writeUInt(field_number, enum_pbnumber);
806 }
807
808 std::optional<std::unordered_map<StringRef, int>> enum_name_to_pbnumber_map;
809 std::optional<std::unordered_map<Int16, int>> enum_value_to_pbnumber_map;
810 bool enum_value_always_equals_pbnumber;
811};
812
813template <>
814std::unique_ptr<ProtobufWriter::IConverter> ProtobufWriter::createConverter<google::protobuf::FieldDescriptor::TYPE_ENUM>(
815 const google::protobuf::FieldDescriptor * field)
816{
817 if (shouldSkipNullValue(field))
818 return std::make_unique<ConverterToEnum<true, false>>(simple_writer, field);
819 else if (shouldPackRepeated(field))
820 return std::make_unique<ConverterToEnum<false, true>>(simple_writer, field);
821 else
822 return std::make_unique<ConverterToEnum<false, false>>(simple_writer, field);
823}
824
825
826ProtobufWriter::ProtobufWriter(
827 WriteBuffer & out, const google::protobuf::Descriptor * message_type, const std::vector<String> & column_names)
828 : simple_writer(out)
829{
830 std::vector<const google::protobuf::FieldDescriptor *> field_descriptors_without_match;
831 root_message = ProtobufColumnMatcher::matchColumns<ColumnMatcherTraits>(column_names, message_type, field_descriptors_without_match);
832 for (const auto * field_descriptor_without_match : field_descriptors_without_match)
833 {
834 if (field_descriptor_without_match->is_required())
835 throw Exception(
836 "Output doesn't have a column named '" + field_descriptor_without_match->name()
837 + "' which is required to write the output in the protobuf format.",
838 ErrorCodes::NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD);
839 }
840 setTraitsDataAfterMatchingColumns(root_message.get());
841}
842
// Destructor defaulted out of line in this source file.
// NOTE(review): presumably done so that smart-pointer members (e.g. root_message) may refer to types that are
// incomplete in ProtobufWriter.h — confirm against the header.
ProtobufWriter::~ProtobufWriter() = default;
844
845void ProtobufWriter::setTraitsDataAfterMatchingColumns(Message * message)
846{
847 Field * parent_field = message->parent ? &message->parent->fields[message->index_in_parent] : nullptr;
848 message->data.parent_field_number = parent_field ? parent_field->field_number : 0;
849 message->data.is_required = parent_field && parent_field->data.is_required;
850
851 if (parent_field && parent_field->data.is_repeatable)
852 message->data.repeatable_container_message = message;
853 else if (message->parent)
854 message->data.repeatable_container_message = message->parent->data.repeatable_container_message;
855 else
856 message->data.repeatable_container_message = nullptr;
857
858 message->data.is_group = parent_field && (parent_field->field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_GROUP);
859
860 for (auto & field : message->fields)
861 {
862 field.data.is_repeatable = field.field_descriptor->is_repeated();
863 field.data.is_required = field.field_descriptor->is_required();
864 field.data.repeatable_container_message = message->data.repeatable_container_message;
865 field.data.should_pack_repeated = shouldPackRepeated(field.field_descriptor);
866
867 if (field.nested_message)
868 {
869 setTraitsDataAfterMatchingColumns(field.nested_message.get());
870 continue;
871 }
872 switch (field.field_descriptor->type())
873 {
874#define PROTOBUF_WRITER_CONVERTER_CREATING_CASE(field_type_id) \
875 case field_type_id: \
876 field.data.converter = createConverter<field_type_id>(field.field_descriptor); \
877 break
878 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_STRING);
879 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BYTES);
880 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT32);
881 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT32);
882 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT32);
883 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED32);
884 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED32);
885 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT64);
886 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT64);
887 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT64);
888 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED64);
889 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED64);
890 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FLOAT);
891 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_DOUBLE);
892 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BOOL);
893 PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_ENUM);
894#undef PROTOBUF_WRITER_CONVERTER_CREATING_CASE
895 default:
896 throw Exception(
897 String("Protobuf type '") + field.field_descriptor->type_name() + "' isn't supported", ErrorCodes::NOT_IMPLEMENTED);
898 }
899 }
900}
901
902void ProtobufWriter::startMessage()
903{
904 current_message = root_message.get();
905 current_field_index = 0;
906 simple_writer.startMessage();
907}
908
909void ProtobufWriter::endMessage()
910{
911 if (!current_message)
912 return;
913 endWritingField();
914 while (current_message->parent)
915 {
916 simple_writer.endNestedMessage(
917 current_message->data.parent_field_number, current_message->data.is_group, !current_message->data.is_required);
918 current_message = current_message->parent;
919 }
920 simple_writer.endMessage();
921 current_message = nullptr;
922}
923
924bool ProtobufWriter::writeField(size_t & column_index)
925{
926 endWritingField();
927 while (true)
928 {
929 if (current_field_index < current_message->fields.size())
930 {
931 Field & field = current_message->fields[current_field_index];
932 if (!field.nested_message)
933 {
934 current_field = &current_message->fields[current_field_index];
935 current_converter = current_field->data.converter.get();
936 column_index = current_field->column_index;
937 if (current_field->data.should_pack_repeated)
938 simple_writer.startRepeatedPack();
939 return true;
940 }
941 simple_writer.startNestedMessage();
942 current_message = field.nested_message.get();
943 current_message->data.need_repeat = false;
944 current_field_index = 0;
945 continue;
946 }
947 if (current_message->parent)
948 {
949 simple_writer.endNestedMessage(
950 current_message->data.parent_field_number, current_message->data.is_group, !current_message->data.is_required);
951 if (current_message->data.need_repeat)
952 {
953 simple_writer.startNestedMessage();
954 current_message->data.need_repeat = false;
955 current_field_index = 0;
956 continue;
957 }
958 current_field_index = current_message->index_in_parent + 1;
959 current_message = current_message->parent;
960 continue;
961 }
962 return false;
963 }
964}
965
966void ProtobufWriter::endWritingField()
967{
968 if (!current_field)
969 return;
970 if (current_field->data.should_pack_repeated)
971 simple_writer.endRepeatedPack(current_field->field_number);
972 else if ((num_values == 0) && current_field->data.is_required)
973 throw Exception(
974 "No data for the required field '" + current_field->field_descriptor->name() + "'",
975 ErrorCodes::NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD);
976
977 current_field = nullptr;
978 current_converter = nullptr;
979 num_values = 0;
980 ++current_field_index;
981}
982
983void ProtobufWriter::setNestedMessageNeedsRepeat()
984{
985 if (current_field->data.repeatable_container_message)
986 current_field->data.repeatable_container_message->data.need_repeat = true;
987 else
988 throw Exception(
989 "Cannot write more than single value to the non-repeated field '" + current_field->field_descriptor->name() + "'",
990 ErrorCodes::PROTOBUF_FIELD_NOT_REPEATED);
991}
992
993}
994#endif
995