Add support for pydantic dataclasses (#406)
This commit is contained in:
@@ -628,7 +628,6 @@ class Message(ABC):
|
||||
# Set current field of each group after `__init__` has already been run.
|
||||
group_current: Dict[str, Optional[str]] = {}
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
|
||||
if meta.group:
|
||||
group_current.setdefault(meta.group)
|
||||
|
||||
@@ -1470,6 +1469,24 @@ class Message(ABC):
|
||||
)
|
||||
return self.__raw_get(name) is not default
|
||||
|
||||
@classmethod
|
||||
def _validate_field_groups(cls, values):
|
||||
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
||||
|
||||
for group, field_set in meta.items():
|
||||
set_fields = [
|
||||
field.name for field in field_set if values[field.name] is not None
|
||||
]
|
||||
if not set_fields:
|
||||
raise ValueError(f"Group {group} has no value; all fields are None")
|
||||
elif len(set_fields) > 1:
|
||||
set_fields_str = ", ".join(set_fields)
|
||||
raise ValueError(
|
||||
f"Group {group} has more than one value; fields {set_fields_str} are not None"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def serialized_on_wire(message: Message) -> bool:
|
||||
"""
|
||||
|
||||
@@ -214,7 +214,6 @@ class ProtoContentBase:
|
||||
|
||||
@dataclass
|
||||
class PluginRequestCompiler:
|
||||
|
||||
plugin_request_obj: CodeGeneratorRequest
|
||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||
|
||||
@@ -247,11 +246,13 @@ class OutputTemplate:
|
||||
imports: Set[str] = field(default_factory=set)
|
||||
datetime_imports: Set[str] = field(default_factory=set)
|
||||
typing_imports: Set[str] = field(default_factory=set)
|
||||
pydantic_imports: Set[str] = field(default_factory=set)
|
||||
builtins_import: bool = False
|
||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||
pydantic_dataclasses: bool = False
|
||||
output: bool = True
|
||||
|
||||
@property
|
||||
@@ -334,6 +335,20 @@ class MessageCompiler(ProtoContentBase):
|
||||
def has_deprecated_fields(self) -> bool:
|
||||
return any(self.deprecated_fields)
|
||||
|
||||
@property
|
||||
def has_oneof_fields(self) -> bool:
|
||||
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
|
||||
|
||||
@property
|
||||
def has_message_field(self) -> bool:
|
||||
return any(
|
||||
(
|
||||
field.proto_obj.type in PROTO_MESSAGE_TYPES
|
||||
for field in self.fields
|
||||
if isinstance(field.proto_obj, FieldDescriptorProto)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def is_map(
|
||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||
@@ -431,6 +446,10 @@ class FieldCompiler(MessageCompiler):
|
||||
imports.add("Dict")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return set()
|
||||
|
||||
@property
|
||||
def use_builtins(self) -> bool:
|
||||
return self.py_type in self.parent.builtins_types or (
|
||||
@@ -440,6 +459,7 @@ class FieldCompiler(MessageCompiler):
|
||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||
output_file.datetime_imports.update(self.datetime_imports)
|
||||
output_file.typing_imports.update(self.typing_imports)
|
||||
output_file.pydantic_imports.update(self.pydantic_imports)
|
||||
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||
|
||||
@property
|
||||
@@ -568,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler):
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
|
||||
@property
|
||||
def optional(self) -> bool:
|
||||
# Force the optional to be True. This will allow the pydantic dataclass
|
||||
# to validate the object correctly by allowing the field to be let empty.
|
||||
# We add a pydantic validator later to ensure exactly one field is defined.
|
||||
return True
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return {"root_validator"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MapEntryCompiler(FieldCompiler):
|
||||
py_k_type: Type = PLACEHOLDER
|
||||
@@ -679,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
|
||||
|
||||
@dataclass
|
||||
class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
parent: ServiceCompiler
|
||||
proto_obj: MethodDescriptorProto
|
||||
path: List[int] = PLACEHOLDER
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import (
|
||||
from betterproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
@@ -30,6 +31,7 @@ from .models import (
|
||||
OneOfFieldCompiler,
|
||||
OutputTemplate,
|
||||
PluginRequestCompiler,
|
||||
PydanticOneOfFieldCompiler,
|
||||
ServiceCompiler,
|
||||
ServiceMethodCompiler,
|
||||
is_map,
|
||||
@@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
# skip outputting Google's well-known types
|
||||
request_data.output_packages[output_package_name].output = False
|
||||
|
||||
if "pydantic_dataclasses" in plugin_options:
|
||||
request_data.output_packages[
|
||||
output_package_name
|
||||
].pydantic_dataclasses = True
|
||||
|
||||
# Read Messages and Enums
|
||||
# We need to read Messages before Services in so that we can
|
||||
# get the references to input/output messages for each service
|
||||
@@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
return response
|
||||
|
||||
|
||||
def _make_one_of_field_compiler(
|
||||
output_package: OutputTemplate,
|
||||
source_file: "FileDescriptorProto",
|
||||
parent: MessageCompiler,
|
||||
proto_obj: "FieldDescriptorProto",
|
||||
path: List[int],
|
||||
) -> FieldCompiler:
|
||||
|
||||
pydantic = output_package.pydantic_dataclasses
|
||||
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
|
||||
return Cls(
|
||||
source_file=source_file,
|
||||
parent=parent,
|
||||
proto_obj=proto_obj,
|
||||
path=path,
|
||||
)
|
||||
|
||||
|
||||
def read_protobuf_type(
|
||||
item: DescriptorProto,
|
||||
path: List[int],
|
||||
@@ -168,11 +193,8 @@ def read_protobuf_type(
|
||||
path=path + [2, index],
|
||||
)
|
||||
elif is_oneof(field):
|
||||
OneOfFieldCompiler(
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
_make_one_of_field_compiler(
|
||||
output_package, source_file, message_data, field, path + [2, index]
|
||||
)
|
||||
else:
|
||||
FieldCompiler(
|
||||
|
||||
@@ -5,7 +5,13 @@
|
||||
{% for i in output_file.python_module_imports|sort %}
|
||||
import {{ i }}
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
from pydantic.dataclasses import dataclass
|
||||
{%- else -%}
|
||||
from dataclasses import dataclass
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.datetime_imports %}
|
||||
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
@@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
|
||||
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_imports %}
|
||||
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
import betterproto
|
||||
{% if output_file.services %}
|
||||
from betterproto.grpc.grpclib_server import ServiceBase
|
||||
@@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
|
||||
@root_validator()
|
||||
def check_oneof(cls, values):
|
||||
return cls._validate_field_groups(values)
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
{% for service in output_file.services %}
|
||||
@@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
{% for message in output_file.messages %}
|
||||
{% if message.has_message_field %}
|
||||
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
Reference in New Issue
Block a user