| 1 | // Protocol Buffers - Google's data interchange format |
| 2 | // Copyright 2008 Google Inc. All rights reserved. |
| 3 | // https://developers.google.com/protocol-buffers/ |
| 4 | // |
| 5 | // Redistribution and use in source and binary forms, with or without |
| 6 | // modification, are permitted provided that the following conditions are |
| 7 | // met: |
| 8 | // |
| 9 | // * Redistributions of source code must retain the above copyright |
| 10 | // notice, this list of conditions and the following disclaimer. |
| 11 | // * Redistributions in binary form must reproduce the above |
| 12 | // copyright notice, this list of conditions and the following disclaimer |
| 13 | // in the documentation and/or other materials provided with the |
| 14 | // distribution. |
| 15 | // * Neither the name of Google Inc. nor the names of its |
| 16 | // contributors may be used to endorse or promote products derived from |
| 17 | // this software without specific prior written permission. |
| 18 | // |
| 19 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| 20 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| 21 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| 22 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| 23 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| 24 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| 25 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| 26 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| 27 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 28 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 29 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 30 | |
| 31 | #include <google/protobuf/compiler/python/pyi_generator.h> |
| 32 | |
| 33 | #include <string> |
| 34 | |
| 35 | #include <google/protobuf/stubs/strutil.h> |
| 36 | #include <google/protobuf/compiler/python/helpers.h> |
| 37 | #include <google/protobuf/descriptor.h> |
| 38 | #include <google/protobuf/descriptor.pb.h> |
| 39 | #include <google/protobuf/io/printer.h> |
| 40 | #include <google/protobuf/io/zero_copy_stream.h> |
| 41 | |
| 42 | namespace google { |
| 43 | namespace protobuf { |
| 44 | namespace compiler { |
| 45 | namespace python { |
| 46 | |
// Comparator that orders descriptor pointers by their short (unqualified)
// name(). Used with std::sort below so that fields, enums, messages and
// services are emitted in a deterministic, alphabetical order.
template <typename DescriptorT>
struct SortByName {
  bool operator()(const DescriptorT* lhs, const DescriptorT* rhs) const {
    const auto& lhs_name = lhs->name();
    const auto& rhs_name = rhs->name();
    return lhs_name.compare(rhs_name) < 0;
  }
};
| 53 | |
| 54 | PyiGenerator::PyiGenerator() : file_(nullptr) {} |
| 55 | |
| 56 | PyiGenerator::~PyiGenerator() {} |
| 57 | |
| 58 | void PyiGenerator::PrintItemMap( |
| 59 | const std::map<std::string, std::string>& item_map) const { |
| 60 | for (const auto& entry : item_map) { |
| 61 | printer_->Print(text: "$key$: $value$\n" , args: "key" , args: entry.first, args: "value" , |
| 62 | args: entry.second); |
| 63 | } |
| 64 | } |
| 65 | |
| 66 | template <typename DescriptorT> |
| 67 | std::string PyiGenerator::ModuleLevelName( |
| 68 | const DescriptorT& descriptor, |
| 69 | const std::map<std::string, std::string>& import_map) const { |
| 70 | std::string name = NamePrefixedWithNestedTypes(descriptor, "." ); |
| 71 | if (descriptor.file() != file_) { |
| 72 | std::string module_alias; |
| 73 | std::string filename = descriptor.file()->name(); |
| 74 | if (import_map.find(x: filename) == import_map.end()) { |
| 75 | std::string module_name = ModuleName(descriptor.file()->name()); |
| 76 | std::vector<std::string> tokens = Split(full: module_name, delim: "." ); |
| 77 | module_alias = "_" + tokens.back(); |
| 78 | } else { |
| 79 | module_alias = import_map.at(k: filename); |
| 80 | } |
| 81 | name = module_alias + "." + name; |
| 82 | } |
| 83 | return name; |
| 84 | } |
| 85 | |
// Flags recording which helper modules and typing aliases the generated stub
// must import. They are filled in by CheckImportModules() and the file-level
// checks in PrintImports(); each flag gates one import line or typing alias.
struct ImportModules {
  bool has_repeated{false};         // _containers
  bool has_iterable{false};         // typing.Iterable
  bool has_messages{false};         // _message
  bool has_enums{false};            // _enum_type_wrapper
  bool has_extendable{false};       // _python_message
  bool has_mapping{false};          // typing.Mapping
  bool has_optional{false};         // typing.Optional
  bool has_union{false};            // typing.Union
  bool has_well_known_type{false};  // _well_known_types
};
| 97 | |
// Returns true when `name` (a fully-qualified message name) is one of the
// well-known types whose stub class also derives from a helper base class in
// google.protobuf.internal.well_known_types.
bool IsWellKnownType(const std::string& name) {
  // LINT.IfChange(wktbases)
  static const char* const kWellKnownTypes[] = {
      "google.protobuf.Any",       "google.protobuf.Duration",
      "google.protobuf.FieldMask", "google.protobuf.ListValue",
      "google.protobuf.Struct",    "google.protobuf.Timestamp",
  };
  // LINT.ThenChange(//depot/google3/net/proto2/python/internal/well_known_types.py:wktbases)
  for (const char* candidate : kWellKnownTypes) {
    if (name == candidate) {
      return true;
    }
  }
  return false;
}
| 109 | |
| 110 | // Checks what modules should be imported for this message |
| 111 | // descriptor. |
| 112 | void CheckImportModules(const Descriptor* descriptor, |
| 113 | ImportModules* import_modules) { |
| 114 | if (descriptor->extension_range_count() > 0) { |
| 115 | import_modules->has_extendable = true; |
| 116 | } |
| 117 | if (descriptor->enum_type_count() > 0) { |
| 118 | import_modules->has_enums = true; |
| 119 | } |
| 120 | if (IsWellKnownType(name: descriptor->full_name())) { |
| 121 | import_modules->has_well_known_type = true; |
| 122 | } |
| 123 | for (int i = 0; i < descriptor->field_count(); ++i) { |
| 124 | const FieldDescriptor* field = descriptor->field(index: i); |
| 125 | if (IsPythonKeyword(name: field->name())) { |
| 126 | continue; |
| 127 | } |
| 128 | import_modules->has_optional = true; |
| 129 | if (field->is_repeated()) { |
| 130 | import_modules->has_repeated = true; |
| 131 | } |
| 132 | if (field->is_map()) { |
| 133 | import_modules->has_mapping = true; |
| 134 | const FieldDescriptor* value_des = field->message_type()->field(index: 1); |
| 135 | if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE || |
| 136 | value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { |
| 137 | import_modules->has_union = true; |
| 138 | } |
| 139 | } else { |
| 140 | if (field->is_repeated()) { |
| 141 | import_modules->has_iterable = true; |
| 142 | } |
| 143 | if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
| 144 | import_modules->has_union = true; |
| 145 | import_modules->has_mapping = true; |
| 146 | } |
| 147 | if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { |
| 148 | import_modules->has_union = true; |
| 149 | } |
| 150 | } |
| 151 | } |
| 152 | for (int i = 0; i < descriptor->nested_type_count(); ++i) { |
| 153 | CheckImportModules(descriptor: descriptor->nested_type(index: i), import_modules); |
| 154 | } |
| 155 | } |
| 156 | |
| 157 | void PyiGenerator::PrintImportForDescriptor( |
| 158 | const FileDescriptor& desc, |
| 159 | std::map<std::string, std::string>* import_map, |
| 160 | std::set<std::string>* seen_aliases) const { |
| 161 | const std::string& filename = desc.name(); |
| 162 | std::string module_name = StrippedModuleName(filename); |
| 163 | size_t last_dot_pos = module_name.rfind(c: '.'); |
| 164 | std::string import_statement; |
| 165 | if (last_dot_pos == std::string::npos) { |
| 166 | import_statement = "import " + module_name; |
| 167 | } else { |
| 168 | import_statement = "from " + module_name.substr(pos: 0, n: last_dot_pos) + |
| 169 | " import " + module_name.substr(pos: last_dot_pos + 1); |
| 170 | module_name = module_name.substr(pos: last_dot_pos + 1); |
| 171 | } |
| 172 | std::string alias = "_" + module_name; |
| 173 | // Generate a unique alias by adding _1 suffixes until we get an unused alias. |
| 174 | while (seen_aliases->find(x: alias) != seen_aliases->end()) { |
| 175 | alias = alias + "_1" ; |
| 176 | } |
| 177 | printer_->Print(text: "$statement$ as $alias$\n" , args: "statement" , |
| 178 | args: import_statement, args: "alias" , args: alias); |
| 179 | (*import_map)[filename] = alias; |
| 180 | seen_aliases->insert(x: alias); |
| 181 | } |
| 182 | |
| 183 | void PyiGenerator::PrintImports( |
| 184 | std::map<std::string, std::string>* item_map, |
| 185 | std::map<std::string, std::string>* import_map) const { |
| 186 | // Prints imported dependent _pb2 files. |
| 187 | std::set<std::string> seen_aliases; |
| 188 | for (int i = 0; i < file_->dependency_count(); ++i) { |
| 189 | const FileDescriptor* dep = file_->dependency(index: i); |
| 190 | PrintImportForDescriptor(desc: *dep, import_map, seen_aliases: &seen_aliases); |
| 191 | for (int j = 0; j < dep->public_dependency_count(); ++j) { |
| 192 | PrintImportForDescriptor( |
| 193 | desc: *dep->public_dependency(index: j), import_map, seen_aliases: &seen_aliases); |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | // Checks what modules should be imported. |
| 198 | ImportModules import_modules; |
| 199 | if (file_->message_type_count() > 0) { |
| 200 | import_modules.has_messages = true; |
| 201 | } |
| 202 | if (file_->enum_type_count() > 0) { |
| 203 | import_modules.has_enums = true; |
| 204 | } |
| 205 | for (int i = 0; i < file_->message_type_count(); i++) { |
| 206 | CheckImportModules(descriptor: file_->message_type(index: i), import_modules: &import_modules); |
| 207 | } |
| 208 | |
| 209 | // Prints modules (e.g. _containers, _messages, typing) that are |
| 210 | // required in the proto file. |
| 211 | if (import_modules.has_repeated) { |
| 212 | printer_->Print( |
| 213 | text: "from google.protobuf.internal import containers as " |
| 214 | "_containers\n" ); |
| 215 | } |
| 216 | if (import_modules.has_enums) { |
| 217 | printer_->Print( |
| 218 | text: "from google.protobuf.internal import enum_type_wrapper" |
| 219 | " as _enum_type_wrapper\n" ); |
| 220 | } |
| 221 | if (import_modules.has_extendable) { |
| 222 | printer_->Print( |
| 223 | text: "from google.protobuf.internal import python_message" |
| 224 | " as _python_message\n" ); |
| 225 | } |
| 226 | if (import_modules.has_well_known_type) { |
| 227 | printer_->Print( |
| 228 | text: "from google.protobuf.internal import well_known_types" |
| 229 | " as _well_known_types\n" ); |
| 230 | } |
| 231 | printer_->Print( |
| 232 | text: "from google.protobuf import" |
| 233 | " descriptor as _descriptor\n" ); |
| 234 | if (import_modules.has_messages) { |
| 235 | printer_->Print( |
| 236 | text: "from google.protobuf import message as _message\n" ); |
| 237 | } |
| 238 | if (HasGenericServices(file: file_)) { |
| 239 | printer_->Print( |
| 240 | text: "from google.protobuf import service as" |
| 241 | " _service\n" ); |
| 242 | } |
| 243 | printer_->Print(text: "from typing import " ); |
| 244 | printer_->Print(text: "ClassVar as _ClassVar" ); |
| 245 | if (import_modules.has_iterable) { |
| 246 | printer_->Print(text: ", Iterable as _Iterable" ); |
| 247 | } |
| 248 | if (import_modules.has_mapping) { |
| 249 | printer_->Print(text: ", Mapping as _Mapping" ); |
| 250 | } |
| 251 | if (import_modules.has_optional) { |
| 252 | printer_->Print(text: ", Optional as _Optional" ); |
| 253 | } |
| 254 | if (import_modules.has_union) { |
| 255 | printer_->Print(text: ", Union as _Union" ); |
| 256 | } |
| 257 | printer_->Print(text: "\n\n" ); |
| 258 | |
| 259 | // Public imports |
| 260 | for (int i = 0; i < file_->public_dependency_count(); ++i) { |
| 261 | const FileDescriptor* public_dep = file_->public_dependency(index: i); |
| 262 | std::string module_name = StrippedModuleName(filename: public_dep->name()); |
| 263 | // Top level messages in public imports |
| 264 | for (int i = 0; i < public_dep->message_type_count(); ++i) { |
| 265 | printer_->Print(text: "from $module$ import $message_class$\n" , args: "module" , |
| 266 | args: module_name, args: "message_class" , |
| 267 | args: public_dep->message_type(index: i)->name()); |
| 268 | } |
| 269 | // Top level enums for public imports |
| 270 | for (int i = 0; i < public_dep->enum_type_count(); ++i) { |
| 271 | printer_->Print(text: "from $module$ import $enum_class$\n" , args: "module" , |
| 272 | args: module_name, args: "enum_class" , |
| 273 | args: public_dep->enum_type(index: i)->name()); |
| 274 | } |
| 275 | // Enum values for public imports |
| 276 | for (int i = 0; i < public_dep->enum_type_count(); ++i) { |
| 277 | const EnumDescriptor* enum_descriptor = public_dep->enum_type(index: i); |
| 278 | for (int j = 0; j < enum_descriptor->value_count(); ++j) { |
| 279 | (*item_map)[enum_descriptor->value(index: j)->name()] = |
| 280 | ModuleLevelName(descriptor: *enum_descriptor, import_map: *import_map); |
| 281 | } |
| 282 | } |
| 283 | // Top level extensions for public imports |
| 284 | AddExtensions(descriptor: *public_dep, item_map); |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const { |
| 289 | std::string enum_name = enum_descriptor.name(); |
| 290 | printer_->Print( |
| 291 | text: "class $enum_name$(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):\n" |
| 292 | " __slots__ = []\n" , |
| 293 | args: "enum_name" , args: enum_name); |
| 294 | } |
| 295 | |
| 296 | // Adds enum value to item map which will be ordered and printed later. |
| 297 | void PyiGenerator::AddEnumValue( |
| 298 | const EnumDescriptor& enum_descriptor, |
| 299 | std::map<std::string, std::string>* item_map, |
| 300 | const std::map<std::string, std::string>& import_map) const { |
| 301 | // enum values |
| 302 | std::string module_enum_name = ModuleLevelName(descriptor: enum_descriptor, import_map); |
| 303 | for (int j = 0; j < enum_descriptor.value_count(); ++j) { |
| 304 | const EnumValueDescriptor* value_descriptor = enum_descriptor.value(index: j); |
| 305 | (*item_map)[value_descriptor->name()] = module_enum_name; |
| 306 | } |
| 307 | } |
| 308 | |
| 309 | // Prints top level enums |
| 310 | void PyiGenerator::PrintTopLevelEnums() const { |
| 311 | for (int i = 0; i < file_->enum_type_count(); ++i) { |
| 312 | printer_->Print(text: "\n" ); |
| 313 | PrintEnum(enum_descriptor: *file_->enum_type(index: i)); |
| 314 | } |
| 315 | } |
| 316 | |
| 317 | // Add top level extensions to item_map which will be ordered and |
| 318 | // printed later. |
| 319 | template <typename DescriptorT> |
| 320 | void PyiGenerator::AddExtensions( |
| 321 | const DescriptorT& descriptor, |
| 322 | std::map<std::string, std::string>* item_map) const { |
| 323 | for (int i = 0; i < descriptor.extension_count(); ++i) { |
| 324 | const FieldDescriptor* extension_field = descriptor.extension(i); |
| 325 | std::string constant_name = extension_field->name() + "_FIELD_NUMBER" ; |
| 326 | ToUpper(s: &constant_name); |
| 327 | (*item_map)[constant_name] = "_ClassVar[int]" ; |
| 328 | (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor" ; |
| 329 | } |
| 330 | } |
| 331 | |
| 332 | // Returns the string format of a field's cpp_type |
| 333 | std::string PyiGenerator::GetFieldType( |
| 334 | const FieldDescriptor& field_des, const Descriptor& containing_des, |
| 335 | const std::map<std::string, std::string>& import_map) const { |
| 336 | switch (field_des.cpp_type()) { |
| 337 | case FieldDescriptor::CPPTYPE_INT32: |
| 338 | case FieldDescriptor::CPPTYPE_UINT32: |
| 339 | case FieldDescriptor::CPPTYPE_INT64: |
| 340 | case FieldDescriptor::CPPTYPE_UINT64: |
| 341 | return "int" ; |
| 342 | case FieldDescriptor::CPPTYPE_DOUBLE: |
| 343 | case FieldDescriptor::CPPTYPE_FLOAT: |
| 344 | return "float" ; |
| 345 | case FieldDescriptor::CPPTYPE_BOOL: |
| 346 | return "bool" ; |
| 347 | case FieldDescriptor::CPPTYPE_ENUM: |
| 348 | return ModuleLevelName(descriptor: *field_des.enum_type(), import_map); |
| 349 | case FieldDescriptor::CPPTYPE_STRING: |
| 350 | if (field_des.type() == FieldDescriptor::TYPE_STRING) { |
| 351 | return "str" ; |
| 352 | } else { |
| 353 | return "bytes" ; |
| 354 | } |
| 355 | case FieldDescriptor::CPPTYPE_MESSAGE: { |
| 356 | // If the field is inside a nested message and the nested message has the |
| 357 | // same name as a top-level message, then we need to prefix the field type |
| 358 | // with the module name for disambiguation. |
| 359 | std::string name = ModuleLevelName(descriptor: *field_des.message_type(), import_map); |
| 360 | if ((containing_des.containing_type() != nullptr && |
| 361 | name == containing_des.name())) { |
| 362 | std::string module = ModuleName(filename: field_des.file()->name()); |
| 363 | name = module + "." + name; |
| 364 | } |
| 365 | return name; |
| 366 | } |
| 367 | default: |
| 368 | GOOGLE_LOG(FATAL) << "Unsupported field type." ; |
| 369 | } |
| 370 | return "" ; |
| 371 | } |
| 372 | |
| 373 | void PyiGenerator::PrintMessage( |
| 374 | const Descriptor& message_descriptor, bool is_nested, |
| 375 | const std::map<std::string, std::string>& import_map) const { |
| 376 | if (!is_nested) { |
| 377 | printer_->Print(text: "\n" ); |
| 378 | } |
| 379 | std::string class_name = message_descriptor.name(); |
| 380 | std::string ; |
| 381 | // A well-known type needs to inherit from its corresponding base class in |
| 382 | // net/proto2/python/internal/well_known_types. |
| 383 | if (IsWellKnownType(name: message_descriptor.full_name())) { |
| 384 | extra_base = ", _well_known_types." + message_descriptor.name(); |
| 385 | } else { |
| 386 | extra_base = "" ; |
| 387 | } |
| 388 | printer_->Print(text: "class $class_name$(_message.Message$extra_base$):\n" , |
| 389 | args: "class_name" , args: class_name, args: "extra_base" , args: extra_base); |
| 390 | printer_->Indent(); |
| 391 | printer_->Indent(); |
| 392 | |
| 393 | std::vector<const FieldDescriptor*> fields; |
| 394 | fields.reserve(n: message_descriptor.field_count()); |
| 395 | for (int i = 0; i < message_descriptor.field_count(); ++i) { |
| 396 | fields.push_back(x: message_descriptor.field(index: i)); |
| 397 | } |
| 398 | std::sort(first: fields.begin(), last: fields.end(), comp: SortByName<FieldDescriptor>()); |
| 399 | |
| 400 | // Prints slots |
| 401 | printer_->Print(text: "__slots__ = [" , args: "class_name" , args: class_name); |
| 402 | bool first_item = true; |
| 403 | for (const auto& field_des : fields) { |
| 404 | if (IsPythonKeyword(name: field_des->name())) { |
| 405 | continue; |
| 406 | } |
| 407 | if (first_item) { |
| 408 | first_item = false; |
| 409 | } else { |
| 410 | printer_->Print(text: ", " ); |
| 411 | } |
| 412 | printer_->Print(text: "\"$field_name$\"" , args: "field_name" , args: field_des->name()); |
| 413 | } |
| 414 | printer_->Print(text: "]\n" ); |
| 415 | |
| 416 | std::map<std::string, std::string> item_map; |
| 417 | // Prints Extensions for extendable messages |
| 418 | if (message_descriptor.extension_range_count() > 0) { |
| 419 | item_map["Extensions" ] = "_python_message._ExtensionDict" ; |
| 420 | } |
| 421 | |
| 422 | // Prints nested enums |
| 423 | std::vector<const EnumDescriptor*> nested_enums; |
| 424 | nested_enums.reserve(n: message_descriptor.enum_type_count()); |
| 425 | for (int i = 0; i < message_descriptor.enum_type_count(); ++i) { |
| 426 | nested_enums.push_back(x: message_descriptor.enum_type(index: i)); |
| 427 | } |
| 428 | std::sort(first: nested_enums.begin(), last: nested_enums.end(), |
| 429 | comp: SortByName<EnumDescriptor>()); |
| 430 | |
| 431 | for (const auto& entry : nested_enums) { |
| 432 | PrintEnum(enum_descriptor: *entry); |
| 433 | // Adds enum value to item_map which will be ordered and printed later |
| 434 | AddEnumValue(enum_descriptor: *entry, item_map: &item_map, import_map); |
| 435 | } |
| 436 | |
| 437 | // Prints nested messages |
| 438 | std::vector<const Descriptor*> nested_messages; |
| 439 | nested_messages.reserve(n: message_descriptor.nested_type_count()); |
| 440 | for (int i = 0; i < message_descriptor.nested_type_count(); ++i) { |
| 441 | nested_messages.push_back(x: message_descriptor.nested_type(index: i)); |
| 442 | } |
| 443 | std::sort(first: nested_messages.begin(), last: nested_messages.end(), |
| 444 | comp: SortByName<Descriptor>()); |
| 445 | |
| 446 | for (const auto& entry : nested_messages) { |
| 447 | PrintMessage(message_descriptor: *entry, is_nested: true, import_map); |
| 448 | } |
| 449 | |
| 450 | // Adds extensions to item_map which will be ordered and printed later |
| 451 | AddExtensions(descriptor: message_descriptor, item_map: &item_map); |
| 452 | |
| 453 | // Adds field number and field descriptor to item_map |
| 454 | for (int i = 0; i < message_descriptor.field_count(); ++i) { |
| 455 | const FieldDescriptor& field_des = *message_descriptor.field(index: i); |
| 456 | item_map[ToUpper(s: field_des.name()) + "_FIELD_NUMBER" ] = |
| 457 | "_ClassVar[int]" ; |
| 458 | if (IsPythonKeyword(name: field_des.name())) { |
| 459 | continue; |
| 460 | } |
| 461 | std::string field_type = "" ; |
| 462 | if (field_des.is_map()) { |
| 463 | const FieldDescriptor* key_des = field_des.message_type()->field(index: 0); |
| 464 | const FieldDescriptor* value_des = field_des.message_type()->field(index: 1); |
| 465 | field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE |
| 466 | ? "_containers.MessageMap[" |
| 467 | : "_containers.ScalarMap[" ); |
| 468 | field_type += GetFieldType(field_des: *key_des, containing_des: message_descriptor, import_map); |
| 469 | field_type += ", " ; |
| 470 | field_type += GetFieldType(field_des: *value_des, containing_des: message_descriptor, import_map); |
| 471 | } else { |
| 472 | if (field_des.is_repeated()) { |
| 473 | field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE |
| 474 | ? "_containers.RepeatedCompositeFieldContainer[" |
| 475 | : "_containers.RepeatedScalarFieldContainer[" ); |
| 476 | } |
| 477 | field_type += GetFieldType(field_des, containing_des: message_descriptor, import_map); |
| 478 | } |
| 479 | |
| 480 | if (field_des.is_repeated()) { |
| 481 | field_type += "]" ; |
| 482 | } |
| 483 | item_map[field_des.name()] = field_type; |
| 484 | } |
| 485 | |
| 486 | // Prints all items in item_map |
| 487 | PrintItemMap(item_map); |
| 488 | |
| 489 | // Prints __init__ |
| 490 | printer_->Print(text: "def __init__(self" ); |
| 491 | bool has_key_words = false; |
| 492 | bool is_first = true; |
| 493 | for (int i = 0; i < message_descriptor.field_count(); ++i) { |
| 494 | const FieldDescriptor* field_des = message_descriptor.field(index: i); |
| 495 | if (IsPythonKeyword(name: field_des->name())) { |
| 496 | has_key_words = true; |
| 497 | continue; |
| 498 | } |
| 499 | std::string field_name = field_des->name(); |
| 500 | if (is_first && field_name == "self" ) { |
| 501 | // See b/144146793 for an example of real code that generates a (self, |
| 502 | // self) method signature. Since repeating a parameter name is illegal in |
| 503 | // Python, we rename the duplicate self. |
| 504 | field_name = "self_" ; |
| 505 | } |
| 506 | is_first = false; |
| 507 | printer_->Print(text: ", $field_name$: " , args: "field_name" , args: field_name); |
| 508 | if (field_des->is_repeated() || |
| 509 | field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) { |
| 510 | printer_->Print(text: "_Optional[" ); |
| 511 | } |
| 512 | if (field_des->is_map()) { |
| 513 | const Descriptor* map_entry = field_des->message_type(); |
| 514 | printer_->Print( |
| 515 | text: "_Mapping[$key_type$, $value_type$]" , args: "key_type" , |
| 516 | args: GetFieldType(field_des: *map_entry->field(index: 0), containing_des: message_descriptor, import_map), |
| 517 | args: "value_type" , |
| 518 | args: GetFieldType(field_des: *map_entry->field(index: 1), containing_des: message_descriptor, import_map)); |
| 519 | } else { |
| 520 | if (field_des->is_repeated()) { |
| 521 | printer_->Print(text: "_Iterable[" ); |
| 522 | } |
| 523 | if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
| 524 | printer_->Print( |
| 525 | text: "_Union[$type_name$, _Mapping]" , args: "type_name" , |
| 526 | args: GetFieldType(field_des: *field_des, containing_des: message_descriptor, import_map)); |
| 527 | } else { |
| 528 | if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { |
| 529 | printer_->Print(text: "_Union[$type_name$, str]" , args: "type_name" , |
| 530 | args: ModuleLevelName(descriptor: *field_des->enum_type(), import_map)); |
| 531 | } else { |
| 532 | printer_->Print( |
| 533 | text: "$type_name$" , args: "type_name" , |
| 534 | args: GetFieldType(field_des: *field_des, containing_des: message_descriptor, import_map)); |
| 535 | } |
| 536 | } |
| 537 | if (field_des->is_repeated()) { |
| 538 | printer_->Print(text: "]" ); |
| 539 | } |
| 540 | } |
| 541 | if (field_des->is_repeated() || |
| 542 | field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) { |
| 543 | printer_->Print(text: "]" ); |
| 544 | } |
| 545 | printer_->Print(text: " = ..." ); |
| 546 | } |
| 547 | if (has_key_words) { |
| 548 | printer_->Print(text: ", **kwargs" ); |
| 549 | } |
| 550 | printer_->Print(text: ") -> None: ...\n" ); |
| 551 | |
| 552 | printer_->Outdent(); |
| 553 | printer_->Outdent(); |
| 554 | } |
| 555 | |
| 556 | void PyiGenerator::PrintMessages( |
| 557 | const std::map<std::string, std::string>& import_map) const { |
| 558 | // Deterministically order the descriptors. |
| 559 | std::vector<const Descriptor*> messages; |
| 560 | messages.reserve(n: file_->message_type_count()); |
| 561 | for (int i = 0; i < file_->message_type_count(); ++i) { |
| 562 | messages.push_back(x: file_->message_type(index: i)); |
| 563 | } |
| 564 | std::sort(first: messages.begin(), last: messages.end(), comp: SortByName<Descriptor>()); |
| 565 | |
| 566 | for (const auto& entry : messages) { |
| 567 | PrintMessage(message_descriptor: *entry, is_nested: false, import_map); |
| 568 | } |
| 569 | } |
| 570 | |
| 571 | void PyiGenerator::PrintServices() const { |
| 572 | std::vector<const ServiceDescriptor*> services; |
| 573 | services.reserve(n: file_->service_count()); |
| 574 | for (int i = 0; i < file_->service_count(); ++i) { |
| 575 | services.push_back(x: file_->service(index: i)); |
| 576 | } |
| 577 | std::sort(first: services.begin(), last: services.end(), comp: SortByName<ServiceDescriptor>()); |
| 578 | |
| 579 | // Prints $Service$ and $Service$_Stub classes |
| 580 | for (const auto& entry : services) { |
| 581 | printer_->Print(text: "\n" ); |
| 582 | printer_->Print( |
| 583 | text: "class $service_name$(_service.service): ...\n\n" |
| 584 | "class $service_name$_Stub($service_name$): ...\n" , |
| 585 | args: "service_name" , args: entry->name()); |
| 586 | } |
| 587 | } |
| 588 | |
| 589 | bool PyiGenerator::Generate(const FileDescriptor* file, |
| 590 | const std::string& parameter, |
| 591 | GeneratorContext* context, |
| 592 | std::string* error) const { |
| 593 | MutexLock lock(&mutex_); |
| 594 | // Calculate file name. |
| 595 | file_ = file; |
| 596 | std::string filename = |
| 597 | parameter.empty() ? GetFileName(file_des: file, suffix: ".pyi" ) : parameter; |
| 598 | |
| 599 | std::unique_ptr<io::ZeroCopyOutputStream> output(context->Open(filename)); |
| 600 | GOOGLE_CHECK(output.get()); |
| 601 | io::Printer printer(output.get(), '$'); |
| 602 | printer_ = &printer; |
| 603 | |
| 604 | // item map will store "DESCRIPTOR", top level extensions, top level enum |
| 605 | // values. The items will be sorted and printed later. |
| 606 | std::map<std::string, std::string> item_map; |
| 607 | |
| 608 | // Adds "DESCRIPTOR" into item_map. |
| 609 | item_map["DESCRIPTOR" ] = "_descriptor.FileDescriptor" ; |
| 610 | |
| 611 | // import_map will be a mapping from filename to module alias, e.g. |
| 612 | // "google3/foo/bar.py" -> "_bar" |
| 613 | std::map<std::string, std::string> import_map; |
| 614 | |
| 615 | PrintImports(item_map: &item_map, import_map: &import_map); |
| 616 | // Adds top level enum values to item_map. |
| 617 | for (int i = 0; i < file_->enum_type_count(); ++i) { |
| 618 | AddEnumValue(enum_descriptor: *file_->enum_type(index: i), item_map: &item_map, import_map); |
| 619 | } |
| 620 | // Adds top level extensions to item_map. |
| 621 | AddExtensions(descriptor: *file_, item_map: &item_map); |
| 622 | // Prints item map |
| 623 | PrintItemMap(item_map); |
| 624 | |
| 625 | PrintMessages(import_map); |
| 626 | PrintTopLevelEnums(); |
| 627 | if (HasGenericServices(file)) { |
| 628 | PrintServices(); |
| 629 | } |
| 630 | return true; |
| 631 | } |
| 632 | |
| 633 | } // namespace python |
| 634 | } // namespace compiler |
| 635 | } // namespace protobuf |
| 636 | } // namespace google |
| 637 | |