* Update protobuf pregenerated files * Update grpcio-tools to latest version * Implement proto3 field presence * Fix to_dict with None optional fields. * Add test with optional enum * Properly support optional enums * Add tests for 64-bit ints and floats * Support field presence for int64 types * Fix oneof serialization with proto3 field presence (#292) = Description The serialization of a oneof message that contains a message with fields with explicit presence was buggy. For example: ``` message A { oneof kind { B b = 1; C c = 2; } } message B {} message C { optional bool z = 1; } ``` Serializing `A(b=B())` would lead to this payload: ``` 0A # tag1, length delimited 00 # length: 0 12 # tag2, length delimited 00 # length: 0 ``` Which when deserialized, leads to the message `A(c=C())`. = Explanation The issue lies in the post_init method. All fields are introspected, and if different from PLACEHOLDER, the message is marked as having been "serialized_on_wire". Then, when serializing `A(b=B())`, we go through each field of the oneof: - field 'b': this is the selected field from the group, so it is serialized - field 'c': marked as 'serialized_on_wire', so it is added as well. = Fix The issue is that support for explicit presence changed the default value from PLACEHOLDER to None. This breaks the post_init method in that case, which is relatively easy to fix: if a field is optional, and set to None, this is considered as the default value (which it is). This fix however has a side-effect: the group_current for this field (the oneof trick for explicit presence) is no longer set. This changes the behavior when serializing the message in JSON: as the value is the default one (None), and the group is not set (which would force the serialization of the field), so None fields are no longer serialized in JSON. This break one test, and will be fixed in the next commit. 
* fix: do not serialize None fields in JSON format This is linked to the fix from the previous commit: after it, scalar None fields were not included in the JSON format, but some were still included. This is all cleaned up: None fields are not added in JSON by default, as they indicate the default value of fields with explicit presence. However, if `include_default_values` is set, they are included. * Fix: use builtin annotation prefix * Remove comment Co-authored-by: roblabla <unfiltered@roblab.la> Co-authored-by: Vincent Thiberville <vthib@pm.me>
809 lines
27 KiB
Python
809 lines
27 KiB
Python
"""Plugin model dataclasses.
|
|
|
|
These classes are meant to be an intermediate representation
|
|
of protobuf objects. They are used to organize the data collected during parsing.
|
|
|
|
The general intention is to create a doubly-linked tree-like structure
|
|
with the following types of references:
|
|
- Downwards references: from message -> fields, from output package -> messages
|
|
or from service -> service methods
|
|
- Upwards references: from field -> message, message -> package.
|
|
- Input/output message references: from a service method to its corresponding
|
|
input/output messages, which may even be in another package.
|
|
|
|
There are convenience methods to allow climbing up and down this tree, for
|
|
example to retrieve the list of all messages that are in the same package as
|
|
the current message.
|
|
|
|
Most of these classes take as inputs:
|
|
- proto_obj: A reference to its corresponding protobuf object as
|
|
presented by the protoc plugin.
|
|
- parent: a reference to the parent object in the tree.
|
|
|
|
With this information, the class is able to expose attributes,
|
|
such as a pythonized name, that will be calculated from proto_obj.
|
|
|
|
The instantiation should also attach a reference to the new object
|
|
into the corresponding place within its parent object. For example,
|
|
instantiating field `A` with parent message `B` should add a
|
|
reference to `A` to `B`'s `fields` attribute.
|
|
"""
|
|
|
|
|
|
import builtins
|
|
import betterproto
|
|
from betterproto import which_one_of
|
|
from betterproto.casing import sanitize_name
|
|
from betterproto.compile.importing import (
|
|
get_type_reference,
|
|
parse_source_type_name,
|
|
)
|
|
from betterproto.compile.naming import (
|
|
pythonize_class_name,
|
|
pythonize_field_name,
|
|
pythonize_method_name,
|
|
)
|
|
from betterproto.lib.google.protobuf import (
|
|
DescriptorProto,
|
|
EnumDescriptorProto,
|
|
FileDescriptorProto,
|
|
MethodDescriptorProto,
|
|
Field,
|
|
FieldDescriptorProto,
|
|
FieldDescriptorProtoType,
|
|
FieldDescriptorProtoLabel,
|
|
)
|
|
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
|
|
|
|
|
import re
|
|
import textwrap
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
|
|
|
|
from ..casing import sanitize_name
|
|
from ..compile.importing import get_type_reference, parse_source_type_name
|
|
from ..compile.naming import (
|
|
pythonize_class_name,
|
|
pythonize_field_name,
|
|
pythonize_method_name,
|
|
)
|
|
|
|
|
|
# Unique sentinel used as the default of dataclass fields that are logically
# required but must carry a default because of dataclass inheritance rules:
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
PLACEHOLDER = object()


# Proto field types grouped into categories.
# The number after each entry is the one recorded alongside it in this file.

# Types rendered as Python ``float``.
PROTO_FLOAT_TYPES = (
    FieldDescriptorProtoType.TYPE_DOUBLE,  # 1
    FieldDescriptorProtoType.TYPE_FLOAT,  # 2
)

# Types rendered as Python ``int``.
PROTO_INT_TYPES = (
    FieldDescriptorProtoType.TYPE_INT64,  # 3
    FieldDescriptorProtoType.TYPE_UINT64,  # 4
    FieldDescriptorProtoType.TYPE_INT32,  # 5
    FieldDescriptorProtoType.TYPE_FIXED64,  # 6
    FieldDescriptorProtoType.TYPE_FIXED32,  # 7
    FieldDescriptorProtoType.TYPE_UINT32,  # 13
    FieldDescriptorProtoType.TYPE_SFIXED32,  # 15
    FieldDescriptorProtoType.TYPE_SFIXED64,  # 16
    FieldDescriptorProtoType.TYPE_SINT32,  # 17
    FieldDescriptorProtoType.TYPE_SINT64,  # 18
)

# Single-member categories for ``bool``, ``str`` and ``bytes``.
PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,)  # 8
PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,)  # 9
PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,)  # 12

# Types that reference another definition (a message or a named enum).
PROTO_MESSAGE_TYPES = (
    FieldDescriptorProtoType.TYPE_MESSAGE,  # 11
    FieldDescriptorProtoType.TYPE_ENUM,  # 14
)

# Type used by map fields.
PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,)  # 11

# Types whose repeated form is treated as packed (see FieldCompiler.packed).
PROTO_PACKED_TYPES = (
    FieldDescriptorProtoType.TYPE_DOUBLE,  # 1
    FieldDescriptorProtoType.TYPE_FLOAT,  # 2
    FieldDescriptorProtoType.TYPE_INT64,  # 3
    FieldDescriptorProtoType.TYPE_UINT64,  # 4
    FieldDescriptorProtoType.TYPE_INT32,  # 5
    FieldDescriptorProtoType.TYPE_FIXED64,  # 6
    FieldDescriptorProtoType.TYPE_FIXED32,  # 7
    FieldDescriptorProtoType.TYPE_BOOL,  # 8
    FieldDescriptorProtoType.TYPE_UINT32,  # 13
    FieldDescriptorProtoType.TYPE_SFIXED32,  # 15
    FieldDescriptorProtoType.TYPE_SFIXED64,  # 16
    FieldDescriptorProtoType.TYPE_SINT32,  # 17
    FieldDescriptorProtoType.TYPE_SINT64,  # 18
)
|
|
|
|
|
|
def monkey_patch_oneof_index():
    """Make ``oneof_index`` behave like the only member of a oneof group.

    The compiler message types are written for proto2, but we read them as
    proto3. The ``oneof_index`` fields depend on being able to tell whether
    they were set, so the generated ``FieldDescriptorProto`` and ``Field``
    classes are patched after the fact: the ``group`` attribute of each
    ``oneof_index`` field's betterproto metadata is set to ``"oneof_index"``.
    """
    for patched_cls in (FieldDescriptorProto, Field):
        metadata = patched_cls.__dataclass_fields__["oneof_index"].metadata[
            "betterproto"
        ]
        # Written through object.__setattr__ rather than plain attribute
        # assignment, so any __setattr__ defined by the metadata class is
        # bypassed.
        object.__setattr__(metadata, "group", "oneof_index")
|
|
|
|
|
|
def get_comment(
    proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
) -> str:
    """Render the leading comment found at ``path`` as Python source text.

    The locations in ``proto_file.source_code_info`` are scanned in order and
    the first one whose path equals ``path`` and that has a non-empty leading
    comment is used. Its text is stripped, newlines are removed, and it is
    re-wrapped to ``79 - indent`` columns.

    Fields are rendered as ``#`` comment lines. Everything else is rendered
    as a docstring: on one line when short enough, as a multi-line block
    otherwise. Every line is prefixed with ``indent`` spaces.

    Returns an empty string when no matching commented location exists.
    """
    prefix = " " * indent
    wrap_width = 79 - indent

    for location in proto_file.source_code_info.location:
        if list(location.path) != path or not location.leading_comments:
            continue

        flattened = location.leading_comments.strip().replace("\n", "")
        wrapped = textwrap.wrap(flattened, width=wrap_width)

        # A 2 in the second-to-last position marks a field, unless the
        # element four from the end is 6, in which case it is a method.
        if path[-2] == 2 and path[-4] != 6:
            hash_prefix = f"\n{prefix}# "
            return f"{prefix}# " + hash_prefix.join(wrapped)

        # Message, enum, service, or method: emit a docstring. The 6 leaves
        # room for the opening and closing triple quotes.
        if len(wrapped) == 1 and len(wrapped[0]) < wrap_width - 6:
            wrapped[0] = wrapped[0].strip('"')
            return f'{prefix}"""{wrapped[0]}"""'

        body = f"\n{prefix}".join(wrapped)
        return f'{prefix}"""\n{prefix}{body}\n{prefix}"""'

    return ""
|
|
|
|
|
|
class ProtoContentBase:
    """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""

    source_file: FileDescriptorProto
    path: List[int]
    comment_indent: int = 4
    parent: Union["betterproto.Message", "OutputTemplate"]

    __dataclass_fields__: Dict[str, object]

    def __post_init__(self) -> None:
        """Raise ``ValueError`` if a dataclass field entry is the placeholder.

        NOTE(review): the identity check is made against the values stored in
        ``__dataclass_fields__``, not against the attribute values of this
        instance -- confirm that this is the intended comparison.
        """
        offending = next(
            (
                name
                for name, entry in self.__dataclass_fields__.items()
                if entry is PLACEHOLDER
            ),
            None,
        )
        if offending is not None:
            raise ValueError(f"`{offending}` is a required field.")

    @property
    def output_file(self) -> "OutputTemplate":
        """The nearest ``OutputTemplate`` found by following ``parent`` links."""
        node = self
        while not isinstance(node, OutputTemplate):
            node = node.parent
        return node

    @property
    def request(self) -> "PluginRequestCompiler":
        """The ``parent_request`` of the enclosing ``OutputTemplate``."""
        node = self
        while not isinstance(node, OutputTemplate):
            node = node.parent
        return node.parent_request

    @property
    def comment(self) -> str:
        """Comment for this object, looked up in the proto source by ``path``
        and indented by ``comment_indent`` spaces.
        """
        return get_comment(
            proto_file=self.source_file,
            path=self.path,
            indent=self.comment_indent,
        )
|
|
|
|
|
|
@dataclass
class PluginRequestCompiler:
    """Root of the model tree: the plugin request and its output packages."""

    # The request object this compiler was created for.
    plugin_request_obj: CodeGeneratorRequest
    # Output templates, keyed by string.
    output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)

    @property
    def all_messages(self) -> List["MessageCompiler"]:
        """All of the messages in this request.

        Returns
        -------
        List[MessageCompiler]
            The ``messages`` of every output package, concatenated in the
            iteration order of ``output_packages``.
        """
        collected = []
        for template in self.output_packages.values():
            collected.extend(template.messages)
        return collected
|
|
|
|
|
|
@dataclass
class OutputTemplate:
    """Representation of an output .py file.

    Each output file corresponds to a .proto input file,
    but may need references to other .proto files to be
    built.
    """

    parent_request: PluginRequestCompiler
    package_proto_obj: FileDescriptorProto
    # Elements are dereferenced through ``.name`` in ``input_filenames``, so
    # they are file descriptor objects rather than plain strings.
    # NOTE(review): presumably FileDescriptorProto -- confirm against the
    # code that appends to this list.
    input_files: List[FileDescriptorProto] = field(default_factory=list)
    imports: Set[str] = field(default_factory=set)
    datetime_imports: Set[str] = field(default_factory=set)
    typing_imports: Set[str] = field(default_factory=set)
    # Set to True once any field needs the ``builtins.`` annotation prefix.
    builtins_import: bool = False
    messages: List["MessageCompiler"] = field(default_factory=list)
    enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
    services: List["ServiceCompiler"] = field(default_factory=list)

    @property
    def package(self) -> str:
        """Name of input package.

        Returns
        -------
        str
            Name of input package.
        """
        return self.package_proto_obj.package

    @property
    def input_filenames(self) -> Iterable[str]:
        """Names of the input files used to build this output.

        Returns
        -------
        Iterable[str]
            Sorted names of the input files used to build this output.
        """
        return sorted(f.name for f in self.input_files)

    @property
    def python_module_imports(self) -> Set[str]:
        """Plain module names the generated file has to import.

        ``"warnings"`` is included when at least one message has a deprecated
        field, ``"builtins"`` when ``builtins_import`` is set.
        """
        imports = set()
        if any(x for x in self.messages if any(x.deprecated_fields)):
            imports.add("warnings")
        if self.builtins_import:
            imports.add("builtins")
        return imports
|
|
|
|
|
|
@dataclass
class MessageCompiler(ProtoContentBase):
    """Representation of a protobuf message."""

    source_file: FileDescriptorProto
    parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
    proto_obj: DescriptorProto = PLACEHOLDER
    path: List[int] = PLACEHOLDER
    fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
        default_factory=list
    )
    # Copied from ``proto_obj.options.deprecated`` in __post_init__.
    deprecated: bool = field(default=False, init=False)
    # Names of fields of this message that are also names of builtins;
    # filled in by FieldCompiler.get_field_string.
    builtins_types: Set[str] = field(default_factory=set)

    def __post_init__(self) -> None:
        """Register with the output file, record deprecation, run base checks."""
        # Only objects whose direct parent is the output file are registered.
        if isinstance(self.parent, OutputTemplate):
            if isinstance(self, EnumDefinitionCompiler):
                self.output_file.enums.append(self)
            else:
                self.output_file.messages.append(self)
        self.deprecated = self.proto_obj.options.deprecated
        super().__post_init__()

    @property
    def proto_name(self) -> str:
        """Name as declared on the protobuf object."""
        return self.proto_obj.name

    @property
    def py_name(self) -> str:
        """Pythonized class name derived from ``proto_name``."""
        return pythonize_class_name(self.proto_name)

    @property
    def annotation(self) -> str:
        """Type annotation string: ``List[<py_name>]`` if repeated, else ``py_name``.

        Only field subclasses define ``repeated``; objects without it are
        treated as not repeated instead of raising ``AttributeError``.
        """
        if getattr(self, "repeated", False):
            return f"List[{self.py_name}]"
        return self.py_name

    @property
    def deprecated_fields(self) -> Iterator[str]:
        """Yield the ``py_name`` of every entry in ``fields`` marked deprecated."""
        for f in self.fields:
            if f.deprecated:
                yield f.py_name

    @property
    def has_deprecated_fields(self) -> bool:
        """True if ``deprecated_fields`` yields at least one truthy name."""
        return any(self.deprecated_fields)
|
|
|
|
|
|
def is_map(
    proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool:
    """True if proto_field_obj is a map, otherwise False.

    A field counts as a map when it is message-typed, the last dotted
    component of its type name equals ``<field name>entry`` (underscores
    removed, case-insensitive), and ``parent_message`` has a nested type with
    that same normalized name whose ``map_entry`` option is set.
    """
    if proto_field_obj.type != FieldDescriptorProtoType.TYPE_MESSAGE:
        return False

    # This might be a map...
    referenced_type = proto_field_obj.type_name.split(".")[-1].lower()
    expected_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
    if referenced_type != expected_entry:
        return False

    return any(
        nested.name.replace("_", "").lower() == expected_entry
        and nested.options.map_entry
        for nested in parent_message.nested_type
    )
|
|
|
|
|
|
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
    """
    True if proto_field_obj is a OneOf, otherwise False.

    .. warning::
        The message from protoc is defined in proto2 while betterproto works
        with proto3. Interpreting ``FieldDescriptorProto.oneof_index``
        requires telling a default value apart from an unset one, which
        proto3 does not support, so the generated ``FieldDescriptorProto``
        class has to be hacked for this to work.
        The hack sets ``group="oneof_index"`` in the field metadata, making
        ``oneof_index`` the sole member of a one_of group. Whether it was set
        can then be read through the ``which_one_of`` interface.
    """
    selected_member = which_one_of(proto_field_obj, "oneof_index")[0]
    return selected_member == "oneof_index"
|
|
|
|
|
|
@dataclass
class FieldCompiler(MessageCompiler):
    """Representation of a single field of a protobuf message."""

    parent: MessageCompiler = PLACEHOLDER
    proto_obj: FieldDescriptorProto = PLACEHOLDER

    def __post_init__(self) -> None:
        """Attach the field to its parent and record the imports it needs."""
        # Add field to message
        self.parent.fields.append(self)
        # Check for new imports
        self.add_imports_to(self.output_file)
        super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

    def get_field_string(self, indent: int = 4) -> str:
        """Construct string representation of this field as a field.

        The result has the form
        ``<py_name>: <annotation> = betterproto.<field_type>_field(<number>, ...)``.
        ``indent`` is accepted for interface compatibility and is not used.

        Side effect: if the field's name is also the name of a builtin, the
        name is added to the parent's ``builtins_types``.
        """
        name = f"{self.py_name}"
        annotations = f": {self.annotation}"
        # Evaluate the property once; every extra argument gets a ", " prefix
        # so it can follow the field number directly.
        extra_args = self.betterproto_field_args
        field_args = "".join(f", {arg}" for arg in extra_args)
        betterproto_field_type = (
            f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
        )
        if self.py_name in dir(builtins):
            self.parent.builtins_types.add(self.py_name)
        return f"{name}{annotations} = {betterproto_field_type}"

    @property
    def betterproto_field_args(self) -> List[str]:
        """Extra arguments, as source strings, for the betterproto field call."""
        args = []
        wraps = self.field_wraps
        if wraps:
            args.append(f"wraps={wraps}")
        if self.optional:
            args.append("optional=True")
        return args

    @property
    def datetime_imports(self) -> Set[str]:
        """Names to import from ``datetime``, based on the annotation text."""
        imports = set()
        annotation = self.annotation
        # FIXME: false positives - e.g. `MyDatetimedelta`
        if "timedelta" in annotation:
            imports.add("timedelta")
        if "datetime" in annotation:
            imports.add("datetime")
        return imports

    @property
    def typing_imports(self) -> Set[str]:
        """Names to import from ``typing``, based on the annotation text."""
        imports = set()
        annotation = self.annotation
        if "Optional[" in annotation:
            imports.add("Optional")
        if "List[" in annotation:
            imports.add("List")
        if "Dict[" in annotation:
            imports.add("Dict")
        return imports

    @property
    def use_builtins(self) -> bool:
        """True if the annotation must be prefixed with ``builtins.``.

        That is the case when the Python type is listed in the parent's
        ``builtins_types``, or when the type and the field share the same
        name and that name is a builtin.
        """
        return self.py_type in self.parent.builtins_types or (
            self.py_type == self.py_name and self.py_name in dir(builtins)
        )

    def add_imports_to(self, output_file: OutputTemplate) -> None:
        """Merge this field's import requirements into ``output_file``."""
        output_file.datetime_imports.update(self.datetime_imports)
        output_file.typing_imports.update(self.typing_imports)
        output_file.builtins_import = output_file.builtins_import or self.use_builtins

    @property
    def field_wraps(self) -> Optional[str]:
        """Returns betterproto wrapped field type or None."""
        match_wrapper = re.match(
            r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
        )
        if match_wrapper:
            wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
            # Only names that exist as betterproto attributes are accepted.
            if hasattr(betterproto, wrapped_type):
                return f"betterproto.{wrapped_type}"
        return None

    @property
    def repeated(self) -> bool:
        """True if the field has the repeated label and is not a map."""
        # is_map inspects the ``nested_type`` of the parent's protobuf
        # descriptor, so pass the descriptor, not the compiler wrapper.
        return (
            self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
            and not is_map(self.proto_obj, self.parent.proto_obj)
        )

    @property
    def optional(self) -> bool:
        """Value of the descriptor's ``proto3_optional`` flag."""
        return self.proto_obj.proto3_optional

    @property
    def mutable(self) -> bool:
        """True if the field is a mutable type, otherwise False."""
        return self.annotation.startswith(("List[", "Dict["))

    @property
    def field_type(self) -> str:
        """String representation of proto field type."""
        return (
            FieldDescriptorProtoType(self.proto_obj.type)
            .name.lower()
            .replace("type_", "")
        )

    @property
    def default_value_string(self) -> str:
        """Python representation of the default proto value."""
        if self.repeated:
            return "[]"
        if self.optional:
            return "None"

        scalar_defaults = {
            "int": "0",
            "float": "0.0",
            "bool": "False",
            "str": '""',
            "bytes": 'b""',
        }
        py_type = self.py_type
        if py_type in scalar_defaults:
            return scalar_defaults[py_type]

        if self.field_type == "enum":
            # The enum is looked up by its unqualified name among the enums
            # registered on this field's output file.
            enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
            enum = next(
                e
                for e in self.output_file.enums
                if e.proto_obj.name == enum_proto_obj_name
            )
            return enum.default_value_string

        # Message type
        return "None"

    @property
    def packed(self) -> bool:
        """True if the wire representation is a packed format."""
        return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES

    @property
    def py_name(self) -> str:
        """Pythonized name."""
        return pythonize_field_name(self.proto_name)

    @property
    def proto_name(self) -> str:
        """Original protobuf name."""
        return self.proto_obj.name

    @property
    def py_type(self) -> str:
        """String representation of Python type.

        Raises
        ------
        NotImplementedError
            If the proto type belongs to none of the known categories.
        """
        proto_type = self.proto_obj.type
        if proto_type in PROTO_FLOAT_TYPES:
            return "float"
        if proto_type in PROTO_INT_TYPES:
            return "int"
        if proto_type in PROTO_BOOL_TYPES:
            return "bool"
        if proto_type in PROTO_STR_TYPES:
            return "str"
        if proto_type in PROTO_BYTES_TYPES:
            return "bytes"
        if proto_type in PROTO_MESSAGE_TYPES:
            # Type referencing another defined Message or a named enum
            return get_type_reference(
                package=self.output_file.package,
                imports=self.output_file.imports,
                source_type=self.proto_obj.type_name,
            )
        # Report the descriptor's type; the module-level name ``field`` is
        # the dataclasses helper and has no ``type`` attribute.
        raise NotImplementedError(f"Unknown type {proto_type}")

    @property
    def annotation(self) -> str:
        """Type annotation string for the generated dataclass attribute."""
        py_type = self.py_type
        if self.use_builtins:
            py_type = f"builtins.{py_type}"
        if self.repeated:
            return f"List[{py_type}]"
        if self.optional:
            return f"Optional[{py_type}]"
        return py_type
|
|
|
|
|
|
@dataclass
class OneOfFieldCompiler(FieldCompiler):
    """Field that is a member of a oneof group."""

    @property
    def betterproto_field_args(self) -> List[str]:
        """The inherited arguments followed by ``group="<oneof name>"``.

        The name is read from the parent message's ``oneof_decl`` entry at
        this field's ``oneof_index``.
        """
        inherited = super().betterproto_field_args
        oneof_name = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
        return [*inherited, f'group="{oneof_name}"']
|
|
|
|
|
|
@dataclass
class MapEntryCompiler(FieldCompiler):
    """Field that is a protobuf map."""

    py_k_type: Type = PLACEHOLDER
    py_v_type: Type = PLACEHOLDER
    proto_k_type: str = PLACEHOLDER
    proto_v_type: str = PLACEHOLDER

    def __post_init__(self) -> None:
        """Explore nested types and set k_type and v_type if unset."""
        entry_name = f"{self.proto_obj.name.replace('_', '').lower()}entry"
        for candidate in self.parent.proto_obj.nested_type:
            normalized = candidate.name.replace("_", "").lower()
            if normalized != entry_name or not candidate.options.map_entry:
                continue

            # Python types: temporary field compilers are built for the key
            # (first nested field) and the value (second nested field).
            key_proto = candidate.field[0]
            key_compiler = FieldCompiler(
                source_file=self.source_file,
                parent=self,
                proto_obj=key_proto,
            )
            self.py_k_type = key_compiler.py_type

            value_proto = candidate.field[1]
            value_compiler = FieldCompiler(
                source_file=self.source_file,
                parent=self,
                proto_obj=value_proto,
            )
            self.py_v_type = value_compiler.py_type

            # Proto types, stored as enum member names.
            self.proto_k_type = FieldDescriptorProtoType(key_proto.type).name
            self.proto_v_type = FieldDescriptorProtoType(value_proto.type).name
        super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

    @property
    def betterproto_field_args(self) -> List[str]:
        """Key and value proto types as ``betterproto.<TYPE>`` source strings."""
        key_arg = f"betterproto.{self.proto_k_type}"
        value_arg = f"betterproto.{self.proto_v_type}"
        return [key_arg, value_arg]

    @property
    def field_type(self) -> str:
        """Always ``"map"``."""
        return "map"

    @property
    def annotation(self) -> str:
        """``Dict[...]`` annotation built from the key and value Python types."""
        return f"Dict[{self.py_k_type}, {self.py_v_type}]"

    @property
    def repeated(self) -> bool:
        """Always False."""
        return False  # maps cannot be repeated
|
|
|
|
|
|
@dataclass
class EnumDefinitionCompiler(MessageCompiler):
    """Representation of a proto Enum definition."""

    proto_obj: EnumDescriptorProto = PLACEHOLDER
    entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER

    @dataclass(unsafe_hash=True)
    class EnumEntry:
        """Representation of an Enum entry."""

        name: str
        value: int
        comment: str

    def __post_init__(self) -> None:
        """Build ``entries`` from the enum's values, then run the base setup."""
        # One entry per declared value. The comment is looked up at this
        # enum's path extended with [2, <position of the value>].
        collected = []
        for position, value_proto in enumerate(self.proto_obj.value):
            entry_comment = get_comment(
                proto_file=self.source_file, path=self.path + [2, position]
            )
            collected.append(
                self.EnumEntry(
                    name=sanitize_name(value_proto.name),
                    value=value_proto.number,
                    comment=entry_comment,
                )
            )
        self.entries = collected
        super().__post_init__()  # call MessageCompiler __post_init__

    @property
    def default_value_string(self) -> str:
        """Python representation of the default value for Enums.

        As per the spec, this is the first value of the Enum.
        """
        first_entry = self.entries[0]
        return str(first_entry.value)  # ideally, should ALWAYS be int(0)!
|
|
|
|
|
|
@dataclass
class ServiceCompiler(ProtoContentBase):
    """Representation of a protobuf service."""

    parent: OutputTemplate = PLACEHOLDER
    proto_obj: DescriptorProto = PLACEHOLDER
    path: List[int] = PLACEHOLDER
    methods: List["ServiceMethodCompiler"] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Register the service and request the ``Dict`` typing import."""
        output = self.output_file
        # Add service to output file
        output.services.append(self)
        output.typing_imports.add("Dict")
        super().__post_init__()  # check for unset fields

    @property
    def proto_name(self) -> str:
        """Name as declared on the protobuf object."""
        return self.proto_obj.name

    @property
    def py_name(self) -> str:
        """Pythonized class name derived from ``proto_name``."""
        return pythonize_class_name(self.proto_name)
|
|
|
|
|
|
@dataclass
class ServiceMethodCompiler(ProtoContentBase):
    """Representation of one method of a protobuf service."""

    parent: ServiceCompiler
    proto_obj: MethodDescriptorProto
    path: List[int] = PLACEHOLDER
    comment_indent: int = 8

    def __post_init__(self) -> None:
        """Attach the method to its service and record the imports it needs."""
        # Add method to service
        self.parent.methods.append(self)

        # Check for imports
        input_message = self.py_input_message
        if input_message:
            for message_field in input_message.fields:
                message_field.add_imports_to(self.output_file)
        if "Optional" in self.py_output_message_type:
            self.output_file.typing_imports.add("Optional")
        # Evaluated purely for its side effect on the typing imports, so that
        # it has happened before rendering.
        self.mutable_default_args

        # Check for Async imports
        if self.client_streaming:
            for typing_name in ("AsyncIterable", "Iterable", "Union"):
                self.output_file.typing_imports.add(typing_name)

        # Required by both client and server
        if self.client_streaming or self.server_streaming:
            self.output_file.typing_imports.add("AsyncIterator")

        super().__post_init__()  # check for unset fields

    @property
    def mutable_default_args(self) -> Dict[str, str]:
        """Handle mutable default arguments.

        Collects the arguments of this method whose default value is mutable.
        The defaults are swapped out for None and replaced back inside
        the method's body.
        Reference:
        https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

        As a side effect, ``"Optional"`` is added to the output file's typing
        imports for every argument collected.

        Returns
        -------
        Dict[str, str]
            Name and actual default value (as a string)
            for each argument with mutable default values.
        """
        collected = {}

        input_message = self.py_input_message
        if input_message:
            for message_field in input_message.fields:
                if (
                    self.client_streaming
                    or message_field.default_value_string == "None"
                    or not message_field.mutable
                ):
                    continue
                collected[message_field.py_name] = message_field.default_value_string
                self.output_file.typing_imports.add("Optional")

        return collected

    @property
    def py_name(self) -> str:
        """Pythonized method name."""
        return pythonize_method_name(self.proto_obj.name)

    @property
    def proto_name(self) -> str:
        """Original protobuf name."""
        return self.proto_obj.name

    @property
    def route(self) -> str:
        """Route string ``/<package>.<service>/<method>``.

        The ``<package>.`` part is omitted when the package name is empty.
        """
        package = self.output_file.package
        package_part = f"{package}." if package else ""
        return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"

    @property
    def py_input_message(self) -> Optional[MessageCompiler]:
        """Find the input message object.

        Returns
        -------
        Optional[MessageCompiler]
            The first message of the request whose ``py_name`` and package
            match the method's input type.
            None if no such message is found.
        """
        package, name = parse_source_type_name(self.proto_obj.input_type)

        # Nested types are currently flattened without dots.
        # Todo: keep a fully quantified name in types, that is
        # comparable with method.input_type
        flattened_name = name.replace(".", "")
        return next(
            (
                msg
                for msg in self.request.all_messages
                if msg.py_name == flattened_name
                and msg.output_file.package == package
            ),
            None,
        )

    @property
    def py_input_message_type(self) -> str:
        """String representation of the Python type corresponding to the
        input message.

        Returns
        -------
        str
            The type reference for the method's input type, with surrounding
            double quotes stripped.
        """
        reference = get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            source_type=self.proto_obj.input_type,
        )
        return reference.strip('"')

    @property
    def py_output_message_type(self) -> str:
        """String representation of the Python type corresponding to the
        output message.

        Returns
        -------
        str
            The type reference for the method's output type, requested with
            ``unwrap=False``, with surrounding double quotes stripped.
        """
        reference = get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            source_type=self.proto_obj.output_type,
            unwrap=False,
        )
        return reference.strip('"')

    @property
    def client_streaming(self) -> bool:
        """Value of the descriptor's ``client_streaming`` flag."""
        return self.proto_obj.client_streaming

    @property
    def server_streaming(self) -> bool:
        """Value of the descriptor's ``server_streaming`` flag."""
        return self.proto_obj.server_streaming
|