diff --git a/README.md b/README.md index 48da7fa..38fb168 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i - Timezone-aware `datetime` and `timedelta` objects - Relative imports - Mypy type checking +- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models) This project is heavily inspired by, and borrows functionality from: @@ -364,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc) {'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'} ``` +## Generating Pydantic Models + +You can use python-betterproto to generate pydantic based models, using +pydantic dataclasses. This means the results of the protobuf unmarshalling will +be typed checked. The usage is the same, but you need to add a custom option +when calling the protobuf compiler: + + +``` +protoc -I . --custom_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto +``` + +With the important change being `--custom_opt=pydantic_dataclasses`. This will +swap the dataclass implementation from the builtin python dataclass to the +pydantic dataclass. You must have pydantic as a dependency in your project for +this to work. + + + ## Development - _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_ diff --git a/poetry.lock b/poetry.lock index b9f3319..84d1b43 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1248,6 +1248,59 @@ files = [ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, ] +[[package]] +name = "pydantic" +version = "1.10.4" +description = "Data validation and settings management using python type hints" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-1.10.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5635de53e6686fe7a44b5cf25fcc419a0d5e5c1a1efe73d49d48fe7586db854"}, + {file = "pydantic-1.10.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dc1cc241440ed7ca9ab59d9929075445da6b7c94ced281b3dd4cfe6c8cff817"}, + {file = "pydantic-1.10.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51bdeb10d2db0f288e71d49c9cefa609bca271720ecd0c58009bd7504a0c464c"}, + {file = "pydantic-1.10.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78cec42b95dbb500a1f7120bdf95c401f6abb616bbe8785ef09887306792e66e"}, + {file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8775d4ef5e7299a2f4699501077a0defdaac5b6c4321173bcb0f3c496fbadf85"}, + {file = "pydantic-1.10.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:572066051eeac73d23f95ba9a71349c42a3e05999d0ee1572b7860235b850cc6"}, + {file = "pydantic-1.10.4-cp310-cp310-win_amd64.whl", hash = "sha256:7feb6a2d401f4d6863050f58325b8d99c1e56f4512d98b11ac64ad1751dc647d"}, + {file = "pydantic-1.10.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:39f4a73e5342b25c2959529f07f026ef58147249f9b7431e1ba8414a36761f53"}, + {file = "pydantic-1.10.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:983e720704431a6573d626b00662eb78a07148c9115129f9b4351091ec95ecc3"}, + {file = "pydantic-1.10.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d52162fe6b2b55964fbb0af2ee58e99791a3138588c482572bb6087953113a"}, + {file = "pydantic-1.10.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdf8d759ef326962b4678d89e275ffc55b7ce59d917d9f72233762061fd04a2d"}, + {file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:05a81b006be15655b2a1bae5faa4280cf7c81d0e09fcb49b342ebf826abe5a72"}, + {file = "pydantic-1.10.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d88c4c0e5c5dfd05092a4b271282ef0588e5f4aaf345778056fc5259ba098857"}, + {file = "pydantic-1.10.4-cp311-cp311-win_amd64.whl", hash = "sha256:6a05a9db1ef5be0fe63e988f9617ca2551013f55000289c671f71ec16f4985e3"}, + {file = "pydantic-1.10.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:887ca463c3bc47103c123bc06919c86720e80e1214aab79e9b779cda0ff92a00"}, + {file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdf88ab63c3ee282c76d652fc86518aacb737ff35796023fae56a65ced1a5978"}, + {file = "pydantic-1.10.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a48f1953c4a1d9bd0b5167ac50da9a79f6072c63c4cef4cf2a3736994903583e"}, + {file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:a9f2de23bec87ff306aef658384b02aa7c32389766af3c5dee9ce33e80222dfa"}, + {file = "pydantic-1.10.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:cd8702c5142afda03dc2b1ee6bc358b62b3735b2cce53fc77b31ca9f728e4bc8"}, + {file = "pydantic-1.10.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6e7124d6855b2780611d9f5e1e145e86667eaa3bd9459192c8dc1a097f5e9903"}, + {file = "pydantic-1.10.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b53e1d41e97063d51a02821b80538053ee4608b9a181c1005441f1673c55423"}, + {file = "pydantic-1.10.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:55b1625899acd33229c4352ce0ae54038529b412bd51c4915349b49ca575258f"}, + {file = "pydantic-1.10.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:301d626a59edbe5dfb48fcae245896379a450d04baeed50ef40d8199f2733b06"}, + {file = "pydantic-1.10.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f9d649892a6f54a39ed56b8dfd5e08b5f3be5f893da430bed76975f3735d15"}, + {file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7b5a3821225f5c43496c324b0d6875fde910a1c2933d726a743ce328fbb2a8c"}, + {file = "pydantic-1.10.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f2f7eb6273dd12472d7f218e1fef6f7c7c2f00ac2e1ecde4db8824c457300416"}, + {file = "pydantic-1.10.4-cp38-cp38-win_amd64.whl", hash = "sha256:4b05697738e7d2040696b0a66d9f0a10bec0efa1883ca75ee9e55baf511909d6"}, + {file = "pydantic-1.10.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a9a6747cac06c2beb466064dda999a13176b23535e4c496c9d48e6406f92d42d"}, + {file = "pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eb992a1ef739cc7b543576337bebfc62c0e6567434e522e97291b251a41dad7f"}, + {file = "pydantic-1.10.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:990406d226dea0e8f25f643b370224771878142155b879784ce89f633541a024"}, + {file = "pydantic-1.10.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e82a6d37a95e0b1b42b82ab340ada3963aea1317fd7f888bb6b9dfbf4fff57c"}, + {file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9193d4f4ee8feca58bc56c8306bcb820f5c7905fd919e0750acdeeeef0615b28"}, + {file = "pydantic-1.10.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2b3ce5f16deb45c472dde1a0ee05619298c864a20cded09c4edd820e1454129f"}, + {file = "pydantic-1.10.4-cp39-cp39-win_amd64.whl", hash = "sha256:9cbdc268a62d9a98c56e2452d6c41c0263d64a2009aac69246486f01b4f594c4"}, + {file = "pydantic-1.10.4-py3-none-any.whl", hash = "sha256:4948f264678c703f3877d1c8877c4e3b2e12e549c57795107f08cf70c6ec7774"}, + {file = "pydantic-1.10.4.tar.gz", hash = "sha256:b9a3859f24eb4e097502a3be1fb4b2abb79b6103dd9e2e0edb70613a4459a648"}, +] + +[package.dependencies] +typing-extensions = ">=4.2.0" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] + [[package]] name = "pygments" version = "2.14.0" @@ -1386,6 +1439,13 @@ files = [ {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, + {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, + {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, @@ -1808,4 +1868,4 @@ compiler = ["black", "isort", "jinja2"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "80c6215b1b18ae9de0b8ed966cfab05018272dd7eff7dde0fb28d2b2231ac051" +content-hash = "f9503e42026d0807dc3ed344b5f8e6fa3f1c7ffa9c66b086de929aefaa8cb8c6" diff --git a/pyproject.toml b/pyproject.toml index 47dc937..81dba4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ sphinx-rtd-theme = "0.5.0" tomlkit = "^0.7.0" tox = "^3.15.1" pre-commit = "^2.17.0" +pydantic = ">=1.8.0" [tool.poetry.scripts] diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index e217314..860f0cf 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -628,7 +628,6 @@ class Message(ABC): # Set current field of each group after `__init__` has already been run. group_current: Dict[str, Optional[str]] = {} for field_name, meta in self._betterproto.meta_by_field_name.items(): - if meta.group: group_current.setdefault(meta.group) @@ -1470,6 +1469,24 @@ class Message(ABC): ) return self.__raw_get(name) is not default + @classmethod + def _validate_field_groups(cls, values): + meta = cls._betterproto_meta.oneof_field_by_group # type: ignore + + for group, field_set in meta.items(): + set_fields = [ + field.name for field in field_set if values[field.name] is not None + ] + if not set_fields: + raise ValueError(f"Group {group} has no value; all fields are None") + elif len(set_fields) > 1: + set_fields_str = ", ".join(set_fields) + raise ValueError( + f"Group {group} has more than one value; fields {set_fields_str} are not None" + ) + + return values + def serialized_on_wire(message: Message) -> bool: """ diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 1d3239c..ea819d4 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -214,7 +214,6 @@ class ProtoContentBase: @dataclass class PluginRequestCompiler: - plugin_request_obj: CodeGeneratorRequest output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) @@ -247,11 +246,13 @@ class OutputTemplate: imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set) + pydantic_imports: Set[str] = field(default_factory=set) builtins_import: bool = False messages: List["MessageCompiler"] = field(default_factory=list) enums: List["EnumDefinitionCompiler"] = field(default_factory=list) services: List["ServiceCompiler"] = field(default_factory=list) imports_type_checking_only: Set[str] = field(default_factory=set) + pydantic_dataclasses: bool = False output: bool = True @property @@ -334,6 +335,20 @@ class MessageCompiler(ProtoContentBase): def has_deprecated_fields(self) -> bool: return any(self.deprecated_fields) + @property + def has_oneof_fields(self) -> bool: + return any(isinstance(field, OneOfFieldCompiler) for field in self.fields) + + @property + def has_message_field(self) -> bool: + return any( + ( + field.proto_obj.type in PROTO_MESSAGE_TYPES + for field in self.fields + if isinstance(field.proto_obj, FieldDescriptorProto) + ) + ) + def is_map( proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto @@ -431,6 +446,10 @@ class FieldCompiler(MessageCompiler): imports.add("Dict") return imports + @property + def pydantic_imports(self) -> Set[str]: + return set() + @property def use_builtins(self) -> bool: return self.py_type in self.parent.builtins_types or ( @@ -440,6 +459,7 @@ class FieldCompiler(MessageCompiler): def add_imports_to(self, output_file: OutputTemplate) -> None: output_file.datetime_imports.update(self.datetime_imports) output_file.typing_imports.update(self.typing_imports) + output_file.pydantic_imports.update(self.pydantic_imports) output_file.builtins_import = output_file.builtins_import or self.use_builtins @property @@ -568,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler): return args +@dataclass +class PydanticOneOfFieldCompiler(OneOfFieldCompiler): + @property + def optional(self) -> bool: + # Force the optional to be True. This will allow the pydantic dataclass + # to validate the object correctly by allowing the field to be let empty. + # We add a pydantic validator later to ensure exactly one field is defined. + return True + + @property + def pydantic_imports(self) -> Set[str]: + return {"root_validator"} + + @dataclass class MapEntryCompiler(FieldCompiler): py_k_type: Type = PLACEHOLDER @@ -679,7 +713,6 @@ class ServiceCompiler(ProtoContentBase): @dataclass class ServiceMethodCompiler(ProtoContentBase): - parent: ServiceCompiler proto_obj: MethodDescriptorProto path: List[int] = PLACEHOLDER diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index c0d32f6..358cc20 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -11,6 +11,7 @@ from typing import ( from betterproto.lib.google.protobuf import ( DescriptorProto, EnumDescriptorProto, + FieldDescriptorProto, FileDescriptorProto, ServiceDescriptorProto, ) @@ -30,6 +31,7 @@ from .models import ( OneOfFieldCompiler, OutputTemplate, PluginRequestCompiler, + PydanticOneOfFieldCompiler, ServiceCompiler, ServiceMethodCompiler, is_map, @@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # skip outputting Google's well-known types request_data.output_packages[output_package_name].output = False + if "pydantic_dataclasses" in plugin_options: + request_data.output_packages[ + output_package_name + ].pydantic_dataclasses = True + # Read Messages and Enums # We need to read Messages before Services in so that we can # get the references to input/output messages for each service @@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: return response +def _make_one_of_field_compiler( + output_package: OutputTemplate, + source_file: "FileDescriptorProto", + parent: MessageCompiler, + proto_obj: "FieldDescriptorProto", + path: List[int], +) -> FieldCompiler: + + pydantic = output_package.pydantic_dataclasses + Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler + return Cls( + source_file=source_file, + parent=parent, + proto_obj=proto_obj, + path=path, + ) + + def read_protobuf_type( item: DescriptorProto, path: List[int], @@ -168,11 +193,8 @@ def read_protobuf_type( path=path + [2, index], ) elif is_oneof(field): - OneOfFieldCompiler( - source_file=source_file, - parent=message_data, - proto_obj=field, - path=path + [2, index], + _make_one_of_field_compiler( + output_package, source_file, message_data, field, path + [2, index] ) else: FieldCompiler( diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index c80cd77..643fbe6 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -5,7 +5,13 @@ {% for i in output_file.python_module_imports|sort %} import {{ i }} {% endfor %} + +{% if output_file.pydantic_dataclasses %} +from pydantic.dataclasses import dataclass +{%- else -%} from dataclasses import dataclass +{% endif %} + {% if output_file.datetime_imports %} from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} @@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no {% endif %} +{% if output_file.pydantic_imports %} +from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} + +{% endif %} + import betterproto {% if output_file.services %} from betterproto.grpc.grpclib_server import ServiceBase @@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message): {% endfor %} {% endif %} + {% if output_file.pydantic_dataclasses and message.has_oneof_fields %} + @root_validator() + def check_oneof(cls, values): + return cls._validate_field_groups(values) + {% endif %} {% endfor %} {% for service in output_file.services %} @@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase): } {% endfor %} + +{% if output_file.pydantic_dataclasses %} +{% for message in output_file.messages %} +{% if message.has_message_field %} +{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore +{% endif %} +{% endfor %} +{% endif %} diff --git a/tests/generate.py b/tests/generate.py index afbc4e2..9ce375f 100755 --- a/tests/generate.py +++ b/tests/generate.py @@ -11,6 +11,7 @@ from tests.util import ( get_directories, inputs_path, output_path_betterproto, + output_path_betterproto_pydantic, output_path_reference, protoc, ) @@ -80,9 +81,11 @@ async def generate_test_case_output( test_case_output_path_reference = output_path_reference.joinpath(test_case_name) test_case_output_path_betterproto = output_path_betterproto + test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True) + os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True) clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_betterproto) @@ -90,9 +93,13 @@ async def generate_test_case_output( ( (ref_out, ref_err, ref_code), (plg_out, plg_err, plg_code), + (plg_out_pyd, plg_err_pyd, plg_code_pyd), ) = await asyncio.gather( protoc(test_case_input_path, test_case_output_path_reference, True), protoc(test_case_input_path, test_case_output_path_betterproto, False), + protoc( + test_case_input_path, test_case_output_path_betterproto_pyd, False, True + ), ) if ref_code == 0: @@ -131,7 +138,27 @@ async def generate_test_case_output( sys.stderr.buffer.write(plg_err) sys.stderr.buffer.flush() - return max(ref_code, plg_code) + if plg_code_pyd == 0: + print( + f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m" + ) + else: + print( + f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m" + ) + + if verbose: + if plg_out_pyd: + print("Plugin stdout:") + sys.stdout.buffer.write(plg_out_pyd) + sys.stdout.buffer.flush() + + if plg_err_pyd: + print("Plugin stderr:") + sys.stderr.buffer.write(plg_err_pyd) + sys.stderr.buffer.flush() + + return max(ref_code, plg_code, plg_code_pyd) HELP = "\n".join( diff --git a/tests/inputs/bool/test_bool.py b/tests/inputs/bool/test_bool.py index e91bf0a..c32a170 100644 --- a/tests/inputs/bool/test_bool.py +++ b/tests/inputs/bool/test_bool.py @@ -1,6 +1,19 @@ +import pytest + from tests.output_betterproto.bool import Test +from tests.output_betterproto_pydantic.bool import Test as TestPyd def test_value(): message = Test() assert not message.value, "Boolean is False by default" + + +def test_pydantic_no_value(): + with pytest.raises(ValueError): + TestPyd() + + +def test_pydantic_value(): + message = Test(value=False) + assert not message.value diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py index d126765..92e8e77 100644 --- a/tests/inputs/oneof/test_oneof.py +++ b/tests/inputs/oneof/test_oneof.py @@ -1,5 +1,6 @@ import betterproto from tests.output_betterproto.oneof import Test +from tests.output_betterproto_pydantic.oneof import Test as TestPyd from tests.util import get_test_case_json_data @@ -13,3 +14,8 @@ def test_which_name(): message = Test() message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") + + +def test_which_count_pyd(): + message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar") + assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") diff --git a/tests/util.py b/tests/util.py index e3b43aa..22c4f90 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,7 +1,10 @@ import asyncio +import atexit import importlib import os +import platform import sys +import tempfile from dataclasses import dataclass from pathlib import Path from types import ModuleType @@ -22,6 +25,7 @@ root_path = Path(__file__).resolve().parent inputs_path = root_path.joinpath("inputs") output_path_reference = root_path.joinpath("output_reference") output_path_betterproto = root_path.joinpath("output_betterproto") +output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic") def get_files(path, suffix: str) -> Generator[str, None, None]: @@ -36,19 +40,56 @@ def get_directories(path): async def protoc( - path: Union[str, Path], output_dir: Union[str, Path], reference: bool = False + path: Union[str, Path], + output_dir: Union[str, Path], + reference: bool = False, + pydantic_dataclasses: bool = False, ): path: Path = Path(path).resolve() output_dir: Path = Path(output_dir).resolve() python_out_option: str = "python_betterproto_out" if not reference else "python_out" - command = [ - sys.executable, - "-m", - "grpc.tools.protoc", - f"--proto_path={path.as_posix()}", - f"--{python_out_option}={output_dir.as_posix()}", - *[p.as_posix() for p in path.glob("*.proto")], - ] + + if pydantic_dataclasses: + plugin_path = Path("src/betterproto/plugin/main.py") + + if "Win" in platform.system(): + with tempfile.NamedTemporaryFile( + "w", encoding="UTF-8", suffix=".bat", delete=False + ) as tf: + # See https://stackoverflow.com/a/42622705 + tf.writelines( + [ + "@echo off", + f"\nchdir {os.getcwd()}", + f"\n{sys.executable} -u {plugin_path.as_posix()}", + ] + ) + + tf.flush() + + plugin_path = Path(tf.name) + atexit.register(os.remove, plugin_path) + + command = [ + sys.executable, + "-m", + "grpc.tools.protoc", + f"--plugin=protoc-gen-custom={plugin_path.as_posix()}", + "--experimental_allow_proto3_optional", + "--custom_opt=pydantic_dataclasses", + f"--proto_path={path.as_posix()}", + f"--custom_out={output_dir.as_posix()}", + *[p.as_posix() for p in path.glob("*.proto")], + ] + else: + command = [ + sys.executable, + "-m", + "grpc.tools.protoc", + f"--proto_path={path.as_posix()}", + f"--{python_out_option}={output_dir.as_posix()}", + *[p.as_posix() for p in path.glob("*.proto")], + ] proc = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE )