Add support for pydantic dataclasses (#406)

This commit is contained in:
Samuel Yvon 2023-02-13 07:37:16 -08:00 committed by GitHub
parent 6df8cef3f0
commit 13d656587c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 283 additions and 19 deletions

View File

@ -14,6 +14,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
- Timezone-aware `datetime` and `timedelta` objects
- Relative imports
- Mypy type checking
- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models)
This project is heavily inspired by, and borrows functionality from:
@ -364,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
```
## Generating Pydantic Models
You can use python-betterproto to generate pydantic based models, using
pydantic dataclasses. This means the results of the protobuf unmarshalling will
be typed checked. The usage is the same, but you need to add a custom option
when calling the protobuf compiler:
```
protoc -I . --custom_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto
```
With the important change being `--custom_opt=pydantic_dataclasses`. This will
swap the dataclass implementation from the builtin python dataclass to the
pydantic dataclass. You must have pydantic as a dependency in your project for
this to work.
## Development
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_

62
poetry.lock generated
View File

@ -1248,6 +1248,59 @@ files = [
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
]
[[package]]
name = "pydantic"
version = "1.10.4"
description = "Data validation and settings management using python type hints"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pydantic-1.10.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5635de53e6686fe7a44b5cf25fcc419a0d5e5c1a1efe73d49d48fe7586db854"},
{file = "pydantic-1.10.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dc1cc241440ed7ca9ab59d9929075445da6b7c94ced281b3dd4cfe6c8cff817"},
{file = "pydantic-1.10.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51bdeb10d2db0f288e71d49c9cefa609bca271720ecd0c58009bd7504a0c464c"},
{file = "pydantic-1.10.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78cec42b95dbb500a1f7120bdf95c401f6abb616bbe8785ef09887306792e66e"},
{file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8775d4ef5e7299a2f4699501077a0defdaac5b6c4321173bcb0f3c496fbadf85"},
{file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:572066051eeac73d23f95ba9a71349c42a3e05999d0ee1572b7860235b850cc6"},
{file = "pydantic-1.10.4-cp310-cp310-win_amd64.whl", hash = "sha256:7feb6a2d401f4d6863050f58325b8d99c1e56f4512d98b11ac64ad1751dc647d"},
{file = "pydantic-1.10.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:39f4a73e5342b25c2959529f07f026ef58147249f9b7431e1ba8414a36761f53"},
{file = "pydantic-1.10.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:983e720704431a6573d626b00662eb78a07148c9115129f9b4351091ec95ecc3"},
{file = "pydantic-1.10.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d52162fe6b2b55964fbb0af2ee58e99791a3138588c482572bb6087953113a"},
{file = "pydantic-1.10.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdf8d759ef326962b4678d89e275ffc55b7ce59d917d9f72233762061fd04a2d"},
{file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:05a81b006be15655b2a1bae5faa4280cf7c81d0e09fcb49b342ebf826abe5a72"},
{file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d88c4c0e5c5dfd05092a4b271282ef0588e5f4aaf345778056fc5259ba098857"},
{file = "pydantic-1.10.4-cp311-cp311-win_amd64.whl", hash = "sha256:6a05a9db1ef5be0fe63e988f9617ca2551013f55000289c671f71ec16f4985e3"},
{file = "pydantic-1.10.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:887ca463c3bc47103c123bc06919c86720e80e1214aab79e9b779cda0ff92a00"},
{file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdf88ab63c3ee282c76d652fc86518aacb737ff35796023fae56a65ced1a5978"},
{file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a48f1953c4a1d9bd0b5167ac50da9a79f6072c63c4cef4cf2a3736994903583e"},
{file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:a9f2de23bec87ff306aef658384b02aa7c32389766af3c5dee9ce33e80222dfa"},
{file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:cd8702c5142afda03dc2b1ee6bc358b62b3735b2cce53fc77b31ca9f728e4bc8"},
{file = "pydantic-1.10.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6e7124d6855b2780611d9f5e1e145e86667eaa3bd9459192c8dc1a097f5e9903"},
{file = "pydantic-1.10.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b53e1d41e97063d51a02821b80538053ee4608b9a181c1005441f1673c55423"},
{file = "pydantic-1.10.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:55b1625899acd33229c4352ce0ae54038529b412bd51c4915349b49ca575258f"},
{file = "pydantic-1.10.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:301d626a59edbe5dfb48fcae245896379a450d04baeed50ef40d8199f2733b06"},
{file = "pydantic-1.10.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f9d649892a6f54a39ed56b8dfd5e08b5f3be5f893da430bed76975f3735d15"},
{file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7b5a3821225f5c43496c324b0d6875fde910a1c2933d726a743ce328fbb2a8c"},
{file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f2f7eb6273dd12472d7f218e1fef6f7c7c2f00ac2e1ecde4db8824c457300416"},
{file = "pydantic-1.10.4-cp38-cp38-win_amd64.whl", hash = "sha256:4b05697738e7d2040696b0a66d9f0a10bec0efa1883ca75ee9e55baf511909d6"},
{file = "pydantic-1.10.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a9a6747cac06c2beb466064dda999a13176b23535e4c496c9d48e6406f92d42d"},
{file = "pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eb992a1ef739cc7b543576337bebfc62c0e6567434e522e97291b251a41dad7f"},
{file = "pydantic-1.10.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:990406d226dea0e8f25f643b370224771878142155b879784ce89f633541a024"},
{file = "pydantic-1.10.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e82a6d37a95e0b1b42b82ab340ada3963aea1317fd7f888bb6b9dfbf4fff57c"},
{file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9193d4f4ee8feca58bc56c8306bcb820f5c7905fd919e0750acdeeeef0615b28"},
{file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2b3ce5f16deb45c472dde1a0ee05619298c864a20cded09c4edd820e1454129f"},
{file = "pydantic-1.10.4-cp39-cp39-win_amd64.whl", hash = "sha256:9cbdc268a62d9a98c56e2452d6c41c0263d64a2009aac69246486f01b4f594c4"},
{file = "pydantic-1.10.4-py3-none-any.whl", hash = "sha256:4948f264678c703f3877d1c8877c4e3b2e12e549c57795107f08cf70c6ec7774"},
{file = "pydantic-1.10.4.tar.gz", hash = "sha256:b9a3859f24eb4e097502a3be1fb4b2abb79b6103dd9e2e0edb70613a4459a648"},
]
[package.dependencies]
typing-extensions = ">=4.2.0"
[package.extras]
dotenv = ["python-dotenv (>=0.10.4)"]
email = ["email-validator (>=1.0.3)"]
[[package]]
name = "pygments"
version = "2.14.0"
@ -1386,6 +1439,13 @@ files = [
{file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"},
{file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"},
{file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"},
{file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"},
{file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"},
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"},
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"},
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"},
{file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"},
{file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"},
{file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"},
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"},
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"},
@ -1808,4 +1868,4 @@ compiler = ["black", "isort", "jinja2"]
[metadata]
lock-version = "2.0"
python-versions = "^3.7"
content-hash = "80c6215b1b18ae9de0b8ed966cfab05018272dd7eff7dde0fb28d2b2231ac051"
content-hash = "f9503e42026d0807dc3ed344b5f8e6fa3f1c7ffa9c66b086de929aefaa8cb8c6"

View File

@ -36,6 +36,7 @@ sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0"
tox = "^3.15.1"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"
[tool.poetry.scripts]

View File

@ -628,7 +628,6 @@ class Message(ABC):
# Set current field of each group after `__init__` has already been run.
group_current: Dict[str, Optional[str]] = {}
for field_name, meta in self._betterproto.meta_by_field_name.items():
if meta.group:
group_current.setdefault(meta.group)
@ -1470,6 +1469,24 @@ class Message(ABC):
)
return self.__raw_get(name) is not default
@classmethod
def _validate_field_groups(cls, values):
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore
for group, field_set in meta.items():
set_fields = [
field.name for field in field_set if values[field.name] is not None
]
if not set_fields:
raise ValueError(f"Group {group} has no value; all fields are None")
elif len(set_fields) > 1:
set_fields_str = ", ".join(set_fields)
raise ValueError(
f"Group {group} has more than one value; fields {set_fields_str} are not None"
)
return values
def serialized_on_wire(message: Message) -> bool:
"""

View File

@ -214,7 +214,6 @@ class ProtoContentBase:
@dataclass
class PluginRequestCompiler:
plugin_request_obj: CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
@ -247,11 +246,13 @@ class OutputTemplate:
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
@property
@ -334,6 +335,20 @@ class MessageCompiler(ProtoContentBase):
def has_deprecated_fields(self) -> bool:
return any(self.deprecated_fields)
@property
def has_oneof_fields(self) -> bool:
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
@property
def has_message_field(self) -> bool:
return any(
(
field.proto_obj.type in PROTO_MESSAGE_TYPES
for field in self.fields
if isinstance(field.proto_obj, FieldDescriptorProto)
)
)
def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
@ -431,6 +446,10 @@ class FieldCompiler(MessageCompiler):
imports.add("Dict")
return imports
@property
def pydantic_imports(self) -> Set[str]:
return set()
@property
def use_builtins(self) -> bool:
return self.py_type in self.parent.builtins_types or (
@ -440,6 +459,7 @@ class FieldCompiler(MessageCompiler):
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins
@property
@ -568,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler):
return args
@dataclass
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
@property
def optional(self) -> bool:
# Force the optional to be True. This will allow the pydantic dataclass
# to validate the object correctly by allowing the field to be let empty.
# We add a pydantic validator later to ensure exactly one field is defined.
return True
@property
def pydantic_imports(self) -> Set[str]:
return {"root_validator"}
@dataclass
class MapEntryCompiler(FieldCompiler):
py_k_type: Type = PLACEHOLDER
@ -679,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
@dataclass
class ServiceMethodCompiler(ProtoContentBase):
parent: ServiceCompiler
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER

View File

@ -11,6 +11,7 @@ from typing import (
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
@ -30,6 +31,7 @@ from .models import (
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
PydanticOneOfFieldCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False
if "pydantic_dataclasses" in plugin_options:
request_data.output_packages[
output_package_name
].pydantic_dataclasses = True
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
return response
def _make_one_of_field_compiler(
output_package: OutputTemplate,
source_file: "FileDescriptorProto",
parent: MessageCompiler,
proto_obj: "FieldDescriptorProto",
path: List[int],
) -> FieldCompiler:
pydantic = output_package.pydantic_dataclasses
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
return Cls(
source_file=source_file,
parent=parent,
proto_obj=proto_obj,
path=path,
)
def read_protobuf_type(
item: DescriptorProto,
path: List[int],
@ -168,11 +193,8 @@ def read_protobuf_type(
path=path + [2, index],
)
elif is_oneof(field):
OneOfFieldCompiler(
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
_make_one_of_field_compiler(
output_package, source_file, message_data, field, path + [2, index]
)
else:
FieldCompiler(

View File

@ -5,7 +5,13 @@
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
{% endfor %}
{% endif %}
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
{% endfor %}
{% for service in output_file.services %}
@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
{% endif %}
{% endfor %}
{% endif %}

View File

@ -11,6 +11,7 @@ from tests.util import (
get_directories,
inputs_path,
output_path_betterproto,
output_path_betterproto_pydantic,
output_path_reference,
protoc,
)
@ -80,9 +81,11 @@ async def generate_test_case_output(
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
test_case_output_path_betterproto = output_path_betterproto
test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic
os.makedirs(test_case_output_path_reference, exist_ok=True)
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True)
clear_directory(test_case_output_path_reference)
clear_directory(test_case_output_path_betterproto)
@ -90,9 +93,13 @@ async def generate_test_case_output(
(
(ref_out, ref_err, ref_code),
(plg_out, plg_err, plg_code),
(plg_out_pyd, plg_err_pyd, plg_code_pyd),
) = await asyncio.gather(
protoc(test_case_input_path, test_case_output_path_reference, True),
protoc(test_case_input_path, test_case_output_path_betterproto, False),
protoc(
test_case_input_path, test_case_output_path_betterproto_pyd, False, True
),
)
if ref_code == 0:
@ -131,7 +138,27 @@ async def generate_test_case_output(
sys.stderr.buffer.write(plg_err)
sys.stderr.buffer.flush()
return max(ref_code, plg_code)
if plg_code_pyd == 0:
print(
f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
)
else:
print(
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
)
if verbose:
if plg_out_pyd:
print("Plugin stdout:")
sys.stdout.buffer.write(plg_out_pyd)
sys.stdout.buffer.flush()
if plg_err_pyd:
print("Plugin stderr:")
sys.stderr.buffer.write(plg_err_pyd)
sys.stderr.buffer.flush()
return max(ref_code, plg_code, plg_code_pyd)
HELP = "\n".join(

View File

@ -1,6 +1,19 @@
import pytest
from tests.output_betterproto.bool import Test
from tests.output_betterproto_pydantic.bool import Test as TestPyd
def test_value():
message = Test()
assert not message.value, "Boolean is False by default"
def test_pydantic_no_value():
with pytest.raises(ValueError):
TestPyd()
def test_pydantic_value():
message = Test(value=False)
assert not message.value

View File

@ -1,5 +1,6 @@
import betterproto
from tests.output_betterproto.oneof import Test
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
from tests.util import get_test_case_json_data
@ -13,3 +14,8 @@ def test_which_name():
message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
def test_which_count_pyd():
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")

View File

@ -1,7 +1,10 @@
import asyncio
import atexit
import importlib
import os
import platform
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
@ -22,6 +25,7 @@ root_path = Path(__file__).resolve().parent
inputs_path = root_path.joinpath("inputs")
output_path_reference = root_path.joinpath("output_reference")
output_path_betterproto = root_path.joinpath("output_betterproto")
output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic")
def get_files(path, suffix: str) -> Generator[str, None, None]:
@ -36,19 +40,56 @@ def get_directories(path):
async def protoc(
path: Union[str, Path], output_dir: Union[str, Path], reference: bool = False
path: Union[str, Path],
output_dir: Union[str, Path],
reference: bool = False,
pydantic_dataclasses: bool = False,
):
path: Path = Path(path).resolve()
output_dir: Path = Path(output_dir).resolve()
python_out_option: str = "python_betterproto_out" if not reference else "python_out"
command = [
sys.executable,
"-m",
"grpc.tools.protoc",
f"--proto_path={path.as_posix()}",
f"--{python_out_option}={output_dir.as_posix()}",
*[p.as_posix() for p in path.glob("*.proto")],
]
if pydantic_dataclasses:
plugin_path = Path("src/betterproto/plugin/main.py")
if "Win" in platform.system():
with tempfile.NamedTemporaryFile(
"w", encoding="UTF-8", suffix=".bat", delete=False
) as tf:
# See https://stackoverflow.com/a/42622705
tf.writelines(
[
"@echo off",
f"\nchdir {os.getcwd()}",
f"\n{sys.executable} -u {plugin_path.as_posix()}",
]
)
tf.flush()
plugin_path = Path(tf.name)
atexit.register(os.remove, plugin_path)
command = [
sys.executable,
"-m",
"grpc.tools.protoc",
f"--plugin=protoc-gen-custom={plugin_path.as_posix()}",
"--experimental_allow_proto3_optional",
"--custom_opt=pydantic_dataclasses",
f"--proto_path={path.as_posix()}",
f"--custom_out={output_dir.as_posix()}",
*[p.as_posix() for p in path.glob("*.proto")],
]
else:
command = [
sys.executable,
"-m",
"grpc.tools.protoc",
f"--proto_path={path.as_posix()}",
f"--{python_out_option}={output_dir.as_posix()}",
*[p.as_posix() for p in path.glob("*.proto")],
]
proc = await asyncio.create_subprocess_exec(
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)