Compare commits
	
		
			85 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 87b84afc4b | ||
|  | 8283ef7298 | ||
|  | 0931eb3bf5 | ||
|  | 8f535913a1 | ||
|  | fd02cb6180 | ||
|  | 950d2f6536 | ||
|  | 29f12ea88d | ||
|  | 219233b50e | ||
|  | 2d30bdb7b2 | ||
|  | bdd3389b17 | ||
|  | 441844b97a | ||
|  | aa81680c83 | ||
|  | a413d08fc1 | ||
|  | 24d694afe2 | ||
|  | 84af157122 | ||
|  | df0c17bf0a | ||
|  | 8659c51123 | ||
|  | d1825026db | ||
|  | a12c9d24de | ||
|  | d79a9eee14 | ||
|  | d848d05710 | ||
|  | 26da86d2cd | ||
|  | 604dcb104f | ||
|  | 421aa78014 | ||
|  | 0fda2cc05d | ||
|  | 4cdf1bb9e0 | ||
|  | d203659a44 | ||
|  | 6faac1d1ca | ||
|  | 098989e9e9 | ||
|  | 182aedaec4 | ||
|  | a7532bbadc | ||
|  | 73d1fa3d5b | ||
|  | c00bc96db7 | ||
|  | d3e9621aa8 | ||
|  | fcbd8a3759 | ||
|  | aad7d2ad76 | ||
|  | 37e53fce85 | ||
|  | 2b41383745 | ||
|  | b0b6cd24ad | ||
|  | b81195eb44 | ||
|  | d2af2f2fac | ||
|  | e7f07fa2a1 | ||
|  | 50fa4e6268 | ||
|  | 2fa0be2141 | ||
|  | 13d656587c | ||
|  | 6df8cef3f0 | ||
|  | 1b1bd47cb1 | ||
|  | 0adcc9020c | ||
|  | bfc0fac754 | ||
|  | 8fbf4476a8 | ||
|  | 591ec5efb3 | ||
|  | f31d51cf3c | ||
|  | 496eba2750 | ||
|  | d663a318b7 | ||
|  | 2fb37dd108 | ||
|  | 42d2df6de6 | ||
|  | 3fd5a0d662 | ||
|  | bc13e7070d | ||
|  | 6536181902 | ||
|  | 85e4be96d8 | ||
|  | 06c26ba60d | ||
|  | 6a70b8e8ea | ||
|  | 3ca092a724 | ||
|  | 6f7d706a8e | ||
|  | ac96d8254b | ||
|  | e7133adeb3 | ||
|  | 204e04dd69 | ||
|  | b9b0b22d57 | ||
|  | 402c21256f | ||
|  | 5f7e4d58ef | ||
|  | 1aaf7728cc | ||
|  | 70310c9e8c | ||
|  | 18a518efa7 | ||
|  | 62da35b3ea | ||
|  | 69f4192341 | ||
|  | 9c1bf25304 | ||
|  | a836fb23bc | ||
|  | bd69862a02 | ||
|  | 74205e3319 | ||
|  | 3f377e3bfd | ||
|  | 8c727d904f | ||
|  | eeddc844a5 | ||
|  | 9b5594adbe | ||
|  | d991040ff6 | ||
|  | d260f071e0 | 
							
								
								
									
										16
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										16
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							| @@ -13,17 +13,15 @@ jobs: | |||||||
|     name: ${{ matrix.os }} / ${{ matrix.python-version }} |     name: ${{ matrix.os }} / ${{ matrix.python-version }} | ||||||
|     runs-on: ${{ matrix.os }}-latest |     runs-on: ${{ matrix.os }}-latest | ||||||
|     strategy: |     strategy: | ||||||
|  |       fail-fast: false | ||||||
|       matrix: |       matrix: | ||||||
|         os: [Ubuntu, MacOS, Windows] |         os: [Ubuntu, MacOS, Windows] | ||||||
|         python-version: ['3.6.7', '3.7', '3.8', '3.9', '3.10'] |         python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] | ||||||
|         exclude: |  | ||||||
|           - os: Windows |  | ||||||
|             python-version: 3.6 |  | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|       - name: Set up Python ${{ matrix.python-version }} |       - name: Set up Python ${{ matrix.python-version }} | ||||||
|         uses: actions/setup-python@v2 |         uses: actions/setup-python@v4 | ||||||
|         with: |         with: | ||||||
|           python-version: ${{ matrix.python-version }} |           python-version: ${{ matrix.python-version }} | ||||||
|  |  | ||||||
| @@ -43,7 +41,7 @@ jobs: | |||||||
|         run: poetry config virtualenvs.in-project true |         run: poetry config virtualenvs.in-project true | ||||||
|  |  | ||||||
|       - name: Set up cache |       - name: Set up cache | ||||||
|         uses: actions/cache@v2 |         uses: actions/cache@v3 | ||||||
|         id: cache |         id: cache | ||||||
|         with: |         with: | ||||||
|           path: .venv |           path: .venv | ||||||
| @@ -56,9 +54,7 @@ jobs: | |||||||
|  |  | ||||||
|       - name: Install dependencies |       - name: Install dependencies | ||||||
|         shell: bash |         shell: bash | ||||||
|         run: | |         run: poetry install -E compiler | ||||||
|           poetry run python -m pip install pip -U |  | ||||||
|           poetry install |  | ||||||
|  |  | ||||||
|       - name: Generate code from proto files |       - name: Generate code from proto files | ||||||
|         shell: bash |         shell: bash | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								.github/workflows/code-quality.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/code-quality.yml
									
									
									
									
										vendored
									
									
								
							| @@ -13,14 +13,6 @@ jobs: | |||||||
|     name: Check code/doc formatting |     name: Check code/doc formatting | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|     - uses: actions/checkout@v2 |     - uses: actions/checkout@v3 | ||||||
|     - name: Run Black |     - uses: actions/setup-python@v4 | ||||||
|       uses: lgeiger/black-action@master |     - uses: pre-commit/action@v2.0.3 | ||||||
|       with: |  | ||||||
|         args: --check src/ tests/ benchmarks/ |  | ||||||
|  |  | ||||||
|     - name: Install rST dependcies |  | ||||||
|       run: python -m pip install doc8 |  | ||||||
|     - name: Lint documentation for errors |  | ||||||
|       run: python -m doc8 docs --max-line-length 88 --ignore-path-errors "docs/migrating.rst;D001" |  | ||||||
|       # it has a table which is longer than 88 characters long |  | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								.github/workflows/release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -15,9 +15,9 @@ jobs: | |||||||
|     name: Distribution |     name: Distribution | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v2 |       - uses: actions/checkout@v3 | ||||||
|       - name: Set up Python 3.8 |       - name: Set up Python 3.8 | ||||||
|         uses: actions/setup-python@v2 |         uses: actions/setup-python@v4 | ||||||
|         with: |         with: | ||||||
|           python-version: 3.8 |           python-version: 3.8 | ||||||
|       - name: Install poetry |       - name: Install poetry | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -17,3 +17,4 @@ output | |||||||
| .venv | .venv | ||||||
| .asv | .asv | ||||||
| venv | venv | ||||||
|  | .devcontainer | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | ci: | ||||||
|  |   autofix_prs: false | ||||||
|  |  | ||||||
|  | repos: | ||||||
|  |   - repo: https://github.com/pycqa/isort | ||||||
|  |     rev: 5.11.5 | ||||||
|  |     hooks: | ||||||
|  |       - id: isort | ||||||
|  |  | ||||||
|  |   - repo: https://github.com/psf/black | ||||||
|  |     rev: 23.1.0 | ||||||
|  |     hooks: | ||||||
|  |       - id: black | ||||||
|  |         args: ["--target-version", "py310"] | ||||||
|  |  | ||||||
|  |   - repo: https://github.com/PyCQA/doc8 | ||||||
|  |     rev: 0.10.1 | ||||||
|  |     hooks: | ||||||
|  |     -   id: doc8 | ||||||
|  |         additional_dependencies: | ||||||
|  |           - toml | ||||||
							
								
								
									
										72
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										72
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -7,6 +7,78 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 | |||||||
|  |  | ||||||
| - Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`. | - Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`. | ||||||
|  |  | ||||||
|  | ## [2.0.0b6] - 2023-06-25 | ||||||
|  |  | ||||||
|  | - **Breaking**: the minimum Python version has been bumped to `3.7` [#444](https://github.com/danielgtaylor/python-betterproto/pull/444) | ||||||
|  |  | ||||||
|  | - Support generating [Pydantic dataclasses](https://docs.pydantic.dev/latest/usage/dataclasses).  | ||||||
|  |   Pydantic dataclasses are are drop-in replacement for dataclasses in the standard library that additionally supports validation. | ||||||
|  |   Pass `--python_betterproto_opt=pydantic_dataclasses` to enable this feature. | ||||||
|  |   Refer to [#406](https://github.com/danielgtaylor/python-betterproto/pull/406) | ||||||
|  |   and [README.md](https://github.com/danielgtaylor/python-betterproto#generating-pydantic-models) for more information. | ||||||
|  |    | ||||||
|  | - Added support for `@generated` marker [#382](https://github.com/danielgtaylor/python-betterproto/pull/382) | ||||||
|  | - Pull down the `include_default_values` argument to `to_json()` [#405](https://github.com/danielgtaylor/python-betterproto/pull/405) | ||||||
|  | - Pythonize input_type name in py_input_message [#436](https://github.com/danielgtaylor/python-betterproto/pull/436) | ||||||
|  | - Widen `from_dict()` to accept any `Mapping` [#451](https://github.com/danielgtaylor/python-betterproto/pull/451) | ||||||
|  | - Replace `pkg_resources` with `importlib` [#462](https://github.com/danielgtaylor/python-betterproto/pull/462) | ||||||
|  |    | ||||||
|  | - Fix typechecker compatiblity checks in server streaming methods [#413](https://github.com/danielgtaylor/python-betterproto/pull/413) | ||||||
|  | - Fix "empty-valued" repeated fields not being serialised [#417](https://github.com/danielgtaylor/python-betterproto/pull/417) | ||||||
|  | - Fix `dict` encoding for timezone-aware `datetimes` [#468](https://github.com/danielgtaylor/python-betterproto/pull/468) | ||||||
|  | - Fix `to_pydict()` serialization for optional fields [#495](https://github.com/danielgtaylor/python-betterproto/pull/495) | ||||||
|  | - Handle empty value objects properly [#481](https://github.com/danielgtaylor/python-betterproto/pull/481) | ||||||
|  |    | ||||||
|  | ## [2.0.0b5] - 2022-08-01 | ||||||
|  |  | ||||||
|  | - **Breaking**: Client and Service Stubs no longer pack and unpack the input message fields as parameters [#331](https://github.com/danielgtaylor/python-betterproto/pull/311) | ||||||
|  |  | ||||||
|  |     Update your client calls and server handlers as follows: | ||||||
|  |  | ||||||
|  |     Clients before: | ||||||
|  |  | ||||||
|  |     ```py | ||||||
|  |     response = await service.echo(value="hello", extra_times=1) | ||||||
|  |     ``` | ||||||
|  |  | ||||||
|  |     Clients after: | ||||||
|  |  | ||||||
|  |     ```py | ||||||
|  |     response = await service.echo(EchoRequest(value="hello", extra_times=1)) | ||||||
|  |     ``` | ||||||
|  |  | ||||||
|  |     Servers before: | ||||||
|  |  | ||||||
|  |     ```py | ||||||
|  |     async def echo(self, value: str, extra_times: int) -> EchoResponse: ... | ||||||
|  |     ``` | ||||||
|  |  | ||||||
|  |     Servers after: | ||||||
|  |  | ||||||
|  |     ```py | ||||||
|  |     async def echo(self, echo_request: EchoRequest) -> EchoResponse: | ||||||
|  |         # Use echo_request.value | ||||||
|  |         # Use echo_request.extra_times | ||||||
|  |         ... | ||||||
|  |     ``` | ||||||
|  |  | ||||||
|  | - Add `to/from_pydict()` for `Message` [#203](https://github.com/danielgtaylor/python-betterproto/pull/203) | ||||||
|  | - Format field comments also as docstrings [#304](https://github.com/danielgtaylor/python-betterproto/pull/304) | ||||||
|  | - Implement `__deepcopy__` for `Message` [#339](https://github.com/danielgtaylor/python-betterproto/pull/339) | ||||||
|  | - Run isort on compiled code [#355](https://github.com/danielgtaylor/python-betterproto/pull/355) | ||||||
|  | - Expose timeout, deadline and metadata parameters from grpclib [#352](https://github.com/danielgtaylor/python-betterproto/pull/352) | ||||||
|  | - Make `Message.__getattribute__` invisible to type checkers [#359](https://github.com/danielgtaylor/python-betterproto/pull/359) | ||||||
|  |  | ||||||
|  | - Fix map field edge-case [#254](https://github.com/danielgtaylor/python-betterproto/pull/254) | ||||||
|  | - Fix message text in `NotImplementedError` [#325](https://github.com/danielgtaylor/python-betterproto/pull/325) | ||||||
|  | - Fix `Message.from_dict()` in the presence of optional datetime fields [#329](https://github.com/danielgtaylor/python-betterproto/pull/329) | ||||||
|  | - Support Jinja2 3.0 to prevent version conflicts [#330](https://github.com/danielgtaylor/python-betterproto/pull/330) | ||||||
|  | - Fix overwriting top level `__init__.py` [#337](https://github.com/danielgtaylor/python-betterproto/pull/337) | ||||||
|  | - Remove deprecation warnings when fields are initialised with non-default values [#348](https://github.com/danielgtaylor/python-betterproto/pull/348) | ||||||
|  | - Ensure nested class names are converted to PascalCase [#353](https://github.com/danielgtaylor/python-betterproto/pull/353) | ||||||
|  | - Fix `Message.to_dict()` mutating the underlying Message [#378](https://github.com/danielgtaylor/python-betterproto/pull/378) | ||||||
|  | - Fix some parameters being missing from services [#381](https://github.com/danielgtaylor/python-betterproto/pull/381) | ||||||
|  |  | ||||||
| ## [2.0.0b4] - 2022-01-03 | ## [2.0.0b4] - 2022-01-03 | ||||||
|  |  | ||||||
| - **Breaking**: the minimum Python version has been bumped to `3.6.2` | - **Breaking**: the minimum Python version has been bumped to `3.6.2` | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								LICENSE.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								LICENSE.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | MIT License | ||||||
|  |  | ||||||
|  | Copyright (c) 2023 Daniel G. Taylor | ||||||
|  |  | ||||||
|  | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
|  | of this software and associated documentation files (the "Software"), to deal | ||||||
|  | in the Software without restriction, including without limitation the rights | ||||||
|  | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
|  | copies of the Software, and to permit persons to whom the Software is | ||||||
|  | furnished to do so, subject to the following conditions: | ||||||
|  |  | ||||||
|  | The above copyright notice and this permission notice shall be included in all | ||||||
|  | copies or substantial portions of the Software. | ||||||
|  |  | ||||||
|  | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
|  | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
|  | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
|  | SOFTWARE. | ||||||
							
								
								
									
										52
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								README.md
									
									
									
									
									
								
							| @@ -7,13 +7,14 @@ This project aims to provide an improved experience when using Protobuf / gRPC i | |||||||
|  |  | ||||||
| - Protobuf 3 & gRPC code generation | - Protobuf 3 & gRPC code generation | ||||||
|   - Both binary & JSON serialization is built-in |   - Both binary & JSON serialization is built-in | ||||||
| - Python 3.6+ making use of: | - Python 3.7+ making use of: | ||||||
|   - Enums |   - Enums | ||||||
|   - Dataclasses |   - Dataclasses | ||||||
|   - `async`/`await` |   - `async`/`await` | ||||||
|   - Timezone-aware `datetime` and `timedelta` objects |   - Timezone-aware `datetime` and `timedelta` objects | ||||||
|   - Relative imports |   - Relative imports | ||||||
|   - Mypy type checking |   - Mypy type checking | ||||||
|  | - [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models) | ||||||
|  |  | ||||||
| This project is heavily inspired by, and borrows functionality from: | This project is heavily inspired by, and borrows functionality from: | ||||||
|  |  | ||||||
| @@ -38,6 +39,8 @@ This project exists because I am unhappy with the state of the official Google p | |||||||
|   - Uses `SerializeToString()` rather than the built-in `__bytes__()` |   - Uses `SerializeToString()` rather than the built-in `__bytes__()` | ||||||
|   - Special wrapped types don't use Python's `None` |   - Special wrapped types don't use Python's `None` | ||||||
|   - Timestamp/duration types don't use Python's built-in `datetime` module |   - Timestamp/duration types don't use Python's built-in `datetime` module | ||||||
|  |  | ||||||
|  |  | ||||||
| This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. | This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. | ||||||
|  |  | ||||||
| ## Installation | ## Installation | ||||||
| @@ -58,7 +61,7 @@ pip install betterproto | |||||||
|  |  | ||||||
| ### Compiling proto files | ### Compiling proto files | ||||||
|  |  | ||||||
| Now, given you installed the compiler and have a proto file, e.g `example.proto`: | Given you installed the compiler and have a proto file, e.g `example.proto`: | ||||||
|  |  | ||||||
| ```protobuf | ```protobuf | ||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
| @@ -177,10 +180,10 @@ from grpclib.client import Channel | |||||||
| async def main(): | async def main(): | ||||||
|     channel = Channel(host="127.0.0.1", port=50051) |     channel = Channel(host="127.0.0.1", port=50051) | ||||||
|     service = echo.EchoStub(channel) |     service = echo.EchoStub(channel) | ||||||
|     response = await service.echo(value="hello", extra_times=1) |     response = await service.echo(echo.EchoRequest(value="hello", extra_times=1)) | ||||||
|     print(response) |     print(response) | ||||||
|  |  | ||||||
|     async for response in service.echo_stream(value="hello", extra_times=1): |     async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)): | ||||||
|         print(response) |         print(response) | ||||||
|  |  | ||||||
|     # don't forget to close the channel when done! |     # don't forget to close the channel when done! | ||||||
| @@ -192,6 +195,7 @@ if __name__ == "__main__": | |||||||
|     loop.run_until_complete(main()) |     loop.run_until_complete(main()) | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| which would output | which would output | ||||||
| ```python | ```python | ||||||
| EchoResponse(values=['hello', 'hello']) | EchoResponse(values=['hello', 'hello']) | ||||||
| @@ -206,18 +210,18 @@ service methods: | |||||||
|  |  | ||||||
| ```python | ```python | ||||||
| import asyncio | import asyncio | ||||||
| from echo import EchoBase, EchoResponse, EchoStreamResponse | from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse | ||||||
| from grpclib.server import Server | from grpclib.server import Server | ||||||
| from typing import AsyncIterator | from typing import AsyncIterator | ||||||
|  |  | ||||||
|  |  | ||||||
| class EchoService(EchoBase): | class EchoService(EchoBase): | ||||||
|     async def echo(self, value: str, extra_times: int) -> "EchoResponse": |     async def echo(self, echo_request: "EchoRequest") -> "EchoResponse": | ||||||
|         return EchoResponse([value for _ in range(extra_times)]) |         return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)]) | ||||||
|  |  | ||||||
|     async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]: |     async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]: | ||||||
|         for _ in range(extra_times): |         for _ in range(echo_request.extra_times): | ||||||
|             yield EchoStreamResponse(value) |             yield EchoStreamResponse(echo_request.value) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def main(): | async def main(): | ||||||
| @@ -361,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc) | |||||||
| {'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'} | {'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'} | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  | ## Generating Pydantic Models | ||||||
|  |  | ||||||
|  | You can use python-betterproto to generate pydantic based models, using | ||||||
|  | pydantic dataclasses. This means the results of the protobuf unmarshalling will | ||||||
|  | be typed checked. The usage is the same, but you need to add a custom option | ||||||
|  | when calling the protobuf compiler: | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | protoc -I . --python_betterproto_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | With the important change being `--python_betterproto_opt=pydantic_dataclasses`. This will | ||||||
|  | swap the dataclass implementation from the builtin python dataclass to the | ||||||
|  | pydantic dataclass. You must have pydantic as a dependency in your project for | ||||||
|  | this to work. | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Development | ## Development | ||||||
|  |  | ||||||
| - _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_ | - _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_ | ||||||
| @@ -368,7 +391,7 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc) | |||||||
|  |  | ||||||
| ### Requirements | ### Requirements | ||||||
|  |  | ||||||
| - Python (3.6 or higher) | - Python (3.7 or higher) | ||||||
|  |  | ||||||
| - [poetry](https://python-poetry.org/docs/#installation) | - [poetry](https://python-poetry.org/docs/#installation) | ||||||
|   *Needed to install dependencies in a virtual environment* |   *Needed to install dependencies in a virtual environment* | ||||||
| @@ -381,8 +404,7 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc) | |||||||
|  |  | ||||||
| ```sh | ```sh | ||||||
| # Get set up with the virtual env & dependencies | # Get set up with the virtual env & dependencies | ||||||
| poetry run pip install --upgrade pip | poetry install -E compiler | ||||||
| poetry install |  | ||||||
|  |  | ||||||
| # Activate the poetry environment | # Activate the poetry environment | ||||||
| poetry shell | poetry shell | ||||||
| @@ -417,7 +439,7 @@ Adding a standard test case is easy. | |||||||
|  |  | ||||||
| It will be picked up automatically when you run the tests. | It will be picked up automatically when you run the tests. | ||||||
|  |  | ||||||
| - See also: [Standard Tests Development Guide](betterproto/tests/README.md) | - See also: [Standard Tests Development Guide](tests/README.md) | ||||||
|  |  | ||||||
| #### Custom tests | #### Custom tests | ||||||
|  |  | ||||||
| @@ -442,7 +464,7 @@ poe full-test | |||||||
|  |  | ||||||
| ### (Re)compiling Google Well-known Types | ### (Re)compiling Google Well-known Types | ||||||
|  |  | ||||||
| Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google). | Betterproto includes compiled versions for Google's well-known types at [src/betterproto/lib/google](src/betterproto/lib/google). | ||||||
| Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests. | Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests. | ||||||
|  |  | ||||||
| Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`. | Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`. | ||||||
|   | |||||||
| @@ -1 +0,0 @@ | |||||||
|  |  | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| import betterproto |  | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
|  |  | ||||||
| from typing import List | from typing import List | ||||||
|  |  | ||||||
|  | import betterproto | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class TestMessage(betterproto.Message): | class TestMessage(betterproto.Message): | ||||||
|   | |||||||
							
								
								
									
										72
									
								
								betterproto-extras/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								betterproto-extras/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,72 @@ | |||||||
|  | /target | ||||||
|  |  | ||||||
|  | # Byte-compiled / optimized / DLL files | ||||||
|  | __pycache__/ | ||||||
|  | .pytest_cache/ | ||||||
|  | *.py[cod] | ||||||
|  |  | ||||||
|  | # C extensions | ||||||
|  | *.so | ||||||
|  |  | ||||||
|  | # Distribution / packaging | ||||||
|  | .Python | ||||||
|  | .venv/ | ||||||
|  | env/ | ||||||
|  | bin/ | ||||||
|  | build/ | ||||||
|  | develop-eggs/ | ||||||
|  | dist/ | ||||||
|  | eggs/ | ||||||
|  | lib/ | ||||||
|  | lib64/ | ||||||
|  | parts/ | ||||||
|  | sdist/ | ||||||
|  | var/ | ||||||
|  | include/ | ||||||
|  | man/ | ||||||
|  | venv/ | ||||||
|  | *.egg-info/ | ||||||
|  | .installed.cfg | ||||||
|  | *.egg | ||||||
|  |  | ||||||
|  | # Installer logs | ||||||
|  | pip-log.txt | ||||||
|  | pip-delete-this-directory.txt | ||||||
|  | pip-selfcheck.json | ||||||
|  |  | ||||||
|  | # Unit test / coverage reports | ||||||
|  | htmlcov/ | ||||||
|  | .tox/ | ||||||
|  | .coverage | ||||||
|  | .cache | ||||||
|  | nosetests.xml | ||||||
|  | coverage.xml | ||||||
|  |  | ||||||
|  | # Translations | ||||||
|  | *.mo | ||||||
|  |  | ||||||
|  | # Mr Developer | ||||||
|  | .mr.developer.cfg | ||||||
|  | .project | ||||||
|  | .pydevproject | ||||||
|  |  | ||||||
|  | # Rope | ||||||
|  | .ropeproject | ||||||
|  |  | ||||||
|  | # Django stuff: | ||||||
|  | *.log | ||||||
|  | *.pot | ||||||
|  |  | ||||||
|  | .DS_Store | ||||||
|  |  | ||||||
|  | # Sphinx documentation | ||||||
|  | docs/_build/ | ||||||
|  |  | ||||||
|  | # PyCharm | ||||||
|  | .idea/ | ||||||
|  |  | ||||||
|  | # VSCode | ||||||
|  | .vscode/ | ||||||
|  |  | ||||||
|  | # Pyenv | ||||||
|  | .python-version | ||||||
							
								
								
									
										383
									
								
								betterproto-extras/Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										383
									
								
								betterproto-extras/Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,383 @@ | |||||||
|  | # This file is automatically @generated by Cargo. | ||||||
|  | # It is not intended for manual editing. | ||||||
|  | version = 3 | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "anyhow" | ||||||
|  | version = "1.0.75" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "autocfg" | ||||||
|  | version = "1.1.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "betterproto-extras" | ||||||
|  | version = "0.1.0" | ||||||
|  | dependencies = [ | ||||||
|  |  "indoc 2.0.3", | ||||||
|  |  "prost-reflect", | ||||||
|  |  "pyo3", | ||||||
|  |  "thiserror", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "bitflags" | ||||||
|  | version = "1.3.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "bytes" | ||||||
|  | version = "1.4.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "cfg-if" | ||||||
|  | version = "1.0.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "either" | ||||||
|  | version = "1.9.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "indoc" | ||||||
|  | version = "1.0.9" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "indoc" | ||||||
|  | version = "2.0.3" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "2c785eefb63ebd0e33416dfcb8d6da0bf27ce752843a45632a67bf10d4d4b5c4" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "itertools" | ||||||
|  | version = "0.10.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" | ||||||
|  | dependencies = [ | ||||||
|  |  "either", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "libc" | ||||||
|  | version = "0.2.147" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "lock_api" | ||||||
|  | version = "0.4.10" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" | ||||||
|  | dependencies = [ | ||||||
|  |  "autocfg", | ||||||
|  |  "scopeguard", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "memoffset" | ||||||
|  | version = "0.9.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" | ||||||
|  | dependencies = [ | ||||||
|  |  "autocfg", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "once_cell" | ||||||
|  | version = "1.18.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "parking_lot" | ||||||
|  | version = "0.12.1" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" | ||||||
|  | dependencies = [ | ||||||
|  |  "lock_api", | ||||||
|  |  "parking_lot_core", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "parking_lot_core" | ||||||
|  | version = "0.9.8" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" | ||||||
|  | dependencies = [ | ||||||
|  |  "cfg-if", | ||||||
|  |  "libc", | ||||||
|  |  "redox_syscall", | ||||||
|  |  "smallvec", | ||||||
|  |  "windows-targets", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "proc-macro2" | ||||||
|  | version = "1.0.66" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" | ||||||
|  | dependencies = [ | ||||||
|  |  "unicode-ident", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "prost" | ||||||
|  | version = "0.11.9" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" | ||||||
|  | dependencies = [ | ||||||
|  |  "bytes", | ||||||
|  |  "prost-derive", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "prost-derive" | ||||||
|  | version = "0.11.9" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" | ||||||
|  | dependencies = [ | ||||||
|  |  "anyhow", | ||||||
|  |  "itertools", | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "syn 1.0.109", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "prost-reflect" | ||||||
|  | version = "0.11.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "6b823de344848e011658ac981009100818b322421676740546f8b52ed5249428" | ||||||
|  | dependencies = [ | ||||||
|  |  "once_cell", | ||||||
|  |  "prost", | ||||||
|  |  "prost-types", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "prost-types" | ||||||
|  | version = "0.11.9" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" | ||||||
|  | dependencies = [ | ||||||
|  |  "prost", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "pyo3" | ||||||
|  | version = "0.19.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" | ||||||
|  | dependencies = [ | ||||||
|  |  "cfg-if", | ||||||
|  |  "indoc 1.0.9", | ||||||
|  |  "libc", | ||||||
|  |  "memoffset", | ||||||
|  |  "parking_lot", | ||||||
|  |  "pyo3-build-config", | ||||||
|  |  "pyo3-ffi", | ||||||
|  |  "pyo3-macros", | ||||||
|  |  "unindent", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "pyo3-build-config" | ||||||
|  | version = "0.19.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" | ||||||
|  | dependencies = [ | ||||||
|  |  "once_cell", | ||||||
|  |  "target-lexicon", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "pyo3-ffi" | ||||||
|  | version = "0.19.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" | ||||||
|  | dependencies = [ | ||||||
|  |  "libc", | ||||||
|  |  "pyo3-build-config", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "pyo3-macros" | ||||||
|  | version = "0.19.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  |  "pyo3-macros-backend", | ||||||
|  |  "quote", | ||||||
|  |  "syn 1.0.109", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "pyo3-macros-backend" | ||||||
|  | version = "0.19.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "syn 1.0.109", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "quote" | ||||||
|  | version = "1.0.33" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "redox_syscall" | ||||||
|  | version = "0.3.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" | ||||||
|  | dependencies = [ | ||||||
|  |  "bitflags", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "scopeguard" | ||||||
|  | version = "1.2.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "smallvec" | ||||||
|  | version = "1.11.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "syn" | ||||||
|  | version = "1.0.109" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "unicode-ident", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "syn" | ||||||
|  | version = "2.0.29" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "unicode-ident", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "target-lexicon" | ||||||
|  | version = "0.12.11" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "thiserror" | ||||||
|  | version = "1.0.47" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" | ||||||
|  | dependencies = [ | ||||||
|  |  "thiserror-impl", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "thiserror-impl" | ||||||
|  | version = "1.0.47" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" | ||||||
|  | dependencies = [ | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "syn 2.0.29", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "unicode-ident" | ||||||
|  | version = "1.0.11" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "unindent" | ||||||
|  | version = "0.1.11" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows-targets" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" | ||||||
|  | dependencies = [ | ||||||
|  |  "windows_aarch64_gnullvm", | ||||||
|  |  "windows_aarch64_msvc", | ||||||
|  |  "windows_i686_gnu", | ||||||
|  |  "windows_i686_msvc", | ||||||
|  |  "windows_x86_64_gnu", | ||||||
|  |  "windows_x86_64_gnullvm", | ||||||
|  |  "windows_x86_64_msvc", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_aarch64_gnullvm" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_aarch64_msvc" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_i686_gnu" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_i686_msvc" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_x86_64_gnu" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_x86_64_gnullvm" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "windows_x86_64_msvc" | ||||||
|  | version = "0.48.5" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" | ||||||
							
								
								
									
										14
									
								
								betterproto-extras/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								betterproto-extras/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | [package] | ||||||
|  | name = "betterproto-extras" | ||||||
|  | version = "0.1.0" | ||||||
|  | edition = "2021" | ||||||
|  |  | ||||||
|  | [lib] | ||||||
|  | name = "betterproto_extras" | ||||||
|  | crate-type = ["cdylib"] | ||||||
|  |  | ||||||
|  | [dependencies] | ||||||
|  | indoc = "2.0.3" | ||||||
|  | prost-reflect = "0.11.5" | ||||||
|  | pyo3 = { version = "0.19.2", features = ["abi3-py37", "extension-module"] } | ||||||
|  | thiserror = "1.0.47" | ||||||
							
								
								
									
										5
									
								
								betterproto-extras/betterproto_extras.pyi
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								betterproto-extras/betterproto_extras.pyi
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | def deserialize(msg, data: bytes): | ||||||
|  |     """ | ||||||
|  |     Parses the binary encoded Protobuf `data` with respect to the metadata | ||||||
|  |     given by the betterproto message `msg`, and merges the result into `msg`. | ||||||
|  |     """ | ||||||
							
								
								
									
										16
									
								
								betterproto-extras/pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								betterproto-extras/pyproject.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | |||||||
|  | [build-system] | ||||||
|  | requires = ["maturin>=1.2,<2.0"] | ||||||
|  | build-backend = "maturin" | ||||||
|  |  | ||||||
|  | [project] | ||||||
|  | name = "betterproto-extras" | ||||||
|  | requires-python = ">=3.7" | ||||||
|  | classifiers = [ | ||||||
|  |     "Programming Language :: Rust", | ||||||
|  |     "Programming Language :: Python :: Implementation :: CPython", | ||||||
|  |     "Programming Language :: Python :: Implementation :: PyPy", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | [tool.maturin] | ||||||
|  | features = ["pyo3/extension-module"] | ||||||
							
								
								
									
										289
									
								
								betterproto-extras/src/descriptor_pool.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										289
									
								
								betterproto-extras/src/descriptor_pool.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,289 @@ | |||||||
|  | use crate::{ | ||||||
|  |     error::{Error, Result}, | ||||||
|  |     py_any_extras::PyAnyExtras, | ||||||
|  | }; | ||||||
|  | use prost_reflect::{ | ||||||
|  |     prost_types::{ | ||||||
|  |         field_descriptor_proto::{Label, Type}, | ||||||
|  |         DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto, | ||||||
|  |         FileDescriptorProto, MessageOptions, OneofDescriptorProto, | ||||||
|  |     }, | ||||||
|  |     DescriptorPool, MessageDescriptor, | ||||||
|  | }; | ||||||
|  | use pyo3::PyAny; | ||||||
|  | use std::sync::{Mutex, OnceLock}; | ||||||
|  |  | ||||||
|  | pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> { | ||||||
|  |     static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new(); | ||||||
|  |     let mut pool = DESCRIPTOR_POOL | ||||||
|  |         .get_or_init(|| Mutex::new(DescriptorPool::global())) | ||||||
|  |         .lock() | ||||||
|  |         .unwrap(); | ||||||
|  |  | ||||||
|  |     let cls = obj.getattr("__class__")?; | ||||||
|  |     let name = format!("{}_{}", cls.qualified_name()?, cls.py_identifier()); | ||||||
|  |     if let Some(desc) = pool.get_message_by_name(&name) { | ||||||
|  |         return Ok(desc); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let mut file = FileDescriptorProto { | ||||||
|  |         name: Some(name.clone()), | ||||||
|  |         ..Default::default() | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     add_message_to_file(name.clone(), obj, &pool, &mut file)?; | ||||||
|  |     pool.add_file_descriptor_proto(file)?; | ||||||
|  |     Ok(pool.get_message_by_name(&name).expect("Just registered...")) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn add_message_to_file( | ||||||
|  |     message_name: String, | ||||||
|  |     obj: &PyAny, | ||||||
|  |     pool: &DescriptorPool, | ||||||
|  |     file: &mut FileDescriptorProto, | ||||||
|  | ) -> Result<()> { | ||||||
|  |     let mut messages_to_add = vec![(message_name, obj)]; | ||||||
|  |  | ||||||
|  |     while let Some((message_name, obj)) = messages_to_add.pop() { | ||||||
|  |         let meta = obj.get_proto_meta()?; | ||||||
|  |         let mut message = DescriptorProto { | ||||||
|  |             name: Some(message_name.to_string()), | ||||||
|  |             ..Default::default() | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         for item in meta | ||||||
|  |             .getattr("meta_by_field_name")? | ||||||
|  |             .call_method0("items")? | ||||||
|  |             .iter()? | ||||||
|  |         { | ||||||
|  |             let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?; | ||||||
|  |             message.field.push({ | ||||||
|  |                 let mut field = FieldDescriptorProto { | ||||||
|  |                     name: Some(field_name.to_string()), | ||||||
|  |                     number: Some(field_meta.getattr("number")?.extract::<i32>()?), | ||||||
|  |                     ..Default::default() | ||||||
|  |                 }; | ||||||
|  |                 let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?; | ||||||
|  |  | ||||||
|  |                 if proto_type == "map" { | ||||||
|  |                     field.set_type(Type::Message); | ||||||
|  |                     let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?; | ||||||
|  |                     let key = map_type(key)?; | ||||||
|  |                     let val = map_type(val)?; | ||||||
|  |  | ||||||
|  |                     if matches!( | ||||||
|  |                         key, | ||||||
|  |                         Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum | ||||||
|  |                     ) { | ||||||
|  |                         return Err(Error::UnsupportedMapKeyType(key)); | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     let map_entry_name = format!("{field_name}Entry"); | ||||||
|  |                     field.type_name = Some(format!("{message_name}.{map_entry_name}")); | ||||||
|  |                     field.set_label(Label::Repeated); | ||||||
|  |                     message.nested_type.push(DescriptorProto { | ||||||
|  |                         name: Some(map_entry_name), | ||||||
|  |                         field: vec![ | ||||||
|  |                             { | ||||||
|  |                                 let mut proto = FieldDescriptorProto { | ||||||
|  |                                     name: Some("key".to_string()), | ||||||
|  |                                     number: Some(1), | ||||||
|  |                                     ..Default::default() | ||||||
|  |                                 }; | ||||||
|  |                                 proto.set_type(key); | ||||||
|  |                                 proto | ||||||
|  |                             }, | ||||||
|  |                             { | ||||||
|  |                                 let mut proto = FieldDescriptorProto { | ||||||
|  |                                     name: Some("value".to_string()), | ||||||
|  |                                     number: Some(2), | ||||||
|  |                                     ..Default::default() | ||||||
|  |                                 }; | ||||||
|  |                                 proto.set_type(val); | ||||||
|  |                                 if val == Type::Message { | ||||||
|  |                                     set_type_name( | ||||||
|  |                                         &message_name, | ||||||
|  |                                         meta.get_class(&format!("{field_name}.value"))?, | ||||||
|  |                                         &mut proto, | ||||||
|  |                                         file, | ||||||
|  |                                         &mut messages_to_add, | ||||||
|  |                                         pool, | ||||||
|  |                                     )?; | ||||||
|  |                                 } | ||||||
|  |                                 proto | ||||||
|  |                             }, | ||||||
|  |                         ], | ||||||
|  |                         options: Some(MessageOptions { | ||||||
|  |                             map_entry: Some(true), | ||||||
|  |                             ..Default::default() | ||||||
|  |                         }), | ||||||
|  |                         ..Default::default() | ||||||
|  |                     }) | ||||||
|  |                 } else { | ||||||
|  |                     field.set_type(map_type(proto_type)?); | ||||||
|  |                     match field.r#type() { | ||||||
|  |                         Type::Message => match field_meta | ||||||
|  |                             .getattr("wraps")? | ||||||
|  |                             .extract::<Option<&str>>()? | ||||||
|  |                             .map(map_type) | ||||||
|  |                             .transpose()? | ||||||
|  |                         { | ||||||
|  |                             Some(Type::Bool) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.BoolValue".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Double) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.DoubleValue".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Float) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.FloatValue".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Int64) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.Int64Value".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Uint64) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.UInt64Value".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Int32) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.Int32Value".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Uint32) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.UInt32Value".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::String) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.StringValue".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(Type::Bytes) => { | ||||||
|  |                                 field.type_name = Some("google.protobuf.BytesValue".to_string()); | ||||||
|  |                             } | ||||||
|  |                             Some(t) => return Err(Error::UnsupportedWrapperType(t)), | ||||||
|  |                             None => { | ||||||
|  |                                 set_type_name( | ||||||
|  |                                     &message_name, | ||||||
|  |                                     meta.get_class(field_name)?, | ||||||
|  |                                     &mut field, | ||||||
|  |                                     file, | ||||||
|  |                                     &mut messages_to_add, | ||||||
|  |                                     pool, | ||||||
|  |                                 )?; | ||||||
|  |                             } | ||||||
|  |                         }, | ||||||
|  |                         Type::Enum => { | ||||||
|  |                             let cls = meta.get_class(field_name)?; | ||||||
|  |                             let cls_name = | ||||||
|  |                                 format!("{}_{}", cls.qualified_name()?, cls.py_identifier()); | ||||||
|  |                             field.type_name = Some(cls_name.to_string()); | ||||||
|  |  | ||||||
|  |                             if pool.get_enum_by_name(&cls_name).is_none() | ||||||
|  |                                 && !file.enum_type.iter().any(|item| item.name() == cls_name) | ||||||
|  |                             { | ||||||
|  |                                 let mut proto = EnumDescriptorProto { | ||||||
|  |                                     name: Some(cls_name.clone()), | ||||||
|  |                                     ..Default::default() | ||||||
|  |                                 }; | ||||||
|  |  | ||||||
|  |                                 for item in cls.iter()? { | ||||||
|  |                                     let item = item?; | ||||||
|  |                                     proto.value.push(EnumValueDescriptorProto { | ||||||
|  |                                         number: Some(item.getattr("value")?.extract()?), | ||||||
|  |                                         name: Some(format!( | ||||||
|  |                                             "{}_{}", | ||||||
|  |                                             cls_name, | ||||||
|  |                                             item.getattr("name")?.extract::<&str>()? | ||||||
|  |                                         )), | ||||||
|  |                                         ..Default::default() | ||||||
|  |                                     }); | ||||||
|  |                                 } | ||||||
|  |  | ||||||
|  |                                 file.enum_type.push(proto); | ||||||
|  |                             } | ||||||
|  |                         } | ||||||
|  |                         _ => {} | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     if meta.is_list_field(field_name)? { | ||||||
|  |                         field.set_label(Label::Repeated); | ||||||
|  |                     } else if field_meta.getattr("optional")?.extract::<bool>()? { | ||||||
|  |                         field.proto3_optional = Some(true); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 if let Some(grp) = meta.oneof_group(field_name)? { | ||||||
|  |                     let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp); | ||||||
|  |  | ||||||
|  |                     match oneof_index { | ||||||
|  |                         Some(i) => field.oneof_index = Some(i as i32), | ||||||
|  |                         None => { | ||||||
|  |                             message.oneof_decl.push(OneofDescriptorProto { | ||||||
|  |                                 name: Some(grp), | ||||||
|  |                                 ..Default::default() | ||||||
|  |                             }); | ||||||
|  |                             field.oneof_index = Some((message.oneof_decl.len() - 1) as i32) | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 field | ||||||
|  |             }); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         file.message_type.push(message); | ||||||
|  |     } | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn map_type(str: &str) -> Result<Type> { | ||||||
|  |     match str { | ||||||
|  |         "enum" => Ok(Type::Enum), | ||||||
|  |         "bool" => Ok(Type::Bool), | ||||||
|  |         "int32" => Ok(Type::Int32), | ||||||
|  |         "int64" => Ok(Type::Int64), | ||||||
|  |         "uint32" => Ok(Type::Uint32), | ||||||
|  |         "uint64" => Ok(Type::Uint64), | ||||||
|  |         "sint32" => Ok(Type::Sint32), | ||||||
|  |         "sint64" => Ok(Type::Sint64), | ||||||
|  |         "float" => Ok(Type::Float), | ||||||
|  |         "double" => Ok(Type::Double), | ||||||
|  |         "fixed32" => Ok(Type::Fixed32), | ||||||
|  |         "sfixed32" => Ok(Type::Sfixed32), | ||||||
|  |         "fixed64" => Ok(Type::Fixed64), | ||||||
|  |         "sfixed64" => Ok(Type::Sfixed64), | ||||||
|  |         "string" => Ok(Type::String), | ||||||
|  |         "bytes" => Ok(Type::Bytes), | ||||||
|  |         "message" => Ok(Type::Message), | ||||||
|  |         _ => Err(Error::UnsupportedType(str.to_string())), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn set_type_name<'py>( | ||||||
|  |     message_name: &str, | ||||||
|  |     field_cls: &'py PyAny, | ||||||
|  |     field: &mut FieldDescriptorProto, | ||||||
|  |     file: &FileDescriptorProto, | ||||||
|  |     messages_to_add: &mut Vec<(String, &'py PyAny)>, | ||||||
|  |     pool: &DescriptorPool, | ||||||
|  | ) -> Result<()> { | ||||||
|  |     let cls_name = field_cls.qualified_name()?; | ||||||
|  |  | ||||||
|  |     match cls_name.as_str() { | ||||||
|  |         "datetime.datetime" => { | ||||||
|  |             field.type_name = Some("google.protobuf.Timestamp".to_string()); | ||||||
|  |         } | ||||||
|  |         "datetime.timedelta" => { | ||||||
|  |             field.type_name = Some("google.protobuf.Duration".to_string()); | ||||||
|  |         } | ||||||
|  |         _ => { | ||||||
|  |             let cls_name = format!("{}_{}", cls_name, field_cls.py_identifier()); | ||||||
|  |             field.type_name = Some(cls_name.clone()); | ||||||
|  |  | ||||||
|  |             if message_name != cls_name | ||||||
|  |                 && pool.get_message_by_name(&cls_name).is_none() | ||||||
|  |                 && !file.message_type.iter().any(|item| item.name() == cls_name) | ||||||
|  |                 && !messages_to_add.iter().any(|item| item.0 == cls_name) | ||||||
|  |             { | ||||||
|  |                 messages_to_add.push((cls_name, field_cls.call0()?)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
							
								
								
									
										29
									
								
								betterproto-extras/src/error.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								betterproto-extras/src/error.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | use prost_reflect::{ | ||||||
|  |     prost::DecodeError, prost_types::field_descriptor_proto::Type, DescriptorError, | ||||||
|  | }; | ||||||
|  | use pyo3::{exceptions::PyRuntimeError, PyErr}; | ||||||
|  | use thiserror::Error; | ||||||
|  |  | ||||||
|  | #[derive(Error, Debug)] | ||||||
|  | pub enum Error { | ||||||
|  |     #[error("Given object is not a valid betterproto message.")] | ||||||
|  |     NoBetterprotoMessage(#[from] PyErr), | ||||||
|  |     #[error("Unsupported type `{0}`.")] | ||||||
|  |     UnsupportedType(String), | ||||||
|  |     #[error("Unsupported map key type `{0:?}`.")] | ||||||
|  |     UnsupportedMapKeyType(Type), | ||||||
|  |     #[error("Unsupported wrapper type `{0:?}`.")] | ||||||
|  |     UnsupportedWrapperType(Type), | ||||||
|  |     #[error("Error on proto registration")] | ||||||
|  |     FailedToRegisterDescriptor(#[from] DescriptorError), | ||||||
|  |     #[error("The given binary data does not match the protobuf schema.")] | ||||||
|  |     FailedToDecode(#[from] DecodeError), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub type Result<T> = core::result::Result<T, Error>; | ||||||
|  |  | ||||||
|  | impl From<Error> for PyErr { | ||||||
|  |     fn from(value: Error) -> Self { | ||||||
|  |         PyRuntimeError::new_err(value.to_string()) | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										24
									
								
								betterproto-extras/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								betterproto-extras/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | mod descriptor_pool; | ||||||
|  | mod error; | ||||||
|  | mod merging; | ||||||
|  | mod py_any_extras; | ||||||
|  |  | ||||||
|  | use descriptor_pool::create_cached_descriptor; | ||||||
|  | use error::Result; | ||||||
|  | use merging::merge_msg_into_pyobj; | ||||||
|  | use prost_reflect::DynamicMessage; | ||||||
|  | use pyo3::prelude::*; | ||||||
|  |  | ||||||
|  | #[pyfunction] | ||||||
|  | fn deserialize(obj: &PyAny, buf: &[u8]) -> Result<()> { | ||||||
|  |     let desc = create_cached_descriptor(obj)?; | ||||||
|  |     let msg = DynamicMessage::decode(desc, buf)?; | ||||||
|  |     merge_msg_into_pyobj(obj, msg)?; | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[pymodule] | ||||||
|  | fn betterproto_extras(_py: Python, m: &PyModule) -> PyResult<()> { | ||||||
|  |     m.add_function(wrap_pyfunction!(deserialize, m)?)?; | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
							
								
								
									
										182
									
								
								betterproto-extras/src/merging.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										182
									
								
								betterproto-extras/src/merging.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,182 @@ | |||||||
|  | use crate::{error::Result, py_any_extras::PyAnyExtras}; | ||||||
|  | use indoc::indoc; | ||||||
|  | use prost_reflect::{ | ||||||
|  |     prost_types::{Duration, Timestamp}, | ||||||
|  |     DynamicMessage, MapKey, ReflectMessage, Value, | ||||||
|  | }; | ||||||
|  | use pyo3::{ | ||||||
|  |     sync::GILOnceCell, | ||||||
|  |     types::{IntoPyDict, PyBytes, PyModule}, | ||||||
|  |     Py, PyAny, PyObject, Python, ToPyObject, | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> { | ||||||
|  |     for field in msg.take_fields() { | ||||||
|  |         let field_name = field.0.name(); | ||||||
|  |         let proto_meta = obj.get_proto_meta()?; | ||||||
|  |         obj.setattr( | ||||||
|  |             field_name, | ||||||
|  |             map_field_value(field_name, field.1, proto_meta)?, | ||||||
|  |         )?; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let mut buf = vec![]; | ||||||
|  |     for field in msg.unknown_fields() { | ||||||
|  |         field.encode(&mut buf); | ||||||
|  |     } | ||||||
|  |     if !buf.is_empty() { | ||||||
|  |         let mut unknown_fields = obj.getattr("_unknown_fields")?.extract::<Vec<u8>>()?; | ||||||
|  |         unknown_fields.append(&mut buf); | ||||||
|  |         obj.setattr("_unknown_fields", PyBytes::new(obj.py(), &unknown_fields))?; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     obj.setattr("_serialized_on_wire", true)?; | ||||||
|  |  | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> Result<PyObject> { | ||||||
|  |     let py = proto_meta.py(); | ||||||
|  |     match field_value { | ||||||
|  |         Value::Bool(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::Bytes(x) => Ok(PyBytes::new(py, &x).to_object(py)), | ||||||
|  |         Value::F32(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::F64(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::I32(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::I64(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::String(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::U32(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::U64(x) => Ok(x.to_object(py)), | ||||||
|  |         Value::Message(msg) => match msg.descriptor().full_name() { | ||||||
|  |             "google.protobuf.BoolValue" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_bool()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.DoubleValue" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_f64()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.FloatValue" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_f32()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.Int64Value" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_i64()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.UInt64Value" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_u64()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.Int32Value" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_i32()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.UInt32Value" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_u32()) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.StringValue" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_str().map(|s| s.to_string())) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.BytesValue" => Ok(msg | ||||||
|  |                 .get_field_by_number(1) | ||||||
|  |                 .and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b))) | ||||||
|  |                 .to_object(py)), | ||||||
|  |             "google.protobuf.Timestamp" => { | ||||||
|  |                 let msg = msg.transcode_to::<Timestamp>()?; | ||||||
|  |                 Ok(create_py_datetime(&msg, py)) | ||||||
|  |             } | ||||||
|  |             "google.protobuf.Duration" => { | ||||||
|  |                 let msg = msg.transcode_to::<Duration>()?; | ||||||
|  |                 Ok(create_py_timedelta(&msg, py)) | ||||||
|  |             } | ||||||
|  |             _ => { | ||||||
|  |                 let obj = proto_meta.create_instance(field_name)?; | ||||||
|  |                 merge_msg_into_pyobj(obj, msg)?; | ||||||
|  |                 Ok(obj.to_object(py)) | ||||||
|  |             } | ||||||
|  |         }, | ||||||
|  |         Value::List(ls) => Ok(ls | ||||||
|  |             .into_iter() | ||||||
|  |             .map(|x| map_field_value(field_name, x, proto_meta)) | ||||||
|  |             .collect::<Result<Vec<PyObject>>>()? | ||||||
|  |             .to_object(py)), | ||||||
|  |         Value::EnumNumber(x) => { | ||||||
|  |             let cls = proto_meta.get_class(field_name)?; | ||||||
|  |             Ok(cls.call1((x,))?.to_object(py)) | ||||||
|  |         } | ||||||
|  |         Value::Map(map) => { | ||||||
|  |             let res: Result<Vec<_>> = map | ||||||
|  |                 .into_iter() | ||||||
|  |                 .map(|(k, v)| { | ||||||
|  |                     let key = map_key(k, py); | ||||||
|  |                     let val = map_field_value(&format!("{field_name}.value"), v, proto_meta)?; | ||||||
|  |                     Ok((key, val)) | ||||||
|  |                 }) | ||||||
|  |                 .collect(); | ||||||
|  |             Ok(res?.into_py_dict(py).to_object(py)) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn map_key(key: MapKey, py: Python) -> PyObject { | ||||||
|  |     match key { | ||||||
|  |         MapKey::Bool(x) => x.to_object(py), | ||||||
|  |         MapKey::I32(x) => x.to_object(py), | ||||||
|  |         MapKey::I64(x) => x.to_object(py), | ||||||
|  |         MapKey::U32(x) => x.to_object(py), | ||||||
|  |         MapKey::U64(x) => x.to_object(py), | ||||||
|  |         MapKey::String(x) => x.to_object(py), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn create_py_datetime(ts: &Timestamp, py: Python) -> PyObject { | ||||||
|  |     static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new(); | ||||||
|  |     let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || { | ||||||
|  |         let constructor = PyModule::from_code( | ||||||
|  |             py, | ||||||
|  |             indoc! {" | ||||||
|  |                 from datetime import datetime, timezone | ||||||
|  |                  | ||||||
|  |                 def constructor(ts): | ||||||
|  |                     return datetime.fromtimestamp(ts, tz=timezone.utc) | ||||||
|  |             "}, | ||||||
|  |             "", | ||||||
|  |             "", | ||||||
|  |         ) | ||||||
|  |         .expect("This is a valid Python module") | ||||||
|  |         .getattr("constructor") | ||||||
|  |         .expect("Attribute exists"); | ||||||
|  |         Py::from(constructor) | ||||||
|  |     }); | ||||||
|  |     let ts = (ts.seconds as f64) + (ts.nanos as f64) / 1e9; | ||||||
|  |     constructor | ||||||
|  |         .call1(py, (ts,)) | ||||||
|  |         .expect("static function will not fail") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn create_py_timedelta(duration: &Duration, py: Python) -> PyObject { | ||||||
|  |     static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new(); | ||||||
|  |     let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || { | ||||||
|  |         let constructor = PyModule::from_code( | ||||||
|  |             py, | ||||||
|  |             indoc! {" | ||||||
|  |                 from datetime import timedelta | ||||||
|  |                  | ||||||
|  |                 def constructor(s, ms): | ||||||
|  |                     return timedelta(seconds=s, microseconds=ms) | ||||||
|  |             "}, | ||||||
|  |             "", | ||||||
|  |             "", | ||||||
|  |         ) | ||||||
|  |         .expect("This is a valid Python module") | ||||||
|  |         .getattr("constructor") | ||||||
|  |         .expect("Attribute exists"); | ||||||
|  |         Py::from(constructor) | ||||||
|  |     }); | ||||||
|  |     constructor | ||||||
|  |         .call1(py, (duration.seconds as f64, (duration.nanos as f64) / 1e3)) | ||||||
|  |         .expect("static function will not fail") | ||||||
|  | } | ||||||
							
								
								
									
										68
									
								
								betterproto-extras/src/py_any_extras.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								betterproto-extras/src/py_any_extras.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,68 @@ | |||||||
|  | use crate::error::Result; | ||||||
|  | use pyo3::{PyAny, Py, sync::GILOnceCell}; | ||||||
|  |  | ||||||
|  | pub trait PyAnyExtras { | ||||||
|  |     fn qualified_name(&self) -> Result<String>; | ||||||
|  |     fn qualified_class_name(&self) -> Result<String>; | ||||||
|  |     fn get_proto_meta(&self) -> Result<&PyAny>; | ||||||
|  |     fn get_class(&self, field_name: &str) -> Result<&PyAny>; | ||||||
|  |     fn create_instance(&self, field_name: &str) -> Result<&PyAny>; | ||||||
|  |     fn is_list_field(&self, field_name: &str) -> Result<bool>; | ||||||
|  |     fn oneof_group(&self, field_name: &str) -> Result<Option<String>>; | ||||||
|  |     fn py_identifier(&self) -> u64; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl PyAnyExtras for PyAny { | ||||||
|  |     fn qualified_name(&self) -> Result<String> { | ||||||
|  |         let module = self.getattr("__module__")?; | ||||||
|  |         let name = self.getattr("__name__")?; | ||||||
|  |         Ok(format!("{module}.{name}")) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn qualified_class_name(&self) -> Result<String> { | ||||||
|  |         self.getattr("__class__")?.qualified_name() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get_proto_meta(&self) -> Result<&PyAny> { | ||||||
|  |         Ok(self.getattr("_betterproto")?) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn get_class(&self, field_name: &str) -> Result<&PyAny> { | ||||||
|  |         let cls = self.getattr("cls_by_field")?.get_item(field_name)?; | ||||||
|  |         Ok(cls) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn create_instance(&self, field_name: &str) -> Result<&PyAny> { | ||||||
|  |         Ok(self.get_class(field_name)?.call0()?) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn is_list_field(&self, field_name: &str) -> Result<bool> { | ||||||
|  |         let cls = self.getattr("default_gen")?.get_item(field_name)?; | ||||||
|  |         let module = cls.getattr("__module__")?; | ||||||
|  |         let name = cls.getattr("__name__")?; | ||||||
|  |         Ok(module.to_string() == "builtins" && name.to_string() == "list") | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn oneof_group(&self, field_name: &str) -> Result<Option<String>> { | ||||||
|  |         let opt = self | ||||||
|  |             .getattr("oneof_group_by_field")? | ||||||
|  |             .call_method1("get", (field_name,))? | ||||||
|  |             .extract()?; | ||||||
|  |         Ok(opt) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn py_identifier(&self) -> u64 { | ||||||
|  |         static FUN_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new(); | ||||||
|  |         let py = self.py(); | ||||||
|  |         let fun = FUN_CACHE.get_or_init(py, || { | ||||||
|  |             let fun = py | ||||||
|  |                 .eval("id", None, None) | ||||||
|  |                 .expect("This is a valid Python expression"); | ||||||
|  |             Py::from(fun) | ||||||
|  |         }); | ||||||
|  |         fun.call1(py, (self,)) | ||||||
|  |             .expect("Identity function is callable") | ||||||
|  |             .extract::<u64>(py) | ||||||
|  |             .expect("Identity function always returns an integer") | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								example.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								example.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | # dev tests | ||||||
|  | # to be deleted later | ||||||
|  |  | ||||||
|  | import betterproto | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from typing import Dict, List, Optional | ||||||
|  |  | ||||||
|  | @dataclass(repr=False) | ||||||
|  | class Baz(betterproto.Message): | ||||||
|  |     a: float = betterproto.float_field(1, group = "x") | ||||||
|  |     b: int = betterproto.int64_field(2, group = "x") | ||||||
|  |     c: float = betterproto.float_field(3, group = "y") | ||||||
|  |     d: int = betterproto.int64_field(4, group = "y") | ||||||
|  |     e: Optional[int] = betterproto.int32_field(5, group = "_e", optional = True) | ||||||
|  |  | ||||||
|  | @dataclass(repr=False) | ||||||
|  | class Foo(betterproto.Message): | ||||||
|  |     x: int = betterproto.int32_field(1) | ||||||
|  |     y: float = betterproto.double_field(2) | ||||||
|  |     z: List[Baz] = betterproto.message_field(3) | ||||||
|  |  | ||||||
|  | class Enm(betterproto.Enum): | ||||||
|  |     A = 0 | ||||||
|  |     B = 1 | ||||||
|  |     C = 2 | ||||||
|  |  | ||||||
|  | @dataclass(repr=False) | ||||||
|  | class Bar(betterproto.Message): | ||||||
|  |     foo1: Foo = betterproto.message_field(1) | ||||||
|  |     foo2: Foo = betterproto.message_field(2) | ||||||
|  |     packed: List[int] = betterproto.int64_field(3) | ||||||
|  |     enm: Enm = betterproto.enum_field(4) | ||||||
|  |     map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL) | ||||||
|  |     maybe: Optional[bool] = betterproto.message_field(6, wraps=betterproto.TYPE_BOOL) | ||||||
|  |     bts: bytes = betterproto.bytes_field(7) | ||||||
|  |  | ||||||
|  | # Serialization has not been changed yet. So nothing unusual here | ||||||
|  | buffer = bytes( | ||||||
|  |     Bar( | ||||||
|  |         foo1=Foo(1, 2.34), | ||||||
|  |         foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]), | ||||||
|  |         packed=[5, 3, 1], | ||||||
|  |         enm=Enm.B, | ||||||
|  |         map={ | ||||||
|  |             1: True, | ||||||
|  |             42: False | ||||||
|  |         }, | ||||||
|  |         maybe=True, | ||||||
|  |         bts=b'Hi There!' | ||||||
|  |     ) | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | # Native deserialization happening here | ||||||
|  | bar = Bar().parse(buffer) | ||||||
|  | print(bar) | ||||||
							
								
								
									
										2620
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2620
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,6 +1,6 @@ | |||||||
| [tool.poetry] | [tool.poetry] | ||||||
| name = "betterproto" | name = "betterproto" | ||||||
| version = "2.0.0b4" | version = "2.0.0b6" | ||||||
| description = "A better Protobuf / gRPC generator & library" | description = "A better Protobuf / gRPC generator & library" | ||||||
| authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"] | authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"] | ||||||
| readme = "README.md" | readme = "README.md" | ||||||
| @@ -12,22 +12,23 @@ packages = [ | |||||||
| ] | ] | ||||||
|  |  | ||||||
| [tool.poetry.dependencies] | [tool.poetry.dependencies] | ||||||
| python = ">=3.6.2,<4.0" | python = "^3.7" | ||||||
| black = { version = ">=19.3b0", optional = true } | black = { version = ">=23.1.0", optional = true } | ||||||
| dataclasses = { version = "^0.7", python = ">=3.6, <3.7" } |  | ||||||
| grpclib = "^0.4.1" | grpclib = "^0.4.1" | ||||||
| jinja2 = { version = "^2.11.2", optional = true } | importlib-metadata = { version = ">=1.6.0", python = "<3.8" } | ||||||
|  | jinja2 = { version = ">=3.0.3", optional = true } | ||||||
| python-dateutil = "^2.8" | python-dateutil = "^2.8" | ||||||
|  | isort = {version = "^5.11.5", optional = true} | ||||||
|  | betterproto-extras = { path = "betterproto-extras" } | ||||||
|  |  | ||||||
| [tool.poetry.dev-dependencies] | [tool.poetry.dev-dependencies] | ||||||
| asv = "^0.4.2" | asv = "^0.4.2" | ||||||
| black = "^21.11b0" |  | ||||||
| bpython = "^0.19" | bpython = "^0.19" | ||||||
| grpcio-tools = "^1.40.0" | grpcio-tools = "^1.54.2" | ||||||
| jinja2 = "^2.11.2" | jinja2 = ">=3.0.3" | ||||||
| mypy = "^0.930" | mypy = "^0.930" | ||||||
| poethepoet = ">=0.9.0" | poethepoet = ">=0.9.0" | ||||||
| protobuf = "^3.12.2" | protobuf = "^4.21.6" | ||||||
| pytest = "^6.2.5" | pytest = "^6.2.5" | ||||||
| pytest-asyncio = "^0.12.0" | pytest-asyncio = "^0.12.0" | ||||||
| pytest-cov = "^2.9.0" | pytest-cov = "^2.9.0" | ||||||
| @@ -36,13 +37,15 @@ sphinx = "3.1.2" | |||||||
| sphinx-rtd-theme = "0.5.0" | sphinx-rtd-theme = "0.5.0" | ||||||
| tomlkit = "^0.7.0" | tomlkit = "^0.7.0" | ||||||
| tox = "^3.15.1" | tox = "^3.15.1" | ||||||
|  | pre-commit = "^2.17.0" | ||||||
|  | pydantic = ">=1.8.0" | ||||||
|  |  | ||||||
|  |  | ||||||
| [tool.poetry.scripts] | [tool.poetry.scripts] | ||||||
| protoc-gen-python_betterproto = "betterproto.plugin:main" | protoc-gen-python_betterproto = "betterproto.plugin:main" | ||||||
|  |  | ||||||
| [tool.poetry.extras] | [tool.poetry.extras] | ||||||
| compiler = ["black", "jinja2"] | compiler = ["black", "isort", "jinja2"] | ||||||
|  |  | ||||||
|  |  | ||||||
| # Dev workflow tasks | # Dev workflow tasks | ||||||
| @@ -60,7 +63,7 @@ cmd  = "mypy src --ignore-missing-imports" | |||||||
| help = "Check types with mypy" | help = "Check types with mypy" | ||||||
|  |  | ||||||
| [tool.poe.tasks.format] | [tool.poe.tasks.format] | ||||||
| cmd  = "black . --exclude tests/output_" | cmd  = "black . --exclude tests/output_ --target-version py310" | ||||||
| help = "Apply black formatting to source code" | help = "Apply black formatting to source code" | ||||||
|  |  | ||||||
| [tool.poe.tasks.docs] | [tool.poe.tasks.docs] | ||||||
| @@ -85,8 +88,8 @@ protoc | |||||||
|     --plugin=protoc-gen-custom=src/betterproto/plugin/main.py |     --plugin=protoc-gen-custom=src/betterproto/plugin/main.py | ||||||
|     --custom_opt=INCLUDE_GOOGLE |     --custom_opt=INCLUDE_GOOGLE | ||||||
|     --custom_out=src/betterproto/lib |     --custom_out=src/betterproto/lib | ||||||
|     -I /usr/local/include/ |     -I C:\\work\\include | ||||||
|     /usr/local/include/google/protobuf/**/*.proto |     C:\\work\\include\\google\\protobuf\\**\\*.proto | ||||||
| """ | """ | ||||||
| help = "Regenerate the types in betterproto.lib.google" | help = "Regenerate the types in betterproto.lib.google" | ||||||
|  |  | ||||||
| @@ -97,12 +100,30 @@ shell = "poe generate && tox" | |||||||
| help = "Run tests with multiple pythons" | help = "Run tests with multiple pythons" | ||||||
|  |  | ||||||
| [tool.poe.tasks.check-style] | [tool.poe.tasks.check-style] | ||||||
| cmd = "black . --check --diff --exclude tests/output_" | cmd = "black . --check --diff" | ||||||
| help = "Check if code style is correct" | help = "Check if code style is correct" | ||||||
|  |  | ||||||
|  | [tool.isort] | ||||||
|  | py_version = 37 | ||||||
|  | profile = "black" | ||||||
|  | force_single_line = false | ||||||
|  | combine_as_imports = true | ||||||
|  | lines_after_imports = 2 | ||||||
|  | include_trailing_comma = true | ||||||
|  | force_grid_wrap = 2 | ||||||
|  | src_paths = ["src", "tests"] | ||||||
|  |  | ||||||
| [tool.black] | [tool.black] | ||||||
| target-version = ['py36'] | target-version = ['py37'] | ||||||
|  |  | ||||||
|  | [tool.doc8] | ||||||
|  | paths = ["docs"] | ||||||
|  | max_line_length = 88 | ||||||
|  |  | ||||||
|  | [tool.doc8.ignore_path_errors] | ||||||
|  | "docs/migrating.rst" = [ | ||||||
|  |     "D001",  # contains table which is longer than 88 characters long | ||||||
|  | ] | ||||||
|  |  | ||||||
| [tool.coverage.run] | [tool.coverage.run] | ||||||
| omit = ["betterproto/tests/*"] | omit = ["betterproto/tests/*"] | ||||||
| @@ -111,7 +132,7 @@ omit = ["betterproto/tests/*"] | |||||||
| legacy_tox_ini = """ | legacy_tox_ini = """ | ||||||
| [tox] | [tox] | ||||||
| isolated_build = true | isolated_build = true | ||||||
| envlist = py36, py37, py38 | envlist = py37, py38, py310 | ||||||
|  |  | ||||||
| [testenv] | [testenv] | ||||||
| whitelist_externals = poetry | whitelist_externals = poetry | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,7 +1,12 @@ | |||||||
| from typing import TYPE_CHECKING, TypeVar | from typing import ( | ||||||
|  |     TYPE_CHECKING, | ||||||
|  |     TypeVar, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from grpclib._typing import IProtoMessage |     from grpclib._typing import IProtoMessage | ||||||
|  |  | ||||||
|     from . import Message |     from . import Message | ||||||
|  |  | ||||||
| # Bound type variable to allow methods to return `self` of subclasses | # Bound type variable to allow methods to return `self` of subclasses | ||||||
|   | |||||||
| @@ -1,3 +1,7 @@ | |||||||
| from pkg_resources import get_distribution | try: | ||||||
|  |     from importlib import metadata | ||||||
|  | except ImportError:  # for Python<3.8 | ||||||
|  |     import importlib_metadata as metadata  # type: ignore | ||||||
|  |  | ||||||
| __version__ = get_distribution("betterproto").version |  | ||||||
|  | __version__ = metadata.version("betterproto") | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| import keyword | import keyword | ||||||
| import re | import re | ||||||
|  |  | ||||||
|  |  | ||||||
| # Word delimiters and symbols that will not be preserved when re-casing. | # Word delimiters and symbols that will not be preserved when re-casing. | ||||||
| # language=PythonRegExp | # language=PythonRegExp | ||||||
| SYMBOLS = "[^a-zA-Z0-9]*" | SYMBOLS = "[^a-zA-Z0-9]*" | ||||||
|   | |||||||
| @@ -1,11 +1,18 @@ | |||||||
| import os | import os | ||||||
| import re | import re | ||||||
| from typing import Dict, List, Set, Tuple, Type | from typing import ( | ||||||
|  |     Dict, | ||||||
|  |     List, | ||||||
|  |     Set, | ||||||
|  |     Tuple, | ||||||
|  |     Type, | ||||||
|  | ) | ||||||
|  |  | ||||||
| from ..casing import safe_snake_case | from ..casing import safe_snake_case | ||||||
| from ..lib.google import protobuf as google_protobuf | from ..lib.google import protobuf as google_protobuf | ||||||
| from .naming import pythonize_class_name | from .naming import pythonize_class_name | ||||||
|  |  | ||||||
|  |  | ||||||
| WRAPPER_TYPES: Dict[str, Type] = { | WRAPPER_TYPES: Dict[str, Type] = { | ||||||
|     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, |     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, | ||||||
|     ".google.protobuf.FloatValue": google_protobuf.FloatValue, |     ".google.protobuf.FloatValue": google_protobuf.FloatValue, | ||||||
| @@ -36,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_type_reference( | def get_type_reference( | ||||||
|     package: str, imports: set, source_type: str, unwrap: bool = True |     *, package: str, imports: set, source_type: str, unwrap: bool = True | ||||||
| ) -> str: | ) -> str: | ||||||
|     """ |     """ | ||||||
|     Return a Python type name for a proto type reference. Adds the import if |     Return a Python type name for a proto type reference. Adds the import if | ||||||
|   | |||||||
| @@ -15,17 +15,22 @@ from typing import ( | |||||||
|  |  | ||||||
| import grpclib.const | import grpclib.const | ||||||
|  |  | ||||||
| from .._types import ST, T |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from grpclib.client import Channel |     from grpclib.client import Channel | ||||||
|     from grpclib.metadata import Deadline |     from grpclib.metadata import Deadline | ||||||
|  |  | ||||||
|  |     from .._types import ( | ||||||
|  |         ST, | ||||||
|  |         IProtoMessage, | ||||||
|  |         Message, | ||||||
|  |         T, | ||||||
|  |     ) | ||||||
|  |  | ||||||
| _Value = Union[str, bytes] |  | ||||||
| _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] | Value = Union[str, bytes] | ||||||
| _MessageLike = Union[T, ST] | MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] | ||||||
| _MessageSource = Union[Iterable[ST], AsyncIterable[ST]] | MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServiceStub(ABC): | class ServiceStub(ABC): | ||||||
| @@ -39,7 +44,7 @@ class ServiceStub(ABC): | |||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[_MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.channel = channel |         self.channel = channel | ||||||
|         self.timeout = timeout |         self.timeout = timeout | ||||||
| @@ -50,7 +55,7 @@ class ServiceStub(ABC): | |||||||
|         self, |         self, | ||||||
|         timeout: Optional[float], |         timeout: Optional[float], | ||||||
|         deadline: Optional["Deadline"], |         deadline: Optional["Deadline"], | ||||||
|         metadata: Optional[_MetadataLike], |         metadata: Optional[MetadataLike], | ||||||
|     ): |     ): | ||||||
|         return { |         return { | ||||||
|             "timeout": self.timeout if timeout is None else timeout, |             "timeout": self.timeout if timeout is None else timeout, | ||||||
| @@ -61,13 +66,13 @@ class ServiceStub(ABC): | |||||||
|     async def _unary_unary( |     async def _unary_unary( | ||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request: _MessageLike, |         request: "IProtoMessage", | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[_MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> T: |     ) -> "T": | ||||||
|         """Make a unary request and return the response.""" |         """Make a unary request and return the response.""" | ||||||
|         async with self.channel.request( |         async with self.channel.request( | ||||||
|             route, |             route, | ||||||
| @@ -84,13 +89,13 @@ class ServiceStub(ABC): | |||||||
|     async def _unary_stream( |     async def _unary_stream( | ||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request: _MessageLike, |         request: "IProtoMessage", | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[_MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> AsyncIterator[T]: |     ) -> AsyncIterator["T"]: | ||||||
|         """Make a unary request and return the stream response iterator.""" |         """Make a unary request and return the stream response iterator.""" | ||||||
|         async with self.channel.request( |         async with self.channel.request( | ||||||
|             route, |             route, | ||||||
| @@ -106,14 +111,14 @@ class ServiceStub(ABC): | |||||||
|     async def _stream_unary( |     async def _stream_unary( | ||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request_iterator: _MessageSource, |         request_iterator: MessageSource, | ||||||
|         request_type: Type[ST], |         request_type: Type["IProtoMessage"], | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[_MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> T: |     ) -> "T": | ||||||
|         """Make a stream request and return the response.""" |         """Make a stream request and return the response.""" | ||||||
|         async with self.channel.request( |         async with self.channel.request( | ||||||
|             route, |             route, | ||||||
| @@ -130,14 +135,14 @@ class ServiceStub(ABC): | |||||||
|     async def _stream_stream( |     async def _stream_stream( | ||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request_iterator: _MessageSource, |         request_iterator: MessageSource, | ||||||
|         request_type: Type[ST], |         request_type: Type["IProtoMessage"], | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[_MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> AsyncIterator[T]: |     ) -> AsyncIterator["T"]: | ||||||
|         """ |         """ | ||||||
|         Make a stream request and return an AsyncIterator to iterate over response |         Make a stream request and return an AsyncIterator to iterate over response | ||||||
|         messages. |         messages. | ||||||
| @@ -161,7 +166,7 @@ class ServiceStub(ABC): | |||||||
|                 raise |                 raise | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     async def _send_messages(stream, messages: _MessageSource): |     async def _send_messages(stream, messages: MessageSource): | ||||||
|         if isinstance(messages, AsyncIterable): |         if isinstance(messages, AsyncIterable): | ||||||
|             async for message in messages: |             async for message in messages: | ||||||
|                 await stream.send_message(message) |                 await stream.send_message(message) | ||||||
|   | |||||||
| @@ -1,6 +1,10 @@ | |||||||
| from abc import ABC | from abc import ABC | ||||||
| from collections.abc import AsyncIterable | from collections.abc import AsyncIterable | ||||||
| from typing import Callable, Any, Dict | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Callable, | ||||||
|  |     Dict, | ||||||
|  | ) | ||||||
|  |  | ||||||
| import grpclib | import grpclib | ||||||
| import grpclib.server | import grpclib.server | ||||||
| @@ -15,10 +19,9 @@ class ServiceBase(ABC): | |||||||
|         self, |         self, | ||||||
|         handler: Callable, |         handler: Callable, | ||||||
|         stream: grpclib.server.Stream, |         stream: grpclib.server.Stream, | ||||||
|         request_kwargs: Dict[str, Any], |         request: Any, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  |         response_iter = handler(request) | ||||||
|         response_iter = handler(**request_kwargs) |  | ||||||
|         # check if response is actually an AsyncIterator |         # check if response is actually an AsyncIterator | ||||||
|         # this might be false if the method just returns without |         # this might be false if the method just returns without | ||||||
|         # yielding at least once |         # yielding at least once | ||||||
|   | |||||||
| @@ -1,5 +1,13 @@ | |||||||
| import asyncio | import asyncio | ||||||
| from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union | from typing import ( | ||||||
|  |     AsyncIterable, | ||||||
|  |     AsyncIterator, | ||||||
|  |     Iterable, | ||||||
|  |     Optional, | ||||||
|  |     TypeVar, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| T = TypeVar("T") | T = TypeVar("T") | ||||||
|  |  | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,14 +1,17 @@ | |||||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||||
| # sources: google/protobuf/compiler/plugin.proto | # sources: google/protobuf/compiler/plugin.proto | ||||||
| # plugin: python-betterproto | # plugin: python-betterproto | ||||||
|  | # This file has been @generated | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import List | from typing import List | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| from betterproto.grpc.grpclib_server import ServiceBase | import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf | ||||||
|  |  | ||||||
|  |  | ||||||
| class CodeGeneratorResponseFeature(betterproto.Enum): | class CodeGeneratorResponseFeature(betterproto.Enum): | ||||||
|  |     """Sync with code_generator.h.""" | ||||||
|  |  | ||||||
|     FEATURE_NONE = 0 |     FEATURE_NONE = 0 | ||||||
|     FEATURE_PROTO3_OPTIONAL = 1 |     FEATURE_PROTO3_OPTIONAL = 1 | ||||||
|  |  | ||||||
| @@ -20,54 +23,69 @@ class Version(betterproto.Message): | |||||||
|     major: int = betterproto.int32_field(1) |     major: int = betterproto.int32_field(1) | ||||||
|     minor: int = betterproto.int32_field(2) |     minor: int = betterproto.int32_field(2) | ||||||
|     patch: int = betterproto.int32_field(3) |     patch: int = betterproto.int32_field(3) | ||||||
|     # A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should |  | ||||||
|     # be empty for mainline stable releases. |  | ||||||
|     suffix: str = betterproto.string_field(4) |     suffix: str = betterproto.string_field(4) | ||||||
|  |     """ | ||||||
|  |     A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should | ||||||
|  |     be empty for mainline stable releases. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass(eq=False, repr=False) | @dataclass(eq=False, repr=False) | ||||||
| class CodeGeneratorRequest(betterproto.Message): | class CodeGeneratorRequest(betterproto.Message): | ||||||
|     """An encoded CodeGeneratorRequest is written to the plugin's stdin.""" |     """An encoded CodeGeneratorRequest is written to the plugin's stdin.""" | ||||||
|  |  | ||||||
|     # The .proto files that were explicitly listed on the command-line.  The code |  | ||||||
|     # generator should generate code only for these files.  Each file's |  | ||||||
|     # descriptor will be included in proto_file, below. |  | ||||||
|     file_to_generate: List[str] = betterproto.string_field(1) |     file_to_generate: List[str] = betterproto.string_field(1) | ||||||
|     # The generator parameter passed on the command-line. |     """ | ||||||
|  |     The .proto files that were explicitly listed on the command-line.  The code | ||||||
|  |     generator should generate code only for these files.  Each file's | ||||||
|  |     descriptor will be included in proto_file, below. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     parameter: str = betterproto.string_field(2) |     parameter: str = betterproto.string_field(2) | ||||||
|     # FileDescriptorProtos for all files in files_to_generate and everything they |     """The generator parameter passed on the command-line.""" | ||||||
|     # import.  The files will appear in topological order, so each file appears |  | ||||||
|     # before any file that imports it. protoc guarantees that all proto_files |  | ||||||
|     # will be written after the fields above, even though this is not technically |  | ||||||
|     # guaranteed by the protobuf wire format.  This theoretically could allow a |  | ||||||
|     # plugin to stream in the FileDescriptorProtos and handle them one by one |  | ||||||
|     # rather than read the entire set into memory at once.  However, as of this |  | ||||||
|     # writing, this is not similarly optimized on protoc's end -- it will store |  | ||||||
|     # all fields in memory at once before sending them to the plugin. Type names |  | ||||||
|     # of fields and extensions in the FileDescriptorProto are always fully |  | ||||||
|     # qualified. |  | ||||||
|     proto_file: List[ |     proto_file: List[ | ||||||
|         "betterproto_lib_google_protobuf.FileDescriptorProto" |         "betterproto_lib_google_protobuf.FileDescriptorProto" | ||||||
|     ] = betterproto.message_field(15) |     ] = betterproto.message_field(15) | ||||||
|     # The version number of protocol compiler. |     """ | ||||||
|  |     FileDescriptorProtos for all files in files_to_generate and everything they | ||||||
|  |     import.  The files will appear in topological order, so each file appears | ||||||
|  |     before any file that imports it. protoc guarantees that all proto_files | ||||||
|  |     will be written after the fields above, even though this is not technically | ||||||
|  |     guaranteed by the protobuf wire format.  This theoretically could allow a | ||||||
|  |     plugin to stream in the FileDescriptorProtos and handle them one by one | ||||||
|  |     rather than read the entire set into memory at once.  However, as of this | ||||||
|  |     writing, this is not similarly optimized on protoc's end -- it will store | ||||||
|  |     all fields in memory at once before sending them to the plugin. Type names | ||||||
|  |     of fields and extensions in the FileDescriptorProto are always fully | ||||||
|  |     qualified. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     compiler_version: "Version" = betterproto.message_field(3) |     compiler_version: "Version" = betterproto.message_field(3) | ||||||
|  |     """The version number of protocol compiler.""" | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass(eq=False, repr=False) | @dataclass(eq=False, repr=False) | ||||||
| class CodeGeneratorResponse(betterproto.Message): | class CodeGeneratorResponse(betterproto.Message): | ||||||
|     """The plugin writes an encoded CodeGeneratorResponse to stdout.""" |     """The plugin writes an encoded CodeGeneratorResponse to stdout.""" | ||||||
|  |  | ||||||
|     # Error message.  If non-empty, code generation failed.  The plugin process |  | ||||||
|     # should exit with status code zero even if it reports an error in this way. |  | ||||||
|     # This should be used to indicate errors in .proto files which prevent the |  | ||||||
|     # code generator from generating correct code.  Errors which indicate a |  | ||||||
|     # problem in protoc itself -- such as the input CodeGeneratorRequest being |  | ||||||
|     # unparseable -- should be reported by writing a message to stderr and |  | ||||||
|     # exiting with a non-zero status code. |  | ||||||
|     error: str = betterproto.string_field(1) |     error: str = betterproto.string_field(1) | ||||||
|     # A bitmask of supported features that the code generator supports. This is a |     """ | ||||||
|     # bitwise "or" of values from the Feature enum. |     Error message.  If non-empty, code generation failed.  The plugin process | ||||||
|  |     should exit with status code zero even if it reports an error in this way. | ||||||
|  |     This should be used to indicate errors in .proto files which prevent the | ||||||
|  |     code generator from generating correct code.  Errors which indicate a | ||||||
|  |     problem in protoc itself -- such as the input CodeGeneratorRequest being | ||||||
|  |     unparseable -- should be reported by writing a message to stderr and | ||||||
|  |     exiting with a non-zero status code. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     supported_features: int = betterproto.uint64_field(2) |     supported_features: int = betterproto.uint64_field(2) | ||||||
|  |     """ | ||||||
|  |     A bitmask of supported features that the code generator supports. This is a | ||||||
|  |     bitwise "or" of values from the Feature enum. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15) |     file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -75,54 +93,60 @@ class CodeGeneratorResponse(betterproto.Message): | |||||||
| class CodeGeneratorResponseFile(betterproto.Message): | class CodeGeneratorResponseFile(betterproto.Message): | ||||||
|     """Represents a single generated file.""" |     """Represents a single generated file.""" | ||||||
|  |  | ||||||
|     # The file name, relative to the output directory.  The name must not contain |  | ||||||
|     # "." or ".." components and must be relative, not be absolute (so, the file |  | ||||||
|     # cannot lie outside the output directory).  "/" must be used as the path |  | ||||||
|     # separator, not "\". If the name is omitted, the content will be appended to |  | ||||||
|     # the previous file.  This allows the generator to break large files into |  | ||||||
|     # small chunks, and allows the generated text to be streamed back to protoc |  | ||||||
|     # so that large files need not reside completely in memory at one time.  Note |  | ||||||
|     # that as of this writing protoc does not optimize for this -- it will read |  | ||||||
|     # the entire CodeGeneratorResponse before writing files to disk. |  | ||||||
|     name: str = betterproto.string_field(1) |     name: str = betterproto.string_field(1) | ||||||
|     # If non-empty, indicates that the named file should already exist, and the |     """ | ||||||
|     # content here is to be inserted into that file at a defined insertion point. |     The file name, relative to the output directory.  The name must not contain | ||||||
|     # This feature allows a code generator to extend the output produced by |     "." or ".." components and must be relative, not be absolute (so, the file | ||||||
|     # another code generator.  The original generator may provide insertion |     cannot lie outside the output directory).  "/" must be used as the path | ||||||
|     # points by placing special annotations in the file that look like: |     separator, not "\". If the name is omitted, the content will be appended to | ||||||
|     # @@protoc_insertion_point(NAME) The annotation can have arbitrary text |     the previous file.  This allows the generator to break large files into | ||||||
|     # before and after it on the line, which allows it to be placed in a comment. |     small chunks, and allows the generated text to be streamed back to protoc | ||||||
|     # NAME should be replaced with an identifier naming the point -- this is what |     so that large files need not reside completely in memory at one time.  Note | ||||||
|     # other generators will use as the insertion_point.  Code inserted at this |     that as of this writing protoc does not optimize for this -- it will read | ||||||
|     # point will be placed immediately above the line containing the insertion |     the entire CodeGeneratorResponse before writing files to disk. | ||||||
|     # point (thus multiple insertions to the same point will come out in the |     """ | ||||||
|     # order they were added). The double-@ is intended to make it unlikely that |  | ||||||
|     # the generated code could contain things that look like insertion points by |  | ||||||
|     # accident. For example, the C++ code generator places the following line in |  | ||||||
|     # the .pb.h files that it generates:   // |  | ||||||
|     # @@protoc_insertion_point(namespace_scope) This line appears within the |  | ||||||
|     # scope of the file's package namespace, but outside of any particular class. |  | ||||||
|     # Another plugin can then specify the insertion_point "namespace_scope" to |  | ||||||
|     # generate additional classes or other declarations that should be placed in |  | ||||||
|     # this scope. Note that if the line containing the insertion point begins |  | ||||||
|     # with whitespace, the same whitespace will be added to every line of the |  | ||||||
|     # inserted text.  This is useful for languages like Python, where indentation |  | ||||||
|     # matters.  In these languages, the insertion point comment should be |  | ||||||
|     # indented the same amount as any inserted code will need to be in order to |  | ||||||
|     # work correctly in that context. The code generator that generates the |  | ||||||
|     # initial file and the one which inserts into it must both run as part of a |  | ||||||
|     # single invocation of protoc. Code generators are executed in the order in |  | ||||||
|     # which they appear on the command line. If |insertion_point| is present, |  | ||||||
|     # |name| must also be present. |  | ||||||
|     insertion_point: str = betterproto.string_field(2) |     insertion_point: str = betterproto.string_field(2) | ||||||
|     # The file contents. |     """ | ||||||
|  |     If non-empty, indicates that the named file should already exist, and the | ||||||
|  |     content here is to be inserted into that file at a defined insertion point. | ||||||
|  |     This feature allows a code generator to extend the output produced by | ||||||
|  |     another code generator.  The original generator may provide insertion | ||||||
|  |     points by placing special annotations in the file that look like: | ||||||
|  |     @@protoc_insertion_point(NAME) The annotation can have arbitrary text | ||||||
|  |     before and after it on the line, which allows it to be placed in a comment. | ||||||
|  |     NAME should be replaced with an identifier naming the point -- this is what | ||||||
|  |     other generators will use as the insertion_point.  Code inserted at this | ||||||
|  |     point will be placed immediately above the line containing the insertion | ||||||
|  |     point (thus multiple insertions to the same point will come out in the | ||||||
|  |     order they were added). The double-@ is intended to make it unlikely that | ||||||
|  |     the generated code could contain things that look like insertion points by | ||||||
|  |     accident. For example, the C++ code generator places the following line in | ||||||
|  |     the .pb.h files that it generates:   // | ||||||
|  |     @@protoc_insertion_point(namespace_scope) This line appears within the | ||||||
|  |     scope of the file's package namespace, but outside of any particular class. | ||||||
|  |     Another plugin can then specify the insertion_point "namespace_scope" to | ||||||
|  |     generate additional classes or other declarations that should be placed in | ||||||
|  |     this scope. Note that if the line containing the insertion point begins | ||||||
|  |     with whitespace, the same whitespace will be added to every line of the | ||||||
|  |     inserted text.  This is useful for languages like Python, where indentation | ||||||
|  |     matters.  In these languages, the insertion point comment should be | ||||||
|  |     indented the same amount as any inserted code will need to be in order to | ||||||
|  |     work correctly in that context. The code generator that generates the | ||||||
|  |     initial file and the one which inserts into it must both run as part of a | ||||||
|  |     single invocation of protoc. Code generators are executed in the order in | ||||||
|  |     which they appear on the command line. If |insertion_point| is present, | ||||||
|  |     |name| must also be present. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     content: str = betterproto.string_field(15) |     content: str = betterproto.string_field(15) | ||||||
|     # Information describing the file content being inserted. If an insertion |     """The file contents.""" | ||||||
|     # point is used, this information will be appropriately offset and inserted |  | ||||||
|     # into the code generation metadata for the generated files. |  | ||||||
|     generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = ( |     generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = ( | ||||||
|         betterproto.message_field(16) |         betterproto.message_field(16) | ||||||
|     ) |     ) | ||||||
|  |     """ | ||||||
|  |     Information describing the file content being inserted. If an insertion | ||||||
| import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf |     point is used, this information will be appropriately offset and inserted | ||||||
|  |     into the code generation metadata for the generated files. | ||||||
|  |     """ | ||||||
|   | |||||||
| @@ -1,8 +1,10 @@ | |||||||
| import os.path | import os.path | ||||||
|  |  | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     # betterproto[compiler] specific dependencies |     # betterproto[compiler] specific dependencies | ||||||
|     import black |     import black | ||||||
|  |     import isort.api | ||||||
|     import jinja2 |     import jinja2 | ||||||
| except ImportError as err: | except ImportError as err: | ||||||
|     print( |     print( | ||||||
| @@ -19,7 +21,6 @@ from .models import OutputTemplate | |||||||
|  |  | ||||||
|  |  | ||||||
| def outputfile_compiler(output_file: OutputTemplate) -> str: | def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||||
|  |  | ||||||
|     templates_folder = os.path.abspath( |     templates_folder = os.path.abspath( | ||||||
|         os.path.join(os.path.dirname(__file__), "..", "templates") |         os.path.join(os.path.dirname(__file__), "..", "templates") | ||||||
|     ) |     ) | ||||||
| @@ -31,7 +32,19 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: | |||||||
|     ) |     ) | ||||||
|     template = env.get_template("template.py.j2") |     template = env.get_template("template.py.j2") | ||||||
|  |  | ||||||
|  |     code = template.render(output_file=output_file) | ||||||
|  |     code = isort.api.sort_code_string( | ||||||
|  |         code=code, | ||||||
|  |         show_diff=False, | ||||||
|  |         py_version=37, | ||||||
|  |         profile="black", | ||||||
|  |         combine_as_imports=True, | ||||||
|  |         lines_after_imports=2, | ||||||
|  |         quiet=True, | ||||||
|  |         force_grid_wrap=2, | ||||||
|  |         known_third_party=["grpclib", "betterproto"], | ||||||
|  |     ) | ||||||
|     return black.format_str( |     return black.format_str( | ||||||
|         template.render(output_file=output_file), |         src_contents=code, | ||||||
|         mode=black.Mode(), |         mode=black.Mode(), | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -7,9 +7,8 @@ from betterproto.lib.google.protobuf.compiler import ( | |||||||
|     CodeGeneratorRequest, |     CodeGeneratorRequest, | ||||||
|     CodeGeneratorResponse, |     CodeGeneratorResponse, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| from betterproto.plugin.parser import generate_code |  | ||||||
| from betterproto.plugin.models import monkey_patch_oneof_index | from betterproto.plugin.models import monkey_patch_oneof_index | ||||||
|  | from betterproto.plugin.parser import generate_code | ||||||
|  |  | ||||||
|  |  | ||||||
| def main() -> None: | def main() -> None: | ||||||
|   | |||||||
| @@ -31,6 +31,23 @@ reference to `A` to `B`'s `fields` attribute. | |||||||
|  |  | ||||||
|  |  | ||||||
| import builtins | import builtins | ||||||
|  | import re | ||||||
|  | import textwrap | ||||||
|  | from dataclasses import ( | ||||||
|  |     dataclass, | ||||||
|  |     field, | ||||||
|  | ) | ||||||
|  | from typing import ( | ||||||
|  |     Dict, | ||||||
|  |     Iterable, | ||||||
|  |     Iterator, | ||||||
|  |     List, | ||||||
|  |     Optional, | ||||||
|  |     Set, | ||||||
|  |     Type, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| from betterproto import which_one_of | from betterproto import which_one_of | ||||||
| from betterproto.casing import sanitize_name | from betterproto.casing import sanitize_name | ||||||
| @@ -46,23 +63,20 @@ from betterproto.compile.naming import ( | |||||||
| from betterproto.lib.google.protobuf import ( | from betterproto.lib.google.protobuf import ( | ||||||
|     DescriptorProto, |     DescriptorProto, | ||||||
|     EnumDescriptorProto, |     EnumDescriptorProto, | ||||||
|     FileDescriptorProto, |  | ||||||
|     MethodDescriptorProto, |  | ||||||
|     Field, |     Field, | ||||||
|     FieldDescriptorProto, |     FieldDescriptorProto, | ||||||
|     FieldDescriptorProtoType, |  | ||||||
|     FieldDescriptorProtoLabel, |     FieldDescriptorProtoLabel, | ||||||
|  |     FieldDescriptorProtoType, | ||||||
|  |     FileDescriptorProto, | ||||||
|  |     MethodDescriptorProto, | ||||||
| ) | ) | ||||||
| from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest | from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest | ||||||
|  |  | ||||||
|  |  | ||||||
| import re |  | ||||||
| import textwrap |  | ||||||
| from dataclasses import dataclass, field |  | ||||||
| from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union |  | ||||||
|  |  | ||||||
| from ..casing import sanitize_name | from ..casing import sanitize_name | ||||||
| from ..compile.importing import get_type_reference, parse_source_type_name | from ..compile.importing import ( | ||||||
|  |     get_type_reference, | ||||||
|  |     parse_source_type_name, | ||||||
|  | ) | ||||||
| from ..compile.naming import ( | from ..compile.naming import ( | ||||||
|     pythonize_class_name, |     pythonize_class_name, | ||||||
|     pythonize_field_name, |     pythonize_field_name, | ||||||
| @@ -128,12 +142,12 @@ def monkey_patch_oneof_index(): | |||||||
|             "betterproto" |             "betterproto" | ||||||
|         ], |         ], | ||||||
|         "group", |         "group", | ||||||
|         "oneof_index", |         "_oneof_index", | ||||||
|     ) |     ) | ||||||
|     object.__setattr__( |     object.__setattr__( | ||||||
|         Field.__dataclass_fields__["oneof_index"].metadata["betterproto"], |         Field.__dataclass_fields__["oneof_index"].metadata["betterproto"], | ||||||
|         "group", |         "group", | ||||||
|         "oneof_index", |         "_oneof_index", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -147,17 +161,13 @@ def get_comment( | |||||||
|                 sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent |                 sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             if path[-2] == 2 and path[-4] != 6: |             # This is a field, message, enum, service, or method | ||||||
|                 # This is a field |             if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: | ||||||
|                 return f"{pad}# " + f"\n{pad}# ".join(lines) |                 lines[0] = lines[0].strip('"') | ||||||
|  |                 return f'{pad}"""{lines[0]}"""' | ||||||
|             else: |             else: | ||||||
|                 # This is a message, enum, service, or method |                 joined = f"\n{pad}".join(lines) | ||||||
|                 if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: |                 return f'{pad}"""\n{pad}{joined}\n{pad}"""' | ||||||
|                     lines[0] = lines[0].strip('"') |  | ||||||
|                     return f'{pad}"""{lines[0]}"""' |  | ||||||
|                 else: |  | ||||||
|                     joined = f"\n{pad}".join(lines) |  | ||||||
|                     return f'{pad}"""\n{pad}{joined}\n{pad}"""' |  | ||||||
|  |  | ||||||
|     return "" |     return "" | ||||||
|  |  | ||||||
| @@ -204,7 +214,6 @@ class ProtoContentBase: | |||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class PluginRequestCompiler: | class PluginRequestCompiler: | ||||||
|  |  | ||||||
|     plugin_request_obj: CodeGeneratorRequest |     plugin_request_obj: CodeGeneratorRequest | ||||||
|     output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) |     output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) | ||||||
|  |  | ||||||
| @@ -237,10 +246,14 @@ class OutputTemplate: | |||||||
|     imports: Set[str] = field(default_factory=set) |     imports: Set[str] = field(default_factory=set) | ||||||
|     datetime_imports: Set[str] = field(default_factory=set) |     datetime_imports: Set[str] = field(default_factory=set) | ||||||
|     typing_imports: Set[str] = field(default_factory=set) |     typing_imports: Set[str] = field(default_factory=set) | ||||||
|  |     pydantic_imports: Set[str] = field(default_factory=set) | ||||||
|     builtins_import: bool = False |     builtins_import: bool = False | ||||||
|     messages: List["MessageCompiler"] = field(default_factory=list) |     messages: List["MessageCompiler"] = field(default_factory=list) | ||||||
|     enums: List["EnumDefinitionCompiler"] = field(default_factory=list) |     enums: List["EnumDefinitionCompiler"] = field(default_factory=list) | ||||||
|     services: List["ServiceCompiler"] = field(default_factory=list) |     services: List["ServiceCompiler"] = field(default_factory=list) | ||||||
|  |     imports_type_checking_only: Set[str] = field(default_factory=set) | ||||||
|  |     pydantic_dataclasses: bool = False | ||||||
|  |     output: bool = True | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def package(self) -> str: |     def package(self) -> str: | ||||||
| @@ -322,12 +335,29 @@ class MessageCompiler(ProtoContentBase): | |||||||
|     def has_deprecated_fields(self) -> bool: |     def has_deprecated_fields(self) -> bool: | ||||||
|         return any(self.deprecated_fields) |         return any(self.deprecated_fields) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def has_oneof_fields(self) -> bool: | ||||||
|  |         return any(isinstance(field, OneOfFieldCompiler) for field in self.fields) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def has_message_field(self) -> bool: | ||||||
|  |         return any( | ||||||
|  |             ( | ||||||
|  |                 field.proto_obj.type in PROTO_MESSAGE_TYPES | ||||||
|  |                 for field in self.fields | ||||||
|  |                 if isinstance(field.proto_obj, FieldDescriptorProto) | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_map( | def is_map( | ||||||
|     proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto |     proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto | ||||||
| ) -> bool: | ) -> bool: | ||||||
|     """True if proto_field_obj is a map, otherwise False.""" |     """True if proto_field_obj is a map, otherwise False.""" | ||||||
|     if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE: |     if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE: | ||||||
|  |         if not hasattr(parent_message, "nested_type"): | ||||||
|  |             return False | ||||||
|  |  | ||||||
|         # This might be a map... |         # This might be a map... | ||||||
|         message_type = proto_field_obj.type_name.split(".").pop().lower() |         message_type = proto_field_obj.type_name.split(".").pop().lower() | ||||||
|         map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" |         map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" | ||||||
| @@ -355,7 +385,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: | |||||||
|         us to tell whether it was set, via the which_one_of interface. |         us to tell whether it was set, via the which_one_of interface. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index" |     return which_one_of(proto_field_obj, "_oneof_index")[0] == "oneof_index" | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| @@ -416,6 +446,10 @@ class FieldCompiler(MessageCompiler): | |||||||
|             imports.add("Dict") |             imports.add("Dict") | ||||||
|         return imports |         return imports | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def pydantic_imports(self) -> Set[str]: | ||||||
|  |         return set() | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def use_builtins(self) -> bool: |     def use_builtins(self) -> bool: | ||||||
|         return self.py_type in self.parent.builtins_types or ( |         return self.py_type in self.parent.builtins_types or ( | ||||||
| @@ -425,6 +459,7 @@ class FieldCompiler(MessageCompiler): | |||||||
|     def add_imports_to(self, output_file: OutputTemplate) -> None: |     def add_imports_to(self, output_file: OutputTemplate) -> None: | ||||||
|         output_file.datetime_imports.update(self.datetime_imports) |         output_file.datetime_imports.update(self.datetime_imports) | ||||||
|         output_file.typing_imports.update(self.typing_imports) |         output_file.typing_imports.update(self.typing_imports) | ||||||
|  |         output_file.pydantic_imports.update(self.pydantic_imports) | ||||||
|         output_file.builtins_import = output_file.builtins_import or self.use_builtins |         output_file.builtins_import = output_file.builtins_import or self.use_builtins | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -529,7 +564,7 @@ class FieldCompiler(MessageCompiler): | |||||||
|                 source_type=self.proto_obj.type_name, |                 source_type=self.proto_obj.type_name, | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError(f"Unknown type {field.type}") |             raise NotImplementedError(f"Unknown type {self.proto_obj.type}") | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def annotation(self) -> str: |     def annotation(self) -> str: | ||||||
| @@ -553,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler): | |||||||
|         return args |         return args | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class PydanticOneOfFieldCompiler(OneOfFieldCompiler): | ||||||
|  |     @property | ||||||
|  |     def optional(self) -> bool: | ||||||
|  |         # Force the optional to be True. This will allow the pydantic dataclass | ||||||
|  |         # to validate the object correctly by allowing the field to be let empty. | ||||||
|  |         # We add a pydantic validator later to ensure exactly one field is defined. | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def pydantic_imports(self) -> Set[str]: | ||||||
|  |         return {"root_validator"} | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class MapEntryCompiler(FieldCompiler): | class MapEntryCompiler(FieldCompiler): | ||||||
|     py_k_type: Type = PLACEHOLDER |     py_k_type: Type = PLACEHOLDER | ||||||
| @@ -664,7 +713,6 @@ class ServiceCompiler(ProtoContentBase): | |||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class ServiceMethodCompiler(ProtoContentBase): | class ServiceMethodCompiler(ProtoContentBase): | ||||||
|  |  | ||||||
|     parent: ServiceCompiler |     parent: ServiceCompiler | ||||||
|     proto_obj: MethodDescriptorProto |     proto_obj: MethodDescriptorProto | ||||||
|     path: List[int] = PLACEHOLDER |     path: List[int] = PLACEHOLDER | ||||||
| @@ -675,12 +723,8 @@ class ServiceMethodCompiler(ProtoContentBase): | |||||||
|         self.parent.methods.append(self) |         self.parent.methods.append(self) | ||||||
|  |  | ||||||
|         # Check for imports |         # Check for imports | ||||||
|         if self.py_input_message: |  | ||||||
|             for f in self.py_input_message.fields: |  | ||||||
|                 f.add_imports_to(self.output_file) |  | ||||||
|         if "Optional" in self.py_output_message_type: |         if "Optional" in self.py_output_message_type: | ||||||
|             self.output_file.typing_imports.add("Optional") |             self.output_file.typing_imports.add("Optional") | ||||||
|         self.mutable_default_args  # ensure this is called before rendering |  | ||||||
|  |  | ||||||
|         # Check for Async imports |         # Check for Async imports | ||||||
|         if self.client_streaming: |         if self.client_streaming: | ||||||
| @@ -692,39 +736,18 @@ class ServiceMethodCompiler(ProtoContentBase): | |||||||
|         if self.client_streaming or self.server_streaming: |         if self.client_streaming or self.server_streaming: | ||||||
|             self.output_file.typing_imports.add("AsyncIterator") |             self.output_file.typing_imports.add("AsyncIterator") | ||||||
|  |  | ||||||
|  |         # add imports required for request arguments timeout, deadline and metadata | ||||||
|  |         self.output_file.typing_imports.add("Optional") | ||||||
|  |         self.output_file.imports_type_checking_only.add("import grpclib.server") | ||||||
|  |         self.output_file.imports_type_checking_only.add( | ||||||
|  |             "from betterproto.grpc.grpclib_client import MetadataLike" | ||||||
|  |         ) | ||||||
|  |         self.output_file.imports_type_checking_only.add( | ||||||
|  |             "from grpclib.metadata import Deadline" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         super().__post_init__()  # check for unset fields |         super().__post_init__()  # check for unset fields | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def mutable_default_args(self) -> Dict[str, str]: |  | ||||||
|         """Handle mutable default arguments. |  | ||||||
|  |  | ||||||
|         Returns a list of tuples containing the name and default value |  | ||||||
|         for arguments to this message who's default value is mutable. |  | ||||||
|         The defaults are swapped out for None and replaced back inside |  | ||||||
|         the method's body. |  | ||||||
|         Reference: |  | ||||||
|         https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments |  | ||||||
|  |  | ||||||
|         Returns |  | ||||||
|         ------- |  | ||||||
|         Dict[str, str] |  | ||||||
|             Name and actual default value (as a string) |  | ||||||
|             for each argument with mutable default values. |  | ||||||
|         """ |  | ||||||
|         mutable_default_args = {} |  | ||||||
|  |  | ||||||
|         if self.py_input_message: |  | ||||||
|             for f in self.py_input_message.fields: |  | ||||||
|                 if ( |  | ||||||
|                     not self.client_streaming |  | ||||||
|                     and f.default_value_string != "None" |  | ||||||
|                     and f.mutable |  | ||||||
|                 ): |  | ||||||
|                     mutable_default_args[f.py_name] = f.default_value_string |  | ||||||
|                     self.output_file.typing_imports.add("Optional") |  | ||||||
|  |  | ||||||
|         return mutable_default_args |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def py_name(self) -> str: |     def py_name(self) -> str: | ||||||
|         """Pythonized method name.""" |         """Pythonized method name.""" | ||||||
| @@ -760,7 +783,7 @@ class ServiceMethodCompiler(ProtoContentBase): | |||||||
|         # comparable with method.input_type |         # comparable with method.input_type | ||||||
|         for msg in self.request.all_messages: |         for msg in self.request.all_messages: | ||||||
|             if ( |             if ( | ||||||
|                 msg.py_name == name.replace(".", "") |                 msg.py_name == pythonize_class_name(name.replace(".", "")) | ||||||
|                 and msg.output_file.package == package |                 and msg.output_file.package == package | ||||||
|             ): |             ): | ||||||
|                 return msg |                 return msg | ||||||
| @@ -780,8 +803,20 @@ class ServiceMethodCompiler(ProtoContentBase): | |||||||
|             package=self.output_file.package, |             package=self.output_file.package, | ||||||
|             imports=self.output_file.imports, |             imports=self.output_file.imports, | ||||||
|             source_type=self.proto_obj.input_type, |             source_type=self.proto_obj.input_type, | ||||||
|  |             unwrap=False, | ||||||
|         ).strip('"') |         ).strip('"') | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def py_input_message_param(self) -> str: | ||||||
|  |         """Param name corresponding to py_input_message_type. | ||||||
|  |  | ||||||
|  |         Returns | ||||||
|  |         ------- | ||||||
|  |         str | ||||||
|  |             Param name corresponding to py_input_message_type. | ||||||
|  |         """ | ||||||
|  |         return pythonize_field_name(self.py_input_message_type) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def py_output_message_type(self) -> str: |     def py_output_message_type(self) -> str: | ||||||
|         """String representation of the Python type corresponding to the |         """String representation of the Python type corresponding to the | ||||||
|   | |||||||
| @@ -1,3 +1,13 @@ | |||||||
|  | import pathlib | ||||||
|  | import sys | ||||||
|  | from typing import ( | ||||||
|  |     Generator, | ||||||
|  |     List, | ||||||
|  |     Set, | ||||||
|  |     Tuple, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  |  | ||||||
| from betterproto.lib.google.protobuf import ( | from betterproto.lib.google.protobuf import ( | ||||||
|     DescriptorProto, |     DescriptorProto, | ||||||
|     EnumDescriptorProto, |     EnumDescriptorProto, | ||||||
| @@ -11,10 +21,7 @@ from betterproto.lib.google.protobuf.compiler import ( | |||||||
|     CodeGeneratorResponseFeature, |     CodeGeneratorResponseFeature, | ||||||
|     CodeGeneratorResponseFile, |     CodeGeneratorResponseFile, | ||||||
| ) | ) | ||||||
| import itertools |  | ||||||
| import pathlib |  | ||||||
| import sys |  | ||||||
| from typing import Iterator, List, Set, Tuple, TYPE_CHECKING, Union |  | ||||||
| from .compiler import outputfile_compiler | from .compiler import outputfile_compiler | ||||||
| from .models import ( | from .models import ( | ||||||
|     EnumDefinitionCompiler, |     EnumDefinitionCompiler, | ||||||
| @@ -24,41 +31,40 @@ from .models import ( | |||||||
|     OneOfFieldCompiler, |     OneOfFieldCompiler, | ||||||
|     OutputTemplate, |     OutputTemplate, | ||||||
|     PluginRequestCompiler, |     PluginRequestCompiler, | ||||||
|  |     PydanticOneOfFieldCompiler, | ||||||
|     ServiceCompiler, |     ServiceCompiler, | ||||||
|     ServiceMethodCompiler, |     ServiceMethodCompiler, | ||||||
|     is_map, |     is_map, | ||||||
|     is_oneof, |     is_oneof, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| if TYPE_CHECKING: |  | ||||||
|     from google.protobuf.descriptor import Descriptor |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def traverse( | def traverse( | ||||||
|     proto_file: FieldDescriptorProto, |     proto_file: FileDescriptorProto, | ||||||
| ) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": | ) -> Generator[ | ||||||
|  |     Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None | ||||||
|  | ]: | ||||||
|     # Todo: Keep information about nested hierarchy |     # Todo: Keep information about nested hierarchy | ||||||
|     def _traverse( |     def _traverse( | ||||||
|         path: List[int], items: List["EnumDescriptorProto"], prefix="" |         path: List[int], | ||||||
|     ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: |         items: Union[List[EnumDescriptorProto], List[DescriptorProto]], | ||||||
|  |         prefix: str = "", | ||||||
|  |     ) -> Generator[ | ||||||
|  |         Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None | ||||||
|  |     ]: | ||||||
|         for i, item in enumerate(items): |         for i, item in enumerate(items): | ||||||
|             # Adjust the name since we flatten the hierarchy. |             # Adjust the name since we flatten the hierarchy. | ||||||
|             # Todo: don't change the name, but include full name in returned tuple |             # Todo: don't change the name, but include full name in returned tuple | ||||||
|             item.name = next_prefix = prefix + item.name |             item.name = next_prefix = f"{prefix}_{item.name}" | ||||||
|             yield item, path + [i] |             yield item, [*path, i] | ||||||
|  |  | ||||||
|             if isinstance(item, DescriptorProto): |             if isinstance(item, DescriptorProto): | ||||||
|                 for enum in item.enum_type: |                 # Get nested types. | ||||||
|                     enum.name = next_prefix + enum.name |                 yield from _traverse([*path, i, 4], item.enum_type, next_prefix) | ||||||
|                     yield enum, path + [i, 4] |                 yield from _traverse([*path, i, 3], item.nested_type, next_prefix) | ||||||
|  |  | ||||||
|                 if item.nested_type: |     yield from _traverse([5], proto_file.enum_type) | ||||||
|                     for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): |     yield from _traverse([4], proto_file.message_type) | ||||||
|                         yield n, p |  | ||||||
|  |  | ||||||
|     return itertools.chain( |  | ||||||
|         _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | ||||||
| @@ -70,14 +76,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|     request_data = PluginRequestCompiler(plugin_request_obj=request) |     request_data = PluginRequestCompiler(plugin_request_obj=request) | ||||||
|     # Gather output packages |     # Gather output packages | ||||||
|     for proto_file in request.proto_file: |     for proto_file in request.proto_file: | ||||||
|         if ( |  | ||||||
|             proto_file.package == "google.protobuf" |  | ||||||
|             and "INCLUDE_GOOGLE" not in plugin_options |  | ||||||
|         ): |  | ||||||
|             # If not INCLUDE_GOOGLE, |  | ||||||
|             # skip re-compiling Google's well-known types |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         output_package_name = proto_file.package |         output_package_name = proto_file.package | ||||||
|         if output_package_name not in request_data.output_packages: |         if output_package_name not in request_data.output_packages: | ||||||
|             # Create a new output if there is no output for this package |             # Create a new output if there is no output for this package | ||||||
| @@ -87,6 +85,19 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|         # Add this input file to the output corresponding to this package |         # Add this input file to the output corresponding to this package | ||||||
|         request_data.output_packages[output_package_name].input_files.append(proto_file) |         request_data.output_packages[output_package_name].input_files.append(proto_file) | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             proto_file.package == "google.protobuf" | ||||||
|  |             and "INCLUDE_GOOGLE" not in plugin_options | ||||||
|  |         ): | ||||||
|  |             # If not INCLUDE_GOOGLE, | ||||||
|  |             # skip outputting Google's well-known types | ||||||
|  |             request_data.output_packages[output_package_name].output = False | ||||||
|  |  | ||||||
|  |         if "pydantic_dataclasses" in plugin_options: | ||||||
|  |             request_data.output_packages[ | ||||||
|  |                 output_package_name | ||||||
|  |             ].pydantic_dataclasses = True | ||||||
|  |  | ||||||
|     # Read Messages and Enums |     # Read Messages and Enums | ||||||
|     # We need to read Messages before Services in so that we can |     # We need to read Messages before Services in so that we can | ||||||
|     # get the references to input/output messages for each service |     # get the references to input/output messages for each service | ||||||
| @@ -109,6 +120,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|     # Generate output files |     # Generate output files | ||||||
|     output_paths: Set[pathlib.Path] = set() |     output_paths: Set[pathlib.Path] = set() | ||||||
|     for output_package_name, output_package in request_data.output_packages.items(): |     for output_package_name, output_package in request_data.output_packages.items(): | ||||||
|  |         if not output_package.output: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|         # Add files to the response object |         # Add files to the response object | ||||||
|         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") |         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") | ||||||
| @@ -127,6 +140,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|         directory.joinpath("__init__.py") |         directory.joinpath("__init__.py") | ||||||
|         for path in output_paths |         for path in output_paths | ||||||
|         for directory in path.parents |         for directory in path.parents | ||||||
|  |         if not directory.joinpath("__init__.py").exists() | ||||||
|     } - output_paths |     } - output_paths | ||||||
|  |  | ||||||
|     for init_file in init_files: |     for init_file in init_files: | ||||||
| @@ -138,6 +152,23 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|     return response |     return response | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _make_one_of_field_compiler( | ||||||
|  |     output_package: OutputTemplate, | ||||||
|  |     source_file: "FileDescriptorProto", | ||||||
|  |     parent: MessageCompiler, | ||||||
|  |     proto_obj: "FieldDescriptorProto", | ||||||
|  |     path: List[int], | ||||||
|  | ) -> FieldCompiler: | ||||||
|  |     pydantic = output_package.pydantic_dataclasses | ||||||
|  |     Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler | ||||||
|  |     return Cls( | ||||||
|  |         source_file=source_file, | ||||||
|  |         parent=parent, | ||||||
|  |         proto_obj=proto_obj, | ||||||
|  |         path=path, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_protobuf_type( | def read_protobuf_type( | ||||||
|     item: DescriptorProto, |     item: DescriptorProto, | ||||||
|     path: List[int], |     path: List[int], | ||||||
| @@ -161,11 +192,8 @@ def read_protobuf_type( | |||||||
|                     path=path + [2, index], |                     path=path + [2, index], | ||||||
|                 ) |                 ) | ||||||
|             elif is_oneof(field): |             elif is_oneof(field): | ||||||
|                 OneOfFieldCompiler( |                 _make_one_of_field_compiler( | ||||||
|                     source_file=source_file, |                     output_package, source_file, message_data, field, path + [2, index] | ||||||
|                     parent=message_data, |  | ||||||
|                     proto_obj=field, |  | ||||||
|                     path=path + [2, index], |  | ||||||
|                 ) |                 ) | ||||||
|             else: |             else: | ||||||
|                 FieldCompiler( |                 FieldCompiler( | ||||||
|   | |||||||
| @@ -1,10 +1,21 @@ | |||||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||||
| # sources: {{ ', '.join(output_file.input_filenames) }} | # sources: {{ ', '.join(output_file.input_filenames) }} | ||||||
| # plugin: python-betterproto | # plugin: python-betterproto | ||||||
|  | # This file has been @generated | ||||||
| {% for i in output_file.python_module_imports|sort %} | {% for i in output_file.python_module_imports|sort %} | ||||||
| import {{ i }} | import {{ i }} | ||||||
| {% endfor %} | {% endfor %} | ||||||
|  |  | ||||||
|  | {% if output_file.pydantic_dataclasses %} | ||||||
|  | from typing import TYPE_CHECKING | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from dataclasses import dataclass | ||||||
|  | else: | ||||||
|  |     from pydantic.dataclasses import dataclass | ||||||
|  | {%- else -%} | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
|  | {% endif %} | ||||||
|  |  | ||||||
| {% if output_file.datetime_imports %} | {% if output_file.datetime_imports %} | ||||||
| from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||||
|  |  | ||||||
| @@ -14,12 +25,28 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no | |||||||
|  |  | ||||||
| {% endif %} | {% endif %} | ||||||
|  |  | ||||||
|  | {% if output_file.pydantic_imports %} | ||||||
|  | from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||||
|  |  | ||||||
|  | {% endif %} | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| from betterproto.grpc.grpclib_server import ServiceBase |  | ||||||
| {% if output_file.services %} | {% if output_file.services %} | ||||||
|  | from betterproto.grpc.grpclib_server import ServiceBase | ||||||
| import grpclib | import grpclib | ||||||
| {% endif %} | {% endif %} | ||||||
|  |  | ||||||
|  | {% for i in output_file.imports|sort %} | ||||||
|  | {{ i }} | ||||||
|  | {% endfor %} | ||||||
|  |  | ||||||
|  | {% if output_file.imports_type_checking_only %} | ||||||
|  | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  | {% for i in output_file.imports_type_checking_only|sort %}    {{ i }} | ||||||
|  | {% endfor %} | ||||||
|  | {% endif %} | ||||||
|  |  | ||||||
| {% if output_file.enums %}{% for enum in output_file.enums %} | {% if output_file.enums %}{% for enum in output_file.enums %} | ||||||
| class {{ enum.py_name }}(betterproto.Enum): | class {{ enum.py_name }}(betterproto.Enum): | ||||||
| @@ -28,10 +55,11 @@ class {{ enum.py_name }}(betterproto.Enum): | |||||||
|  |  | ||||||
|     {% endif %} |     {% endif %} | ||||||
|     {% for entry in enum.entries %} |     {% for entry in enum.entries %} | ||||||
|  |     {{ entry.name }} = {{ entry.value }} | ||||||
|         {% if entry.comment %} |         {% if entry.comment %} | ||||||
| {{ entry.comment }} | {{ entry.comment }} | ||||||
|  |  | ||||||
|         {% endif %} |         {% endif %} | ||||||
|     {{ entry.name }} = {{ entry.value }} |  | ||||||
|     {% endfor %} |     {% endfor %} | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -45,10 +73,11 @@ class {{ message.py_name }}(betterproto.Message): | |||||||
|  |  | ||||||
|     {% endif %} |     {% endif %} | ||||||
|     {% for field in message.fields %} |     {% for field in message.fields %} | ||||||
|  |     {{ field.get_field_string() }} | ||||||
|         {% if field.comment %} |         {% if field.comment %} | ||||||
| {{ field.comment }} | {{ field.comment }} | ||||||
|  |  | ||||||
|         {% endif %} |         {% endif %} | ||||||
|     {{ field.get_field_string() }} |  | ||||||
|     {% endfor %} |     {% endfor %} | ||||||
|     {% if not message.fields %} |     {% if not message.fields %} | ||||||
|     pass |     pass | ||||||
| @@ -61,11 +90,16 @@ class {{ message.py_name }}(betterproto.Message): | |||||||
|         {% endif %} |         {% endif %} | ||||||
|         super().__post_init__() |         super().__post_init__() | ||||||
|         {% for field in message.deprecated_fields %} |         {% for field in message.deprecated_fields %} | ||||||
|         if self.{{ field }}: |         if self.is_set("{{ field }}"): | ||||||
|             warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning) |             warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning) | ||||||
|         {% endfor %} |         {% endfor %} | ||||||
|     {%  endif %} |     {%  endif %} | ||||||
|  |  | ||||||
|  |     {% if output_file.pydantic_dataclasses and message.has_oneof_fields %} | ||||||
|  |     @root_validator() | ||||||
|  |     def check_oneof(cls, values): | ||||||
|  |         return cls._validate_field_groups(values) | ||||||
|  |     {%  endif %} | ||||||
|  |  | ||||||
| {% endfor %} | {% endfor %} | ||||||
| {% for service in output_file.services %} | {% for service in output_file.services %} | ||||||
| @@ -79,60 +113,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | |||||||
|     {% for method in service.methods %} |     {% for method in service.methods %} | ||||||
|     async def {{ method.py_name }}(self |     async def {{ method.py_name }}(self | ||||||
|         {%- if not method.client_streaming -%} |         {%- if not method.client_streaming -%} | ||||||
|             {%- if method.py_input_message and method.py_input_message.fields -%}, *, |             {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} | ||||||
|                 {%- for field in method.py_input_message.fields -%} |  | ||||||
|                     {{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%} |  | ||||||
|                                             Optional[{{ field.annotation }}] |  | ||||||
|                                          {%- else -%} |  | ||||||
|                                             {{ field.annotation }} |  | ||||||
|                                          {%- endif -%} = |  | ||||||
|                                             {%- if field.py_name not in method.mutable_default_args -%} |  | ||||||
|                                                 {{ field.default_value_string }} |  | ||||||
|                                             {%- else -%} |  | ||||||
|                                                 None |  | ||||||
|                                             {% endif -%} |  | ||||||
|                     {%- if not loop.last %}, {% endif -%} |  | ||||||
|                 {%- endfor -%} |  | ||||||
|             {%- endif -%} |  | ||||||
|         {%- else -%} |         {%- else -%} | ||||||
|             {# Client streaming: need a request iterator instead #} |             {# Client streaming: need a request iterator instead #} | ||||||
|             , request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] |             , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] | ||||||
|         {%- endif -%} |         {%- endif -%} | ||||||
|  |             , | ||||||
|  |             * | ||||||
|  |             , timeout: Optional[float] = None | ||||||
|  |             , deadline: Optional["Deadline"] = None | ||||||
|  |             , metadata: Optional["MetadataLike"] = None | ||||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: |             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||||
|         {% if method.comment %} |         {% if method.comment %} | ||||||
| {{ method.comment }} | {{ method.comment }} | ||||||
|  |  | ||||||
|         {% endif %} |         {% endif %} | ||||||
|         {%- for py_name, zero in method.mutable_default_args.items() %} |  | ||||||
|         {{ py_name }} = {{ py_name }} or {{ zero }} |  | ||||||
|         {% endfor %} |  | ||||||
|  |  | ||||||
|         {% if not method.client_streaming %} |  | ||||||
|         request = {{ method.py_input_message_type }}() |  | ||||||
|         {% for field in method.py_input_message.fields %} |  | ||||||
|             {% if field.field_type == 'message' %} |  | ||||||
|         if {{ field.py_name }} is not None: |  | ||||||
|             request.{{ field.py_name }} = {{ field.py_name }} |  | ||||||
|             {% else %} |  | ||||||
|         request.{{ field.py_name }} = {{ field.py_name }} |  | ||||||
|             {% endif %} |  | ||||||
|         {% endfor %} |  | ||||||
|         {% endif %} |  | ||||||
|  |  | ||||||
|         {% if method.server_streaming %} |         {% if method.server_streaming %} | ||||||
|             {% if method.client_streaming %} |             {% if method.client_streaming %} | ||||||
|         async for response in self._stream_stream( |         async for response in self._stream_stream( | ||||||
|             "{{ method.route }}", |             "{{ method.route }}", | ||||||
|             request_iterator, |             {{ method.py_input_message_param }}_iterator, | ||||||
|             {{ method.py_input_message_type }}, |             {{ method.py_input_message_type }}, | ||||||
|             {{ method.py_output_message_type.strip('"') }}, |             {{ method.py_output_message_type.strip('"') }}, | ||||||
|  |             timeout=timeout, | ||||||
|  |             deadline=deadline, | ||||||
|  |             metadata=metadata, | ||||||
|         ): |         ): | ||||||
|             yield response |             yield response | ||||||
|             {% else %}{# i.e. not client streaming #} |             {% else %}{# i.e. not client streaming #} | ||||||
|         async for response in self._unary_stream( |         async for response in self._unary_stream( | ||||||
|             "{{ method.route }}", |             "{{ method.route }}", | ||||||
|             request, |             {{ method.py_input_message_param }}, | ||||||
|             {{ method.py_output_message_type.strip('"') }}, |             {{ method.py_output_message_type.strip('"') }}, | ||||||
|  |             timeout=timeout, | ||||||
|  |             deadline=deadline, | ||||||
|  |             metadata=metadata, | ||||||
|         ): |         ): | ||||||
|             yield response |             yield response | ||||||
|  |  | ||||||
| @@ -141,15 +156,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | |||||||
|             {% if method.client_streaming %} |             {% if method.client_streaming %} | ||||||
|         return await self._stream_unary( |         return await self._stream_unary( | ||||||
|             "{{ method.route }}", |             "{{ method.route }}", | ||||||
|             request_iterator, |             {{ method.py_input_message_param }}_iterator, | ||||||
|             {{ method.py_input_message_type }}, |             {{ method.py_input_message_type }}, | ||||||
|             {{ method.py_output_message_type.strip('"') }} |             {{ method.py_output_message_type.strip('"') }}, | ||||||
|  |             timeout=timeout, | ||||||
|  |             deadline=deadline, | ||||||
|  |             metadata=metadata, | ||||||
|         ) |         ) | ||||||
|             {% else %}{# i.e. not client streaming #} |             {% else %}{# i.e. not client streaming #} | ||||||
|         return await self._unary_unary( |         return await self._unary_unary( | ||||||
|             "{{ method.route }}", |             "{{ method.route }}", | ||||||
|             request, |             {{ method.py_input_message_param }}, | ||||||
|             {{ method.py_output_message_type.strip('"') }} |             {{ method.py_output_message_type.strip('"') }}, | ||||||
|  |             timeout=timeout, | ||||||
|  |             deadline=deadline, | ||||||
|  |             metadata=metadata, | ||||||
|         ) |         ) | ||||||
|             {% endif %}{# client streaming #} |             {% endif %}{# client streaming #} | ||||||
|         {% endif %} |         {% endif %} | ||||||
| @@ -167,19 +188,10 @@ class {{ service.py_name }}Base(ServiceBase): | |||||||
|     {% for method in service.methods %} |     {% for method in service.methods %} | ||||||
|     async def {{ method.py_name }}(self |     async def {{ method.py_name }}(self | ||||||
|         {%- if not method.client_streaming -%} |         {%- if not method.client_streaming -%} | ||||||
|             {%- if method.py_input_message and method.py_input_message.fields -%}, |             {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} | ||||||
|                 {%- for field in method.py_input_message.fields -%} |  | ||||||
|                     {{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%} |  | ||||||
|                                             Optional[{{ field.annotation }}] |  | ||||||
|                                          {%- else -%} |  | ||||||
|                                             {{ field.annotation }} |  | ||||||
|                                          {%- endif -%} |  | ||||||
|                     {%- if not loop.last %}, {% endif -%} |  | ||||||
|                 {%- endfor -%} |  | ||||||
|             {%- endif -%} |  | ||||||
|         {%- else -%} |         {%- else -%} | ||||||
|             {# Client streaming: need a request iterator instead #} |             {# Client streaming: need a request iterator instead #} | ||||||
|             , request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] |             , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"] | ||||||
|         {%- endif -%} |         {%- endif -%} | ||||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: |             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||||
|         {% if method.comment %} |         {% if method.comment %} | ||||||
| @@ -187,32 +199,27 @@ class {{ service.py_name }}Base(ServiceBase): | |||||||
|  |  | ||||||
|         {% endif %} |         {% endif %} | ||||||
|         raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) |         raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) | ||||||
|  |         {% if method.server_streaming %} | ||||||
|  |         yield {{ method.py_output_message_type }}() | ||||||
|  |         {% endif %} | ||||||
|  |  | ||||||
|     {% endfor %} |     {% endfor %} | ||||||
|  |  | ||||||
|     {% for method in service.methods %} |     {% for method in service.methods %} | ||||||
|     async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: |     async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None: | ||||||
|         {% if not method.client_streaming %} |         {% if not method.client_streaming %} | ||||||
|         request = await stream.recv_message() |         request = await stream.recv_message() | ||||||
|  |  | ||||||
|         request_kwargs = { |  | ||||||
|         {% for field in method.py_input_message.fields %} |  | ||||||
|             "{{ field.py_name }}": request.{{ field.py_name }}, |  | ||||||
|         {% endfor %} |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         {% else %} |         {% else %} | ||||||
|         request_kwargs = {"request_iterator": stream.__aiter__()} |         request = stream.__aiter__() | ||||||
|         {% endif %} |         {% endif %} | ||||||
|  |  | ||||||
|         {% if not method.server_streaming %} |         {% if not method.server_streaming %} | ||||||
|         response = await self.{{ method.py_name }}(**request_kwargs) |         response = await self.{{ method.py_name }}(request) | ||||||
|         await stream.send_message(response) |         await stream.send_message(response) | ||||||
|         {% else %} |         {% else %} | ||||||
|         await self._call_rpc_handler_server_stream( |         await self._call_rpc_handler_server_stream( | ||||||
|             self.{{ method.py_name }}, |             self.{{ method.py_name }}, | ||||||
|             stream, |             stream, | ||||||
|             request_kwargs, |             request, | ||||||
|         ) |         ) | ||||||
|         {% endif %} |         {% endif %} | ||||||
|  |  | ||||||
| @@ -240,6 +247,10 @@ class {{ service.py_name }}Base(ServiceBase): | |||||||
|  |  | ||||||
| {% endfor %} | {% endfor %} | ||||||
|  |  | ||||||
| {% for i in output_file.imports|sort %} | {% if output_file.pydantic_dataclasses %} | ||||||
| {{ i }} | {% for message in output_file.messages %} | ||||||
|  | {% if message.has_message_field %} | ||||||
|  | {{ message.py_name }}.__pydantic_model__.update_forward_refs()  # type: ignore | ||||||
|  | {% endif %} | ||||||
| {% endfor %} | {% endfor %} | ||||||
|  | {% endif %} | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | import copy | ||||||
|  | import sys | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -10,3 +13,10 @@ def pytest_addoption(parser): | |||||||
| @pytest.fixture(scope="session") | @pytest.fixture(scope="session") | ||||||
| def repeat(request): | def repeat(request): | ||||||
|     return request.config.getoption("repeat") |     return request.config.getoption("repeat") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def reset_sys_path(): | ||||||
|  |     original = copy.deepcopy(sys.path) | ||||||
|  |     yield | ||||||
|  |     sys.path = original | ||||||
|   | |||||||
| @@ -1,20 +1,22 @@ | |||||||
| #!/usr/bin/env python | #!/usr/bin/env python | ||||||
| import asyncio | import asyncio | ||||||
| import os | import os | ||||||
| from pathlib import Path |  | ||||||
| import platform | import platform | ||||||
| import shutil | import shutil | ||||||
| import sys | import sys | ||||||
|  | from pathlib import Path | ||||||
| from typing import Set | from typing import Set | ||||||
|  |  | ||||||
| from tests.util import ( | from tests.util import ( | ||||||
|     get_directories, |     get_directories, | ||||||
|     inputs_path, |     inputs_path, | ||||||
|     output_path_betterproto, |     output_path_betterproto, | ||||||
|  |     output_path_betterproto_pydantic, | ||||||
|     output_path_reference, |     output_path_reference, | ||||||
|     protoc, |     protoc, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Force pure-python implementation instead of C++, otherwise imports | # Force pure-python implementation instead of C++, otherwise imports | ||||||
| # break things because we can't properly reset the symbol database. | # break things because we can't properly reset the symbol database. | ||||||
| os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | ||||||
| @@ -78,10 +80,12 @@ async def generate_test_case_output( | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     test_case_output_path_reference = output_path_reference.joinpath(test_case_name) |     test_case_output_path_reference = output_path_reference.joinpath(test_case_name) | ||||||
|     test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name) |     test_case_output_path_betterproto = output_path_betterproto | ||||||
|  |     test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic | ||||||
|  |  | ||||||
|     os.makedirs(test_case_output_path_reference, exist_ok=True) |     os.makedirs(test_case_output_path_reference, exist_ok=True) | ||||||
|     os.makedirs(test_case_output_path_betterproto, exist_ok=True) |     os.makedirs(test_case_output_path_betterproto, exist_ok=True) | ||||||
|  |     os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True) | ||||||
|  |  | ||||||
|     clear_directory(test_case_output_path_reference) |     clear_directory(test_case_output_path_reference) | ||||||
|     clear_directory(test_case_output_path_betterproto) |     clear_directory(test_case_output_path_betterproto) | ||||||
| @@ -89,9 +93,13 @@ async def generate_test_case_output( | |||||||
|     ( |     ( | ||||||
|         (ref_out, ref_err, ref_code), |         (ref_out, ref_err, ref_code), | ||||||
|         (plg_out, plg_err, plg_code), |         (plg_out, plg_err, plg_code), | ||||||
|  |         (plg_out_pyd, plg_err_pyd, plg_code_pyd), | ||||||
|     ) = await asyncio.gather( |     ) = await asyncio.gather( | ||||||
|         protoc(test_case_input_path, test_case_output_path_reference, True), |         protoc(test_case_input_path, test_case_output_path_reference, True), | ||||||
|         protoc(test_case_input_path, test_case_output_path_betterproto, False), |         protoc(test_case_input_path, test_case_output_path_betterproto, False), | ||||||
|  |         protoc( | ||||||
|  |             test_case_input_path, test_case_output_path_betterproto_pyd, False, True | ||||||
|  |         ), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     if ref_code == 0: |     if ref_code == 0: | ||||||
| @@ -130,7 +138,27 @@ async def generate_test_case_output( | |||||||
|             sys.stderr.buffer.write(plg_err) |             sys.stderr.buffer.write(plg_err) | ||||||
|             sys.stderr.buffer.flush() |             sys.stderr.buffer.flush() | ||||||
|  |  | ||||||
|     return max(ref_code, plg_code) |     if plg_code_pyd == 0: | ||||||
|  |         print( | ||||||
|  |             f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m" | ||||||
|  |         ) | ||||||
|  |     else: | ||||||
|  |         print( | ||||||
|  |             f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     if verbose: | ||||||
|  |         if plg_out_pyd: | ||||||
|  |             print("Plugin stdout:") | ||||||
|  |             sys.stdout.buffer.write(plg_out_pyd) | ||||||
|  |             sys.stdout.buffer.flush() | ||||||
|  |  | ||||||
|  |         if plg_err_pyd: | ||||||
|  |             print("Plugin stderr:") | ||||||
|  |             sys.stderr.buffer.write(plg_err_pyd) | ||||||
|  |             sys.stderr.buffer.flush() | ||||||
|  |  | ||||||
|  |     return max(ref_code, plg_code, plg_code_pyd) | ||||||
|  |  | ||||||
|  |  | ||||||
| HELP = "\n".join( | HELP = "\n".join( | ||||||
| @@ -159,9 +187,19 @@ def main(): | |||||||
|         whitelist = set(sys.argv[1:]) |         whitelist = set(sys.argv[1:]) | ||||||
|  |  | ||||||
|     if platform.system() == "Windows": |     if platform.system() == "Windows": | ||||||
|         asyncio.set_event_loop(asyncio.ProactorEventLoop()) |         # for python version prior to 3.8, loop policy needs to be set explicitly | ||||||
|  |         # https://docs.python.org/3/library/asyncio-policy.html#asyncio.DefaultEventLoopPolicy | ||||||
|  |         try: | ||||||
|  |             asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) | ||||||
|  |         except AttributeError: | ||||||
|  |             # python < 3.7 does not have asyncio.WindowsProactorEventLoopPolicy | ||||||
|  |             asyncio.get_event_loop_policy().set_event_loop(asyncio.ProactorEventLoop()) | ||||||
|  |  | ||||||
|     asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) |     try: | ||||||
|  |         asyncio.run(generate(whitelist, verbose)) | ||||||
|  |     except AttributeError: | ||||||
|  |         # compatibility code for python < 3.7 | ||||||
|  |         asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   | |||||||
| @@ -1,23 +1,27 @@ | |||||||
| import asyncio | import asyncio | ||||||
| import sys | import sys | ||||||
|  | import uuid | ||||||
|  |  | ||||||
| from tests.output_betterproto.service.service import ( | import grpclib | ||||||
|  | import grpclib.client | ||||||
|  | import grpclib.metadata | ||||||
|  | import grpclib.server | ||||||
|  | import pytest | ||||||
|  | from grpclib.testing import ChannelFor | ||||||
|  |  | ||||||
|  | from betterproto.grpc.util.async_channel import AsyncChannel | ||||||
|  | from tests.output_betterproto.service import ( | ||||||
|     DoThingRequest, |     DoThingRequest, | ||||||
|     DoThingResponse, |     DoThingResponse, | ||||||
|     GetThingRequest, |     GetThingRequest, | ||||||
|     TestStub as ThingServiceClient, |     TestStub as ThingServiceClient, | ||||||
| ) | ) | ||||||
| import grpclib |  | ||||||
| import grpclib.metadata |  | ||||||
| import grpclib.server |  | ||||||
| from grpclib.testing import ChannelFor |  | ||||||
| import pytest |  | ||||||
| from betterproto.grpc.util.async_channel import AsyncChannel |  | ||||||
| from .thing_service import ThingService | from .thing_service import ThingService | ||||||
|  |  | ||||||
|  |  | ||||||
| async def _test_client(client, name="clean room", **kwargs): | async def _test_client(client: ThingServiceClient, name="clean room", **kwargs): | ||||||
|     response = await client.do_thing(name=name) |     response = await client.do_thing(DoThingRequest(name=name), **kwargs) | ||||||
|     assert response.names == [name] |     assert response.names == [name] | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -62,7 +66,7 @@ async def test_trailer_only_error_unary_unary( | |||||||
|     ) |     ) | ||||||
|     async with ChannelFor([service]) as channel: |     async with ChannelFor([service]) as channel: | ||||||
|         with pytest.raises(grpclib.exceptions.GRPCError) as e: |         with pytest.raises(grpclib.exceptions.GRPCError) as e: | ||||||
|             await ThingServiceClient(channel).do_thing(name="something") |             await ThingServiceClient(channel).do_thing(DoThingRequest(name="something")) | ||||||
|         assert e.value.status == grpclib.Status.UNAUTHENTICATED |         assert e.value.status == grpclib.Status.UNAUTHENTICATED | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -80,7 +84,7 @@ async def test_trailer_only_error_stream_unary( | |||||||
|     async with ChannelFor([service]) as channel: |     async with ChannelFor([service]) as channel: | ||||||
|         with pytest.raises(grpclib.exceptions.GRPCError) as e: |         with pytest.raises(grpclib.exceptions.GRPCError) as e: | ||||||
|             await ThingServiceClient(channel).do_many_things( |             await ThingServiceClient(channel).do_many_things( | ||||||
|                 request_iterator=[DoThingRequest(name="something")] |                 do_thing_request_iterator=[DoThingRequest(name="something")] | ||||||
|             ) |             ) | ||||||
|             await _test_client(ThingServiceClient(channel)) |             await _test_client(ThingServiceClient(channel)) | ||||||
|         assert e.value.status == grpclib.Status.UNAUTHENTICATED |         assert e.value.status == grpclib.Status.UNAUTHENTICATED | ||||||
| @@ -171,6 +175,56 @@ async def test_service_call_lower_level_with_overrides(): | |||||||
|         assert response.names == [THING_TO_DO] |         assert response.names == [THING_TO_DO] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     ("overrides_gen",), | ||||||
|  |     [ | ||||||
|  |         (lambda: dict(timeout=10),), | ||||||
|  |         (lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),), | ||||||
|  |         (lambda: dict(metadata={"authorization": str(uuid.uuid4())}),), | ||||||
|  |         (lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | async def test_service_call_high_level_with_overrides(mocker, overrides_gen): | ||||||
|  |     overrides = overrides_gen() | ||||||
|  |     request_spy = mocker.spy(grpclib.client.Channel, "request") | ||||||
|  |     name = str(uuid.uuid4()) | ||||||
|  |     defaults = dict( | ||||||
|  |         timeout=99, | ||||||
|  |         deadline=grpclib.metadata.Deadline.from_timeout(99), | ||||||
|  |         metadata={"authorization": name}, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     async with ChannelFor( | ||||||
|  |         [ | ||||||
|  |             ThingService( | ||||||
|  |                 test_hook=_assert_request_meta_received( | ||||||
|  |                     deadline=grpclib.metadata.Deadline.from_timeout( | ||||||
|  |                         overrides.get("timeout", 99) | ||||||
|  |                     ), | ||||||
|  |                     metadata=overrides.get("metadata", defaults.get("metadata")), | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         ] | ||||||
|  |     ) as channel: | ||||||
|  |         client = ThingServiceClient(channel, **defaults) | ||||||
|  |         await _test_client(client, name=name, **overrides) | ||||||
|  |         assert request_spy.call_count == 1 | ||||||
|  |  | ||||||
|  |         # for python <3.8 request_spy.call_args.kwargs do not work | ||||||
|  |         _, request_spy_call_kwargs = request_spy.call_args_list[0] | ||||||
|  |  | ||||||
|  |         # ensure all overrides were successful | ||||||
|  |         for key, value in overrides.items(): | ||||||
|  |             assert key in request_spy_call_kwargs | ||||||
|  |             assert request_spy_call_kwargs[key] == value | ||||||
|  |  | ||||||
|  |         # ensure default values were retained | ||||||
|  |         for key in set(defaults.keys()) - set(overrides.keys()): | ||||||
|  |             assert key in request_spy_call_kwargs | ||||||
|  |             assert request_spy_call_kwargs[key] == defaults[key] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| async def test_async_gen_for_unary_stream_request(): | async def test_async_gen_for_unary_stream_request(): | ||||||
|     thing_name = "my milkshakes" |     thing_name = "my milkshakes" | ||||||
| @@ -178,7 +232,9 @@ async def test_async_gen_for_unary_stream_request(): | |||||||
|     async with ChannelFor([ThingService()]) as channel: |     async with ChannelFor([ThingService()]) as channel: | ||||||
|         client = ThingServiceClient(channel) |         client = ThingServiceClient(channel) | ||||||
|         expected_versions = [5, 4, 3, 2, 1] |         expected_versions = [5, 4, 3, 2, 1] | ||||||
|         async for response in client.get_thing_versions(name=thing_name): |         async for response in client.get_thing_versions( | ||||||
|  |             GetThingRequest(name=thing_name) | ||||||
|  |         ): | ||||||
|             assert response.name == thing_name |             assert response.name == thing_name | ||||||
|             assert response.version == expected_versions.pop() |             assert response.version == expected_versions.pop() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,9 +1,11 @@ | |||||||
| import asyncio | import asyncio | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from typing import AsyncIterator | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| from betterproto.grpc.util.async_channel import AsyncChannel | from betterproto.grpc.util.async_channel import AsyncChannel | ||||||
| from dataclasses import dataclass |  | ||||||
| import pytest |  | ||||||
| from typing import AsyncIterator |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
|   | |||||||
| @@ -1,12 +1,14 @@ | |||||||
| from tests.output_betterproto.service.service import ( | from typing import Dict | ||||||
|     DoThingResponse, |  | ||||||
|  | import grpclib | ||||||
|  | import grpclib.server | ||||||
|  |  | ||||||
|  | from tests.output_betterproto.service import ( | ||||||
|     DoThingRequest, |     DoThingRequest, | ||||||
|  |     DoThingResponse, | ||||||
|     GetThingRequest, |     GetThingRequest, | ||||||
|     GetThingResponse, |     GetThingResponse, | ||||||
| ) | ) | ||||||
| import grpclib |  | ||||||
| import grpclib.server |  | ||||||
| from typing import Dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ThingService: | class ThingService: | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package bool; | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|     bool value = 1; |     bool value = 1; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,19 @@ | |||||||
|  | import pytest | ||||||
|  |  | ||||||
| from tests.output_betterproto.bool import Test | from tests.output_betterproto.bool import Test | ||||||
|  | from tests.output_betterproto_pydantic.bool import Test as TestPyd | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_value(): | def test_value(): | ||||||
|     message = Test() |     message = Test() | ||||||
|     assert not message.value, "Boolean is False by default" |     assert not message.value, "Boolean is False by default" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pydantic_no_value(): | ||||||
|  |     with pytest.raises(ValueError): | ||||||
|  |         TestPyd() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pydantic_value(): | ||||||
|  |     message = Test(value=False) | ||||||
|  |     assert not message.value | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package bytes; | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|     bytes data = 1; |     bytes data = 1; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package casing; | ||||||
|  |  | ||||||
| enum my_enum { | enum my_enum { | ||||||
|   ZERO = 0; |   ZERO = 0; | ||||||
|   ONE = 1; |   ONE = 1; | ||||||
|   | |||||||
							
								
								
									
										11
									
								
								tests/inputs/casing_inner_class/casing_inner_class.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								tests/inputs/casing_inner_class/casing_inner_class.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,11 @@ | |||||||
|  | // https://github.com/danielgtaylor/python-betterproto/issues/344 | ||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package casing_inner_class; | ||||||
|  |  | ||||||
|  | message Test { | ||||||
|  |   message inner_class { | ||||||
|  |     sint32 old_exp = 1; | ||||||
|  |   } | ||||||
|  |   inner_class inner = 2; | ||||||
|  | } | ||||||
							
								
								
									
										14
									
								
								tests/inputs/casing_inner_class/test_casing_inner_class.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								tests/inputs/casing_inner_class/test_casing_inner_class.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | import tests.output_betterproto.casing_inner_class as casing_inner_class | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_message_casing_inner_class_name(): | ||||||
|  |     assert hasattr( | ||||||
|  |         casing_inner_class, "TestInnerClass" | ||||||
|  |     ), "Inline defined Message is correctly converted to CamelCase" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_message_casing_inner_class_attributes(): | ||||||
|  |     message = casing_inner_class.Test() | ||||||
|  |     assert hasattr( | ||||||
|  |         message.inner, "old_exp" | ||||||
|  |     ), "Inline defined Message attribute is snake_case" | ||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package casing_message_field_uppercase; | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|   int32 UPPERCASE = 1; |   int32 UPPERCASE = 1; | ||||||
|   int32 UPPERCASE_V2 = 2; |   int32 UPPERCASE_V2 = 2; | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ xfail = { | |||||||
| } | } | ||||||
|  |  | ||||||
| services = { | services = { | ||||||
|  |     "googletypes_request", | ||||||
|     "googletypes_response", |     "googletypes_response", | ||||||
|     "googletypes_response_embedded", |     "googletypes_response_embedded", | ||||||
|     "service", |     "service", | ||||||
| @@ -18,6 +19,7 @@ services = { | |||||||
|     "googletypes_service_returns_googletype", |     "googletypes_service_returns_googletype", | ||||||
|     "example_service", |     "example_service", | ||||||
|     "empty_service", |     "empty_service", | ||||||
|  |     "service_uppercase", | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| { | { | ||||||
|   "v": 10, |   "message": { | ||||||
|  |     "value": "hello" | ||||||
|  |   }, | ||||||
|   "value": 10 |   "value": 10 | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,9 +1,14 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package deprecated; | ||||||
|  |  | ||||||
| // Some documentation about the Test message. | // Some documentation about the Test message. | ||||||
| message Test { | message Test { | ||||||
|     // Some documentation about the value. |     Message message = 1 [deprecated=true]; | ||||||
|     option deprecated = true; |  | ||||||
|     int32 v = 1 [deprecated=true]; |  | ||||||
|     int32 value = 2; |     int32 value = 2; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | message Message { | ||||||
|  |     option deprecated = true; | ||||||
|  |     string value = 1; | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,4 +0,0 @@ | |||||||
| { |  | ||||||
|   "v": 10, |  | ||||||
|   "value": 10 |  | ||||||
| } |  | ||||||
| @@ -1,8 +0,0 @@ | |||||||
| syntax = "proto3"; |  | ||||||
|  |  | ||||||
| // Some documentation about the Test message. |  | ||||||
| message Test { |  | ||||||
|     // Some documentation about the value. |  | ||||||
|     int32 v = 1 [deprecated=true]; |  | ||||||
|     int32 value = 2; |  | ||||||
| } |  | ||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package double; | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|     double count = 1; |     double count = 1; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package empty_repeated; | ||||||
|  |  | ||||||
| message MessageA { | message MessageA { | ||||||
|   repeated float values = 1; |   repeated float values = 1; | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								tests/inputs/entry/entry.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								tests/inputs/entry/entry.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package entry; | ||||||
|  |  | ||||||
|  | // This is a minimal example of a repeated message field that caused issues when | ||||||
|  | // checking whether a message is a map. | ||||||
|  | // | ||||||
|  | // During the check wheter a field is a "map", the string "entry" is added to | ||||||
|  | // the field name, checked against the type name and then further checks are | ||||||
|  | // made against the nested type of a parent message. In this edge-case, the | ||||||
|  | // first check would pass even though it shouldn't and that would cause an | ||||||
|  | // error because the parent type does not have a "nested_type" attribute. | ||||||
|  |  | ||||||
|  | message Test { | ||||||
|  |     repeated ExportEntry export = 1; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message ExportEntry { | ||||||
|  |     string name = 1; | ||||||
|  | } | ||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package enum; | ||||||
|  |  | ||||||
| // Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values | // Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values | ||||||
| message Test { | message Test { | ||||||
|   Choice choice = 1; |   Choice choice = 1; | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| from tests.output_betterproto.enum import ( | from tests.output_betterproto.enum import ( | ||||||
|     Test, |  | ||||||
|     Choice, |     Choice, | ||||||
|  |     Test, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -39,6 +39,8 @@ | |||||||
|  |  | ||||||
| syntax = "proto2"; | syntax = "proto2"; | ||||||
|  |  | ||||||
|  | package example; | ||||||
|  |  | ||||||
| // package google.protobuf; | // package google.protobuf; | ||||||
|  |  | ||||||
| option go_package = "google.golang.org/protobuf/types/descriptorpb"; | option go_package = "google.golang.org/protobuf/types/descriptorpb"; | ||||||
|   | |||||||
| @@ -1,49 +1,52 @@ | |||||||
| from typing import AsyncIterator, AsyncIterable | from typing import ( | ||||||
|  |     AsyncIterable, | ||||||
|  |     AsyncIterator, | ||||||
|  | ) | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
| from grpclib.testing import ChannelFor | from grpclib.testing import ChannelFor | ||||||
|  |  | ||||||
| from tests.output_betterproto.example_service.example_service import ( | from tests.output_betterproto.example_service import ( | ||||||
|     TestBase, |  | ||||||
|     TestStub, |  | ||||||
|     ExampleRequest, |     ExampleRequest, | ||||||
|     ExampleResponse, |     ExampleResponse, | ||||||
|  |     TestBase, | ||||||
|  |     TestStub, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ExampleService(TestBase): | class ExampleService(TestBase): | ||||||
|     async def example_unary_unary( |     async def example_unary_unary( | ||||||
|         self, example_string: str, example_integer: int |         self, example_request: ExampleRequest | ||||||
|     ) -> "ExampleResponse": |     ) -> "ExampleResponse": | ||||||
|         return ExampleResponse( |         return ExampleResponse( | ||||||
|             example_string=example_string, |             example_string=example_request.example_string, | ||||||
|             example_integer=example_integer, |             example_integer=example_request.example_integer, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     async def example_unary_stream( |     async def example_unary_stream( | ||||||
|         self, example_string: str, example_integer: int |         self, example_request: ExampleRequest | ||||||
|     ) -> AsyncIterator["ExampleResponse"]: |     ) -> AsyncIterator["ExampleResponse"]: | ||||||
|         response = ExampleResponse( |         response = ExampleResponse( | ||||||
|             example_string=example_string, |             example_string=example_request.example_string, | ||||||
|             example_integer=example_integer, |             example_integer=example_request.example_integer, | ||||||
|         ) |         ) | ||||||
|         yield response |         yield response | ||||||
|         yield response |         yield response | ||||||
|         yield response |         yield response | ||||||
|  |  | ||||||
|     async def example_stream_unary( |     async def example_stream_unary( | ||||||
|         self, request_iterator: AsyncIterator["ExampleRequest"] |         self, example_request_iterator: AsyncIterator["ExampleRequest"] | ||||||
|     ) -> "ExampleResponse": |     ) -> "ExampleResponse": | ||||||
|         async for example_request in request_iterator: |         async for example_request in example_request_iterator: | ||||||
|             return ExampleResponse( |             return ExampleResponse( | ||||||
|                 example_string=example_request.example_string, |                 example_string=example_request.example_string, | ||||||
|                 example_integer=example_request.example_integer, |                 example_integer=example_request.example_integer, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     async def example_stream_stream( |     async def example_stream_stream( | ||||||
|         self, request_iterator: AsyncIterator["ExampleRequest"] |         self, example_request_iterator: AsyncIterator["ExampleRequest"] | ||||||
|     ) -> AsyncIterator["ExampleResponse"]: |     ) -> AsyncIterator["ExampleResponse"]: | ||||||
|         async for example_request in request_iterator: |         async for example_request in example_request_iterator: | ||||||
|             yield ExampleResponse( |             yield ExampleResponse( | ||||||
|                 example_string=example_request.example_string, |                 example_string=example_request.example_string, | ||||||
|                 example_integer=example_request.example_integer, |                 example_integer=example_request.example_integer, | ||||||
| @@ -52,44 +55,32 @@ class ExampleService(TestBase): | |||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| async def test_calls_with_different_cardinalities(): | async def test_calls_with_different_cardinalities(): | ||||||
|     test_string = "test string" |     example_request = ExampleRequest("test string", 42) | ||||||
|     test_int = 42 |  | ||||||
|  |  | ||||||
|     async with ChannelFor([ExampleService()]) as channel: |     async with ChannelFor([ExampleService()]) as channel: | ||||||
|         stub = TestStub(channel) |         stub = TestStub(channel) | ||||||
|  |  | ||||||
|         # unary unary |         # unary unary | ||||||
|         response = await stub.example_unary_unary( |         response = await stub.example_unary_unary(example_request) | ||||||
|             example_string="test string", |         assert response.example_string == example_request.example_string | ||||||
|             example_integer=42, |         assert response.example_integer == example_request.example_integer | ||||||
|         ) |  | ||||||
|         assert response.example_string == test_string |  | ||||||
|         assert response.example_integer == test_int |  | ||||||
|  |  | ||||||
|         # unary stream |         # unary stream | ||||||
|         async for response in stub.example_unary_stream( |         async for response in stub.example_unary_stream(example_request): | ||||||
|             example_string="test string", |             assert response.example_string == example_request.example_string | ||||||
|             example_integer=42, |             assert response.example_integer == example_request.example_integer | ||||||
|         ): |  | ||||||
|             assert response.example_string == test_string |  | ||||||
|             assert response.example_integer == test_int |  | ||||||
|  |  | ||||||
|         # stream unary |         # stream unary | ||||||
|         request = ExampleRequest( |  | ||||||
|             example_string=test_string, |  | ||||||
|             example_integer=42, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         async def request_iterator(): |         async def request_iterator(): | ||||||
|             yield request |             yield example_request | ||||||
|             yield request |             yield example_request | ||||||
|             yield request |             yield example_request | ||||||
|  |  | ||||||
|         response = await stub.example_stream_unary(request_iterator()) |         response = await stub.example_stream_unary(request_iterator()) | ||||||
|         assert response.example_string == test_string |         assert response.example_string == example_request.example_string | ||||||
|         assert response.example_integer == test_int |         assert response.example_integer == example_request.example_integer | ||||||
|  |  | ||||||
|         # stream stream |         # stream stream | ||||||
|         async for response in stub.example_stream_stream(request_iterator()): |         async for response in stub.example_stream_stream(request_iterator()): | ||||||
|             assert response.example_string == test_string |             assert response.example_string == example_request.example_string | ||||||
|             assert response.example_integer == test_int |             assert response.example_integer == example_request.example_integer | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package field_name_identical_to_type; | ||||||
|  |  | ||||||
| // Tests that messages may contain fields with names that are identical to their python types (PR #294) | // Tests that messages may contain fields with names that are identical to their python types (PR #294) | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package fixed; | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|   fixed32 foo = 1; |   fixed32 foo = 1; | ||||||
|   sfixed32 bar = 2; |   sfixed32 bar = 2; | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package float; | ||||||
|  |  | ||||||
| // Some documentation about the Test message. | // Some documentation about the Test message. | ||||||
| message Test { | message Test { | ||||||
|     double positive = 1; |     double positive = 1; | ||||||
|   | |||||||
| @@ -1,13 +1,17 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| message Foo{ | package google_impl_behavior_equivalence; | ||||||
|   int64 bar = 1; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| message Test{ | message Foo { int64 bar = 1; } | ||||||
|   oneof group{ |  | ||||||
|  | message Test { | ||||||
|  |   oneof group { | ||||||
|     string string = 1; |     string string = 1; | ||||||
|     int64 integer = 2; |     int64 integer = 2; | ||||||
|     Foo foo = 3; |     Foo foo = 3; | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | message Request { Empty foo = 1; } | ||||||
|  |  | ||||||
|  | message Empty {} | ||||||
| @@ -1,19 +1,22 @@ | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from google.protobuf import json_format | from google.protobuf import json_format | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| from tests.output_betterproto.google_impl_behavior_equivalence import ( | from tests.output_betterproto.google_impl_behavior_equivalence import ( | ||||||
|     Test, |     Empty, | ||||||
|     Foo, |     Foo, | ||||||
|  |     Request, | ||||||
|  |     Test, | ||||||
| ) | ) | ||||||
| from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( | from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( | ||||||
|     Test as ReferenceTest, |     Empty as ReferenceEmpty, | ||||||
|     Foo as ReferenceFoo, |     Foo as ReferenceFoo, | ||||||
|  |     Request as ReferenceRequest, | ||||||
|  |     Test as ReferenceTest, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_oneof_serializes_similar_to_google_oneof(): | def test_oneof_serializes_similar_to_google_oneof(): | ||||||
|  |  | ||||||
|     tests = [ |     tests = [ | ||||||
|         (Test(string="abc"), ReferenceTest(string="abc")), |         (Test(string="abc"), ReferenceTest(string="abc")), | ||||||
|         (Test(integer=2), ReferenceTest(integer=2)), |         (Test(integer=2), ReferenceTest(integer=2)), | ||||||
| @@ -30,7 +33,6 @@ def test_oneof_serializes_similar_to_google_oneof(): | |||||||
|  |  | ||||||
|  |  | ||||||
| def test_bytes_are_the_same_for_oneof(): | def test_bytes_are_the_same_for_oneof(): | ||||||
|  |  | ||||||
|     message = Test(string="") |     message = Test(string="") | ||||||
|     message_reference = ReferenceTest(string="") |     message_reference = ReferenceTest(string="") | ||||||
|  |  | ||||||
| @@ -48,8 +50,23 @@ def test_bytes_are_the_same_for_oneof(): | |||||||
|  |  | ||||||
|     # None of these fields were explicitly set BUT they should not actually be null |     # None of these fields were explicitly set BUT they should not actually be null | ||||||
|     # themselves |     # themselves | ||||||
|     assert isinstance(message.foo, Foo) |     assert not hasattr(message, "foo") | ||||||
|     assert isinstance(message2.foo, Foo) |     assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER | ||||||
|  |     assert not hasattr(message2, "foo") | ||||||
|  |     assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER | ||||||
|  |  | ||||||
|     assert isinstance(message_reference.foo, ReferenceFoo) |     assert isinstance(message_reference.foo, ReferenceFoo) | ||||||
|     assert isinstance(message_reference2.foo, ReferenceFoo) |     assert isinstance(message_reference2.foo, ReferenceFoo) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_empty_message_field(): | ||||||
|  |     message = Request() | ||||||
|  |     reference_message = ReferenceRequest() | ||||||
|  |  | ||||||
|  |     message.foo = Empty() | ||||||
|  |     reference_message.foo.CopyFrom(ReferenceEmpty()) | ||||||
|  |  | ||||||
|  |     assert betterproto.serialized_on_wire(message.foo) | ||||||
|  |     assert reference_message.HasField("foo") | ||||||
|  |  | ||||||
|  |     assert bytes(message) == reference_message.SerializeToString() | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes; | ||||||
|  |  | ||||||
| import "google/protobuf/duration.proto"; | import "google/protobuf/duration.proto"; | ||||||
| import "google/protobuf/timestamp.proto"; | import "google/protobuf/timestamp.proto"; | ||||||
| import "google/protobuf/wrappers.proto"; | import "google/protobuf/wrappers.proto"; | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								tests/inputs/googletypes_request/googletypes_request.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/inputs/googletypes_request/googletypes_request.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_request; | ||||||
|  |  | ||||||
|  | import "google/protobuf/duration.proto"; | ||||||
|  | import "google/protobuf/empty.proto"; | ||||||
|  | import "google/protobuf/timestamp.proto"; | ||||||
|  | import "google/protobuf/wrappers.proto"; | ||||||
|  |  | ||||||
|  | // Tests that google types can be used as params | ||||||
|  |  | ||||||
|  | service Test { | ||||||
|  |     rpc SendDouble (google.protobuf.DoubleValue) returns (Input); | ||||||
|  |     rpc SendFloat (google.protobuf.FloatValue) returns (Input); | ||||||
|  |     rpc SendInt64 (google.protobuf.Int64Value) returns (Input); | ||||||
|  |     rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input); | ||||||
|  |     rpc SendInt32 (google.protobuf.Int32Value) returns (Input); | ||||||
|  |     rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input); | ||||||
|  |     rpc SendBool (google.protobuf.BoolValue) returns (Input); | ||||||
|  |     rpc SendString (google.protobuf.StringValue) returns (Input); | ||||||
|  |     rpc SendBytes (google.protobuf.BytesValue) returns (Input); | ||||||
|  |     rpc SendDatetime (google.protobuf.Timestamp) returns (Input); | ||||||
|  |     rpc SendTimedelta (google.protobuf.Duration) returns (Input); | ||||||
|  |     rpc SendEmpty (google.protobuf.Empty) returns (Input); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Input { | ||||||
|  |  | ||||||
|  | } | ||||||
							
								
								
									
										47
									
								
								tests/inputs/googletypes_request/test_googletypes_request.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								tests/inputs/googletypes_request/test_googletypes_request.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | from datetime import ( | ||||||
|  |     datetime, | ||||||
|  |     timedelta, | ||||||
|  | ) | ||||||
|  | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Callable, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | import betterproto.lib.google.protobuf as protobuf | ||||||
|  | from tests.mocks import MockChannel | ||||||
|  | from tests.output_betterproto.googletypes_request import ( | ||||||
|  |     Input, | ||||||
|  |     TestStub, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | test_cases = [ | ||||||
|  |     (TestStub.send_double, protobuf.DoubleValue, 2.5), | ||||||
|  |     (TestStub.send_float, protobuf.FloatValue, 2.5), | ||||||
|  |     (TestStub.send_int64, protobuf.Int64Value, -64), | ||||||
|  |     (TestStub.send_u_int64, protobuf.UInt64Value, 64), | ||||||
|  |     (TestStub.send_int32, protobuf.Int32Value, -32), | ||||||
|  |     (TestStub.send_u_int32, protobuf.UInt32Value, 32), | ||||||
|  |     (TestStub.send_bool, protobuf.BoolValue, True), | ||||||
|  |     (TestStub.send_string, protobuf.StringValue, "string"), | ||||||
|  |     (TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), | ||||||
|  |     (TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)), | ||||||
|  |     (TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)), | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||||
|  | async def test_channel_receives_wrapped_type( | ||||||
|  |     service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value | ||||||
|  | ): | ||||||
|  |     wrapped_value = wrapper_class() | ||||||
|  |     wrapped_value.value = value | ||||||
|  |     channel = MockChannel(responses=[Input()]) | ||||||
|  |     service = TestStub(channel) | ||||||
|  |  | ||||||
|  |     await service_method(service, wrapped_value) | ||||||
|  |  | ||||||
|  |     assert channel.requests[0]["request"] == type(wrapped_value) | ||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_response; | ||||||
|  |  | ||||||
| import "google/protobuf/wrappers.proto"; | import "google/protobuf/wrappers.proto"; | ||||||
|  |  | ||||||
| // Tests that wrapped values can be used directly as return values | // Tests that wrapped values can be used directly as return values | ||||||
|   | |||||||
| @@ -1,10 +1,18 @@ | |||||||
| from typing import Any, Callable, Optional | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Callable, | ||||||
|  |     Optional, | ||||||
|  | ) | ||||||
|  |  | ||||||
| import betterproto.lib.google.protobuf as protobuf |  | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
|  | import betterproto.lib.google.protobuf as protobuf | ||||||
| from tests.mocks import MockChannel | from tests.mocks import MockChannel | ||||||
| from tests.output_betterproto.googletypes_response import TestStub | from tests.output_betterproto.googletypes_response import ( | ||||||
|  |     Input, | ||||||
|  |     TestStub, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| test_cases = [ | test_cases = [ | ||||||
|     (TestStub.get_double, protobuf.DoubleValue, 2.5), |     (TestStub.get_double, protobuf.DoubleValue, 2.5), | ||||||
| @@ -22,14 +30,15 @@ test_cases = [ | |||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||||
| async def test_channel_receives_wrapped_type( | async def test_channel_receives_wrapped_type( | ||||||
|     service_method: Callable[[TestStub], Any], wrapper_class: Callable, value |     service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value | ||||||
| ): | ): | ||||||
|     wrapped_value = wrapper_class() |     wrapped_value = wrapper_class() | ||||||
|     wrapped_value.value = value |     wrapped_value.value = value | ||||||
|     channel = MockChannel(responses=[wrapped_value]) |     channel = MockChannel(responses=[wrapped_value]) | ||||||
|     service = TestStub(channel) |     service = TestStub(channel) | ||||||
|  |     method_param = Input() | ||||||
|  |  | ||||||
|     await service_method(service) |     await service_method(service, method_param) | ||||||
|  |  | ||||||
|     assert channel.requests[0]["response_type"] != Optional[type(value)] |     assert channel.requests[0]["response_type"] != Optional[type(value)] | ||||||
|     assert channel.requests[0]["response_type"] == type(wrapped_value) |     assert channel.requests[0]["response_type"] == type(wrapped_value) | ||||||
| @@ -39,7 +48,7 @@ async def test_channel_receives_wrapped_type( | |||||||
| @pytest.mark.xfail | @pytest.mark.xfail | ||||||
| @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||||
| async def test_service_unwraps_response( | async def test_service_unwraps_response( | ||||||
|     service_method: Callable[[TestStub], Any], wrapper_class: Callable, value |     service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value | ||||||
| ): | ): | ||||||
|     """ |     """ | ||||||
|     grpclib does not unwrap wrapper values returned by services |     grpclib does not unwrap wrapper values returned by services | ||||||
| @@ -47,8 +56,9 @@ async def test_service_unwraps_response( | |||||||
|     wrapped_value = wrapper_class() |     wrapped_value = wrapper_class() | ||||||
|     wrapped_value.value = value |     wrapped_value.value = value | ||||||
|     service = TestStub(MockChannel(responses=[wrapped_value])) |     service = TestStub(MockChannel(responses=[wrapped_value])) | ||||||
|  |     method_param = Input() | ||||||
|  |  | ||||||
|     response_value = await service_method(service) |     response_value = await service_method(service, method_param) | ||||||
|  |  | ||||||
|     assert response_value == value |     assert response_value == value | ||||||
|     assert type(response_value) == type(value) |     assert type(response_value) == type(value) | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_response_embedded; | ||||||
|  |  | ||||||
| import "google/protobuf/wrappers.proto"; | import "google/protobuf/wrappers.proto"; | ||||||
|  |  | ||||||
| // Tests that wrapped values are supported as part of output message | // Tests that wrapped values are supported as part of output message | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ import pytest | |||||||
|  |  | ||||||
| from tests.mocks import MockChannel | from tests.mocks import MockChannel | ||||||
| from tests.output_betterproto.googletypes_response_embedded import ( | from tests.output_betterproto.googletypes_response_embedded import ( | ||||||
|  |     Input, | ||||||
|     Output, |     Output, | ||||||
|     TestStub, |     TestStub, | ||||||
| ) | ) | ||||||
| @@ -26,7 +27,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response(): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     service = TestStub(MockChannel(responses=[output])) |     service = TestStub(MockChannel(responses=[output])) | ||||||
|     response = await service.get_output() |     response = await service.get_output(Input()) | ||||||
|  |  | ||||||
|     assert response.double_value == 10.0 |     assert response.double_value == 10.0 | ||||||
|     assert response.float_value == 12.0 |     assert response.float_value == 12.0 | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_service_returns_empty; | ||||||
|  |  | ||||||
| import "google/protobuf/empty.proto"; | import "google/protobuf/empty.proto"; | ||||||
|  |  | ||||||
| service Test { | service Test { | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_service_returns_googletype; | ||||||
|  |  | ||||||
| import "google/protobuf/empty.proto"; | import "google/protobuf/empty.proto"; | ||||||
| import "google/protobuf/struct.proto"; | import "google/protobuf/struct.proto"; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_struct; | ||||||
|  |  | ||||||
| import "google/protobuf/struct.proto"; | import "google/protobuf/struct.proto"; | ||||||
|  |  | ||||||
| message Test { | message Test { | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_value; | ||||||
|  |  | ||||||
| import "google/protobuf/struct.proto"; | import "google/protobuf/struct.proto"; | ||||||
|  |  | ||||||
| // Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values. | // Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values. | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  |  | ||||||
| package Capitalized; | package import_capitalized_package.Capitalized; | ||||||
|  |  | ||||||
| message Message { | message Message { | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_capitalized_package; | ||||||
|  |  | ||||||
| import "capitalized.proto"; | import "capitalized.proto"; | ||||||
|  |  | ||||||
| // Tests that we can import from a package with a capital name, that looks like a nested type, but isn't. | // Tests that we can import from a package with a capital name, that looks like a nested type, but isn't. | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package package.childpackage; | package import_child_package_from_package.package.childpackage; | ||||||
|  |  | ||||||
| message ChildMessage { | message ChildMessage { | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_child_package_from_package; | ||||||
|  |  | ||||||
| import "package_message.proto"; | import "package_message.proto"; | ||||||
|  |  | ||||||
| // Tests generated imports when a message in a package refers to a message in a nested child package. | // Tests generated imports when a message in a package refers to a message in a nested child package. | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ syntax = "proto3"; | |||||||
|  |  | ||||||
| import "child.proto"; | import "child.proto"; | ||||||
|  |  | ||||||
| package package; | package import_child_package_from_package.package; | ||||||
|  |  | ||||||
| message PackageMessage { | message PackageMessage { | ||||||
|     package.childpackage.ChildMessage c = 1; |     package.childpackage.ChildMessage c = 1; | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package childpackage; | package import_child_package_from_root.childpackage; | ||||||
|  |  | ||||||
| message Message { | message Message { | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_child_package_from_root; | ||||||
|  |  | ||||||
| import "child.proto"; | import "child.proto"; | ||||||
|  |  | ||||||
| // Tests generated imports when a message in root refers to a message in a child package. | // Tests generated imports when a message in root refers to a message in a child package. | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_circular_dependency; | ||||||
|  |  | ||||||
| import "root.proto"; | import "root.proto"; | ||||||
| import "other.proto"; | import "other.proto"; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| import "root.proto"; | import "root.proto"; | ||||||
| package other; | package import_circular_dependency.other; | ||||||
|  |  | ||||||
| message OtherPackageMessage { | message OtherPackageMessage { | ||||||
|     RootPackageMessage rootPackageMessage = 1; |     RootPackageMessage rootPackageMessage = 1; | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_circular_dependency; | ||||||
|  |  | ||||||
| message RootPackageMessage { | message RootPackageMessage { | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package cousin.cousin_subpackage; | package import_cousin_package.cousin.cousin_subpackage; | ||||||
|  |  | ||||||
| message CousinMessage { | message CousinMessage { | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package test.subpackage; | package import_cousin_package.test.subpackage; | ||||||
|  |  | ||||||
| import "cousin.proto"; | import "cousin.proto"; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package cousin.subpackage; | package import_cousin_package_same_name.cousin.subpackage; | ||||||
|  |  | ||||||
| message CousinMessage { | message CousinMessage { | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package test.subpackage; | package import_cousin_package_same_name.test.subpackage; | ||||||
|  |  | ||||||
| import "cousin.proto"; | import "cousin.proto"; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_packages_same_name; | ||||||
|  |  | ||||||
| import "users_v1.proto"; | import "users_v1.proto"; | ||||||
| import "posts_v1.proto"; | import "posts_v1.proto"; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package posts.v1; | package import_packages_same_name.posts.v1; | ||||||
|  |  | ||||||
| message Post { | message Post { | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package users.v1; | package import_packages_same_name.users.v1; | ||||||
|  |  | ||||||
| message User { | message User { | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ syntax = "proto3"; | |||||||
|  |  | ||||||
| import "parent_package_message.proto"; | import "parent_package_message.proto"; | ||||||
|  |  | ||||||
| package parent.child; | package import_parent_package_from_child.parent.child; | ||||||
|  |  | ||||||
| // Tests generated imports when a message refers to a message defined in its parent package | // Tests generated imports when a message refers to a message defined in its parent package | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package parent; | package import_parent_package_from_child.parent; | ||||||
|  |  | ||||||
| message ParentPackageMessage { | message ParentPackageMessage { | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
| package child; | package import_root_package_from_child.child; | ||||||
|  |  | ||||||
| import "root.proto"; | import "root.proto"; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_root_package_from_child; | ||||||
|  |  | ||||||
|  |  | ||||||
| message RootMessage { | message RootMessage { | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| syntax = "proto3"; | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package import_root_sibling; | ||||||
|  |  | ||||||
| import "sibling.proto"; | import "sibling.proto"; | ||||||
|  |  | ||||||
| // Tests generated imports when a message in the root package refers to another message in the root package | // Tests generated imports when a message in the root package refers to another message in the root package | ||||||
|   | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user