Add support for pydantic dataclasses (#406)

This commit is contained in:
Samuel Yvon
2023-02-13 07:37:16 -08:00
committed by GitHub
parent 6df8cef3f0
commit 13d656587c
11 changed files with 283 additions and 19 deletions

View File

@@ -628,7 +628,6 @@ class Message(ABC):
# Set current field of each group after `__init__` has already been run.
group_current: Dict[str, Optional[str]] = {}
for field_name, meta in self._betterproto.meta_by_field_name.items():
if meta.group:
group_current.setdefault(meta.group)
@@ -1470,6 +1469,24 @@ class Message(ABC):
)
return self.__raw_get(name) is not default
@classmethod
def _validate_field_groups(cls, values):
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore
for group, field_set in meta.items():
set_fields = [
field.name for field in field_set if values[field.name] is not None
]
if not set_fields:
raise ValueError(f"Group {group} has no value; all fields are None")
elif len(set_fields) > 1:
set_fields_str = ", ".join(set_fields)
raise ValueError(
f"Group {group} has more than one value; fields {set_fields_str} are not None"
)
return values
def serialized_on_wire(message: Message) -> bool:
"""

View File

@@ -214,7 +214,6 @@ class ProtoContentBase:
@dataclass
class PluginRequestCompiler:
plugin_request_obj: CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
@@ -247,11 +246,13 @@ class OutputTemplate:
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
@property
@@ -334,6 +335,20 @@ class MessageCompiler(ProtoContentBase):
def has_deprecated_fields(self) -> bool:
return any(self.deprecated_fields)
@property
def has_oneof_fields(self) -> bool:
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
@property
def has_message_field(self) -> bool:
return any(
(
field.proto_obj.type in PROTO_MESSAGE_TYPES
for field in self.fields
if isinstance(field.proto_obj, FieldDescriptorProto)
)
)
def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
@@ -431,6 +446,10 @@ class FieldCompiler(MessageCompiler):
imports.add("Dict")
return imports
@property
def pydantic_imports(self) -> Set[str]:
return set()
@property
def use_builtins(self) -> bool:
return self.py_type in self.parent.builtins_types or (
@@ -440,6 +459,7 @@ class FieldCompiler(MessageCompiler):
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins
@property
@@ -568,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler):
return args
@dataclass
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
@property
def optional(self) -> bool:
# Force the optional to be True. This will allow the pydantic dataclass
# to validate the object correctly by allowing the field to be let empty.
# We add a pydantic validator later to ensure exactly one field is defined.
return True
@property
def pydantic_imports(self) -> Set[str]:
return {"root_validator"}
@dataclass
class MapEntryCompiler(FieldCompiler):
py_k_type: Type = PLACEHOLDER
@@ -679,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
@dataclass
class ServiceMethodCompiler(ProtoContentBase):
parent: ServiceCompiler
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER

View File

@@ -11,6 +11,7 @@ from typing import (
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
@@ -30,6 +31,7 @@ from .models import (
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
PydanticOneOfFieldCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
@@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False
if "pydantic_dataclasses" in plugin_options:
request_data.output_packages[
output_package_name
].pydantic_dataclasses = True
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
@@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
return response
def _make_one_of_field_compiler(
output_package: OutputTemplate,
source_file: "FileDescriptorProto",
parent: MessageCompiler,
proto_obj: "FieldDescriptorProto",
path: List[int],
) -> FieldCompiler:
pydantic = output_package.pydantic_dataclasses
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
return Cls(
source_file=source_file,
parent=parent,
proto_obj=proto_obj,
path=path,
)
def read_protobuf_type(
item: DescriptorProto,
path: List[int],
@@ -168,11 +193,8 @@ def read_protobuf_type(
path=path + [2, index],
)
elif is_oneof(field):
OneOfFieldCompiler(
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
_make_one_of_field_compiler(
output_package, source_file, message_data, field, path + [2, index]
)
else:
FieldCompiler(

View File

@@ -5,7 +5,13 @@
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
@@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
@@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
{% endfor %}
{% endif %}
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
{% endfor %}
{% for service in output_file.services %}
@@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
{% endif %}
{% endfor %}
{% endif %}