Add support for pydantic dataclasses (#406)
This commit is contained in:
parent
6df8cef3f0
commit
13d656587c
20
README.md
20
README.md
@ -14,6 +14,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
|
||||
- Timezone-aware `datetime` and `timedelta` objects
|
||||
- Relative imports
|
||||
- Mypy type checking
|
||||
- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models)
|
||||
|
||||
This project is heavily inspired by, and borrows functionality from:
|
||||
|
||||
@ -364,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
|
||||
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
|
||||
```
|
||||
|
||||
## Generating Pydantic Models
|
||||
|
||||
You can use python-betterproto to generate pydantic based models, using
|
||||
pydantic dataclasses. This means the results of the protobuf unmarshalling will
|
||||
be typed checked. The usage is the same, but you need to add a custom option
|
||||
when calling the protobuf compiler:
|
||||
|
||||
|
||||
```
|
||||
protoc -I . --custom_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto
|
||||
```
|
||||
|
||||
With the important change being `--custom_opt=pydantic_dataclasses`. This will
|
||||
swap the dataclass implementation from the builtin python dataclass to the
|
||||
pydantic dataclass. You must have pydantic as a dependency in your project for
|
||||
this to work.
|
||||
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
|
||||
|
62
poetry.lock
generated
62
poetry.lock
generated
@ -1248,6 +1248,59 @@ files = [
|
||||
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "1.10.4"
|
||||
description = "Data validation and settings management using python type hints"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pydantic-1.10.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5635de53e6686fe7a44b5cf25fcc419a0d5e5c1a1efe73d49d48fe7586db854"},
|
||||
{file = "pydantic-1.10.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dc1cc241440ed7ca9ab59d9929075445da6b7c94ced281b3dd4cfe6c8cff817"},
|
||||
{file = "pydantic-1.10.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51bdeb10d2db0f288e71d49c9cefa609bca271720ecd0c58009bd7504a0c464c"},
|
||||
{file = "pydantic-1.10.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78cec42b95dbb500a1f7120bdf95c401f6abb616bbe8785ef09887306792e66e"},
|
||||
{file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8775d4ef5e7299a2f4699501077a0defdaac5b6c4321173bcb0f3c496fbadf85"},
|
||||
{file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:572066051eeac73d23f95ba9a71349c42a3e05999d0ee1572b7860235b850cc6"},
|
||||
{file = "pydantic-1.10.4-cp310-cp310-win_amd64.whl", hash = "sha256:7feb6a2d401f4d6863050f58325b8d99c1e56f4512d98b11ac64ad1751dc647d"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:39f4a73e5342b25c2959529f07f026ef58147249f9b7431e1ba8414a36761f53"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:983e720704431a6573d626b00662eb78a07148c9115129f9b4351091ec95ecc3"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d52162fe6b2b55964fbb0af2ee58e99791a3138588c482572bb6087953113a"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdf8d759ef326962b4678d89e275ffc55b7ce59d917d9f72233762061fd04a2d"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:05a81b006be15655b2a1bae5faa4280cf7c81d0e09fcb49b342ebf826abe5a72"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d88c4c0e5c5dfd05092a4b271282ef0588e5f4aaf345778056fc5259ba098857"},
|
||||
{file = "pydantic-1.10.4-cp311-cp311-win_amd64.whl", hash = "sha256:6a05a9db1ef5be0fe63e988f9617ca2551013f55000289c671f71ec16f4985e3"},
|
||||
{file = "pydantic-1.10.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:887ca463c3bc47103c123bc06919c86720e80e1214aab79e9b779cda0ff92a00"},
|
||||
{file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdf88ab63c3ee282c76d652fc86518aacb737ff35796023fae56a65ced1a5978"},
|
||||
{file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a48f1953c4a1d9bd0b5167ac50da9a79f6072c63c4cef4cf2a3736994903583e"},
|
||||
{file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:a9f2de23bec87ff306aef658384b02aa7c32389766af3c5dee9ce33e80222dfa"},
|
||||
{file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:cd8702c5142afda03dc2b1ee6bc358b62b3735b2cce53fc77b31ca9f728e4bc8"},
|
||||
{file = "pydantic-1.10.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6e7124d6855b2780611d9f5e1e145e86667eaa3bd9459192c8dc1a097f5e9903"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b53e1d41e97063d51a02821b80538053ee4608b9a181c1005441f1673c55423"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:55b1625899acd33229c4352ce0ae54038529b412bd51c4915349b49ca575258f"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:301d626a59edbe5dfb48fcae245896379a450d04baeed50ef40d8199f2733b06"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f9d649892a6f54a39ed56b8dfd5e08b5f3be5f893da430bed76975f3735d15"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7b5a3821225f5c43496c324b0d6875fde910a1c2933d726a743ce328fbb2a8c"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f2f7eb6273dd12472d7f218e1fef6f7c7c2f00ac2e1ecde4db8824c457300416"},
|
||||
{file = "pydantic-1.10.4-cp38-cp38-win_amd64.whl", hash = "sha256:4b05697738e7d2040696b0a66d9f0a10bec0efa1883ca75ee9e55baf511909d6"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a9a6747cac06c2beb466064dda999a13176b23535e4c496c9d48e6406f92d42d"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eb992a1ef739cc7b543576337bebfc62c0e6567434e522e97291b251a41dad7f"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:990406d226dea0e8f25f643b370224771878142155b879784ce89f633541a024"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e82a6d37a95e0b1b42b82ab340ada3963aea1317fd7f888bb6b9dfbf4fff57c"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9193d4f4ee8feca58bc56c8306bcb820f5c7905fd919e0750acdeeeef0615b28"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2b3ce5f16deb45c472dde1a0ee05619298c864a20cded09c4edd820e1454129f"},
|
||||
{file = "pydantic-1.10.4-cp39-cp39-win_amd64.whl", hash = "sha256:9cbdc268a62d9a98c56e2452d6c41c0263d64a2009aac69246486f01b4f594c4"},
|
||||
{file = "pydantic-1.10.4-py3-none-any.whl", hash = "sha256:4948f264678c703f3877d1c8877c4e3b2e12e549c57795107f08cf70c6ec7774"},
|
||||
{file = "pydantic-1.10.4.tar.gz", hash = "sha256:b9a3859f24eb4e097502a3be1fb4b2abb79b6103dd9e2e0edb70613a4459a648"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = ">=4.2.0"
|
||||
|
||||
[package.extras]
|
||||
dotenv = ["python-dotenv (>=0.10.4)"]
|
||||
email = ["email-validator (>=1.0.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.14.0"
|
||||
@ -1386,6 +1439,13 @@ files = [
|
||||
{file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"},
|
||||
{file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"},
|
||||
{file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"},
|
||||
{file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"},
|
||||
{file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"},
|
||||
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"},
|
||||
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"},
|
||||
@ -1808,4 +1868,4 @@ compiler = ["black", "isort", "jinja2"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.7"
|
||||
content-hash = "80c6215b1b18ae9de0b8ed966cfab05018272dd7eff7dde0fb28d2b2231ac051"
|
||||
content-hash = "f9503e42026d0807dc3ed344b5f8e6fa3f1c7ffa9c66b086de929aefaa8cb8c6"
|
||||
|
@ -36,6 +36,7 @@ sphinx-rtd-theme = "0.5.0"
|
||||
tomlkit = "^0.7.0"
|
||||
tox = "^3.15.1"
|
||||
pre-commit = "^2.17.0"
|
||||
pydantic = ">=1.8.0"
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
@ -628,7 +628,6 @@ class Message(ABC):
|
||||
# Set current field of each group after `__init__` has already been run.
|
||||
group_current: Dict[str, Optional[str]] = {}
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
|
||||
if meta.group:
|
||||
group_current.setdefault(meta.group)
|
||||
|
||||
@ -1470,6 +1469,24 @@ class Message(ABC):
|
||||
)
|
||||
return self.__raw_get(name) is not default
|
||||
|
||||
@classmethod
|
||||
def _validate_field_groups(cls, values):
|
||||
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
||||
|
||||
for group, field_set in meta.items():
|
||||
set_fields = [
|
||||
field.name for field in field_set if values[field.name] is not None
|
||||
]
|
||||
if not set_fields:
|
||||
raise ValueError(f"Group {group} has no value; all fields are None")
|
||||
elif len(set_fields) > 1:
|
||||
set_fields_str = ", ".join(set_fields)
|
||||
raise ValueError(
|
||||
f"Group {group} has more than one value; fields {set_fields_str} are not None"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def serialized_on_wire(message: Message) -> bool:
|
||||
"""
|
||||
|
@ -214,7 +214,6 @@ class ProtoContentBase:
|
||||
|
||||
@dataclass
|
||||
class PluginRequestCompiler:
|
||||
|
||||
plugin_request_obj: CodeGeneratorRequest
|
||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||
|
||||
@ -247,11 +246,13 @@ class OutputTemplate:
|
||||
imports: Set[str] = field(default_factory=set)
|
||||
datetime_imports: Set[str] = field(default_factory=set)
|
||||
typing_imports: Set[str] = field(default_factory=set)
|
||||
pydantic_imports: Set[str] = field(default_factory=set)
|
||||
builtins_import: bool = False
|
||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||
pydantic_dataclasses: bool = False
|
||||
output: bool = True
|
||||
|
||||
@property
|
||||
@ -334,6 +335,20 @@ class MessageCompiler(ProtoContentBase):
|
||||
def has_deprecated_fields(self) -> bool:
|
||||
return any(self.deprecated_fields)
|
||||
|
||||
@property
|
||||
def has_oneof_fields(self) -> bool:
|
||||
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
|
||||
|
||||
@property
|
||||
def has_message_field(self) -> bool:
|
||||
return any(
|
||||
(
|
||||
field.proto_obj.type in PROTO_MESSAGE_TYPES
|
||||
for field in self.fields
|
||||
if isinstance(field.proto_obj, FieldDescriptorProto)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def is_map(
|
||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||
@ -431,6 +446,10 @@ class FieldCompiler(MessageCompiler):
|
||||
imports.add("Dict")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return set()
|
||||
|
||||
@property
|
||||
def use_builtins(self) -> bool:
|
||||
return self.py_type in self.parent.builtins_types or (
|
||||
@ -440,6 +459,7 @@ class FieldCompiler(MessageCompiler):
|
||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||
output_file.datetime_imports.update(self.datetime_imports)
|
||||
output_file.typing_imports.update(self.typing_imports)
|
||||
output_file.pydantic_imports.update(self.pydantic_imports)
|
||||
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||
|
||||
@property
|
||||
@ -568,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler):
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
|
||||
@property
|
||||
def optional(self) -> bool:
|
||||
# Force the optional to be True. This will allow the pydantic dataclass
|
||||
# to validate the object correctly by allowing the field to be let empty.
|
||||
# We add a pydantic validator later to ensure exactly one field is defined.
|
||||
return True
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return {"root_validator"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MapEntryCompiler(FieldCompiler):
|
||||
py_k_type: Type = PLACEHOLDER
|
||||
@ -679,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
|
||||
|
||||
@dataclass
|
||||
class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
parent: ServiceCompiler
|
||||
proto_obj: MethodDescriptorProto
|
||||
path: List[int] = PLACEHOLDER
|
||||
|
@ -11,6 +11,7 @@ from typing import (
|
||||
from betterproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
@ -30,6 +31,7 @@ from .models import (
|
||||
OneOfFieldCompiler,
|
||||
OutputTemplate,
|
||||
PluginRequestCompiler,
|
||||
PydanticOneOfFieldCompiler,
|
||||
ServiceCompiler,
|
||||
ServiceMethodCompiler,
|
||||
is_map,
|
||||
@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
# skip outputting Google's well-known types
|
||||
request_data.output_packages[output_package_name].output = False
|
||||
|
||||
if "pydantic_dataclasses" in plugin_options:
|
||||
request_data.output_packages[
|
||||
output_package_name
|
||||
].pydantic_dataclasses = True
|
||||
|
||||
# Read Messages and Enums
|
||||
# We need to read Messages before Services in so that we can
|
||||
# get the references to input/output messages for each service
|
||||
@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
return response
|
||||
|
||||
|
||||
def _make_one_of_field_compiler(
|
||||
output_package: OutputTemplate,
|
||||
source_file: "FileDescriptorProto",
|
||||
parent: MessageCompiler,
|
||||
proto_obj: "FieldDescriptorProto",
|
||||
path: List[int],
|
||||
) -> FieldCompiler:
|
||||
|
||||
pydantic = output_package.pydantic_dataclasses
|
||||
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
|
||||
return Cls(
|
||||
source_file=source_file,
|
||||
parent=parent,
|
||||
proto_obj=proto_obj,
|
||||
path=path,
|
||||
)
|
||||
|
||||
|
||||
def read_protobuf_type(
|
||||
item: DescriptorProto,
|
||||
path: List[int],
|
||||
@ -168,11 +193,8 @@ def read_protobuf_type(
|
||||
path=path + [2, index],
|
||||
)
|
||||
elif is_oneof(field):
|
||||
OneOfFieldCompiler(
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
_make_one_of_field_compiler(
|
||||
output_package, source_file, message_data, field, path + [2, index]
|
||||
)
|
||||
else:
|
||||
FieldCompiler(
|
||||
|
@ -5,7 +5,13 @@
|
||||
{% for i in output_file.python_module_imports|sort %}
|
||||
import {{ i }}
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
from pydantic.dataclasses import dataclass
|
||||
{%- else -%}
|
||||
from dataclasses import dataclass
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.datetime_imports %}
|
||||
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
|
||||
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_imports %}
|
||||
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
import betterproto
|
||||
{% if output_file.services %}
|
||||
from betterproto.grpc.grpclib_server import ServiceBase
|
||||
@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
|
||||
@root_validator()
|
||||
def check_oneof(cls, values):
|
||||
return cls._validate_field_groups(values)
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
{% for service in output_file.services %}
|
||||
@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
{% for message in output_file.messages %}
|
||||
{% if message.has_message_field %}
|
||||
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
@ -11,6 +11,7 @@ from tests.util import (
|
||||
get_directories,
|
||||
inputs_path,
|
||||
output_path_betterproto,
|
||||
output_path_betterproto_pydantic,
|
||||
output_path_reference,
|
||||
protoc,
|
||||
)
|
||||
@ -80,9 +81,11 @@ async def generate_test_case_output(
|
||||
|
||||
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
||||
test_case_output_path_betterproto = output_path_betterproto
|
||||
test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic
|
||||
|
||||
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
||||
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
||||
os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True)
|
||||
|
||||
clear_directory(test_case_output_path_reference)
|
||||
clear_directory(test_case_output_path_betterproto)
|
||||
@ -90,9 +93,13 @@ async def generate_test_case_output(
|
||||
(
|
||||
(ref_out, ref_err, ref_code),
|
||||
(plg_out, plg_err, plg_code),
|
||||
(plg_out_pyd, plg_err_pyd, plg_code_pyd),
|
||||
) = await asyncio.gather(
|
||||
protoc(test_case_input_path, test_case_output_path_reference, True),
|
||||
protoc(test_case_input_path, test_case_output_path_betterproto, False),
|
||||
protoc(
|
||||
test_case_input_path, test_case_output_path_betterproto_pyd, False, True
|
||||
),
|
||||
)
|
||||
|
||||
if ref_code == 0:
|
||||
@ -131,7 +138,27 @@ async def generate_test_case_output(
|
||||
sys.stderr.buffer.write(plg_err)
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
return max(ref_code, plg_code)
|
||||
if plg_code_pyd == 0:
|
||||
print(
|
||||
f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
|
||||
)
|
||||
|
||||
if verbose:
|
||||
if plg_out_pyd:
|
||||
print("Plugin stdout:")
|
||||
sys.stdout.buffer.write(plg_out_pyd)
|
||||
sys.stdout.buffer.flush()
|
||||
|
||||
if plg_err_pyd:
|
||||
print("Plugin stderr:")
|
||||
sys.stderr.buffer.write(plg_err_pyd)
|
||||
sys.stderr.buffer.flush()
|
||||
|
||||
return max(ref_code, plg_code, plg_code_pyd)
|
||||
|
||||
|
||||
HELP = "\n".join(
|
||||
|
@ -1,6 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from tests.output_betterproto.bool import Test
|
||||
from tests.output_betterproto_pydantic.bool import Test as TestPyd
|
||||
|
||||
|
||||
def test_value():
|
||||
message = Test()
|
||||
assert not message.value, "Boolean is False by default"
|
||||
|
||||
|
||||
def test_pydantic_no_value():
|
||||
with pytest.raises(ValueError):
|
||||
TestPyd()
|
||||
|
||||
|
||||
def test_pydantic_value():
|
||||
message = Test(value=False)
|
||||
assert not message.value
|
||||
|
@ -1,5 +1,6 @@
|
||||
import betterproto
|
||||
from tests.output_betterproto.oneof import Test
|
||||
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
|
||||
from tests.util import get_test_case_json_data
|
||||
|
||||
|
||||
@ -13,3 +14,8 @@ def test_which_name():
|
||||
message = Test()
|
||||
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
|
||||
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||
|
||||
|
||||
def test_which_count_pyd():
|
||||
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
|
||||
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||
|
@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import importlib
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
@ -22,6 +25,7 @@ root_path = Path(__file__).resolve().parent
|
||||
inputs_path = root_path.joinpath("inputs")
|
||||
output_path_reference = root_path.joinpath("output_reference")
|
||||
output_path_betterproto = root_path.joinpath("output_betterproto")
|
||||
output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic")
|
||||
|
||||
|
||||
def get_files(path, suffix: str) -> Generator[str, None, None]:
|
||||
@ -36,19 +40,56 @@ def get_directories(path):
|
||||
|
||||
|
||||
async def protoc(
|
||||
path: Union[str, Path], output_dir: Union[str, Path], reference: bool = False
|
||||
path: Union[str, Path],
|
||||
output_dir: Union[str, Path],
|
||||
reference: bool = False,
|
||||
pydantic_dataclasses: bool = False,
|
||||
):
|
||||
path: Path = Path(path).resolve()
|
||||
output_dir: Path = Path(output_dir).resolve()
|
||||
python_out_option: str = "python_betterproto_out" if not reference else "python_out"
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"grpc.tools.protoc",
|
||||
f"--proto_path={path.as_posix()}",
|
||||
f"--{python_out_option}={output_dir.as_posix()}",
|
||||
*[p.as_posix() for p in path.glob("*.proto")],
|
||||
]
|
||||
|
||||
if pydantic_dataclasses:
|
||||
plugin_path = Path("src/betterproto/plugin/main.py")
|
||||
|
||||
if "Win" in platform.system():
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w", encoding="UTF-8", suffix=".bat", delete=False
|
||||
) as tf:
|
||||
# See https://stackoverflow.com/a/42622705
|
||||
tf.writelines(
|
||||
[
|
||||
"@echo off",
|
||||
f"\nchdir {os.getcwd()}",
|
||||
f"\n{sys.executable} -u {plugin_path.as_posix()}",
|
||||
]
|
||||
)
|
||||
|
||||
tf.flush()
|
||||
|
||||
plugin_path = Path(tf.name)
|
||||
atexit.register(os.remove, plugin_path)
|
||||
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"grpc.tools.protoc",
|
||||
f"--plugin=protoc-gen-custom={plugin_path.as_posix()}",
|
||||
"--experimental_allow_proto3_optional",
|
||||
"--custom_opt=pydantic_dataclasses",
|
||||
f"--proto_path={path.as_posix()}",
|
||||
f"--custom_out={output_dir.as_posix()}",
|
||||
*[p.as_posix() for p in path.glob("*.proto")],
|
||||
]
|
||||
else:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"grpc.tools.protoc",
|
||||
f"--proto_path={path.as_posix()}",
|
||||
f"--{python_out_option}={output_dir.as_posix()}",
|
||||
*[p.as_posix() for p in path.glob("*.proto")],
|
||||
]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user