Add support for pydantic dataclasses (#406)
This commit is contained in:
parent
6df8cef3f0
commit
13d656587c
20
README.md
20
README.md
@ -14,6 +14,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
|
|||||||
- Timezone-aware `datetime` and `timedelta` objects
|
- Timezone-aware `datetime` and `timedelta` objects
|
||||||
- Relative imports
|
- Relative imports
|
||||||
- Mypy type checking
|
- Mypy type checking
|
||||||
|
- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models)
|
||||||
|
|
||||||
This project is heavily inspired by, and borrows functionality from:
|
This project is heavily inspired by, and borrows functionality from:
|
||||||
|
|
||||||
@ -364,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
|
|||||||
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
|
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Generating Pydantic Models
|
||||||
|
|
||||||
|
You can use python-betterproto to generate pydantic based models, using
|
||||||
|
pydantic dataclasses. This means the results of the protobuf unmarshalling will
|
||||||
|
be typed checked. The usage is the same, but you need to add a custom option
|
||||||
|
when calling the protobuf compiler:
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
protoc -I . --custom_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto
|
||||||
|
```
|
||||||
|
|
||||||
|
With the important change being `--custom_opt=pydantic_dataclasses`. This will
|
||||||
|
swap the dataclass implementation from the builtin python dataclass to the
|
||||||
|
pydantic dataclass. You must have pydantic as a dependency in your project for
|
||||||
|
this to work.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
|
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
|
||||||
|
62
poetry.lock
generated
62
poetry.lock
generated
@ -1248,6 +1248,59 @@ files = [
|
|||||||
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
|
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pydantic"
|
||||||
|
version = "1.10.4"
|
||||||
|
description = "Data validation and settings management using python type hints"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5635de53e6686fe7a44b5cf25fcc419a0d5e5c1a1efe73d49d48fe7586db854"},
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dc1cc241440ed7ca9ab59d9929075445da6b7c94ced281b3dd4cfe6c8cff817"},
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51bdeb10d2db0f288e71d49c9cefa609bca271720ecd0c58009bd7504a0c464c"},
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78cec42b95dbb500a1f7120bdf95c401f6abb616bbe8785ef09887306792e66e"},
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8775d4ef5e7299a2f4699501077a0defdaac5b6c4321173bcb0f3c496fbadf85"},
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:572066051eeac73d23f95ba9a71349c42a3e05999d0ee1572b7860235b850cc6"},
|
||||||
|
{file = "pydantic-1.10.4-cp310-cp310-win_amd64.whl", hash = "sha256:7feb6a2d401f4d6863050f58325b8d99c1e56f4512d98b11ac64ad1751dc647d"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:39f4a73e5342b25c2959529f07f026ef58147249f9b7431e1ba8414a36761f53"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:983e720704431a6573d626b00662eb78a07148c9115129f9b4351091ec95ecc3"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d52162fe6b2b55964fbb0af2ee58e99791a3138588c482572bb6087953113a"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdf8d759ef326962b4678d89e275ffc55b7ce59d917d9f72233762061fd04a2d"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:05a81b006be15655b2a1bae5faa4280cf7c81d0e09fcb49b342ebf826abe5a72"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d88c4c0e5c5dfd05092a4b271282ef0588e5f4aaf345778056fc5259ba098857"},
|
||||||
|
{file = "pydantic-1.10.4-cp311-cp311-win_amd64.whl", hash = "sha256:6a05a9db1ef5be0fe63e988f9617ca2551013f55000289c671f71ec16f4985e3"},
|
||||||
|
{file = "pydantic-1.10.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:887ca463c3bc47103c123bc06919c86720e80e1214aab79e9b779cda0ff92a00"},
|
||||||
|
{file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdf88ab63c3ee282c76d652fc86518aacb737ff35796023fae56a65ced1a5978"},
|
||||||
|
{file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a48f1953c4a1d9bd0b5167ac50da9a79f6072c63c4cef4cf2a3736994903583e"},
|
||||||
|
{file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:a9f2de23bec87ff306aef658384b02aa7c32389766af3c5dee9ce33e80222dfa"},
|
||||||
|
{file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:cd8702c5142afda03dc2b1ee6bc358b62b3735b2cce53fc77b31ca9f728e4bc8"},
|
||||||
|
{file = "pydantic-1.10.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6e7124d6855b2780611d9f5e1e145e86667eaa3bd9459192c8dc1a097f5e9903"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b53e1d41e97063d51a02821b80538053ee4608b9a181c1005441f1673c55423"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:55b1625899acd33229c4352ce0ae54038529b412bd51c4915349b49ca575258f"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:301d626a59edbe5dfb48fcae245896379a450d04baeed50ef40d8199f2733b06"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f9d649892a6f54a39ed56b8dfd5e08b5f3be5f893da430bed76975f3735d15"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7b5a3821225f5c43496c324b0d6875fde910a1c2933d726a743ce328fbb2a8c"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f2f7eb6273dd12472d7f218e1fef6f7c7c2f00ac2e1ecde4db8824c457300416"},
|
||||||
|
{file = "pydantic-1.10.4-cp38-cp38-win_amd64.whl", hash = "sha256:4b05697738e7d2040696b0a66d9f0a10bec0efa1883ca75ee9e55baf511909d6"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a9a6747cac06c2beb466064dda999a13176b23535e4c496c9d48e6406f92d42d"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eb992a1ef739cc7b543576337bebfc62c0e6567434e522e97291b251a41dad7f"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:990406d226dea0e8f25f643b370224771878142155b879784ce89f633541a024"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e82a6d37a95e0b1b42b82ab340ada3963aea1317fd7f888bb6b9dfbf4fff57c"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9193d4f4ee8feca58bc56c8306bcb820f5c7905fd919e0750acdeeeef0615b28"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2b3ce5f16deb45c472dde1a0ee05619298c864a20cded09c4edd820e1454129f"},
|
||||||
|
{file = "pydantic-1.10.4-cp39-cp39-win_amd64.whl", hash = "sha256:9cbdc268a62d9a98c56e2452d6c41c0263d64a2009aac69246486f01b4f594c4"},
|
||||||
|
{file = "pydantic-1.10.4-py3-none-any.whl", hash = "sha256:4948f264678c703f3877d1c8877c4e3b2e12e549c57795107f08cf70c6ec7774"},
|
||||||
|
{file = "pydantic-1.10.4.tar.gz", hash = "sha256:b9a3859f24eb4e097502a3be1fb4b2abb79b6103dd9e2e0edb70613a4459a648"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing-extensions = ">=4.2.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dotenv = ["python-dotenv (>=0.10.4)"]
|
||||||
|
email = ["email-validator (>=1.0.3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pygments"
|
name = "pygments"
|
||||||
version = "2.14.0"
|
version = "2.14.0"
|
||||||
@ -1386,6 +1439,13 @@ files = [
|
|||||||
{file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"},
|
{file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"},
|
||||||
{file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"},
|
{file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"},
|
||||||
{file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"},
|
{file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"},
|
||||||
|
{file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"},
|
||||||
{file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"},
|
{file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"},
|
||||||
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"},
|
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"},
|
||||||
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"},
|
{file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"},
|
||||||
@ -1808,4 +1868,4 @@ compiler = ["black", "isort", "jinja2"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.7"
|
python-versions = "^3.7"
|
||||||
content-hash = "80c6215b1b18ae9de0b8ed966cfab05018272dd7eff7dde0fb28d2b2231ac051"
|
content-hash = "f9503e42026d0807dc3ed344b5f8e6fa3f1c7ffa9c66b086de929aefaa8cb8c6"
|
||||||
|
@ -36,6 +36,7 @@ sphinx-rtd-theme = "0.5.0"
|
|||||||
tomlkit = "^0.7.0"
|
tomlkit = "^0.7.0"
|
||||||
tox = "^3.15.1"
|
tox = "^3.15.1"
|
||||||
pre-commit = "^2.17.0"
|
pre-commit = "^2.17.0"
|
||||||
|
pydantic = ">=1.8.0"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
@ -628,7 +628,6 @@ class Message(ABC):
|
|||||||
# Set current field of each group after `__init__` has already been run.
|
# Set current field of each group after `__init__` has already been run.
|
||||||
group_current: Dict[str, Optional[str]] = {}
|
group_current: Dict[str, Optional[str]] = {}
|
||||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
|
|
||||||
if meta.group:
|
if meta.group:
|
||||||
group_current.setdefault(meta.group)
|
group_current.setdefault(meta.group)
|
||||||
|
|
||||||
@ -1470,6 +1469,24 @@ class Message(ABC):
|
|||||||
)
|
)
|
||||||
return self.__raw_get(name) is not default
|
return self.__raw_get(name) is not default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_field_groups(cls, values):
|
||||||
|
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
||||||
|
|
||||||
|
for group, field_set in meta.items():
|
||||||
|
set_fields = [
|
||||||
|
field.name for field in field_set if values[field.name] is not None
|
||||||
|
]
|
||||||
|
if not set_fields:
|
||||||
|
raise ValueError(f"Group {group} has no value; all fields are None")
|
||||||
|
elif len(set_fields) > 1:
|
||||||
|
set_fields_str = ", ".join(set_fields)
|
||||||
|
raise ValueError(
|
||||||
|
f"Group {group} has more than one value; fields {set_fields_str} are not None"
|
||||||
|
)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
def serialized_on_wire(message: Message) -> bool:
|
def serialized_on_wire(message: Message) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -214,7 +214,6 @@ class ProtoContentBase:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PluginRequestCompiler:
|
class PluginRequestCompiler:
|
||||||
|
|
||||||
plugin_request_obj: CodeGeneratorRequest
|
plugin_request_obj: CodeGeneratorRequest
|
||||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||||
|
|
||||||
@ -247,11 +246,13 @@ class OutputTemplate:
|
|||||||
imports: Set[str] = field(default_factory=set)
|
imports: Set[str] = field(default_factory=set)
|
||||||
datetime_imports: Set[str] = field(default_factory=set)
|
datetime_imports: Set[str] = field(default_factory=set)
|
||||||
typing_imports: Set[str] = field(default_factory=set)
|
typing_imports: Set[str] = field(default_factory=set)
|
||||||
|
pydantic_imports: Set[str] = field(default_factory=set)
|
||||||
builtins_import: bool = False
|
builtins_import: bool = False
|
||||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||||
|
pydantic_dataclasses: bool = False
|
||||||
output: bool = True
|
output: bool = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -334,6 +335,20 @@ class MessageCompiler(ProtoContentBase):
|
|||||||
def has_deprecated_fields(self) -> bool:
|
def has_deprecated_fields(self) -> bool:
|
||||||
return any(self.deprecated_fields)
|
return any(self.deprecated_fields)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_oneof_fields(self) -> bool:
|
||||||
|
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_message_field(self) -> bool:
|
||||||
|
return any(
|
||||||
|
(
|
||||||
|
field.proto_obj.type in PROTO_MESSAGE_TYPES
|
||||||
|
for field in self.fields
|
||||||
|
if isinstance(field.proto_obj, FieldDescriptorProto)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_map(
|
def is_map(
|
||||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||||
@ -431,6 +446,10 @@ class FieldCompiler(MessageCompiler):
|
|||||||
imports.add("Dict")
|
imports.add("Dict")
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pydantic_imports(self) -> Set[str]:
|
||||||
|
return set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_builtins(self) -> bool:
|
def use_builtins(self) -> bool:
|
||||||
return self.py_type in self.parent.builtins_types or (
|
return self.py_type in self.parent.builtins_types or (
|
||||||
@ -440,6 +459,7 @@ class FieldCompiler(MessageCompiler):
|
|||||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||||
output_file.datetime_imports.update(self.datetime_imports)
|
output_file.datetime_imports.update(self.datetime_imports)
|
||||||
output_file.typing_imports.update(self.typing_imports)
|
output_file.typing_imports.update(self.typing_imports)
|
||||||
|
output_file.pydantic_imports.update(self.pydantic_imports)
|
||||||
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -568,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler):
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
|
||||||
|
@property
|
||||||
|
def optional(self) -> bool:
|
||||||
|
# Force the optional to be True. This will allow the pydantic dataclass
|
||||||
|
# to validate the object correctly by allowing the field to be let empty.
|
||||||
|
# We add a pydantic validator later to ensure exactly one field is defined.
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pydantic_imports(self) -> Set[str]:
|
||||||
|
return {"root_validator"}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MapEntryCompiler(FieldCompiler):
|
class MapEntryCompiler(FieldCompiler):
|
||||||
py_k_type: Type = PLACEHOLDER
|
py_k_type: Type = PLACEHOLDER
|
||||||
@ -679,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ServiceMethodCompiler(ProtoContentBase):
|
class ServiceMethodCompiler(ProtoContentBase):
|
||||||
|
|
||||||
parent: ServiceCompiler
|
parent: ServiceCompiler
|
||||||
proto_obj: MethodDescriptorProto
|
proto_obj: MethodDescriptorProto
|
||||||
path: List[int] = PLACEHOLDER
|
path: List[int] = PLACEHOLDER
|
||||||
|
@ -11,6 +11,7 @@ from typing import (
|
|||||||
from betterproto.lib.google.protobuf import (
|
from betterproto.lib.google.protobuf import (
|
||||||
DescriptorProto,
|
DescriptorProto,
|
||||||
EnumDescriptorProto,
|
EnumDescriptorProto,
|
||||||
|
FieldDescriptorProto,
|
||||||
FileDescriptorProto,
|
FileDescriptorProto,
|
||||||
ServiceDescriptorProto,
|
ServiceDescriptorProto,
|
||||||
)
|
)
|
||||||
@ -30,6 +31,7 @@ from .models import (
|
|||||||
OneOfFieldCompiler,
|
OneOfFieldCompiler,
|
||||||
OutputTemplate,
|
OutputTemplate,
|
||||||
PluginRequestCompiler,
|
PluginRequestCompiler,
|
||||||
|
PydanticOneOfFieldCompiler,
|
||||||
ServiceCompiler,
|
ServiceCompiler,
|
||||||
ServiceMethodCompiler,
|
ServiceMethodCompiler,
|
||||||
is_map,
|
is_map,
|
||||||
@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|||||||
# skip outputting Google's well-known types
|
# skip outputting Google's well-known types
|
||||||
request_data.output_packages[output_package_name].output = False
|
request_data.output_packages[output_package_name].output = False
|
||||||
|
|
||||||
|
if "pydantic_dataclasses" in plugin_options:
|
||||||
|
request_data.output_packages[
|
||||||
|
output_package_name
|
||||||
|
].pydantic_dataclasses = True
|
||||||
|
|
||||||
# Read Messages and Enums
|
# Read Messages and Enums
|
||||||
# We need to read Messages before Services in so that we can
|
# We need to read Messages before Services in so that we can
|
||||||
# get the references to input/output messages for each service
|
# get the references to input/output messages for each service
|
||||||
@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _make_one_of_field_compiler(
|
||||||
|
output_package: OutputTemplate,
|
||||||
|
source_file: "FileDescriptorProto",
|
||||||
|
parent: MessageCompiler,
|
||||||
|
proto_obj: "FieldDescriptorProto",
|
||||||
|
path: List[int],
|
||||||
|
) -> FieldCompiler:
|
||||||
|
|
||||||
|
pydantic = output_package.pydantic_dataclasses
|
||||||
|
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
|
||||||
|
return Cls(
|
||||||
|
source_file=source_file,
|
||||||
|
parent=parent,
|
||||||
|
proto_obj=proto_obj,
|
||||||
|
path=path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_protobuf_type(
|
def read_protobuf_type(
|
||||||
item: DescriptorProto,
|
item: DescriptorProto,
|
||||||
path: List[int],
|
path: List[int],
|
||||||
@ -168,11 +193,8 @@ def read_protobuf_type(
|
|||||||
path=path + [2, index],
|
path=path + [2, index],
|
||||||
)
|
)
|
||||||
elif is_oneof(field):
|
elif is_oneof(field):
|
||||||
OneOfFieldCompiler(
|
_make_one_of_field_compiler(
|
||||||
source_file=source_file,
|
output_package, source_file, message_data, field, path + [2, index]
|
||||||
parent=message_data,
|
|
||||||
proto_obj=field,
|
|
||||||
path=path + [2, index],
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
FieldCompiler(
|
FieldCompiler(
|
||||||
|
@ -5,7 +5,13 @@
|
|||||||
{% for i in output_file.python_module_imports|sort %}
|
{% for i in output_file.python_module_imports|sort %}
|
||||||
import {{ i }}
|
import {{ i }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_dataclasses %}
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
{%- else -%}
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if output_file.datetime_imports %}
|
{% if output_file.datetime_imports %}
|
||||||
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||||
|
|
||||||
@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
|
|||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_imports %}
|
||||||
|
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||||
|
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
{% if output_file.services %}
|
{% if output_file.services %}
|
||||||
from betterproto.grpc.grpclib_server import ServiceBase
|
from betterproto.grpc.grpclib_server import ServiceBase
|
||||||
@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
|
||||||
|
@root_validator()
|
||||||
|
def check_oneof(cls, values):
|
||||||
|
return cls._validate_field_groups(values)
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% for service in output_file.services %}
|
{% for service in output_file.services %}
|
||||||
@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_dataclasses %}
|
||||||
|
{% for message in output_file.messages %}
|
||||||
|
{% if message.has_message_field %}
|
||||||
|
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
@ -11,6 +11,7 @@ from tests.util import (
|
|||||||
get_directories,
|
get_directories,
|
||||||
inputs_path,
|
inputs_path,
|
||||||
output_path_betterproto,
|
output_path_betterproto,
|
||||||
|
output_path_betterproto_pydantic,
|
||||||
output_path_reference,
|
output_path_reference,
|
||||||
protoc,
|
protoc,
|
||||||
)
|
)
|
||||||
@ -80,9 +81,11 @@ async def generate_test_case_output(
|
|||||||
|
|
||||||
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
||||||
test_case_output_path_betterproto = output_path_betterproto
|
test_case_output_path_betterproto = output_path_betterproto
|
||||||
|
test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic
|
||||||
|
|
||||||
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
||||||
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
||||||
|
os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True)
|
||||||
|
|
||||||
clear_directory(test_case_output_path_reference)
|
clear_directory(test_case_output_path_reference)
|
||||||
clear_directory(test_case_output_path_betterproto)
|
clear_directory(test_case_output_path_betterproto)
|
||||||
@ -90,9 +93,13 @@ async def generate_test_case_output(
|
|||||||
(
|
(
|
||||||
(ref_out, ref_err, ref_code),
|
(ref_out, ref_err, ref_code),
|
||||||
(plg_out, plg_err, plg_code),
|
(plg_out, plg_err, plg_code),
|
||||||
|
(plg_out_pyd, plg_err_pyd, plg_code_pyd),
|
||||||
) = await asyncio.gather(
|
) = await asyncio.gather(
|
||||||
protoc(test_case_input_path, test_case_output_path_reference, True),
|
protoc(test_case_input_path, test_case_output_path_reference, True),
|
||||||
protoc(test_case_input_path, test_case_output_path_betterproto, False),
|
protoc(test_case_input_path, test_case_output_path_betterproto, False),
|
||||||
|
protoc(
|
||||||
|
test_case_input_path, test_case_output_path_betterproto_pyd, False, True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if ref_code == 0:
|
if ref_code == 0:
|
||||||
@ -131,7 +138,27 @@ async def generate_test_case_output(
|
|||||||
sys.stderr.buffer.write(plg_err)
|
sys.stderr.buffer.write(plg_err)
|
||||||
sys.stderr.buffer.flush()
|
sys.stderr.buffer.flush()
|
||||||
|
|
||||||
return max(ref_code, plg_code)
|
if plg_code_pyd == 0:
|
||||||
|
print(
|
||||||
|
f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
|
||||||
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
if plg_out_pyd:
|
||||||
|
print("Plugin stdout:")
|
||||||
|
sys.stdout.buffer.write(plg_out_pyd)
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
|
||||||
|
if plg_err_pyd:
|
||||||
|
print("Plugin stderr:")
|
||||||
|
sys.stderr.buffer.write(plg_err_pyd)
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
|
||||||
|
return max(ref_code, plg_code, plg_code_pyd)
|
||||||
|
|
||||||
|
|
||||||
HELP = "\n".join(
|
HELP = "\n".join(
|
||||||
|
@ -1,6 +1,19 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
from tests.output_betterproto.bool import Test
|
from tests.output_betterproto.bool import Test
|
||||||
|
from tests.output_betterproto_pydantic.bool import Test as TestPyd
|
||||||
|
|
||||||
|
|
||||||
def test_value():
|
def test_value():
|
||||||
message = Test()
|
message = Test()
|
||||||
assert not message.value, "Boolean is False by default"
|
assert not message.value, "Boolean is False by default"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pydantic_no_value():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
TestPyd()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pydantic_value():
|
||||||
|
message = Test(value=False)
|
||||||
|
assert not message.value
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import betterproto
|
import betterproto
|
||||||
from tests.output_betterproto.oneof import Test
|
from tests.output_betterproto.oneof import Test
|
||||||
|
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
|
||||||
from tests.util import get_test_case_json_data
|
from tests.util import get_test_case_json_data
|
||||||
|
|
||||||
|
|
||||||
@ -13,3 +14,8 @@ def test_which_name():
|
|||||||
message = Test()
|
message = Test()
|
||||||
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
|
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
|
||||||
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||||
|
|
||||||
|
|
||||||
|
def test_which_count_pyd():
|
||||||
|
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
|
||||||
|
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import atexit
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
@ -22,6 +25,7 @@ root_path = Path(__file__).resolve().parent
|
|||||||
inputs_path = root_path.joinpath("inputs")
|
inputs_path = root_path.joinpath("inputs")
|
||||||
output_path_reference = root_path.joinpath("output_reference")
|
output_path_reference = root_path.joinpath("output_reference")
|
||||||
output_path_betterproto = root_path.joinpath("output_betterproto")
|
output_path_betterproto = root_path.joinpath("output_betterproto")
|
||||||
|
output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic")
|
||||||
|
|
||||||
|
|
||||||
def get_files(path, suffix: str) -> Generator[str, None, None]:
|
def get_files(path, suffix: str) -> Generator[str, None, None]:
|
||||||
@ -36,19 +40,56 @@ def get_directories(path):
|
|||||||
|
|
||||||
|
|
||||||
async def protoc(
|
async def protoc(
|
||||||
path: Union[str, Path], output_dir: Union[str, Path], reference: bool = False
|
path: Union[str, Path],
|
||||||
|
output_dir: Union[str, Path],
|
||||||
|
reference: bool = False,
|
||||||
|
pydantic_dataclasses: bool = False,
|
||||||
):
|
):
|
||||||
path: Path = Path(path).resolve()
|
path: Path = Path(path).resolve()
|
||||||
output_dir: Path = Path(output_dir).resolve()
|
output_dir: Path = Path(output_dir).resolve()
|
||||||
python_out_option: str = "python_betterproto_out" if not reference else "python_out"
|
python_out_option: str = "python_betterproto_out" if not reference else "python_out"
|
||||||
command = [
|
|
||||||
sys.executable,
|
if pydantic_dataclasses:
|
||||||
"-m",
|
plugin_path = Path("src/betterproto/plugin/main.py")
|
||||||
"grpc.tools.protoc",
|
|
||||||
f"--proto_path={path.as_posix()}",
|
if "Win" in platform.system():
|
||||||
f"--{python_out_option}={output_dir.as_posix()}",
|
with tempfile.NamedTemporaryFile(
|
||||||
*[p.as_posix() for p in path.glob("*.proto")],
|
"w", encoding="UTF-8", suffix=".bat", delete=False
|
||||||
]
|
) as tf:
|
||||||
|
# See https://stackoverflow.com/a/42622705
|
||||||
|
tf.writelines(
|
||||||
|
[
|
||||||
|
"@echo off",
|
||||||
|
f"\nchdir {os.getcwd()}",
|
||||||
|
f"\n{sys.executable} -u {plugin_path.as_posix()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
tf.flush()
|
||||||
|
|
||||||
|
plugin_path = Path(tf.name)
|
||||||
|
atexit.register(os.remove, plugin_path)
|
||||||
|
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"grpc.tools.protoc",
|
||||||
|
f"--plugin=protoc-gen-custom={plugin_path.as_posix()}",
|
||||||
|
"--experimental_allow_proto3_optional",
|
||||||
|
"--custom_opt=pydantic_dataclasses",
|
||||||
|
f"--proto_path={path.as_posix()}",
|
||||||
|
f"--custom_out={output_dir.as_posix()}",
|
||||||
|
*[p.as_posix() for p in path.glob("*.proto")],
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"grpc.tools.protoc",
|
||||||
|
f"--proto_path={path.as_posix()}",
|
||||||
|
f"--{python_out_option}={output_dir.as_posix()}",
|
||||||
|
*[p.as_posix() for p in path.glob("*.proto")],
|
||||||
|
]
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user