1 | // Protocol Buffers - Google's data interchange format |
2 | // Copyright 2008 Google Inc. All rights reserved. |
3 | // https://developers.google.com/protocol-buffers/ |
4 | // |
5 | // Redistribution and use in source and binary forms, with or without |
6 | // modification, are permitted provided that the following conditions are |
7 | // met: |
8 | // |
9 | // * Redistributions of source code must retain the above copyright |
10 | // notice, this list of conditions and the following disclaimer. |
11 | // * Redistributions in binary form must reproduce the above |
12 | // copyright notice, this list of conditions and the following disclaimer |
13 | // in the documentation and/or other materials provided with the |
14 | // distribution. |
15 | // * Neither the name of Google Inc. nor the names of its |
16 | // contributors may be used to endorse or promote products derived from |
17 | // this software without specific prior written permission. |
18 | // |
19 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
20 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
21 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
22 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
23 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
24 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
25 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
26 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
27 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
28 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
29 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
30 | |
31 | #include <google/protobuf/compiler/python/pyi_generator.h> |
32 | |
33 | #include <string> |
34 | |
35 | #include <google/protobuf/stubs/strutil.h> |
36 | #include <google/protobuf/compiler/python/helpers.h> |
37 | #include <google/protobuf/descriptor.h> |
38 | #include <google/protobuf/descriptor.pb.h> |
39 | #include <google/protobuf/io/printer.h> |
40 | #include <google/protobuf/io/zero_copy_stream.h> |
41 | |
42 | namespace google { |
43 | namespace protobuf { |
44 | namespace compiler { |
45 | namespace python { |
46 | |
// Comparator for std::sort that orders descriptor pointers by their name(),
// so the generated stub lists messages, enums, fields and services in a
// lexicographic order that does not depend on declaration order.
template <typename DescriptorT>
struct SortByName {
  bool operator()(const DescriptorT* lhs, const DescriptorT* rhs) const {
    const auto& lhs_name = lhs->name();
    return lhs_name.compare(rhs->name()) < 0;
  }
};
53 | |
54 | PyiGenerator::PyiGenerator() : file_(nullptr) {} |
55 | |
56 | PyiGenerator::~PyiGenerator() {} |
57 | |
58 | void PyiGenerator::PrintItemMap( |
59 | const std::map<std::string, std::string>& item_map) const { |
60 | for (const auto& entry : item_map) { |
61 | printer_->Print(text: "$key$: $value$\n" , args: "key" , args: entry.first, args: "value" , |
62 | args: entry.second); |
63 | } |
64 | } |
65 | |
66 | template <typename DescriptorT> |
67 | std::string PyiGenerator::ModuleLevelName( |
68 | const DescriptorT& descriptor, |
69 | const std::map<std::string, std::string>& import_map) const { |
70 | std::string name = NamePrefixedWithNestedTypes(descriptor, "." ); |
71 | if (descriptor.file() != file_) { |
72 | std::string module_alias; |
73 | std::string filename = descriptor.file()->name(); |
74 | if (import_map.find(x: filename) == import_map.end()) { |
75 | std::string module_name = ModuleName(descriptor.file()->name()); |
76 | std::vector<std::string> tokens = Split(full: module_name, delim: "." ); |
77 | module_alias = "_" + tokens.back(); |
78 | } else { |
79 | module_alias = import_map.at(k: filename); |
80 | } |
81 | name = module_alias + "." + name; |
82 | } |
83 | return name; |
84 | } |
85 | |
// Flags recording which optional imports the generated .pyi needs. Every
// flag starts out false and is switched on while the file's messages are
// scanned. The trailing comment names the imported symbol each flag controls.
struct ImportModules {
  bool has_repeated{false};         // _containers
  bool has_iterable{false};         // typing.Iterable
  bool has_messages{false};         // _message
  bool has_enums{false};            // _enum_type_wrapper
  bool has_extendable{false};       // _python_message
  bool has_mapping{false};          // typing.Mapping
  bool has_optional{false};         // typing.Optional
  bool has_union{false};            // typing.Union
  bool has_well_known_type{false};  // _well_known_types
};
97 | |
// Returns true when `name` is the fully-qualified name of a well-known type
// whose stub class additionally derives from a class in _well_known_types.
// Any other name, including unqualified ones, returns false.
bool IsWellKnownType(const std::string& name) {
  // LINT.IfChange(wktbases)
  static const char* const kWellKnownTypesWithBases[] = {
      "google.protobuf.Any",       "google.protobuf.Duration",
      "google.protobuf.FieldMask", "google.protobuf.ListValue",
      "google.protobuf.Struct",    "google.protobuf.Timestamp",
  };
  // LINT.ThenChange(//depot/google3/net/proto2/python/internal/well_known_types.py:wktbases)
  for (const char* candidate : kWellKnownTypesWithBases) {
    if (name == candidate) {
      return true;
    }
  }
  return false;
}
109 | |
110 | // Checks what modules should be imported for this message |
111 | // descriptor. |
112 | void CheckImportModules(const Descriptor* descriptor, |
113 | ImportModules* import_modules) { |
114 | if (descriptor->extension_range_count() > 0) { |
115 | import_modules->has_extendable = true; |
116 | } |
117 | if (descriptor->enum_type_count() > 0) { |
118 | import_modules->has_enums = true; |
119 | } |
120 | if (IsWellKnownType(name: descriptor->full_name())) { |
121 | import_modules->has_well_known_type = true; |
122 | } |
123 | for (int i = 0; i < descriptor->field_count(); ++i) { |
124 | const FieldDescriptor* field = descriptor->field(index: i); |
125 | if (IsPythonKeyword(name: field->name())) { |
126 | continue; |
127 | } |
128 | import_modules->has_optional = true; |
129 | if (field->is_repeated()) { |
130 | import_modules->has_repeated = true; |
131 | } |
132 | if (field->is_map()) { |
133 | import_modules->has_mapping = true; |
134 | const FieldDescriptor* value_des = field->message_type()->field(index: 1); |
135 | if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE || |
136 | value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { |
137 | import_modules->has_union = true; |
138 | } |
139 | } else { |
140 | if (field->is_repeated()) { |
141 | import_modules->has_iterable = true; |
142 | } |
143 | if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
144 | import_modules->has_union = true; |
145 | import_modules->has_mapping = true; |
146 | } |
147 | if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { |
148 | import_modules->has_union = true; |
149 | } |
150 | } |
151 | } |
152 | for (int i = 0; i < descriptor->nested_type_count(); ++i) { |
153 | CheckImportModules(descriptor: descriptor->nested_type(index: i), import_modules); |
154 | } |
155 | } |
156 | |
157 | void PyiGenerator::PrintImportForDescriptor( |
158 | const FileDescriptor& desc, |
159 | std::map<std::string, std::string>* import_map, |
160 | std::set<std::string>* seen_aliases) const { |
161 | const std::string& filename = desc.name(); |
162 | std::string module_name = StrippedModuleName(filename); |
163 | size_t last_dot_pos = module_name.rfind(c: '.'); |
164 | std::string import_statement; |
165 | if (last_dot_pos == std::string::npos) { |
166 | import_statement = "import " + module_name; |
167 | } else { |
168 | import_statement = "from " + module_name.substr(pos: 0, n: last_dot_pos) + |
169 | " import " + module_name.substr(pos: last_dot_pos + 1); |
170 | module_name = module_name.substr(pos: last_dot_pos + 1); |
171 | } |
172 | std::string alias = "_" + module_name; |
173 | // Generate a unique alias by adding _1 suffixes until we get an unused alias. |
174 | while (seen_aliases->find(x: alias) != seen_aliases->end()) { |
175 | alias = alias + "_1" ; |
176 | } |
177 | printer_->Print(text: "$statement$ as $alias$\n" , args: "statement" , |
178 | args: import_statement, args: "alias" , args: alias); |
179 | (*import_map)[filename] = alias; |
180 | seen_aliases->insert(x: alias); |
181 | } |
182 | |
183 | void PyiGenerator::PrintImports( |
184 | std::map<std::string, std::string>* item_map, |
185 | std::map<std::string, std::string>* import_map) const { |
186 | // Prints imported dependent _pb2 files. |
187 | std::set<std::string> seen_aliases; |
188 | for (int i = 0; i < file_->dependency_count(); ++i) { |
189 | const FileDescriptor* dep = file_->dependency(index: i); |
190 | PrintImportForDescriptor(desc: *dep, import_map, seen_aliases: &seen_aliases); |
191 | for (int j = 0; j < dep->public_dependency_count(); ++j) { |
192 | PrintImportForDescriptor( |
193 | desc: *dep->public_dependency(index: j), import_map, seen_aliases: &seen_aliases); |
194 | } |
195 | } |
196 | |
197 | // Checks what modules should be imported. |
198 | ImportModules import_modules; |
199 | if (file_->message_type_count() > 0) { |
200 | import_modules.has_messages = true; |
201 | } |
202 | if (file_->enum_type_count() > 0) { |
203 | import_modules.has_enums = true; |
204 | } |
205 | for (int i = 0; i < file_->message_type_count(); i++) { |
206 | CheckImportModules(descriptor: file_->message_type(index: i), import_modules: &import_modules); |
207 | } |
208 | |
209 | // Prints modules (e.g. _containers, _messages, typing) that are |
210 | // required in the proto file. |
211 | if (import_modules.has_repeated) { |
212 | printer_->Print( |
213 | text: "from google.protobuf.internal import containers as " |
214 | "_containers\n" ); |
215 | } |
216 | if (import_modules.has_enums) { |
217 | printer_->Print( |
218 | text: "from google.protobuf.internal import enum_type_wrapper" |
219 | " as _enum_type_wrapper\n" ); |
220 | } |
221 | if (import_modules.has_extendable) { |
222 | printer_->Print( |
223 | text: "from google.protobuf.internal import python_message" |
224 | " as _python_message\n" ); |
225 | } |
226 | if (import_modules.has_well_known_type) { |
227 | printer_->Print( |
228 | text: "from google.protobuf.internal import well_known_types" |
229 | " as _well_known_types\n" ); |
230 | } |
231 | printer_->Print( |
232 | text: "from google.protobuf import" |
233 | " descriptor as _descriptor\n" ); |
234 | if (import_modules.has_messages) { |
235 | printer_->Print( |
236 | text: "from google.protobuf import message as _message\n" ); |
237 | } |
238 | if (HasGenericServices(file: file_)) { |
239 | printer_->Print( |
240 | text: "from google.protobuf import service as" |
241 | " _service\n" ); |
242 | } |
243 | printer_->Print(text: "from typing import " ); |
244 | printer_->Print(text: "ClassVar as _ClassVar" ); |
245 | if (import_modules.has_iterable) { |
246 | printer_->Print(text: ", Iterable as _Iterable" ); |
247 | } |
248 | if (import_modules.has_mapping) { |
249 | printer_->Print(text: ", Mapping as _Mapping" ); |
250 | } |
251 | if (import_modules.has_optional) { |
252 | printer_->Print(text: ", Optional as _Optional" ); |
253 | } |
254 | if (import_modules.has_union) { |
255 | printer_->Print(text: ", Union as _Union" ); |
256 | } |
257 | printer_->Print(text: "\n\n" ); |
258 | |
259 | // Public imports |
260 | for (int i = 0; i < file_->public_dependency_count(); ++i) { |
261 | const FileDescriptor* public_dep = file_->public_dependency(index: i); |
262 | std::string module_name = StrippedModuleName(filename: public_dep->name()); |
263 | // Top level messages in public imports |
264 | for (int i = 0; i < public_dep->message_type_count(); ++i) { |
265 | printer_->Print(text: "from $module$ import $message_class$\n" , args: "module" , |
266 | args: module_name, args: "message_class" , |
267 | args: public_dep->message_type(index: i)->name()); |
268 | } |
269 | // Top level enums for public imports |
270 | for (int i = 0; i < public_dep->enum_type_count(); ++i) { |
271 | printer_->Print(text: "from $module$ import $enum_class$\n" , args: "module" , |
272 | args: module_name, args: "enum_class" , |
273 | args: public_dep->enum_type(index: i)->name()); |
274 | } |
275 | // Enum values for public imports |
276 | for (int i = 0; i < public_dep->enum_type_count(); ++i) { |
277 | const EnumDescriptor* enum_descriptor = public_dep->enum_type(index: i); |
278 | for (int j = 0; j < enum_descriptor->value_count(); ++j) { |
279 | (*item_map)[enum_descriptor->value(index: j)->name()] = |
280 | ModuleLevelName(descriptor: *enum_descriptor, import_map: *import_map); |
281 | } |
282 | } |
283 | // Top level extensions for public imports |
284 | AddExtensions(descriptor: *public_dep, item_map); |
285 | } |
286 | } |
287 | |
288 | void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const { |
289 | std::string enum_name = enum_descriptor.name(); |
290 | printer_->Print( |
291 | text: "class $enum_name$(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):\n" |
292 | " __slots__ = []\n" , |
293 | args: "enum_name" , args: enum_name); |
294 | } |
295 | |
296 | // Adds enum value to item map which will be ordered and printed later. |
297 | void PyiGenerator::AddEnumValue( |
298 | const EnumDescriptor& enum_descriptor, |
299 | std::map<std::string, std::string>* item_map, |
300 | const std::map<std::string, std::string>& import_map) const { |
301 | // enum values |
302 | std::string module_enum_name = ModuleLevelName(descriptor: enum_descriptor, import_map); |
303 | for (int j = 0; j < enum_descriptor.value_count(); ++j) { |
304 | const EnumValueDescriptor* value_descriptor = enum_descriptor.value(index: j); |
305 | (*item_map)[value_descriptor->name()] = module_enum_name; |
306 | } |
307 | } |
308 | |
309 | // Prints top level enums |
310 | void PyiGenerator::PrintTopLevelEnums() const { |
311 | for (int i = 0; i < file_->enum_type_count(); ++i) { |
312 | printer_->Print(text: "\n" ); |
313 | PrintEnum(enum_descriptor: *file_->enum_type(index: i)); |
314 | } |
315 | } |
316 | |
317 | // Add top level extensions to item_map which will be ordered and |
318 | // printed later. |
319 | template <typename DescriptorT> |
320 | void PyiGenerator::AddExtensions( |
321 | const DescriptorT& descriptor, |
322 | std::map<std::string, std::string>* item_map) const { |
323 | for (int i = 0; i < descriptor.extension_count(); ++i) { |
324 | const FieldDescriptor* extension_field = descriptor.extension(i); |
325 | std::string constant_name = extension_field->name() + "_FIELD_NUMBER" ; |
326 | ToUpper(s: &constant_name); |
327 | (*item_map)[constant_name] = "_ClassVar[int]" ; |
328 | (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor" ; |
329 | } |
330 | } |
331 | |
332 | // Returns the string format of a field's cpp_type |
333 | std::string PyiGenerator::GetFieldType( |
334 | const FieldDescriptor& field_des, const Descriptor& containing_des, |
335 | const std::map<std::string, std::string>& import_map) const { |
336 | switch (field_des.cpp_type()) { |
337 | case FieldDescriptor::CPPTYPE_INT32: |
338 | case FieldDescriptor::CPPTYPE_UINT32: |
339 | case FieldDescriptor::CPPTYPE_INT64: |
340 | case FieldDescriptor::CPPTYPE_UINT64: |
341 | return "int" ; |
342 | case FieldDescriptor::CPPTYPE_DOUBLE: |
343 | case FieldDescriptor::CPPTYPE_FLOAT: |
344 | return "float" ; |
345 | case FieldDescriptor::CPPTYPE_BOOL: |
346 | return "bool" ; |
347 | case FieldDescriptor::CPPTYPE_ENUM: |
348 | return ModuleLevelName(descriptor: *field_des.enum_type(), import_map); |
349 | case FieldDescriptor::CPPTYPE_STRING: |
350 | if (field_des.type() == FieldDescriptor::TYPE_STRING) { |
351 | return "str" ; |
352 | } else { |
353 | return "bytes" ; |
354 | } |
355 | case FieldDescriptor::CPPTYPE_MESSAGE: { |
356 | // If the field is inside a nested message and the nested message has the |
357 | // same name as a top-level message, then we need to prefix the field type |
358 | // with the module name for disambiguation. |
359 | std::string name = ModuleLevelName(descriptor: *field_des.message_type(), import_map); |
360 | if ((containing_des.containing_type() != nullptr && |
361 | name == containing_des.name())) { |
362 | std::string module = ModuleName(filename: field_des.file()->name()); |
363 | name = module + "." + name; |
364 | } |
365 | return name; |
366 | } |
367 | default: |
368 | GOOGLE_LOG(FATAL) << "Unsupported field type." ; |
369 | } |
370 | return "" ; |
371 | } |
372 | |
373 | void PyiGenerator::PrintMessage( |
374 | const Descriptor& message_descriptor, bool is_nested, |
375 | const std::map<std::string, std::string>& import_map) const { |
376 | if (!is_nested) { |
377 | printer_->Print(text: "\n" ); |
378 | } |
379 | std::string class_name = message_descriptor.name(); |
380 | std::string ; |
381 | // A well-known type needs to inherit from its corresponding base class in |
382 | // net/proto2/python/internal/well_known_types. |
383 | if (IsWellKnownType(name: message_descriptor.full_name())) { |
384 | extra_base = ", _well_known_types." + message_descriptor.name(); |
385 | } else { |
386 | extra_base = "" ; |
387 | } |
388 | printer_->Print(text: "class $class_name$(_message.Message$extra_base$):\n" , |
389 | args: "class_name" , args: class_name, args: "extra_base" , args: extra_base); |
390 | printer_->Indent(); |
391 | printer_->Indent(); |
392 | |
393 | std::vector<const FieldDescriptor*> fields; |
394 | fields.reserve(n: message_descriptor.field_count()); |
395 | for (int i = 0; i < message_descriptor.field_count(); ++i) { |
396 | fields.push_back(x: message_descriptor.field(index: i)); |
397 | } |
398 | std::sort(first: fields.begin(), last: fields.end(), comp: SortByName<FieldDescriptor>()); |
399 | |
400 | // Prints slots |
401 | printer_->Print(text: "__slots__ = [" , args: "class_name" , args: class_name); |
402 | bool first_item = true; |
403 | for (const auto& field_des : fields) { |
404 | if (IsPythonKeyword(name: field_des->name())) { |
405 | continue; |
406 | } |
407 | if (first_item) { |
408 | first_item = false; |
409 | } else { |
410 | printer_->Print(text: ", " ); |
411 | } |
412 | printer_->Print(text: "\"$field_name$\"" , args: "field_name" , args: field_des->name()); |
413 | } |
414 | printer_->Print(text: "]\n" ); |
415 | |
416 | std::map<std::string, std::string> item_map; |
417 | // Prints Extensions for extendable messages |
418 | if (message_descriptor.extension_range_count() > 0) { |
419 | item_map["Extensions" ] = "_python_message._ExtensionDict" ; |
420 | } |
421 | |
422 | // Prints nested enums |
423 | std::vector<const EnumDescriptor*> nested_enums; |
424 | nested_enums.reserve(n: message_descriptor.enum_type_count()); |
425 | for (int i = 0; i < message_descriptor.enum_type_count(); ++i) { |
426 | nested_enums.push_back(x: message_descriptor.enum_type(index: i)); |
427 | } |
428 | std::sort(first: nested_enums.begin(), last: nested_enums.end(), |
429 | comp: SortByName<EnumDescriptor>()); |
430 | |
431 | for (const auto& entry : nested_enums) { |
432 | PrintEnum(enum_descriptor: *entry); |
433 | // Adds enum value to item_map which will be ordered and printed later |
434 | AddEnumValue(enum_descriptor: *entry, item_map: &item_map, import_map); |
435 | } |
436 | |
437 | // Prints nested messages |
438 | std::vector<const Descriptor*> nested_messages; |
439 | nested_messages.reserve(n: message_descriptor.nested_type_count()); |
440 | for (int i = 0; i < message_descriptor.nested_type_count(); ++i) { |
441 | nested_messages.push_back(x: message_descriptor.nested_type(index: i)); |
442 | } |
443 | std::sort(first: nested_messages.begin(), last: nested_messages.end(), |
444 | comp: SortByName<Descriptor>()); |
445 | |
446 | for (const auto& entry : nested_messages) { |
447 | PrintMessage(message_descriptor: *entry, is_nested: true, import_map); |
448 | } |
449 | |
450 | // Adds extensions to item_map which will be ordered and printed later |
451 | AddExtensions(descriptor: message_descriptor, item_map: &item_map); |
452 | |
453 | // Adds field number and field descriptor to item_map |
454 | for (int i = 0; i < message_descriptor.field_count(); ++i) { |
455 | const FieldDescriptor& field_des = *message_descriptor.field(index: i); |
456 | item_map[ToUpper(s: field_des.name()) + "_FIELD_NUMBER" ] = |
457 | "_ClassVar[int]" ; |
458 | if (IsPythonKeyword(name: field_des.name())) { |
459 | continue; |
460 | } |
461 | std::string field_type = "" ; |
462 | if (field_des.is_map()) { |
463 | const FieldDescriptor* key_des = field_des.message_type()->field(index: 0); |
464 | const FieldDescriptor* value_des = field_des.message_type()->field(index: 1); |
465 | field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE |
466 | ? "_containers.MessageMap[" |
467 | : "_containers.ScalarMap[" ); |
468 | field_type += GetFieldType(field_des: *key_des, containing_des: message_descriptor, import_map); |
469 | field_type += ", " ; |
470 | field_type += GetFieldType(field_des: *value_des, containing_des: message_descriptor, import_map); |
471 | } else { |
472 | if (field_des.is_repeated()) { |
473 | field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE |
474 | ? "_containers.RepeatedCompositeFieldContainer[" |
475 | : "_containers.RepeatedScalarFieldContainer[" ); |
476 | } |
477 | field_type += GetFieldType(field_des, containing_des: message_descriptor, import_map); |
478 | } |
479 | |
480 | if (field_des.is_repeated()) { |
481 | field_type += "]" ; |
482 | } |
483 | item_map[field_des.name()] = field_type; |
484 | } |
485 | |
486 | // Prints all items in item_map |
487 | PrintItemMap(item_map); |
488 | |
489 | // Prints __init__ |
490 | printer_->Print(text: "def __init__(self" ); |
491 | bool has_key_words = false; |
492 | bool is_first = true; |
493 | for (int i = 0; i < message_descriptor.field_count(); ++i) { |
494 | const FieldDescriptor* field_des = message_descriptor.field(index: i); |
495 | if (IsPythonKeyword(name: field_des->name())) { |
496 | has_key_words = true; |
497 | continue; |
498 | } |
499 | std::string field_name = field_des->name(); |
500 | if (is_first && field_name == "self" ) { |
501 | // See b/144146793 for an example of real code that generates a (self, |
502 | // self) method signature. Since repeating a parameter name is illegal in |
503 | // Python, we rename the duplicate self. |
504 | field_name = "self_" ; |
505 | } |
506 | is_first = false; |
507 | printer_->Print(text: ", $field_name$: " , args: "field_name" , args: field_name); |
508 | if (field_des->is_repeated() || |
509 | field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) { |
510 | printer_->Print(text: "_Optional[" ); |
511 | } |
512 | if (field_des->is_map()) { |
513 | const Descriptor* map_entry = field_des->message_type(); |
514 | printer_->Print( |
515 | text: "_Mapping[$key_type$, $value_type$]" , args: "key_type" , |
516 | args: GetFieldType(field_des: *map_entry->field(index: 0), containing_des: message_descriptor, import_map), |
517 | args: "value_type" , |
518 | args: GetFieldType(field_des: *map_entry->field(index: 1), containing_des: message_descriptor, import_map)); |
519 | } else { |
520 | if (field_des->is_repeated()) { |
521 | printer_->Print(text: "_Iterable[" ); |
522 | } |
523 | if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
524 | printer_->Print( |
525 | text: "_Union[$type_name$, _Mapping]" , args: "type_name" , |
526 | args: GetFieldType(field_des: *field_des, containing_des: message_descriptor, import_map)); |
527 | } else { |
528 | if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { |
529 | printer_->Print(text: "_Union[$type_name$, str]" , args: "type_name" , |
530 | args: ModuleLevelName(descriptor: *field_des->enum_type(), import_map)); |
531 | } else { |
532 | printer_->Print( |
533 | text: "$type_name$" , args: "type_name" , |
534 | args: GetFieldType(field_des: *field_des, containing_des: message_descriptor, import_map)); |
535 | } |
536 | } |
537 | if (field_des->is_repeated()) { |
538 | printer_->Print(text: "]" ); |
539 | } |
540 | } |
541 | if (field_des->is_repeated() || |
542 | field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) { |
543 | printer_->Print(text: "]" ); |
544 | } |
545 | printer_->Print(text: " = ..." ); |
546 | } |
547 | if (has_key_words) { |
548 | printer_->Print(text: ", **kwargs" ); |
549 | } |
550 | printer_->Print(text: ") -> None: ...\n" ); |
551 | |
552 | printer_->Outdent(); |
553 | printer_->Outdent(); |
554 | } |
555 | |
556 | void PyiGenerator::PrintMessages( |
557 | const std::map<std::string, std::string>& import_map) const { |
558 | // Deterministically order the descriptors. |
559 | std::vector<const Descriptor*> messages; |
560 | messages.reserve(n: file_->message_type_count()); |
561 | for (int i = 0; i < file_->message_type_count(); ++i) { |
562 | messages.push_back(x: file_->message_type(index: i)); |
563 | } |
564 | std::sort(first: messages.begin(), last: messages.end(), comp: SortByName<Descriptor>()); |
565 | |
566 | for (const auto& entry : messages) { |
567 | PrintMessage(message_descriptor: *entry, is_nested: false, import_map); |
568 | } |
569 | } |
570 | |
571 | void PyiGenerator::PrintServices() const { |
572 | std::vector<const ServiceDescriptor*> services; |
573 | services.reserve(n: file_->service_count()); |
574 | for (int i = 0; i < file_->service_count(); ++i) { |
575 | services.push_back(x: file_->service(index: i)); |
576 | } |
577 | std::sort(first: services.begin(), last: services.end(), comp: SortByName<ServiceDescriptor>()); |
578 | |
579 | // Prints $Service$ and $Service$_Stub classes |
580 | for (const auto& entry : services) { |
581 | printer_->Print(text: "\n" ); |
582 | printer_->Print( |
583 | text: "class $service_name$(_service.service): ...\n\n" |
584 | "class $service_name$_Stub($service_name$): ...\n" , |
585 | args: "service_name" , args: entry->name()); |
586 | } |
587 | } |
588 | |
589 | bool PyiGenerator::Generate(const FileDescriptor* file, |
590 | const std::string& parameter, |
591 | GeneratorContext* context, |
592 | std::string* error) const { |
593 | MutexLock lock(&mutex_); |
594 | // Calculate file name. |
595 | file_ = file; |
596 | std::string filename = |
597 | parameter.empty() ? GetFileName(file_des: file, suffix: ".pyi" ) : parameter; |
598 | |
599 | std::unique_ptr<io::ZeroCopyOutputStream> output(context->Open(filename)); |
600 | GOOGLE_CHECK(output.get()); |
601 | io::Printer printer(output.get(), '$'); |
602 | printer_ = &printer; |
603 | |
604 | // item map will store "DESCRIPTOR", top level extensions, top level enum |
605 | // values. The items will be sorted and printed later. |
606 | std::map<std::string, std::string> item_map; |
607 | |
608 | // Adds "DESCRIPTOR" into item_map. |
609 | item_map["DESCRIPTOR" ] = "_descriptor.FileDescriptor" ; |
610 | |
611 | // import_map will be a mapping from filename to module alias, e.g. |
612 | // "google3/foo/bar.py" -> "_bar" |
613 | std::map<std::string, std::string> import_map; |
614 | |
615 | PrintImports(item_map: &item_map, import_map: &import_map); |
616 | // Adds top level enum values to item_map. |
617 | for (int i = 0; i < file_->enum_type_count(); ++i) { |
618 | AddEnumValue(enum_descriptor: *file_->enum_type(index: i), item_map: &item_map, import_map); |
619 | } |
620 | // Adds top level extensions to item_map. |
621 | AddExtensions(descriptor: *file_, item_map: &item_map); |
622 | // Prints item map |
623 | PrintItemMap(item_map); |
624 | |
625 | PrintMessages(import_map); |
626 | PrintTopLevelEnums(); |
627 | if (HasGenericServices(file)) { |
628 | PrintServices(); |
629 | } |
630 | return true; |
631 | } |
632 | |
633 | } // namespace python |
634 | } // namespace compiler |
635 | } // namespace protobuf |
636 | } // namespace google |
637 | |