85 Commits

Author SHA1 Message Date
Georg K
87b84afc4b Merge branch 'rust_extras' into next 2023-09-06 14:34:14 +03:00
Erik Friese
8283ef7298 bugfix
map fields with values of message type were not serialized correctly
2023-09-05 21:58:58 +02:00
Erik Friese
0931eb3bf5 bugfix
byte fields were deserialized incorrectly
2023-09-05 20:26:46 +02:00
Erik Friese
8f535913a1 bugfix
using python identifier in message names
necessary to distinguish between dynamically created classes with same name
2023-09-05 20:06:00 +02:00
Erik Friese
fd02cb6180 supporting datetime and timedelta 2023-09-05 11:27:04 +02:00
Erik Friese
950d2f6536 google wrapper types 2023-09-04 21:09:06 +02:00
Erik Friese
29f12ea88d map support 2023-09-04 12:52:12 +02:00
Erik Friese
219233b50e bugfix: parsing unknown fields properly 2023-08-31 17:57:28 +02:00
Erik Friese
2d30bdb7b2 Merge branch 'master' into rust_extras 2023-08-31 17:39:34 +02:00
Erik Friese
bdd3389b17 bugfix in proto descriptor creation
reference cycles in betterproto messages have led to infinite recursion
2023-08-31 17:17:28 +02:00
Erik Friese
441844b97a avoiding name clash 2023-08-31 13:18:07 +02:00
Georg K
aa81680c83 Merge branch 'master_gh' 2023-08-31 00:51:55 +03:00
Erik Friese
a413d08fc1 enum support 2023-08-30 21:06:32 +02:00
Erik Friese
24d694afe2 storing unknown fields 2023-08-30 15:49:25 +02:00
Erik Friese
84af157122 minor refactoring 2023-08-30 15:39:34 +02:00
Erik Friese
df0c17bf0a optional support 2023-08-29 18:08:59 +02:00
Joshua Leivers
8659c51123 Add message streaming support (#518) 2023-08-29 14:26:25 +01:00
Erik Friese
d1825026db lock file reverted to master 2023-08-27 18:46:20 +02:00
Erik Friese
a12c9d24de oneof support 2023-08-27 14:37:23 +02:00
Erik Friese
d79a9eee14 deserializing lists 2023-08-27 13:22:51 +02:00
Erik Friese
d848d05710 minor optimization 2023-08-26 21:47:35 +02:00
Erik Friese
26da86d2cd proper error handling 2023-08-26 21:32:35 +02:00
Erik Friese
604dcb104f type info + doc string added 2023-08-26 21:04:30 +02:00
Erik Friese
421aa78014 Native deserialization based on Rust and PyO3
Proof of concept
Only capable of deserializing (nested) Messages with primitive fields
No handling of lists, maps, enums, .. implemented yet
See `example.py` for a working example
2023-08-26 13:04:26 +02:00
Georg K
0fda2cc05d Merge branch 'master_gh' 2023-08-11 02:56:45 +03:00
Andrew
4cdf1bb9e0 Fix Message equality comparison (#513) 2023-07-29 12:06:56 +01:00
Georg K
d203659a44 Merge branch 'master_gh' 2023-07-27 00:06:35 +03:00
Alexander Khabarov
6faac1d1ca Raise AttributeError on attempts to access unset oneof fields (#510) 2023-07-21 13:26:30 +01:00
Ashwin Madavan
098989e9e9 Bump version to 2.0.0b6 (#500)
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
2023-06-26 00:12:49 +01:00
Ollie
182aedaec4 Handle empty value objects properly (#481)
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
2023-06-24 20:19:13 +01:00
konstantin
a7532bbadc Add Python 3.11 to CI Runs (#445)
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
2023-06-24 19:49:34 +01:00
Alexander Khabarov
73d1fa3d5b Upgrade grpcio-tools and protobuf (#498) 2023-06-24 19:39:11 +01:00
Sriansh Raj Pradhan
c00bc96db7 Create LICENSE.md (#502) 2023-06-16 19:19:51 +01:00
Georg K
d3e9621aa8 Merge branch 'master_gh' 2023-05-29 17:18:33 +03:00
Nick DeRobertis
fcbd8a3759 Fix pydict serialization for optional fields (#495) 2023-05-28 17:47:52 +01:00
pi-slh
aad7d2ad76 Replace pkg_resources with importlib (#462) 2023-05-25 11:12:15 +01:00
Georg K
37e53fce85 Merge branch 'master_gh' 2023-05-05 03:27:56 +03:00
Jinyu Liu
2b41383745 Fix dict encoding for timezone aware datetimes (#468)
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
2023-04-13 23:34:19 +01:00
Marek Pikuła
b0b6cd24ad Fix pydantic_dataclasses reference in README (#474) 2023-04-13 21:56:38 +01:00
Georg K
b81195eb44 fix: _Timestamp.to_datetime works with negative ts 2023-04-05 13:02:25 +03:00
Georg K
d2af2f2fac Merge branch 'master_gh' 2023-04-05 12:51:02 +03:00
James Hilton-Balfe
e7f07fa2a1 Update __init__.py (#451) 2023-03-08 08:20:56 +00:00
Georg K
50fa4e6268 fix: protoc to local 2023-03-02 21:55:15 +03:00
Samuel Yvon
2fa0be2141 Fix for #459 (pydantic code gen only) (#460) 2023-02-21 19:41:32 +00:00
Samuel Yvon
13d656587c Add support for pydantic dataclasses (#406) 2023-02-13 15:37:16 +00:00
James Hilton-Balfe
6df8cef3f0 Fix CI (#456) 2023-02-13 00:20:58 +00:00
James Hilton-Balfe
1b1bd47cb1 Drop support for python3.6 (#444) 2023-02-09 08:35:41 +00:00
Wouter Horré
0adcc9020c Pythonize input_type name in py_input_message (#436)
Co-authored-by: konstantin <konstantin.klein@hochfrequenz.de>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
Fixes https://github.com/danielgtaylor/python-betterproto/issues/427
Fixes https://github.com/danielgtaylor/python-betterproto/issues/438
2022-12-02 22:18:48 +00:00
sterliakov
bfc0fac754 Enforce serialize_empty for repeated fields (#417) 2022-08-31 18:59:12 +01:00
Antonín Říha
8fbf4476a8 Fix typechecker compatiblity checks in server streaming methods (#413) 2022-08-31 00:05:29 +01:00
Samuel Yvon
591ec5efb3 Pull down the include_default_values argument to to_json (#405) 2022-08-08 14:26:28 +01:00
Antonín Říha
f31d51cf3c Added support for @generated marker (#382) 2022-08-03 11:05:13 +01:00
James Hilton-Balfe
496eba2750 Bump version to b5 (#404) 2022-08-02 09:23:44 +10:00
James Hilton-Balfe
d663a318b7 Release v.2.0.0b5 (#350)
* Implement Message.__bool__ for #130
* Add __bool__ to special members
* Tweak __bool__ docstring
* remove compiler: prefix

Co-authored-by: nat <n@natn.me>
2022-08-02 08:59:44 +10:00
Justin Torre
2fb37dd108 Update Jinja 2 version (#402) 2022-08-01 10:44:37 +01:00
Vasile Razdalovschi
42d2df6de6 Fix broken link in readme to tests (#400) 2022-07-14 13:14:05 +01:00
James Hilton-Balfe
3fd5a0d662 Fix parameters missing from services (#381) 2022-07-06 19:05:40 +01:00
MatejKastak
bc13e7070d Fix link to google files (#392) 2022-06-16 16:08:39 +01:00
Flynn
6536181902 Add to/from_pydict methods (#203)
* add to/from_pydict methods

* Remove unnecessary method call

* Fix formatting

Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
2022-05-09 17:34:12 +01:00
GrownNed
85e4be96d8 fix Message.to_dict mutating the underlying Message (#378)
* [fix] to_dict modifies the underlying message (#151)

* add test for mapmessage

* fix for to_dict

* formatting

* Apply suggestions from code review

Co-authored-by: Arun Babu Neelicattu <arun.neelicattu@gmail.com>

* change to_json to to_dict

Co-authored-by: Arun Babu Neelicattu <arun.neelicattu@gmail.com>
2022-05-09 17:29:42 +01:00
efokschaner
06c26ba60d tests: Lazy evaluate Deadline parameter in pytest (#380)
The pytest parameters are evaluated when the tests are loading.
The Deadline.from_timeout is a fixed point in time. 
By deferring the evaluation it helps ensure that the deadline is not reached before the test is executed.
2022-04-30 23:11:21 +01:00
James Hilton-Balfe
6a70b8e8ea compiler: do not overwrite top level __init__.py
Resolves: #168 #260
2022-04-24 01:44:16 +02:00
James Hilton-Balfe
3ca092a724 Fix is_set for optional proto3 fields 2022-04-24 01:37:15 +02:00
James Hilton-Balfe
6f7d706a8e fix float nan comparison
This change also adds minor docstrings fixes and bumps pre-commit blac
version.
2022-04-24 01:13:58 +02:00
James Hilton-Balfe
ac96d8254b Add a minimal repro for #312 (#340) 2022-04-22 11:04:40 +01:00
Max
e7133adeb3 fix: map field edge-case
This change ensures a parent is a nested type when checking if a field is a map.
2022-04-22 11:06:44 +02:00
Pavel Savchenko
204e04dd69 Update README.md
linebreak after bullet-point, and remove redundant conjunction at the start of the section
2022-04-22 10:57:35 +02:00
James Hilton-Balfe
b9b0b22d57 Make Message.__getattribute__ invisible to type checkers (#359)
This lets linters know that we shouldn't access fields that aren't actually defined
2022-04-21 15:44:55 +01:00
James Hilton-Balfe
402c21256f Fix unicodefun import error in black (#366) 2022-04-16 22:15:14 +01:00
Gabriel Pajot
5f7e4d58ef Fix documentation for nested enums (#351) 2022-03-18 22:36:27 +00:00
Arun Babu Neelicattu
1aaf7728cc compiler: Run isort on compiled code (#355) 2022-03-18 22:29:42 +00:00
Arun Babu Neelicattu
70310c9e8c pre-commit: add isort hook and apply (#354) 2022-03-17 00:01:17 +00:00
Arun Babu Neelicattu
18a518efa7 Expose timeout, deadline and metadata parameters from grpclib (#352) 2022-03-13 22:34:11 +00:00
Arun Babu Neelicattu
62da35b3ea parser: ensure prefix is separated when traversing (#353) 2022-03-12 09:08:03 +00:00
Arun Babu Neelicattu
69f4192341 Fix incorrect deprecation warnings on defaults (#348)
This change ensures that deprecation warnings are only raised when
either a deprecated field is explicitly set or a deprecated message is
initialised.

Resolves: #347
2022-03-11 23:36:14 +00:00
Arun Babu Neelicattu
9c1bf25304 tests.generate: stop using asyncio.get_event_loop (#349)
The use of `asyncio.get_event_loop()` has been deprecated in python 3.10+. We replace this usage with `asyncio.run()` for python 3.7+.
2022-03-03 18:11:57 +00:00
Arun Babu Neelicattu
a836fb23bc Configure pre-commit for project (#346) 2022-03-03 18:10:01 +00:00
Arun Babu Neelicattu
bd69862a02 test input: use explicit package declaration (#345) 2022-03-03 13:34:53 +00:00
James Hilton-Balfe
74205e3319 Implement __deepcopy__ for Message (#339) 2022-02-16 23:12:51 +00:00
James Hilton-Balfe
3f377e3bfd Remove the poetry.lock (#338) 2022-02-15 15:37:47 +00:00
Eitan Mosenkis
8c727d904f Fix from_dict() in the presence of optional datetime fields. (#329) 2022-02-03 09:00:56 +00:00
Eitan Mosenkis
eeddc844a5 Bump Jinja2 to 3.0.3. (#330) 2022-02-01 08:32:25 +00:00
Michael Osthege
9b5594adbe Format field comments also as docstrings (#304)
Closes #303

* Format field comments also as docstrings
To make it clear that they refer to the item above.
* Fix placement of enum item docstrings
* Add line breaks after class attribute or enum item docstrings
2022-01-27 09:25:48 +11:00
Danil Akhtarov
d991040ff6 Fix message text in NotImplementedError (#325) 2022-01-21 11:39:09 +00:00
efokschaner
d260f071e0 Client and Service Stubs take 1 request parameter, not one for each field (#311) 2022-01-17 19:58:57 +01:00
161 changed files with 6060 additions and 2211 deletions

View File

@@ -13,17 +13,15 @@ jobs:
name: ${{ matrix.os }} / ${{ matrix.python-version }} name: ${{ matrix.os }} / ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}-latest runs-on: ${{ matrix.os }}-latest
strategy: strategy:
fail-fast: false
matrix: matrix:
os: [Ubuntu, MacOS, Windows] os: [Ubuntu, MacOS, Windows]
python-version: ['3.6.7', '3.7', '3.8', '3.9', '3.10'] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
exclude:
- os: Windows
python-version: 3.6
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
@@ -43,7 +41,7 @@ jobs:
run: poetry config virtualenvs.in-project true run: poetry config virtualenvs.in-project true
- name: Set up cache - name: Set up cache
uses: actions/cache@v2 uses: actions/cache@v3
id: cache id: cache
with: with:
path: .venv path: .venv
@@ -56,9 +54,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
shell: bash shell: bash
run: | run: poetry install -E compiler
poetry run python -m pip install pip -U
poetry install
- name: Generate code from proto files - name: Generate code from proto files
shell: bash shell: bash

View File

@@ -13,14 +13,6 @@ jobs:
name: Check code/doc formatting name: Check code/doc formatting
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Run Black - uses: actions/setup-python@v4
uses: lgeiger/black-action@master - uses: pre-commit/action@v2.0.3
with:
args: --check src/ tests/ benchmarks/
- name: Install rST dependcies
run: python -m pip install doc8
- name: Lint documentation for errors
run: python -m doc8 docs --max-line-length 88 --ignore-path-errors "docs/migrating.rst;D001"
# it has a table which is longer than 88 characters long

View File

@@ -15,9 +15,9 @@ jobs:
name: Distribution name: Distribution
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.8 - name: Set up Python 3.8
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: 3.8 python-version: 3.8
- name: Install poetry - name: Install poetry

1
.gitignore vendored
View File

@@ -17,3 +17,4 @@ output
.venv .venv
.asv .asv
venv venv
.devcontainer

21
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,21 @@
ci:
autofix_prs: false
repos:
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
args: ["--target-version", "py310"]
- repo: https://github.com/PyCQA/doc8
rev: 0.10.1
hooks:
- id: doc8
additional_dependencies:
- toml

View File

@@ -7,6 +7,78 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`. - Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [2.0.0b6] - 2023-06-25
- **Breaking**: the minimum Python version has been bumped to `3.7` [#444](https://github.com/danielgtaylor/python-betterproto/pull/444)
- Support generating [Pydantic dataclasses](https://docs.pydantic.dev/latest/usage/dataclasses).
Pydantic dataclasses are are drop-in replacement for dataclasses in the standard library that additionally supports validation.
Pass `--python_betterproto_opt=pydantic_dataclasses` to enable this feature.
Refer to [#406](https://github.com/danielgtaylor/python-betterproto/pull/406)
and [README.md](https://github.com/danielgtaylor/python-betterproto#generating-pydantic-models) for more information.
- Added support for `@generated` marker [#382](https://github.com/danielgtaylor/python-betterproto/pull/382)
- Pull down the `include_default_values` argument to `to_json()` [#405](https://github.com/danielgtaylor/python-betterproto/pull/405)
- Pythonize input_type name in py_input_message [#436](https://github.com/danielgtaylor/python-betterproto/pull/436)
- Widen `from_dict()` to accept any `Mapping` [#451](https://github.com/danielgtaylor/python-betterproto/pull/451)
- Replace `pkg_resources` with `importlib` [#462](https://github.com/danielgtaylor/python-betterproto/pull/462)
- Fix typechecker compatiblity checks in server streaming methods [#413](https://github.com/danielgtaylor/python-betterproto/pull/413)
- Fix "empty-valued" repeated fields not being serialised [#417](https://github.com/danielgtaylor/python-betterproto/pull/417)
- Fix `dict` encoding for timezone-aware `datetimes` [#468](https://github.com/danielgtaylor/python-betterproto/pull/468)
- Fix `to_pydict()` serialization for optional fields [#495](https://github.com/danielgtaylor/python-betterproto/pull/495)
- Handle empty value objects properly [#481](https://github.com/danielgtaylor/python-betterproto/pull/481)
## [2.0.0b5] - 2022-08-01
- **Breaking**: Client and Service Stubs no longer pack and unpack the input message fields as parameters [#331](https://github.com/danielgtaylor/python-betterproto/pull/311)
Update your client calls and server handlers as follows:
Clients before:
```py
response = await service.echo(value="hello", extra_times=1)
```
Clients after:
```py
response = await service.echo(EchoRequest(value="hello", extra_times=1))
```
Servers before:
```py
async def echo(self, value: str, extra_times: int) -> EchoResponse: ...
```
Servers after:
```py
async def echo(self, echo_request: EchoRequest) -> EchoResponse:
# Use echo_request.value
# Use echo_request.extra_times
...
```
- Add `to/from_pydict()` for `Message` [#203](https://github.com/danielgtaylor/python-betterproto/pull/203)
- Format field comments also as docstrings [#304](https://github.com/danielgtaylor/python-betterproto/pull/304)
- Implement `__deepcopy__` for `Message` [#339](https://github.com/danielgtaylor/python-betterproto/pull/339)
- Run isort on compiled code [#355](https://github.com/danielgtaylor/python-betterproto/pull/355)
- Expose timeout, deadline and metadata parameters from grpclib [#352](https://github.com/danielgtaylor/python-betterproto/pull/352)
- Make `Message.__getattribute__` invisible to type checkers [#359](https://github.com/danielgtaylor/python-betterproto/pull/359)
- Fix map field edge-case [#254](https://github.com/danielgtaylor/python-betterproto/pull/254)
- Fix message text in `NotImplementedError` [#325](https://github.com/danielgtaylor/python-betterproto/pull/325)
- Fix `Message.from_dict()` in the presence of optional datetime fields [#329](https://github.com/danielgtaylor/python-betterproto/pull/329)
- Support Jinja2 3.0 to prevent version conflicts [#330](https://github.com/danielgtaylor/python-betterproto/pull/330)
- Fix overwriting top level `__init__.py` [#337](https://github.com/danielgtaylor/python-betterproto/pull/337)
- Remove deprecation warnings when fields are initialised with non-default values [#348](https://github.com/danielgtaylor/python-betterproto/pull/348)
- Ensure nested class names are converted to PascalCase [#353](https://github.com/danielgtaylor/python-betterproto/pull/353)
- Fix `Message.to_dict()` mutating the underlying Message [#378](https://github.com/danielgtaylor/python-betterproto/pull/378)
- Fix some parameters being missing from services [#381](https://github.com/danielgtaylor/python-betterproto/pull/381)
## [2.0.0b4] - 2022-01-03 ## [2.0.0b4] - 2022-01-03
- **Breaking**: the minimum Python version has been bumped to `3.6.2` - **Breaking**: the minimum Python version has been bumped to `3.6.2`

21
LICENSE.md Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Daniel G. Taylor
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -7,13 +7,14 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
- Protobuf 3 & gRPC code generation - Protobuf 3 & gRPC code generation
- Both binary & JSON serialization is built-in - Both binary & JSON serialization is built-in
- Python 3.6+ making use of: - Python 3.7+ making use of:
- Enums - Enums
- Dataclasses - Dataclasses
- `async`/`await` - `async`/`await`
- Timezone-aware `datetime` and `timedelta` objects - Timezone-aware `datetime` and `timedelta` objects
- Relative imports - Relative imports
- Mypy type checking - Mypy type checking
- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models)
This project is heavily inspired by, and borrows functionality from: This project is heavily inspired by, and borrows functionality from:
@@ -38,6 +39,8 @@ This project exists because I am unhappy with the state of the official Google p
- Uses `SerializeToString()` rather than the built-in `__bytes__()` - Uses `SerializeToString()` rather than the built-in `__bytes__()`
- Special wrapped types don't use Python's `None` - Special wrapped types don't use Python's `None`
- Timestamp/duration types don't use Python's built-in `datetime` module - Timestamp/duration types don't use Python's built-in `datetime` module
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
## Installation ## Installation
@@ -58,7 +61,7 @@ pip install betterproto
### Compiling proto files ### Compiling proto files
Now, given you installed the compiler and have a proto file, e.g `example.proto`: Given you installed the compiler and have a proto file, e.g `example.proto`:
```protobuf ```protobuf
syntax = "proto3"; syntax = "proto3";
@@ -177,10 +180,10 @@ from grpclib.client import Channel
async def main(): async def main():
channel = Channel(host="127.0.0.1", port=50051) channel = Channel(host="127.0.0.1", port=50051)
service = echo.EchoStub(channel) service = echo.EchoStub(channel)
response = await service.echo(value="hello", extra_times=1) response = await service.echo(echo.EchoRequest(value="hello", extra_times=1))
print(response) print(response)
async for response in service.echo_stream(value="hello", extra_times=1): async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)):
print(response) print(response)
# don't forget to close the channel when done! # don't forget to close the channel when done!
@@ -192,6 +195,7 @@ if __name__ == "__main__":
loop.run_until_complete(main()) loop.run_until_complete(main())
``` ```
which would output which would output
```python ```python
EchoResponse(values=['hello', 'hello']) EchoResponse(values=['hello', 'hello'])
@@ -206,18 +210,18 @@ service methods:
```python ```python
import asyncio import asyncio
from echo import EchoBase, EchoResponse, EchoStreamResponse from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse
from grpclib.server import Server from grpclib.server import Server
from typing import AsyncIterator from typing import AsyncIterator
class EchoService(EchoBase): class EchoService(EchoBase):
async def echo(self, value: str, extra_times: int) -> "EchoResponse": async def echo(self, echo_request: "EchoRequest") -> "EchoResponse":
return EchoResponse([value for _ in range(extra_times)]) return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)])
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]: async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]:
for _ in range(extra_times): for _ in range(echo_request.extra_times):
yield EchoStreamResponse(value) yield EchoStreamResponse(echo_request.value)
async def main(): async def main():
@@ -361,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'} {'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
``` ```
## Generating Pydantic Models
You can use python-betterproto to generate pydantic based models, using
pydantic dataclasses. This means the results of the protobuf unmarshalling will
be typed checked. The usage is the same, but you need to add a custom option
when calling the protobuf compiler:
```
protoc -I . --python_betterproto_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto
```
With the important change being `--python_betterproto_opt=pydantic_dataclasses`. This will
swap the dataclass implementation from the builtin python dataclass to the
pydantic dataclass. You must have pydantic as a dependency in your project for
this to work.
## Development ## Development
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_ - _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
@@ -368,7 +391,7 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
### Requirements ### Requirements
- Python (3.6 or higher) - Python (3.7 or higher)
- [poetry](https://python-poetry.org/docs/#installation) - [poetry](https://python-poetry.org/docs/#installation)
*Needed to install dependencies in a virtual environment* *Needed to install dependencies in a virtual environment*
@@ -381,8 +404,7 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
```sh ```sh
# Get set up with the virtual env & dependencies # Get set up with the virtual env & dependencies
poetry run pip install --upgrade pip poetry install -E compiler
poetry install
# Activate the poetry environment # Activate the poetry environment
poetry shell poetry shell
@@ -417,7 +439,7 @@ Adding a standard test case is easy.
It will be picked up automatically when you run the tests. It will be picked up automatically when you run the tests.
- See also: [Standard Tests Development Guide](betterproto/tests/README.md) - See also: [Standard Tests Development Guide](tests/README.md)
#### Custom tests #### Custom tests
@@ -442,7 +464,7 @@ poe full-test
### (Re)compiling Google Well-known Types ### (Re)compiling Google Well-known Types
Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google). Betterproto includes compiled versions for Google's well-known types at [src/betterproto/lib/google](src/betterproto/lib/google).
Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests. Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests.
Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`. Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`.

View File

@@ -1 +0,0 @@

View File

@@ -1,8 +1,8 @@
import betterproto
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
import betterproto
@dataclass @dataclass
class TestMessage(betterproto.Message): class TestMessage(betterproto.Message):

72
betterproto-extras/.gitignore vendored Normal file
View File

@@ -0,0 +1,72 @@
/target
# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.Python
.venv/
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
include/
man/
venv/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
pip-selfcheck.json
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
.DS_Store
# Sphinx documentation
docs/_build/
# PyCharm
.idea/
# VSCode
.vscode/
# Pyenv
.python-version

383
betterproto-extras/Cargo.lock generated Normal file
View File

@@ -0,0 +1,383 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "anyhow"
version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "betterproto-extras"
version = "0.1.0"
dependencies = [
"indoc 2.0.3",
"prost-reflect",
"pyo3",
"thiserror",
]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bytes"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "indoc"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
[[package]]
name = "indoc"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c785eefb63ebd0e33416dfcb8d6da0bf27ce752843a45632a67bf10d4d4b5c4"
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "libc"
version = "0.2.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "lock_api"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]]
name = "parking_lot"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets",
]
[[package]]
name = "proc-macro2"
version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
dependencies = [
"unicode-ident",
]
[[package]]
name = "prost"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive",
]
[[package]]
name = "prost-derive"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
"anyhow",
"itertools",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "prost-reflect"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b823de344848e011658ac981009100818b322421676740546f8b52ed5249428"
dependencies = [
"once_cell",
"prost",
"prost-types",
]
[[package]]
name = "prost-types"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
"prost",
]
[[package]]
name = "pyo3"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38"
dependencies = [
"cfg-if",
"indoc 1.0.9",
"libc",
"memoffset",
"parking_lot",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 1.0.109",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "quote"
version = "1.0.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
dependencies = [
"proc-macro2",
]
[[package]]
name = "redox_syscall"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29"
dependencies = [
"bitflags",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "smallvec"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "target-lexicon"
version = "0.12.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a"
[[package]]
name = "thiserror"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]]
name = "unicode-ident"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c"
[[package]]
name = "unindent"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"

View File

@@ -0,0 +1,14 @@
[package]
name = "betterproto-extras"
version = "0.1.0"
edition = "2021"
[lib]
name = "betterproto_extras"
crate-type = ["cdylib"]
[dependencies]
indoc = "2.0.3"
prost-reflect = "0.11.5"
pyo3 = { version = "0.19.2", features = ["abi3-py37", "extension-module"] }
thiserror = "1.0.47"

View File

@@ -0,0 +1,5 @@
def deserialize(msg, data: bytes):
"""
Parses the binary encoded Protobuf `data` with respect to the metadata
given by the betterproto message `msg`, and merges the result into `msg`.
"""

View File

@@ -0,0 +1,16 @@
[build-system]
requires = ["maturin>=1.2,<2.0"]
build-backend = "maturin"
[project]
name = "betterproto-extras"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
[tool.maturin]
features = ["pyo3/extension-module"]

View File

@@ -0,0 +1,289 @@
use crate::{
error::{Error, Result},
py_any_extras::PyAnyExtras,
};
use prost_reflect::{
prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
FileDescriptorProto, MessageOptions, OneofDescriptorProto,
},
DescriptorPool, MessageDescriptor,
};
use pyo3::PyAny;
use std::sync::{Mutex, OnceLock};
pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
let mut pool = DESCRIPTOR_POOL
.get_or_init(|| Mutex::new(DescriptorPool::global()))
.lock()
.unwrap();
let cls = obj.getattr("__class__")?;
let name = format!("{}_{}", cls.qualified_name()?, cls.py_identifier());
if let Some(desc) = pool.get_message_by_name(&name) {
return Ok(desc);
}
let mut file = FileDescriptorProto {
name: Some(name.clone()),
..Default::default()
};
add_message_to_file(name.clone(), obj, &pool, &mut file)?;
pool.add_file_descriptor_proto(file)?;
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
}
fn add_message_to_file(
message_name: String,
obj: &PyAny,
pool: &DescriptorPool,
file: &mut FileDescriptorProto,
) -> Result<()> {
let mut messages_to_add = vec![(message_name, obj)];
while let Some((message_name, obj)) = messages_to_add.pop() {
let meta = obj.get_proto_meta()?;
let mut message = DescriptorProto {
name: Some(message_name.to_string()),
..Default::default()
};
for item in meta
.getattr("meta_by_field_name")?
.call_method0("items")?
.iter()?
{
let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?;
message.field.push({
let mut field = FieldDescriptorProto {
name: Some(field_name.to_string()),
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
..Default::default()
};
let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?;
if proto_type == "map" {
field.set_type(Type::Message);
let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?;
let key = map_type(key)?;
let val = map_type(val)?;
if matches!(
key,
Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum
) {
return Err(Error::UnsupportedMapKeyType(key));
}
let map_entry_name = format!("{field_name}Entry");
field.type_name = Some(format!("{message_name}.{map_entry_name}"));
field.set_label(Label::Repeated);
message.nested_type.push(DescriptorProto {
name: Some(map_entry_name),
field: vec![
{
let mut proto = FieldDescriptorProto {
name: Some("key".to_string()),
number: Some(1),
..Default::default()
};
proto.set_type(key);
proto
},
{
let mut proto = FieldDescriptorProto {
name: Some("value".to_string()),
number: Some(2),
..Default::default()
};
proto.set_type(val);
if val == Type::Message {
set_type_name(
&message_name,
meta.get_class(&format!("{field_name}.value"))?,
&mut proto,
file,
&mut messages_to_add,
pool,
)?;
}
proto
},
],
options: Some(MessageOptions {
map_entry: Some(true),
..Default::default()
}),
..Default::default()
})
} else {
field.set_type(map_type(proto_type)?);
match field.r#type() {
Type::Message => match field_meta
.getattr("wraps")?
.extract::<Option<&str>>()?
.map(map_type)
.transpose()?
{
Some(Type::Bool) => {
field.type_name = Some("google.protobuf.BoolValue".to_string());
}
Some(Type::Double) => {
field.type_name = Some("google.protobuf.DoubleValue".to_string());
}
Some(Type::Float) => {
field.type_name = Some("google.protobuf.FloatValue".to_string());
}
Some(Type::Int64) => {
field.type_name = Some("google.protobuf.Int64Value".to_string());
}
Some(Type::Uint64) => {
field.type_name = Some("google.protobuf.UInt64Value".to_string());
}
Some(Type::Int32) => {
field.type_name = Some("google.protobuf.Int32Value".to_string());
}
Some(Type::Uint32) => {
field.type_name = Some("google.protobuf.UInt32Value".to_string());
}
Some(Type::String) => {
field.type_name = Some("google.protobuf.StringValue".to_string());
}
Some(Type::Bytes) => {
field.type_name = Some("google.protobuf.BytesValue".to_string());
}
Some(t) => return Err(Error::UnsupportedWrapperType(t)),
None => {
set_type_name(
&message_name,
meta.get_class(field_name)?,
&mut field,
file,
&mut messages_to_add,
pool,
)?;
}
},
Type::Enum => {
let cls = meta.get_class(field_name)?;
let cls_name =
format!("{}_{}", cls.qualified_name()?, cls.py_identifier());
field.type_name = Some(cls_name.to_string());
if pool.get_enum_by_name(&cls_name).is_none()
&& !file.enum_type.iter().any(|item| item.name() == cls_name)
{
let mut proto = EnumDescriptorProto {
name: Some(cls_name.clone()),
..Default::default()
};
for item in cls.iter()? {
let item = item?;
proto.value.push(EnumValueDescriptorProto {
number: Some(item.getattr("value")?.extract()?),
name: Some(format!(
"{}_{}",
cls_name,
item.getattr("name")?.extract::<&str>()?
)),
..Default::default()
});
}
file.enum_type.push(proto);
}
}
_ => {}
}
if meta.is_list_field(field_name)? {
field.set_label(Label::Repeated);
} else if field_meta.getattr("optional")?.extract::<bool>()? {
field.proto3_optional = Some(true);
}
}
if let Some(grp) = meta.oneof_group(field_name)? {
let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp);
match oneof_index {
Some(i) => field.oneof_index = Some(i as i32),
None => {
message.oneof_decl.push(OneofDescriptorProto {
name: Some(grp),
..Default::default()
});
field.oneof_index = Some((message.oneof_decl.len() - 1) as i32)
}
}
}
field
});
}
file.message_type.push(message);
}
Ok(())
}
fn map_type(str: &str) -> Result<Type> {
match str {
"enum" => Ok(Type::Enum),
"bool" => Ok(Type::Bool),
"int32" => Ok(Type::Int32),
"int64" => Ok(Type::Int64),
"uint32" => Ok(Type::Uint32),
"uint64" => Ok(Type::Uint64),
"sint32" => Ok(Type::Sint32),
"sint64" => Ok(Type::Sint64),
"float" => Ok(Type::Float),
"double" => Ok(Type::Double),
"fixed32" => Ok(Type::Fixed32),
"sfixed32" => Ok(Type::Sfixed32),
"fixed64" => Ok(Type::Fixed64),
"sfixed64" => Ok(Type::Sfixed64),
"string" => Ok(Type::String),
"bytes" => Ok(Type::Bytes),
"message" => Ok(Type::Message),
_ => Err(Error::UnsupportedType(str.to_string())),
}
}
fn set_type_name<'py>(
message_name: &str,
field_cls: &'py PyAny,
field: &mut FieldDescriptorProto,
file: &FileDescriptorProto,
messages_to_add: &mut Vec<(String, &'py PyAny)>,
pool: &DescriptorPool,
) -> Result<()> {
let cls_name = field_cls.qualified_name()?;
match cls_name.as_str() {
"datetime.datetime" => {
field.type_name = Some("google.protobuf.Timestamp".to_string());
}
"datetime.timedelta" => {
field.type_name = Some("google.protobuf.Duration".to_string());
}
_ => {
let cls_name = format!("{}_{}", cls_name, field_cls.py_identifier());
field.type_name = Some(cls_name.clone());
if message_name != cls_name
&& pool.get_message_by_name(&cls_name).is_none()
&& !file.message_type.iter().any(|item| item.name() == cls_name)
&& !messages_to_add.iter().any(|item| item.0 == cls_name)
{
messages_to_add.push((cls_name, field_cls.call0()?));
}
}
}
Ok(())
}

View File

@@ -0,0 +1,29 @@
use prost_reflect::{
prost::DecodeError, prost_types::field_descriptor_proto::Type, DescriptorError,
};
use pyo3::{exceptions::PyRuntimeError, PyErr};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Given object is not a valid betterproto message.")]
NoBetterprotoMessage(#[from] PyErr),
#[error("Unsupported type `{0}`.")]
UnsupportedType(String),
#[error("Unsupported map key type `{0:?}`.")]
UnsupportedMapKeyType(Type),
#[error("Unsupported wrapper type `{0:?}`.")]
UnsupportedWrapperType(Type),
#[error("Error on proto registration")]
FailedToRegisterDescriptor(#[from] DescriptorError),
#[error("The given binary data does not match the protobuf schema.")]
FailedToDecode(#[from] DecodeError),
}
pub type Result<T> = core::result::Result<T, Error>;
impl From<Error> for PyErr {
fn from(value: Error) -> Self {
PyRuntimeError::new_err(value.to_string())
}
}

View File

@@ -0,0 +1,24 @@
mod descriptor_pool;
mod error;
mod merging;
mod py_any_extras;
use descriptor_pool::create_cached_descriptor;
use error::Result;
use merging::merge_msg_into_pyobj;
use prost_reflect::DynamicMessage;
use pyo3::prelude::*;
#[pyfunction]
fn deserialize(obj: &PyAny, buf: &[u8]) -> Result<()> {
let desc = create_cached_descriptor(obj)?;
let msg = DynamicMessage::decode(desc, buf)?;
merge_msg_into_pyobj(obj, msg)?;
Ok(())
}
#[pymodule]
fn betterproto_extras(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deserialize, m)?)?;
Ok(())
}

View File

@@ -0,0 +1,182 @@
use crate::{error::Result, py_any_extras::PyAnyExtras};
use indoc::indoc;
use prost_reflect::{
prost_types::{Duration, Timestamp},
DynamicMessage, MapKey, ReflectMessage, Value,
};
use pyo3::{
sync::GILOnceCell,
types::{IntoPyDict, PyBytes, PyModule},
Py, PyAny, PyObject, Python, ToPyObject,
};
pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> {
for field in msg.take_fields() {
let field_name = field.0.name();
let proto_meta = obj.get_proto_meta()?;
obj.setattr(
field_name,
map_field_value(field_name, field.1, proto_meta)?,
)?;
}
let mut buf = vec![];
for field in msg.unknown_fields() {
field.encode(&mut buf);
}
if !buf.is_empty() {
let mut unknown_fields = obj.getattr("_unknown_fields")?.extract::<Vec<u8>>()?;
unknown_fields.append(&mut buf);
obj.setattr("_unknown_fields", PyBytes::new(obj.py(), &unknown_fields))?;
}
obj.setattr("_serialized_on_wire", true)?;
Ok(())
}
fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> Result<PyObject> {
let py = proto_meta.py();
match field_value {
Value::Bool(x) => Ok(x.to_object(py)),
Value::Bytes(x) => Ok(PyBytes::new(py, &x).to_object(py)),
Value::F32(x) => Ok(x.to_object(py)),
Value::F64(x) => Ok(x.to_object(py)),
Value::I32(x) => Ok(x.to_object(py)),
Value::I64(x) => Ok(x.to_object(py)),
Value::String(x) => Ok(x.to_object(py)),
Value::U32(x) => Ok(x.to_object(py)),
Value::U64(x) => Ok(x.to_object(py)),
Value::Message(msg) => match msg.descriptor().full_name() {
"google.protobuf.BoolValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_bool())
.to_object(py)),
"google.protobuf.DoubleValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_f64())
.to_object(py)),
"google.protobuf.FloatValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_f32())
.to_object(py)),
"google.protobuf.Int64Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_i64())
.to_object(py)),
"google.protobuf.UInt64Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_u64())
.to_object(py)),
"google.protobuf.Int32Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_i32())
.to_object(py)),
"google.protobuf.UInt32Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_u32())
.to_object(py)),
"google.protobuf.StringValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_str().map(|s| s.to_string()))
.to_object(py)),
"google.protobuf.BytesValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b)))
.to_object(py)),
"google.protobuf.Timestamp" => {
let msg = msg.transcode_to::<Timestamp>()?;
Ok(create_py_datetime(&msg, py))
}
"google.protobuf.Duration" => {
let msg = msg.transcode_to::<Duration>()?;
Ok(create_py_timedelta(&msg, py))
}
_ => {
let obj = proto_meta.create_instance(field_name)?;
merge_msg_into_pyobj(obj, msg)?;
Ok(obj.to_object(py))
}
},
Value::List(ls) => Ok(ls
.into_iter()
.map(|x| map_field_value(field_name, x, proto_meta))
.collect::<Result<Vec<PyObject>>>()?
.to_object(py)),
Value::EnumNumber(x) => {
let cls = proto_meta.get_class(field_name)?;
Ok(cls.call1((x,))?.to_object(py))
}
Value::Map(map) => {
let res: Result<Vec<_>> = map
.into_iter()
.map(|(k, v)| {
let key = map_key(k, py);
let val = map_field_value(&format!("{field_name}.value"), v, proto_meta)?;
Ok((key, val))
})
.collect();
Ok(res?.into_py_dict(py).to_object(py))
}
}
}
fn map_key(key: MapKey, py: Python) -> PyObject {
match key {
MapKey::Bool(x) => x.to_object(py),
MapKey::I32(x) => x.to_object(py),
MapKey::I64(x) => x.to_object(py),
MapKey::U32(x) => x.to_object(py),
MapKey::U64(x) => x.to_object(py),
MapKey::String(x) => x.to_object(py),
}
}
fn create_py_datetime(ts: &Timestamp, py: Python) -> PyObject {
static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || {
let constructor = PyModule::from_code(
py,
indoc! {"
from datetime import datetime, timezone
def constructor(ts):
return datetime.fromtimestamp(ts, tz=timezone.utc)
"},
"",
"",
)
.expect("This is a valid Python module")
.getattr("constructor")
.expect("Attribute exists");
Py::from(constructor)
});
let ts = (ts.seconds as f64) + (ts.nanos as f64) / 1e9;
constructor
.call1(py, (ts,))
.expect("static function will not fail")
}
fn create_py_timedelta(duration: &Duration, py: Python) -> PyObject {
static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || {
let constructor = PyModule::from_code(
py,
indoc! {"
from datetime import timedelta
def constructor(s, ms):
return timedelta(seconds=s, microseconds=ms)
"},
"",
"",
)
.expect("This is a valid Python module")
.getattr("constructor")
.expect("Attribute exists");
Py::from(constructor)
});
constructor
.call1(py, (duration.seconds as f64, (duration.nanos as f64) / 1e3))
.expect("static function will not fail")
}

View File

@@ -0,0 +1,68 @@
use crate::error::Result;
use pyo3::{PyAny, Py, sync::GILOnceCell};
pub trait PyAnyExtras {
fn qualified_name(&self) -> Result<String>;
fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> Result<&PyAny>;
fn get_class(&self, field_name: &str) -> Result<&PyAny>;
fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
fn is_list_field(&self, field_name: &str) -> Result<bool>;
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
fn py_identifier(&self) -> u64;
}
impl PyAnyExtras for PyAny {
fn qualified_name(&self) -> Result<String> {
let module = self.getattr("__module__")?;
let name = self.getattr("__name__")?;
Ok(format!("{module}.{name}"))
}
fn qualified_class_name(&self) -> Result<String> {
self.getattr("__class__")?.qualified_name()
}
fn get_proto_meta(&self) -> Result<&PyAny> {
Ok(self.getattr("_betterproto")?)
}
fn get_class(&self, field_name: &str) -> Result<&PyAny> {
let cls = self.getattr("cls_by_field")?.get_item(field_name)?;
Ok(cls)
}
fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
Ok(self.get_class(field_name)?.call0()?)
}
fn is_list_field(&self, field_name: &str) -> Result<bool> {
let cls = self.getattr("default_gen")?.get_item(field_name)?;
let module = cls.getattr("__module__")?;
let name = cls.getattr("__name__")?;
Ok(module.to_string() == "builtins" && name.to_string() == "list")
}
fn oneof_group(&self, field_name: &str) -> Result<Option<String>> {
let opt = self
.getattr("oneof_group_by_field")?
.call_method1("get", (field_name,))?
.extract()?;
Ok(opt)
}
fn py_identifier(&self) -> u64 {
static FUN_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let py = self.py();
let fun = FUN_CACHE.get_or_init(py, || {
let fun = py
.eval("id", None, None)
.expect("This is a valid Python expression");
Py::from(fun)
});
fun.call1(py, (self,))
.expect("Identity function is callable")
.extract::<u64>(py)
.expect("Identity function always returns an integer")
}
}

55
example.py Normal file
View File

@@ -0,0 +1,55 @@
# dev tests
# to be deleted later
import betterproto
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass(repr=False)
class Baz(betterproto.Message):
a: float = betterproto.float_field(1, group = "x")
b: int = betterproto.int64_field(2, group = "x")
c: float = betterproto.float_field(3, group = "y")
d: int = betterproto.int64_field(4, group = "y")
e: Optional[int] = betterproto.int32_field(5, group = "_e", optional = True)
@dataclass(repr=False)
class Foo(betterproto.Message):
x: int = betterproto.int32_field(1)
y: float = betterproto.double_field(2)
z: List[Baz] = betterproto.message_field(3)
class Enm(betterproto.Enum):
A = 0
B = 1
C = 2
@dataclass(repr=False)
class Bar(betterproto.Message):
foo1: Foo = betterproto.message_field(1)
foo2: Foo = betterproto.message_field(2)
packed: List[int] = betterproto.int64_field(3)
enm: Enm = betterproto.enum_field(4)
map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL)
maybe: Optional[bool] = betterproto.message_field(6, wraps=betterproto.TYPE_BOOL)
bts: bytes = betterproto.bytes_field(7)
# Serialization has not been changed yet. So nothing unusual here
buffer = bytes(
Bar(
foo1=Foo(1, 2.34),
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
packed=[5, 3, 1],
enm=Enm.B,
map={
1: True,
42: False
},
maybe=True,
bts=b'Hi There!'
)
)
# Native deserialization happening here
bar = Bar().parse(buffer)
print(bar)

2620
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "betterproto" name = "betterproto"
version = "2.0.0b4" version = "2.0.0b6"
description = "A better Protobuf / gRPC generator & library" description = "A better Protobuf / gRPC generator & library"
authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"] authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"]
readme = "README.md" readme = "README.md"
@@ -12,22 +12,23 @@ packages = [
] ]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.6.2,<4.0" python = "^3.7"
black = { version = ">=19.3b0", optional = true } black = { version = ">=23.1.0", optional = true }
dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
grpclib = "^0.4.1" grpclib = "^0.4.1"
jinja2 = { version = "^2.11.2", optional = true } importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8" python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true}
betterproto-extras = { path = "betterproto-extras" }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
asv = "^0.4.2" asv = "^0.4.2"
black = "^21.11b0"
bpython = "^0.19" bpython = "^0.19"
grpcio-tools = "^1.40.0" grpcio-tools = "^1.54.2"
jinja2 = "^2.11.2" jinja2 = ">=3.0.3"
mypy = "^0.930" mypy = "^0.930"
poethepoet = ">=0.9.0" poethepoet = ">=0.9.0"
protobuf = "^3.12.2" protobuf = "^4.21.6"
pytest = "^6.2.5" pytest = "^6.2.5"
pytest-asyncio = "^0.12.0" pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0" pytest-cov = "^2.9.0"
@@ -36,13 +37,15 @@ sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0" sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0" tomlkit = "^0.7.0"
tox = "^3.15.1" tox = "^3.15.1"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"
[tool.poetry.scripts] [tool.poetry.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main" protoc-gen-python_betterproto = "betterproto.plugin:main"
[tool.poetry.extras] [tool.poetry.extras]
compiler = ["black", "jinja2"] compiler = ["black", "isort", "jinja2"]
# Dev workflow tasks # Dev workflow tasks
@@ -60,7 +63,7 @@ cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy" help = "Check types with mypy"
[tool.poe.tasks.format] [tool.poe.tasks.format]
cmd = "black . --exclude tests/output_" cmd = "black . --exclude tests/output_ --target-version py310"
help = "Apply black formatting to source code" help = "Apply black formatting to source code"
[tool.poe.tasks.docs] [tool.poe.tasks.docs]
@@ -85,8 +88,8 @@ protoc
--plugin=protoc-gen-custom=src/betterproto/plugin/main.py --plugin=protoc-gen-custom=src/betterproto/plugin/main.py
--custom_opt=INCLUDE_GOOGLE --custom_opt=INCLUDE_GOOGLE
--custom_out=src/betterproto/lib --custom_out=src/betterproto/lib
-I /usr/local/include/ -I C:\\work\\include
/usr/local/include/google/protobuf/**/*.proto C:\\work\\include\\google\\protobuf\\**\\*.proto
""" """
help = "Regenerate the types in betterproto.lib.google" help = "Regenerate the types in betterproto.lib.google"
@@ -97,12 +100,30 @@ shell = "poe generate && tox"
help = "Run tests with multiple pythons" help = "Run tests with multiple pythons"
[tool.poe.tasks.check-style] [tool.poe.tasks.check-style]
cmd = "black . --check --diff --exclude tests/output_" cmd = "black . --check --diff"
help = "Check if code style is correct" help = "Check if code style is correct"
[tool.isort]
py_version = 37
profile = "black"
force_single_line = false
combine_as_imports = true
lines_after_imports = 2
include_trailing_comma = true
force_grid_wrap = 2
src_paths = ["src", "tests"]
[tool.black] [tool.black]
target-version = ['py36'] target-version = ['py37']
[tool.doc8]
paths = ["docs"]
max_line_length = 88
[tool.doc8.ignore_path_errors]
"docs/migrating.rst" = [
"D001", # contains table which is longer than 88 characters long
]
[tool.coverage.run] [tool.coverage.run]
omit = ["betterproto/tests/*"] omit = ["betterproto/tests/*"]
@@ -111,7 +132,7 @@ omit = ["betterproto/tests/*"]
legacy_tox_ini = """ legacy_tox_ini = """
[tox] [tox]
isolated_build = true isolated_build = true
envlist = py36, py37, py38 envlist = py37, py38, py310
[testenv] [testenv]
whitelist_externals = poetry whitelist_externals = poetry

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,12 @@
from typing import TYPE_CHECKING, TypeVar from typing import (
TYPE_CHECKING,
TypeVar,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from grpclib._typing import IProtoMessage from grpclib._typing import IProtoMessage
from . import Message from . import Message
# Bound type variable to allow methods to return `self` of subclasses # Bound type variable to allow methods to return `self` of subclasses

View File

@@ -1,3 +1,7 @@
from pkg_resources import get_distribution try:
from importlib import metadata
except ImportError: # for Python<3.8
import importlib_metadata as metadata # type: ignore
__version__ = get_distribution("betterproto").version
__version__ = metadata.version("betterproto")

View File

@@ -1,6 +1,7 @@
import keyword import keyword
import re import re
# Word delimiters and symbols that will not be preserved when re-casing. # Word delimiters and symbols that will not be preserved when re-casing.
# language=PythonRegExp # language=PythonRegExp
SYMBOLS = "[^a-zA-Z0-9]*" SYMBOLS = "[^a-zA-Z0-9]*"

View File

@@ -1,11 +1,18 @@
import os import os
import re import re
from typing import Dict, List, Set, Tuple, Type from typing import (
Dict,
List,
Set,
Tuple,
Type,
)
from ..casing import safe_snake_case from ..casing import safe_snake_case
from ..lib.google import protobuf as google_protobuf from ..lib.google import protobuf as google_protobuf
from .naming import pythonize_class_name from .naming import pythonize_class_name
WRAPPER_TYPES: Dict[str, Type] = { WRAPPER_TYPES: Dict[str, Type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue, ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
".google.protobuf.FloatValue": google_protobuf.FloatValue, ".google.protobuf.FloatValue": google_protobuf.FloatValue,
@@ -36,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
def get_type_reference( def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True *, package: str, imports: set, source_type: str, unwrap: bool = True
) -> str: ) -> str:
""" """
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if

View File

@@ -15,17 +15,22 @@ from typing import (
import grpclib.const import grpclib.const
from .._types import ST, T
if TYPE_CHECKING: if TYPE_CHECKING:
from grpclib.client import Channel from grpclib.client import Channel
from grpclib.metadata import Deadline from grpclib.metadata import Deadline
from .._types import (
ST,
IProtoMessage,
Message,
T,
)
_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] Value = Union[str, bytes]
_MessageLike = Union[T, ST] MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]] MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
class ServiceStub(ABC): class ServiceStub(ABC):
@@ -39,7 +44,7 @@ class ServiceStub(ABC):
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> None: ) -> None:
self.channel = channel self.channel = channel
self.timeout = timeout self.timeout = timeout
@@ -50,7 +55,7 @@ class ServiceStub(ABC):
self, self,
timeout: Optional[float], timeout: Optional[float],
deadline: Optional["Deadline"], deadline: Optional["Deadline"],
metadata: Optional[_MetadataLike], metadata: Optional[MetadataLike],
): ):
return { return {
"timeout": self.timeout if timeout is None else timeout, "timeout": self.timeout if timeout is None else timeout,
@@ -61,13 +66,13 @@ class ServiceStub(ABC):
async def _unary_unary( async def _unary_unary(
self, self,
route: str, route: str,
request: _MessageLike, request: "IProtoMessage",
response_type: Type[T], response_type: Type["T"],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> T: ) -> "T":
"""Make a unary request and return the response.""" """Make a unary request and return the response."""
async with self.channel.request( async with self.channel.request(
route, route,
@@ -84,13 +89,13 @@ class ServiceStub(ABC):
async def _unary_stream( async def _unary_stream(
self, self,
route: str, route: str,
request: _MessageLike, request: "IProtoMessage",
response_type: Type[T], response_type: Type["T"],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]: ) -> AsyncIterator["T"]:
"""Make a unary request and return the stream response iterator.""" """Make a unary request and return the stream response iterator."""
async with self.channel.request( async with self.channel.request(
route, route,
@@ -106,14 +111,14 @@ class ServiceStub(ABC):
async def _stream_unary( async def _stream_unary(
self, self,
route: str, route: str,
request_iterator: _MessageSource, request_iterator: MessageSource,
request_type: Type[ST], request_type: Type["IProtoMessage"],
response_type: Type[T], response_type: Type["T"],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> T: ) -> "T":
"""Make a stream request and return the response.""" """Make a stream request and return the response."""
async with self.channel.request( async with self.channel.request(
route, route,
@@ -130,14 +135,14 @@ class ServiceStub(ABC):
async def _stream_stream( async def _stream_stream(
self, self,
route: str, route: str,
request_iterator: _MessageSource, request_iterator: MessageSource,
request_type: Type[ST], request_type: Type["IProtoMessage"],
response_type: Type[T], response_type: Type["T"],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]: ) -> AsyncIterator["T"]:
""" """
Make a stream request and return an AsyncIterator to iterate over response Make a stream request and return an AsyncIterator to iterate over response
messages. messages.
@@ -161,7 +166,7 @@ class ServiceStub(ABC):
raise raise
@staticmethod @staticmethod
async def _send_messages(stream, messages: _MessageSource): async def _send_messages(stream, messages: MessageSource):
if isinstance(messages, AsyncIterable): if isinstance(messages, AsyncIterable):
async for message in messages: async for message in messages:
await stream.send_message(message) await stream.send_message(message)

View File

@@ -1,6 +1,10 @@
from abc import ABC from abc import ABC
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from typing import Callable, Any, Dict from typing import (
Any,
Callable,
Dict,
)
import grpclib import grpclib
import grpclib.server import grpclib.server
@@ -15,10 +19,9 @@ class ServiceBase(ABC):
self, self,
handler: Callable, handler: Callable,
stream: grpclib.server.Stream, stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any], request: Any,
) -> None: ) -> None:
response_iter = handler(request)
response_iter = handler(**request_kwargs)
# check if response is actually an AsyncIterator # check if response is actually an AsyncIterator
# this might be false if the method just returns without # this might be false if the method just returns without
# yielding at least once # yielding at least once

View File

@@ -1,5 +1,13 @@
import asyncio import asyncio
from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union from typing import (
AsyncIterable,
AsyncIterator,
Iterable,
Optional,
TypeVar,
Union,
)
T = TypeVar("T") T = TypeVar("T")

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,17 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto # sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto # plugin: python-betterproto
# This file has been @generated
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
import betterproto import betterproto
from betterproto.grpc.grpclib_server import ServiceBase import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf
class CodeGeneratorResponseFeature(betterproto.Enum): class CodeGeneratorResponseFeature(betterproto.Enum):
"""Sync with code_generator.h."""
FEATURE_NONE = 0 FEATURE_NONE = 0
FEATURE_PROTO3_OPTIONAL = 1 FEATURE_PROTO3_OPTIONAL = 1
@@ -20,54 +23,69 @@ class Version(betterproto.Message):
major: int = betterproto.int32_field(1) major: int = betterproto.int32_field(1)
minor: int = betterproto.int32_field(2) minor: int = betterproto.int32_field(2)
patch: int = betterproto.int32_field(3) patch: int = betterproto.int32_field(3)
# A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
# be empty for mainline stable releases.
suffix: str = betterproto.string_field(4) suffix: str = betterproto.string_field(4)
"""
A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
be empty for mainline stable releases.
"""
@dataclass(eq=False, repr=False) @dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message): class CodeGeneratorRequest(betterproto.Message):
"""An encoded CodeGeneratorRequest is written to the plugin's stdin.""" """An encoded CodeGeneratorRequest is written to the plugin's stdin."""
# The .proto files that were explicitly listed on the command-line. The code
# generator should generate code only for these files. Each file's
# descriptor will be included in proto_file, below.
file_to_generate: List[str] = betterproto.string_field(1) file_to_generate: List[str] = betterproto.string_field(1)
# The generator parameter passed on the command-line. """
The .proto files that were explicitly listed on the command-line. The code
generator should generate code only for these files. Each file's
descriptor will be included in proto_file, below.
"""
parameter: str = betterproto.string_field(2) parameter: str = betterproto.string_field(2)
# FileDescriptorProtos for all files in files_to_generate and everything they """The generator parameter passed on the command-line."""
# import. The files will appear in topological order, so each file appears
# before any file that imports it. protoc guarantees that all proto_files
# will be written after the fields above, even though this is not technically
# guaranteed by the protobuf wire format. This theoretically could allow a
# plugin to stream in the FileDescriptorProtos and handle them one by one
# rather than read the entire set into memory at once. However, as of this
# writing, this is not similarly optimized on protoc's end -- it will store
# all fields in memory at once before sending them to the plugin. Type names
# of fields and extensions in the FileDescriptorProto are always fully
# qualified.
proto_file: List[ proto_file: List[
"betterproto_lib_google_protobuf.FileDescriptorProto" "betterproto_lib_google_protobuf.FileDescriptorProto"
] = betterproto.message_field(15) ] = betterproto.message_field(15)
# The version number of protocol compiler. """
FileDescriptorProtos for all files in files_to_generate and everything they
import. The files will appear in topological order, so each file appears
before any file that imports it. protoc guarantees that all proto_files
will be written after the fields above, even though this is not technically
guaranteed by the protobuf wire format. This theoretically could allow a
plugin to stream in the FileDescriptorProtos and handle them one by one
rather than read the entire set into memory at once. However, as of this
writing, this is not similarly optimized on protoc's end -- it will store
all fields in memory at once before sending them to the plugin. Type names
of fields and extensions in the FileDescriptorProto are always fully
qualified.
"""
compiler_version: "Version" = betterproto.message_field(3) compiler_version: "Version" = betterproto.message_field(3)
"""The version number of protocol compiler."""
@dataclass(eq=False, repr=False) @dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message): class CodeGeneratorResponse(betterproto.Message):
"""The plugin writes an encoded CodeGeneratorResponse to stdout.""" """The plugin writes an encoded CodeGeneratorResponse to stdout."""
# Error message. If non-empty, code generation failed. The plugin process
# should exit with status code zero even if it reports an error in this way.
# This should be used to indicate errors in .proto files which prevent the
# code generator from generating correct code. Errors which indicate a
# problem in protoc itself -- such as the input CodeGeneratorRequest being
# unparseable -- should be reported by writing a message to stderr and
# exiting with a non-zero status code.
error: str = betterproto.string_field(1) error: str = betterproto.string_field(1)
# A bitmask of supported features that the code generator supports. This is a """
# bitwise "or" of values from the Feature enum. Error message. If non-empty, code generation failed. The plugin process
should exit with status code zero even if it reports an error in this way.
This should be used to indicate errors in .proto files which prevent the
code generator from generating correct code. Errors which indicate a
problem in protoc itself -- such as the input CodeGeneratorRequest being
unparseable -- should be reported by writing a message to stderr and
exiting with a non-zero status code.
"""
supported_features: int = betterproto.uint64_field(2) supported_features: int = betterproto.uint64_field(2)
"""
A bitmask of supported features that the code generator supports. This is a
bitwise "or" of values from the Feature enum.
"""
file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15) file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)
@@ -75,54 +93,60 @@ class CodeGeneratorResponse(betterproto.Message):
class CodeGeneratorResponseFile(betterproto.Message): class CodeGeneratorResponseFile(betterproto.Message):
"""Represents a single generated file.""" """Represents a single generated file."""
# The file name, relative to the output directory. The name must not contain
# "." or ".." components and must be relative, not be absolute (so, the file
# cannot lie outside the output directory). "/" must be used as the path
# separator, not "\". If the name is omitted, the content will be appended to
# the previous file. This allows the generator to break large files into
# small chunks, and allows the generated text to be streamed back to protoc
# so that large files need not reside completely in memory at one time. Note
# that as of this writing protoc does not optimize for this -- it will read
# the entire CodeGeneratorResponse before writing files to disk.
name: str = betterproto.string_field(1) name: str = betterproto.string_field(1)
# If non-empty, indicates that the named file should already exist, and the """
# content here is to be inserted into that file at a defined insertion point. The file name, relative to the output directory. The name must not contain
# This feature allows a code generator to extend the output produced by "." or ".." components and must be relative, not be absolute (so, the file
# another code generator. The original generator may provide insertion cannot lie outside the output directory). "/" must be used as the path
# points by placing special annotations in the file that look like: separator, not "\". If the name is omitted, the content will be appended to
# @@protoc_insertion_point(NAME) The annotation can have arbitrary text the previous file. This allows the generator to break large files into
# before and after it on the line, which allows it to be placed in a comment. small chunks, and allows the generated text to be streamed back to protoc
# NAME should be replaced with an identifier naming the point -- this is what so that large files need not reside completely in memory at one time. Note
# other generators will use as the insertion_point. Code inserted at this that as of this writing protoc does not optimize for this -- it will read
# point will be placed immediately above the line containing the insertion the entire CodeGeneratorResponse before writing files to disk.
# point (thus multiple insertions to the same point will come out in the """
# order they were added). The double-@ is intended to make it unlikely that
# the generated code could contain things that look like insertion points by
# accident. For example, the C++ code generator places the following line in
# the .pb.h files that it generates: //
# @@protoc_insertion_point(namespace_scope) This line appears within the
# scope of the file's package namespace, but outside of any particular class.
# Another plugin can then specify the insertion_point "namespace_scope" to
# generate additional classes or other declarations that should be placed in
# this scope. Note that if the line containing the insertion point begins
# with whitespace, the same whitespace will be added to every line of the
# inserted text. This is useful for languages like Python, where indentation
# matters. In these languages, the insertion point comment should be
# indented the same amount as any inserted code will need to be in order to
# work correctly in that context. The code generator that generates the
# initial file and the one which inserts into it must both run as part of a
# single invocation of protoc. Code generators are executed in the order in
# which they appear on the command line. If |insertion_point| is present,
# |name| must also be present.
insertion_point: str = betterproto.string_field(2) insertion_point: str = betterproto.string_field(2)
# The file contents. """
If non-empty, indicates that the named file should already exist, and the
content here is to be inserted into that file at a defined insertion point.
This feature allows a code generator to extend the output produced by
another code generator. The original generator may provide insertion
points by placing special annotations in the file that look like:
@@protoc_insertion_point(NAME) The annotation can have arbitrary text
before and after it on the line, which allows it to be placed in a comment.
NAME should be replaced with an identifier naming the point -- this is what
other generators will use as the insertion_point. Code inserted at this
point will be placed immediately above the line containing the insertion
point (thus multiple insertions to the same point will come out in the
order they were added). The double-@ is intended to make it unlikely that
the generated code could contain things that look like insertion points by
accident. For example, the C++ code generator places the following line in
the .pb.h files that it generates: //
@@protoc_insertion_point(namespace_scope) This line appears within the
scope of the file's package namespace, but outside of any particular class.
Another plugin can then specify the insertion_point "namespace_scope" to
generate additional classes or other declarations that should be placed in
this scope. Note that if the line containing the insertion point begins
with whitespace, the same whitespace will be added to every line of the
inserted text. This is useful for languages like Python, where indentation
matters. In these languages, the insertion point comment should be
indented the same amount as any inserted code will need to be in order to
work correctly in that context. The code generator that generates the
initial file and the one which inserts into it must both run as part of a
single invocation of protoc. Code generators are executed in the order in
which they appear on the command line. If |insertion_point| is present,
|name| must also be present.
"""
content: str = betterproto.string_field(15) content: str = betterproto.string_field(15)
# Information describing the file content being inserted. If an insertion """The file contents."""
# point is used, this information will be appropriately offset and inserted
# into the code generation metadata for the generated files.
generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = ( generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = (
betterproto.message_field(16) betterproto.message_field(16)
) )
"""
Information describing the file content being inserted. If an insertion
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf point is used, this information will be appropriately offset and inserted
into the code generation metadata for the generated files.
"""

View File

@@ -1,8 +1,10 @@
import os.path import os.path
try: try:
# betterproto[compiler] specific dependencies # betterproto[compiler] specific dependencies
import black import black
import isort.api
import jinja2 import jinja2
except ImportError as err: except ImportError as err:
print( print(
@@ -19,7 +21,6 @@ from .models import OutputTemplate
def outputfile_compiler(output_file: OutputTemplate) -> str: def outputfile_compiler(output_file: OutputTemplate) -> str:
templates_folder = os.path.abspath( templates_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "templates") os.path.join(os.path.dirname(__file__), "..", "templates")
) )
@@ -31,7 +32,19 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
) )
template = env.get_template("template.py.j2") template = env.get_template("template.py.j2")
code = template.render(output_file=output_file)
code = isort.api.sort_code_string(
code=code,
show_diff=False,
py_version=37,
profile="black",
combine_as_imports=True,
lines_after_imports=2,
quiet=True,
force_grid_wrap=2,
known_third_party=["grpclib", "betterproto"],
)
return black.format_str( return black.format_str(
template.render(output_file=output_file), src_contents=code,
mode=black.Mode(), mode=black.Mode(),
) )

View File

@@ -7,9 +7,8 @@ from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorRequest, CodeGeneratorRequest,
CodeGeneratorResponse, CodeGeneratorResponse,
) )
from betterproto.plugin.parser import generate_code
from betterproto.plugin.models import monkey_patch_oneof_index from betterproto.plugin.models import monkey_patch_oneof_index
from betterproto.plugin.parser import generate_code
def main() -> None: def main() -> None:

View File

@@ -31,6 +31,23 @@ reference to `A` to `B`'s `fields` attribute.
import builtins import builtins
import re
import textwrap
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Type,
Union,
)
import betterproto import betterproto
from betterproto import which_one_of from betterproto import which_one_of
from betterproto.casing import sanitize_name from betterproto.casing import sanitize_name
@@ -46,23 +63,20 @@ from betterproto.compile.naming import (
from betterproto.lib.google.protobuf import ( from betterproto.lib.google.protobuf import (
DescriptorProto, DescriptorProto,
EnumDescriptorProto, EnumDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
Field, Field,
FieldDescriptorProto, FieldDescriptorProto,
FieldDescriptorProtoType,
FieldDescriptorProtoLabel, FieldDescriptorProtoLabel,
FieldDescriptorProtoType,
FileDescriptorProto,
MethodDescriptorProto,
) )
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
from ..casing import sanitize_name from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name from ..compile.importing import (
get_type_reference,
parse_source_type_name,
)
from ..compile.naming import ( from ..compile.naming import (
pythonize_class_name, pythonize_class_name,
pythonize_field_name, pythonize_field_name,
@@ -128,12 +142,12 @@ def monkey_patch_oneof_index():
"betterproto" "betterproto"
], ],
"group", "group",
"oneof_index", "_oneof_index",
) )
object.__setattr__( object.__setattr__(
Field.__dataclass_fields__["oneof_index"].metadata["betterproto"], Field.__dataclass_fields__["oneof_index"].metadata["betterproto"],
"group", "group",
"oneof_index", "_oneof_index",
) )
@@ -147,11 +161,7 @@ def get_comment(
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
) )
if path[-2] == 2 and path[-4] != 6: # This is a field, message, enum, service, or method
# This is a field
return f"{pad}# " + f"\n{pad}# ".join(lines)
else:
# This is a message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
lines[0] = lines[0].strip('"') lines[0] = lines[0].strip('"')
return f'{pad}"""{lines[0]}"""' return f'{pad}"""{lines[0]}"""'
@@ -204,7 +214,6 @@ class ProtoContentBase:
@dataclass @dataclass
class PluginRequestCompiler: class PluginRequestCompiler:
plugin_request_obj: CodeGeneratorRequest plugin_request_obj: CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
@@ -237,10 +246,14 @@ class OutputTemplate:
imports: Set[str] = field(default_factory=set) imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list) messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list) enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list) services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
@property @property
def package(self) -> str: def package(self) -> str:
@@ -322,12 +335,29 @@ class MessageCompiler(ProtoContentBase):
def has_deprecated_fields(self) -> bool: def has_deprecated_fields(self) -> bool:
return any(self.deprecated_fields) return any(self.deprecated_fields)
@property
def has_oneof_fields(self) -> bool:
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
@property
def has_message_field(self) -> bool:
return any(
(
field.proto_obj.type in PROTO_MESSAGE_TYPES
for field in self.fields
if isinstance(field.proto_obj, FieldDescriptorProto)
)
)
def is_map( def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool: ) -> bool:
"""True if proto_field_obj is a map, otherwise False.""" """True if proto_field_obj is a map, otherwise False."""
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE: if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
if not hasattr(parent_message, "nested_type"):
return False
# This might be a map... # This might be a map...
message_type = proto_field_obj.type_name.split(".").pop().lower() message_type = proto_field_obj.type_name.split(".").pop().lower()
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
@@ -355,7 +385,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
us to tell whether it was set, via the which_one_of interface. us to tell whether it was set, via the which_one_of interface.
""" """
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index" return which_one_of(proto_field_obj, "_oneof_index")[0] == "oneof_index"
@dataclass @dataclass
@@ -416,6 +446,10 @@ class FieldCompiler(MessageCompiler):
imports.add("Dict") imports.add("Dict")
return imports return imports
@property
def pydantic_imports(self) -> Set[str]:
return set()
@property @property
def use_builtins(self) -> bool: def use_builtins(self) -> bool:
return self.py_type in self.parent.builtins_types or ( return self.py_type in self.parent.builtins_types or (
@@ -425,6 +459,7 @@ class FieldCompiler(MessageCompiler):
def add_imports_to(self, output_file: OutputTemplate) -> None: def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports) output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports) output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins output_file.builtins_import = output_file.builtins_import or self.use_builtins
@property @property
@@ -529,7 +564,7 @@ class FieldCompiler(MessageCompiler):
source_type=self.proto_obj.type_name, source_type=self.proto_obj.type_name,
) )
else: else:
raise NotImplementedError(f"Unknown type {field.type}") raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
@property @property
def annotation(self) -> str: def annotation(self) -> str:
@@ -553,6 +588,20 @@ class OneOfFieldCompiler(FieldCompiler):
return args return args
@dataclass
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
@property
def optional(self) -> bool:
# Force the optional to be True. This will allow the pydantic dataclass
# to validate the object correctly by allowing the field to be let empty.
# We add a pydantic validator later to ensure exactly one field is defined.
return True
@property
def pydantic_imports(self) -> Set[str]:
return {"root_validator"}
@dataclass @dataclass
class MapEntryCompiler(FieldCompiler): class MapEntryCompiler(FieldCompiler):
py_k_type: Type = PLACEHOLDER py_k_type: Type = PLACEHOLDER
@@ -664,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
@dataclass @dataclass
class ServiceMethodCompiler(ProtoContentBase): class ServiceMethodCompiler(ProtoContentBase):
parent: ServiceCompiler parent: ServiceCompiler
proto_obj: MethodDescriptorProto proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER path: List[int] = PLACEHOLDER
@@ -675,12 +723,8 @@ class ServiceMethodCompiler(ProtoContentBase):
self.parent.methods.append(self) self.parent.methods.append(self)
# Check for imports # Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
f.add_imports_to(self.output_file)
if "Optional" in self.py_output_message_type: if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional") self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
# Check for Async imports # Check for Async imports
if self.client_streaming: if self.client_streaming:
@@ -692,38 +736,17 @@ class ServiceMethodCompiler(ProtoContentBase):
if self.client_streaming or self.server_streaming: if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator") self.output_file.typing_imports.add("AsyncIterator")
super().__post_init__() # check for unset fields # add imports required for request arguments timeout, deadline and metadata
@property
def mutable_default_args(self) -> Dict[str, str]:
"""Handle mutable default arguments.
Returns a list of tuples containing the name and default value
for arguments to this message who's default value is mutable.
The defaults are swapped out for None and replaced back inside
the method's body.
Reference:
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
Returns
-------
Dict[str, str]
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
if (
not self.client_streaming
and f.default_value_string != "None"
and f.mutable
):
mutable_default_args[f.py_name] = f.default_value_string
self.output_file.typing_imports.add("Optional") self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
)
self.output_file.imports_type_checking_only.add(
"from grpclib.metadata import Deadline"
)
return mutable_default_args super().__post_init__() # check for unset fields
@property @property
def py_name(self) -> str: def py_name(self) -> str:
@@ -760,7 +783,7 @@ class ServiceMethodCompiler(ProtoContentBase):
# comparable with method.input_type # comparable with method.input_type
for msg in self.request.all_messages: for msg in self.request.all_messages:
if ( if (
msg.py_name == name.replace(".", "") msg.py_name == pythonize_class_name(name.replace(".", ""))
and msg.output_file.package == package and msg.output_file.package == package
): ):
return msg return msg
@@ -780,8 +803,20 @@ class ServiceMethodCompiler(ProtoContentBase):
package=self.output_file.package, package=self.output_file.package,
imports=self.output_file.imports, imports=self.output_file.imports,
source_type=self.proto_obj.input_type, source_type=self.proto_obj.input_type,
unwrap=False,
).strip('"') ).strip('"')
@property
def py_input_message_param(self) -> str:
"""Param name corresponding to py_input_message_type.
Returns
-------
str
Param name corresponding to py_input_message_type.
"""
return pythonize_field_name(self.py_input_message_type)
@property @property
def py_output_message_type(self) -> str: def py_output_message_type(self) -> str:
"""String representation of the Python type corresponding to the """String representation of the Python type corresponding to the

View File

@@ -1,3 +1,13 @@
import pathlib
import sys
from typing import (
Generator,
List,
Set,
Tuple,
Union,
)
from betterproto.lib.google.protobuf import ( from betterproto.lib.google.protobuf import (
DescriptorProto, DescriptorProto,
EnumDescriptorProto, EnumDescriptorProto,
@@ -11,10 +21,7 @@ from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorResponseFeature, CodeGeneratorResponseFeature,
CodeGeneratorResponseFile, CodeGeneratorResponseFile,
) )
import itertools
import pathlib
import sys
from typing import Iterator, List, Set, Tuple, TYPE_CHECKING, Union
from .compiler import outputfile_compiler from .compiler import outputfile_compiler
from .models import ( from .models import (
EnumDefinitionCompiler, EnumDefinitionCompiler,
@@ -24,41 +31,40 @@ from .models import (
OneOfFieldCompiler, OneOfFieldCompiler,
OutputTemplate, OutputTemplate,
PluginRequestCompiler, PluginRequestCompiler,
PydanticOneOfFieldCompiler,
ServiceCompiler, ServiceCompiler,
ServiceMethodCompiler, ServiceMethodCompiler,
is_map, is_map,
is_oneof, is_oneof,
) )
if TYPE_CHECKING:
from google.protobuf.descriptor import Descriptor
def traverse( def traverse(
proto_file: FieldDescriptorProto, proto_file: FileDescriptorProto,
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": ) -> Generator[
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
]:
# Todo: Keep information about nested hierarchy # Todo: Keep information about nested hierarchy
def _traverse( def _traverse(
path: List[int], items: List["EnumDescriptorProto"], prefix="" path: List[int],
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
prefix: str = "",
) -> Generator[
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
]:
for i, item in enumerate(items): for i, item in enumerate(items):
# Adjust the name since we flatten the hierarchy. # Adjust the name since we flatten the hierarchy.
# Todo: don't change the name, but include full name in returned tuple # Todo: don't change the name, but include full name in returned tuple
item.name = next_prefix = prefix + item.name item.name = next_prefix = f"{prefix}_{item.name}"
yield item, path + [i] yield item, [*path, i]
if isinstance(item, DescriptorProto): if isinstance(item, DescriptorProto):
for enum in item.enum_type: # Get nested types.
enum.name = next_prefix + enum.name yield from _traverse([*path, i, 4], item.enum_type, next_prefix)
yield enum, path + [i, 4] yield from _traverse([*path, i, 3], item.nested_type, next_prefix)
if item.nested_type: yield from _traverse([5], proto_file.enum_type)
for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): yield from _traverse([4], proto_file.message_type)
yield n, p
return itertools.chain(
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
)
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
@@ -70,14 +76,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
request_data = PluginRequestCompiler(plugin_request_obj=request) request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages # Gather output packages
for proto_file in request.proto_file: for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip re-compiling Google's well-known types
continue
output_package_name = proto_file.package output_package_name = proto_file.package
if output_package_name not in request_data.output_packages: if output_package_name not in request_data.output_packages:
# Create a new output if there is no output for this package # Create a new output if there is no output for this package
@@ -87,6 +85,19 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# Add this input file to the output corresponding to this package # Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file) request_data.output_packages[output_package_name].input_files.append(proto_file)
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False
if "pydantic_dataclasses" in plugin_options:
request_data.output_packages[
output_package_name
].pydantic_dataclasses = True
# Read Messages and Enums # Read Messages and Enums
# We need to read Messages before Services in so that we can # We need to read Messages before Services in so that we can
# get the references to input/output messages for each service # get the references to input/output messages for each service
@@ -109,6 +120,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# Generate output files # Generate output files
output_paths: Set[pathlib.Path] = set() output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items(): for output_package_name, output_package in request_data.output_packages.items():
if not output_package.output:
continue
# Add files to the response object # Add files to the response object
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
@@ -127,6 +140,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
directory.joinpath("__init__.py") directory.joinpath("__init__.py")
for path in output_paths for path in output_paths
for directory in path.parents for directory in path.parents
if not directory.joinpath("__init__.py").exists()
} - output_paths } - output_paths
for init_file in init_files: for init_file in init_files:
@@ -138,6 +152,23 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
return response return response
def _make_one_of_field_compiler(
output_package: OutputTemplate,
source_file: "FileDescriptorProto",
parent: MessageCompiler,
proto_obj: "FieldDescriptorProto",
path: List[int],
) -> FieldCompiler:
pydantic = output_package.pydantic_dataclasses
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
return Cls(
source_file=source_file,
parent=parent,
proto_obj=proto_obj,
path=path,
)
def read_protobuf_type( def read_protobuf_type(
item: DescriptorProto, item: DescriptorProto,
path: List[int], path: List[int],
@@ -161,11 +192,8 @@ def read_protobuf_type(
path=path + [2, index], path=path + [2, index],
) )
elif is_oneof(field): elif is_oneof(field):
OneOfFieldCompiler( _make_one_of_field_compiler(
source_file=source_file, output_package, source_file, message_data, field, path + [2, index]
parent=message_data,
proto_obj=field,
path=path + [2, index],
) )
else: else:
FieldCompiler( FieldCompiler(

View File

@@ -1,10 +1,21 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }} # sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto # plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %} {% for i in output_file.python_module_imports|sort %}
import {{ i }} import {{ i }}
{% endfor %} {% endfor %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %} {% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
@@ -14,12 +25,28 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %} {% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
{% if output_file.services %} {% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib import grpclib
{% endif %} {% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}
{% if output_file.enums %}{% for enum in output_file.enums %} {% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum): class {{ enum.py_name }}(betterproto.Enum):
@@ -28,10 +55,11 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endif %} {% endif %}
{% for entry in enum.entries %} {% for entry in enum.entries %}
{{ entry.name }} = {{ entry.value }}
{% if entry.comment %} {% if entry.comment %}
{{ entry.comment }} {{ entry.comment }}
{% endif %} {% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %} {% endfor %}
@@ -45,10 +73,11 @@ class {{ message.py_name }}(betterproto.Message):
{% endif %} {% endif %}
{% for field in message.fields %} {% for field in message.fields %}
{{ field.get_field_string() }}
{% if field.comment %} {% if field.comment %}
{{ field.comment }} {{ field.comment }}
{% endif %} {% endif %}
{{ field.get_field_string() }}
{% endfor %} {% endfor %}
{% if not message.fields %} {% if not message.fields %}
pass pass
@@ -61,11 +90,16 @@ class {{ message.py_name }}(betterproto.Message):
{% endif %} {% endif %}
super().__post_init__() super().__post_init__()
{% for field in message.deprecated_fields %} {% for field in message.deprecated_fields %}
if self.{{ field }}: if self.is_set("{{ field }}"):
warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning) warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning)
{% endfor %} {% endfor %}
{% endif %} {% endif %}
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
{% endfor %} {% endfor %}
{% for service in output_file.services %} {% for service in output_file.services %}
@@ -79,60 +113,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self async def {{ method.py_name }}(self
{%- if not method.client_streaming -%} {%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, *, {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
None
{% endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%} {%- else -%}
{# Client streaming: need a request iterator instead #} {# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%} {%- endif -%}
,
*
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %} {% if method.comment %}
{{ method.comment }} {{ method.comment }}
{% endif %} {% endif %}
{%- for py_name, zero in method.mutable_default_args.items() %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}
{% if not method.client_streaming %}
request = {{ method.py_input_message_type }}()
{% for field in method.py_input_message.fields %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %} {% if method.server_streaming %}
{% if method.client_streaming %} {% if method.client_streaming %}
async for response in self._stream_stream( async for response in self._stream_stream(
"{{ method.route }}", "{{ method.route }}",
request_iterator, {{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }}, {{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}, {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
): ):
yield response yield response
{% else %}{# i.e. not client streaming #} {% else %}{# i.e. not client streaming #}
async for response in self._unary_stream( async for response in self._unary_stream(
"{{ method.route }}", "{{ method.route }}",
request, {{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }}, {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
): ):
yield response yield response
@@ -141,15 +156,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if method.client_streaming %} {% if method.client_streaming %}
return await self._stream_unary( return await self._stream_unary(
"{{ method.route }}", "{{ method.route }}",
request_iterator, {{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }}, {{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }} {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
) )
{% else %}{# i.e. not client streaming #} {% else %}{# i.e. not client streaming #}
return await self._unary_unary( return await self._unary_unary(
"{{ method.route }}", "{{ method.route }}",
request, {{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }} {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
) )
{% endif %}{# client streaming #} {% endif %}{# client streaming #}
{% endif %} {% endif %}
@@ -167,19 +188,10 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self async def {{ method.py_name }}(self
{%- if not method.client_streaming -%} {%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%} {%- else -%}
{# Client streaming: need a request iterator instead #} {# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%} {%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %} {% if method.comment %}
@@ -187,32 +199,27 @@ class {{ service.py_name }}Base(ServiceBase):
{% endif %} {% endif %}
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
{% if method.server_streaming %}
yield {{ method.py_output_message_type }}()
{% endif %}
{% endfor %} {% endfor %}
{% for method in service.methods %} {% for method in service.methods %}
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None:
{% if not method.client_streaming %} {% if not method.client_streaming %}
request = await stream.recv_message() request = await stream.recv_message()
request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}
{% else %} {% else %}
request_kwargs = {"request_iterator": stream.__aiter__()} request = stream.__aiter__()
{% endif %} {% endif %}
{% if not method.server_streaming %} {% if not method.server_streaming %}
response = await self.{{ method.py_name }}(**request_kwargs) response = await self.{{ method.py_name }}(request)
await stream.send_message(response) await stream.send_message(response)
{% else %} {% else %}
await self._call_rpc_handler_server_stream( await self._call_rpc_handler_server_stream(
self.{{ method.py_name }}, self.{{ method.py_name }},
stream, stream,
request_kwargs, request,
) )
{% endif %} {% endif %}
@@ -240,6 +247,10 @@ class {{ service.py_name }}Base(ServiceBase):
{% endfor %} {% endfor %}
{% for i in output_file.imports|sort %} {% if output_file.pydantic_dataclasses %}
{{ i }} {% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
{% endif %}
{% endfor %} {% endfor %}
{% endif %}

View File

@@ -1,3 +1,6 @@
import copy
import sys
import pytest import pytest
@@ -10,3 +13,10 @@ def pytest_addoption(parser):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def repeat(request): def repeat(request):
return request.config.getoption("repeat") return request.config.getoption("repeat")
@pytest.fixture
def reset_sys_path():
original = copy.deepcopy(sys.path)
yield
sys.path = original

View File

@@ -1,20 +1,22 @@
#!/usr/bin/env python #!/usr/bin/env python
import asyncio import asyncio
import os import os
from pathlib import Path
import platform import platform
import shutil import shutil
import sys import sys
from pathlib import Path
from typing import Set from typing import Set
from tests.util import ( from tests.util import (
get_directories, get_directories,
inputs_path, inputs_path,
output_path_betterproto, output_path_betterproto,
output_path_betterproto_pydantic,
output_path_reference, output_path_reference,
protoc, protoc,
) )
# Force pure-python implementation instead of C++, otherwise imports # Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database. # break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@@ -78,10 +80,12 @@ async def generate_test_case_output(
""" """
test_case_output_path_reference = output_path_reference.joinpath(test_case_name) test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name) test_case_output_path_betterproto = output_path_betterproto
test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic
os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_reference, exist_ok=True)
os.makedirs(test_case_output_path_betterproto, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True)
os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True)
clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_reference)
clear_directory(test_case_output_path_betterproto) clear_directory(test_case_output_path_betterproto)
@@ -89,9 +93,13 @@ async def generate_test_case_output(
( (
(ref_out, ref_err, ref_code), (ref_out, ref_err, ref_code),
(plg_out, plg_err, plg_code), (plg_out, plg_err, plg_code),
(plg_out_pyd, plg_err_pyd, plg_code_pyd),
) = await asyncio.gather( ) = await asyncio.gather(
protoc(test_case_input_path, test_case_output_path_reference, True), protoc(test_case_input_path, test_case_output_path_reference, True),
protoc(test_case_input_path, test_case_output_path_betterproto, False), protoc(test_case_input_path, test_case_output_path_betterproto, False),
protoc(
test_case_input_path, test_case_output_path_betterproto_pyd, False, True
),
) )
if ref_code == 0: if ref_code == 0:
@@ -130,7 +138,27 @@ async def generate_test_case_output(
sys.stderr.buffer.write(plg_err) sys.stderr.buffer.write(plg_err)
sys.stderr.buffer.flush() sys.stderr.buffer.flush()
return max(ref_code, plg_code) if plg_code_pyd == 0:
print(
f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
)
else:
print(
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
)
if verbose:
if plg_out_pyd:
print("Plugin stdout:")
sys.stdout.buffer.write(plg_out_pyd)
sys.stdout.buffer.flush()
if plg_err_pyd:
print("Plugin stderr:")
sys.stderr.buffer.write(plg_err_pyd)
sys.stderr.buffer.flush()
return max(ref_code, plg_code, plg_code_pyd)
HELP = "\n".join( HELP = "\n".join(
@@ -159,8 +187,18 @@ def main():
whitelist = set(sys.argv[1:]) whitelist = set(sys.argv[1:])
if platform.system() == "Windows": if platform.system() == "Windows":
asyncio.set_event_loop(asyncio.ProactorEventLoop()) # for python version prior to 3.8, loop policy needs to be set explicitly
# https://docs.python.org/3/library/asyncio-policy.html#asyncio.DefaultEventLoopPolicy
try:
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
except AttributeError:
# python < 3.7 does not have asyncio.WindowsProactorEventLoopPolicy
asyncio.get_event_loop_policy().set_event_loop(asyncio.ProactorEventLoop())
try:
asyncio.run(generate(whitelist, verbose))
except AttributeError:
# compatibility code for python < 3.7
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))

View File

@@ -1,23 +1,27 @@
import asyncio import asyncio
import sys import sys
import uuid
from tests.output_betterproto.service.service import ( import grpclib
import grpclib.client
import grpclib.metadata
import grpclib.server
import pytest
from grpclib.testing import ChannelFor
from betterproto.grpc.util.async_channel import AsyncChannel
from tests.output_betterproto.service import (
DoThingRequest, DoThingRequest,
DoThingResponse, DoThingResponse,
GetThingRequest, GetThingRequest,
TestStub as ThingServiceClient, TestStub as ThingServiceClient,
) )
import grpclib
import grpclib.metadata
import grpclib.server
from grpclib.testing import ChannelFor
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from .thing_service import ThingService from .thing_service import ThingService
async def _test_client(client, name="clean room", **kwargs): async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(name=name) response = await client.do_thing(DoThingRequest(name=name), **kwargs)
assert response.names == [name] assert response.names == [name]
@@ -62,7 +66,7 @@ async def test_trailer_only_error_unary_unary(
) )
async with ChannelFor([service]) as channel: async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e: with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(name="something") await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
assert e.value.status == grpclib.Status.UNAUTHENTICATED assert e.value.status == grpclib.Status.UNAUTHENTICATED
@@ -80,7 +84,7 @@ async def test_trailer_only_error_stream_unary(
async with ChannelFor([service]) as channel: async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e: with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things( await ThingServiceClient(channel).do_many_things(
request_iterator=[DoThingRequest(name="something")] do_thing_request_iterator=[DoThingRequest(name="something")]
) )
await _test_client(ThingServiceClient(channel)) await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED assert e.value.status == grpclib.Status.UNAUTHENTICATED
@@ -171,6 +175,56 @@ async def test_service_call_lower_level_with_overrides():
assert response.names == [THING_TO_DO] assert response.names == [THING_TO_DO]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("overrides_gen",),
[
(lambda: dict(timeout=10),),
(lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
(lambda: dict(metadata={"authorization": str(uuid.uuid4())}),),
(lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
],
)
async def test_service_call_high_level_with_overrides(mocker, overrides_gen):
overrides = overrides_gen()
request_spy = mocker.spy(grpclib.client.Channel, "request")
name = str(uuid.uuid4())
defaults = dict(
timeout=99,
deadline=grpclib.metadata.Deadline.from_timeout(99),
metadata={"authorization": name},
)
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_received(
deadline=grpclib.metadata.Deadline.from_timeout(
overrides.get("timeout", 99)
),
metadata=overrides.get("metadata", defaults.get("metadata")),
)
)
]
) as channel:
client = ThingServiceClient(channel, **defaults)
await _test_client(client, name=name, **overrides)
assert request_spy.call_count == 1
# for python <3.8 request_spy.call_args.kwargs do not work
_, request_spy_call_kwargs = request_spy.call_args_list[0]
# ensure all overrides were successful
for key, value in overrides.items():
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == value
# ensure default values were retained
for key in set(defaults.keys()) - set(overrides.keys()):
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == defaults[key]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_gen_for_unary_stream_request(): async def test_async_gen_for_unary_stream_request():
thing_name = "my milkshakes" thing_name = "my milkshakes"
@@ -178,7 +232,9 @@ async def test_async_gen_for_unary_stream_request():
async with ChannelFor([ThingService()]) as channel: async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel) client = ThingServiceClient(channel)
expected_versions = [5, 4, 3, 2, 1] expected_versions = [5, 4, 3, 2, 1]
async for response in client.get_thing_versions(name=thing_name): async for response in client.get_thing_versions(
GetThingRequest(name=thing_name)
):
assert response.name == thing_name assert response.name == thing_name
assert response.version == expected_versions.pop() assert response.version == expected_versions.pop()

View File

@@ -1,9 +1,11 @@
import asyncio import asyncio
from dataclasses import dataclass
from typing import AsyncIterator
import pytest
import betterproto import betterproto
from betterproto.grpc.util.async_channel import AsyncChannel from betterproto.grpc.util.async_channel import AsyncChannel
from dataclasses import dataclass
import pytest
from typing import AsyncIterator
@dataclass @dataclass

View File

@@ -1,12 +1,14 @@
from tests.output_betterproto.service.service import ( from typing import Dict
DoThingResponse,
import grpclib
import grpclib.server
from tests.output_betterproto.service import (
DoThingRequest, DoThingRequest,
DoThingResponse,
GetThingRequest, GetThingRequest,
GetThingResponse, GetThingResponse,
) )
import grpclib
import grpclib.server
from typing import Dict
class ThingService: class ThingService:

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package bool;
message Test { message Test {
bool value = 1; bool value = 1;
} }

View File

@@ -1,6 +1,19 @@
import pytest
from tests.output_betterproto.bool import Test from tests.output_betterproto.bool import Test
from tests.output_betterproto_pydantic.bool import Test as TestPyd
def test_value(): def test_value():
message = Test() message = Test()
assert not message.value, "Boolean is False by default" assert not message.value, "Boolean is False by default"
def test_pydantic_no_value():
with pytest.raises(ValueError):
TestPyd()
def test_pydantic_value():
message = Test(value=False)
assert not message.value

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package bytes;
message Test { message Test {
bytes data = 1; bytes data = 1;
} }

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package casing;
enum my_enum { enum my_enum {
ZERO = 0; ZERO = 0;
ONE = 1; ONE = 1;

View File

@@ -0,0 +1,11 @@
// https://github.com/danielgtaylor/python-betterproto/issues/344
syntax = "proto3";
package casing_inner_class;
message Test {
message inner_class {
sint32 old_exp = 1;
}
inner_class inner = 2;
}

View File

@@ -0,0 +1,14 @@
import tests.output_betterproto.casing_inner_class as casing_inner_class
def test_message_casing_inner_class_name():
assert hasattr(
casing_inner_class, "TestInnerClass"
), "Inline defined Message is correctly converted to CamelCase"
def test_message_casing_inner_class_attributes():
message = casing_inner_class.Test()
assert hasattr(
message.inner, "old_exp"
), "Inline defined Message attribute is snake_case"

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package casing_message_field_uppercase;
message Test { message Test {
int32 UPPERCASE = 1; int32 UPPERCASE = 1;
int32 UPPERCASE_V2 = 2; int32 UPPERCASE_V2 = 2;

View File

@@ -9,6 +9,7 @@ xfail = {
} }
services = { services = {
"googletypes_request",
"googletypes_response", "googletypes_response",
"googletypes_response_embedded", "googletypes_response_embedded",
"service", "service",
@@ -18,6 +19,7 @@ services = {
"googletypes_service_returns_googletype", "googletypes_service_returns_googletype",
"example_service", "example_service",
"empty_service", "empty_service",
"service_uppercase",
} }

View File

@@ -1,4 +1,6 @@
{ {
"v": 10, "message": {
"value": "hello"
},
"value": 10 "value": 10
} }

View File

@@ -1,9 +1,14 @@
syntax = "proto3"; syntax = "proto3";
package deprecated;
// Some documentation about the Test message. // Some documentation about the Test message.
message Test { message Test {
// Some documentation about the value. Message message = 1 [deprecated=true];
option deprecated = true;
int32 v = 1 [deprecated=true];
int32 value = 2; int32 value = 2;
} }
message Message {
option deprecated = true;
string value = 1;
}

View File

@@ -1,4 +0,0 @@
{
"v": 10,
"value": 10
}

View File

@@ -1,8 +0,0 @@
syntax = "proto3";
// Some documentation about the Test message.
message Test {
// Some documentation about the value.
int32 v = 1 [deprecated=true];
int32 value = 2;
}

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package double;
message Test { message Test {
double count = 1; double count = 1;
} }

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package empty_repeated;
message MessageA { message MessageA {
repeated float values = 1; repeated float values = 1;
} }

View File

@@ -0,0 +1,20 @@
syntax = "proto3";
package entry;
// This is a minimal example of a repeated message field that caused issues when
// checking whether a message is a map.
//
// During the check wheter a field is a "map", the string "entry" is added to
// the field name, checked against the type name and then further checks are
// made against the nested type of a parent message. In this edge-case, the
// first check would pass even though it shouldn't and that would cause an
// error because the parent type does not have a "nested_type" attribute.
message Test {
repeated ExportEntry export = 1;
}
message ExportEntry {
string name = 1;
}

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package enum;
// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values // Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
message Test { message Test {
Choice choice = 1; Choice choice = 1;

View File

@@ -1,6 +1,6 @@
from tests.output_betterproto.enum import ( from tests.output_betterproto.enum import (
Test,
Choice, Choice,
Test,
) )

View File

@@ -39,6 +39,8 @@
syntax = "proto2"; syntax = "proto2";
package example;
// package google.protobuf; // package google.protobuf;
option go_package = "google.golang.org/protobuf/types/descriptorpb"; option go_package = "google.golang.org/protobuf/types/descriptorpb";

View File

@@ -1,49 +1,52 @@
from typing import AsyncIterator, AsyncIterable from typing import (
AsyncIterable,
AsyncIterator,
)
import pytest import pytest
from grpclib.testing import ChannelFor from grpclib.testing import ChannelFor
from tests.output_betterproto.example_service.example_service import ( from tests.output_betterproto.example_service import (
TestBase,
TestStub,
ExampleRequest, ExampleRequest,
ExampleResponse, ExampleResponse,
TestBase,
TestStub,
) )
class ExampleService(TestBase): class ExampleService(TestBase):
async def example_unary_unary( async def example_unary_unary(
self, example_string: str, example_integer: int self, example_request: ExampleRequest
) -> "ExampleResponse": ) -> "ExampleResponse":
return ExampleResponse( return ExampleResponse(
example_string=example_string, example_string=example_request.example_string,
example_integer=example_integer, example_integer=example_request.example_integer,
) )
async def example_unary_stream( async def example_unary_stream(
self, example_string: str, example_integer: int self, example_request: ExampleRequest
) -> AsyncIterator["ExampleResponse"]: ) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse( response = ExampleResponse(
example_string=example_string, example_string=example_request.example_string,
example_integer=example_integer, example_integer=example_request.example_integer,
) )
yield response yield response
yield response yield response
yield response yield response
async def example_stream_unary( async def example_stream_unary(
self, request_iterator: AsyncIterator["ExampleRequest"] self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> "ExampleResponse": ) -> "ExampleResponse":
async for example_request in request_iterator: async for example_request in example_request_iterator:
return ExampleResponse( return ExampleResponse(
example_string=example_request.example_string, example_string=example_request.example_string,
example_integer=example_request.example_integer, example_integer=example_request.example_integer,
) )
async def example_stream_stream( async def example_stream_stream(
self, request_iterator: AsyncIterator["ExampleRequest"] self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]: ) -> AsyncIterator["ExampleResponse"]:
async for example_request in request_iterator: async for example_request in example_request_iterator:
yield ExampleResponse( yield ExampleResponse(
example_string=example_request.example_string, example_string=example_request.example_string,
example_integer=example_request.example_integer, example_integer=example_request.example_integer,
@@ -52,44 +55,32 @@ class ExampleService(TestBase):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calls_with_different_cardinalities(): async def test_calls_with_different_cardinalities():
test_string = "test string" example_request = ExampleRequest("test string", 42)
test_int = 42
async with ChannelFor([ExampleService()]) as channel: async with ChannelFor([ExampleService()]) as channel:
stub = TestStub(channel) stub = TestStub(channel)
# unary unary # unary unary
response = await stub.example_unary_unary( response = await stub.example_unary_unary(example_request)
example_string="test string", assert response.example_string == example_request.example_string
example_integer=42, assert response.example_integer == example_request.example_integer
)
assert response.example_string == test_string
assert response.example_integer == test_int
# unary stream # unary stream
async for response in stub.example_unary_stream( async for response in stub.example_unary_stream(example_request):
example_string="test string", assert response.example_string == example_request.example_string
example_integer=42, assert response.example_integer == example_request.example_integer
):
assert response.example_string == test_string
assert response.example_integer == test_int
# stream unary # stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)
async def request_iterator(): async def request_iterator():
yield request yield example_request
yield request yield example_request
yield request yield example_request
response = await stub.example_stream_unary(request_iterator()) response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string assert response.example_string == example_request.example_string
assert response.example_integer == test_int assert response.example_integer == example_request.example_integer
# stream stream # stream stream
async for response in stub.example_stream_stream(request_iterator()): async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string assert response.example_string == example_request.example_string
assert response.example_integer == test_int assert response.example_integer == example_request.example_integer

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package field_name_identical_to_type;
// Tests that messages may contain fields with names that are identical to their python types (PR #294) // Tests that messages may contain fields with names that are identical to their python types (PR #294)
message Test { message Test {

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package fixed;
message Test { message Test {
fixed32 foo = 1; fixed32 foo = 1;
sfixed32 bar = 2; sfixed32 bar = 2;

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package float;
// Some documentation about the Test message. // Some documentation about the Test message.
message Test { message Test {
double positive = 1; double positive = 1;

View File

@@ -1,13 +1,17 @@
syntax = "proto3"; syntax = "proto3";
message Foo{ package google_impl_behavior_equivalence;
int64 bar = 1;
}
message Test{ message Foo { int64 bar = 1; }
oneof group{
message Test {
oneof group {
string string = 1; string string = 1;
int64 integer = 2; int64 integer = 2;
Foo foo = 3; Foo foo = 3;
} }
} }
message Request { Empty foo = 1; }
message Empty {}

View File

@@ -1,19 +1,22 @@
import pytest import pytest
from google.protobuf import json_format from google.protobuf import json_format
import betterproto import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import ( from tests.output_betterproto.google_impl_behavior_equivalence import (
Test, Empty,
Foo, Foo,
Request,
Test,
) )
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
Test as ReferenceTest, Empty as ReferenceEmpty,
Foo as ReferenceFoo, Foo as ReferenceFoo,
Request as ReferenceRequest,
Test as ReferenceTest,
) )
def test_oneof_serializes_similar_to_google_oneof(): def test_oneof_serializes_similar_to_google_oneof():
tests = [ tests = [
(Test(string="abc"), ReferenceTest(string="abc")), (Test(string="abc"), ReferenceTest(string="abc")),
(Test(integer=2), ReferenceTest(integer=2)), (Test(integer=2), ReferenceTest(integer=2)),
@@ -30,7 +33,6 @@ def test_oneof_serializes_similar_to_google_oneof():
def test_bytes_are_the_same_for_oneof(): def test_bytes_are_the_same_for_oneof():
message = Test(string="") message = Test(string="")
message_reference = ReferenceTest(string="") message_reference = ReferenceTest(string="")
@@ -48,8 +50,23 @@ def test_bytes_are_the_same_for_oneof():
# None of these fields were explicitly set BUT they should not actually be null # None of these fields were explicitly set BUT they should not actually be null
# themselves # themselves
assert isinstance(message.foo, Foo) assert not hasattr(message, "foo")
assert isinstance(message2.foo, Foo) assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
assert not hasattr(message2, "foo")
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER
assert isinstance(message_reference.foo, ReferenceFoo) assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo) assert isinstance(message_reference2.foo, ReferenceFoo)
def test_empty_message_field():
message = Request()
reference_message = ReferenceRequest()
message.foo = Empty()
reference_message.foo.CopyFrom(ReferenceEmpty())
assert betterproto.serialized_on_wire(message.foo)
assert reference_message.HasField("foo")
assert bytes(message) == reference_message.SerializeToString()

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes;
import "google/protobuf/duration.proto"; import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto"; import "google/protobuf/wrappers.proto";

View File

@@ -0,0 +1,29 @@
syntax = "proto3";
package googletypes_request;
import "google/protobuf/duration.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
// Tests that google types can be used as params
service Test {
rpc SendDouble (google.protobuf.DoubleValue) returns (Input);
rpc SendFloat (google.protobuf.FloatValue) returns (Input);
rpc SendInt64 (google.protobuf.Int64Value) returns (Input);
rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input);
rpc SendInt32 (google.protobuf.Int32Value) returns (Input);
rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input);
rpc SendBool (google.protobuf.BoolValue) returns (Input);
rpc SendString (google.protobuf.StringValue) returns (Input);
rpc SendBytes (google.protobuf.BytesValue) returns (Input);
rpc SendDatetime (google.protobuf.Timestamp) returns (Input);
rpc SendTimedelta (google.protobuf.Duration) returns (Input);
rpc SendEmpty (google.protobuf.Empty) returns (Input);
}
message Input {
}

View File

@@ -0,0 +1,47 @@
from datetime import (
datetime,
timedelta,
)
from typing import (
Any,
Callable,
)
import pytest
import betterproto.lib.google.protobuf as protobuf
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_request import (
Input,
TestStub,
)
test_cases = [
(TestStub.send_double, protobuf.DoubleValue, 2.5),
(TestStub.send_float, protobuf.FloatValue, 2.5),
(TestStub.send_int64, protobuf.Int64Value, -64),
(TestStub.send_u_int64, protobuf.UInt64Value, 64),
(TestStub.send_int32, protobuf.Int32Value, -32),
(TestStub.send_u_int32, protobuf.UInt32Value, 32),
(TestStub.send_bool, protobuf.BoolValue, True),
(TestStub.send_string, protobuf.StringValue, "string"),
(TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
(TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)),
(TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)),
]
@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[Input()])
service = TestStub(channel)
await service_method(service, wrapped_value)
assert channel.requests[0]["request"] == type(wrapped_value)

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes_response;
import "google/protobuf/wrappers.proto"; import "google/protobuf/wrappers.proto";
// Tests that wrapped values can be used directly as return values // Tests that wrapped values can be used directly as return values

View File

@@ -1,10 +1,18 @@
from typing import Any, Callable, Optional from typing import (
Any,
Callable,
Optional,
)
import betterproto.lib.google.protobuf as protobuf
import pytest import pytest
import betterproto.lib.google.protobuf as protobuf
from tests.mocks import MockChannel from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response import TestStub from tests.output_betterproto.googletypes_response import (
Input,
TestStub,
)
test_cases = [ test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5), (TestStub.get_double, protobuf.DoubleValue, 2.5),
@@ -22,14 +30,15 @@ test_cases = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type( async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
): ):
wrapped_value = wrapper_class() wrapped_value = wrapper_class()
wrapped_value.value = value wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value]) channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel) service = TestStub(channel)
method_param = Input()
await service_method(service) await service_method(service, method_param)
assert channel.requests[0]["response_type"] != Optional[type(value)] assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value) assert channel.requests[0]["response_type"] == type(wrapped_value)
@@ -39,7 +48,7 @@ async def test_channel_receives_wrapped_type(
@pytest.mark.xfail @pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response( async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
): ):
""" """
grpclib does not unwrap wrapper values returned by services grpclib does not unwrap wrapper values returned by services
@@ -47,8 +56,9 @@ async def test_service_unwraps_response(
wrapped_value = wrapper_class() wrapped_value = wrapper_class()
wrapped_value.value = value wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value])) service = TestStub(MockChannel(responses=[wrapped_value]))
method_param = Input()
response_value = await service_method(service) response_value = await service_method(service, method_param)
assert response_value == value assert response_value == value
assert type(response_value) == type(value) assert type(response_value) == type(value)

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes_response_embedded;
import "google/protobuf/wrappers.proto"; import "google/protobuf/wrappers.proto";
// Tests that wrapped values are supported as part of output message // Tests that wrapped values are supported as part of output message

View File

@@ -2,6 +2,7 @@ import pytest
from tests.mocks import MockChannel from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response_embedded import ( from tests.output_betterproto.googletypes_response_embedded import (
Input,
Output, Output,
TestStub, TestStub,
) )
@@ -26,7 +27,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
) )
service = TestStub(MockChannel(responses=[output])) service = TestStub(MockChannel(responses=[output]))
response = await service.get_output() response = await service.get_output(Input())
assert response.double_value == 10.0 assert response.double_value == 10.0
assert response.float_value == 12.0 assert response.float_value == 12.0

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes_service_returns_empty;
import "google/protobuf/empty.proto"; import "google/protobuf/empty.proto";
service Test { service Test {

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes_service_returns_googletype;
import "google/protobuf/empty.proto"; import "google/protobuf/empty.proto";
import "google/protobuf/struct.proto"; import "google/protobuf/struct.proto";

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes_struct;
import "google/protobuf/struct.proto"; import "google/protobuf/struct.proto";
message Test { message Test {

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package googletypes_value;
import "google/protobuf/struct.proto"; import "google/protobuf/struct.proto";
// Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values. // Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values.

View File

@@ -1,7 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package Capitalized; package import_capitalized_package.Capitalized;
message Message { message Message {

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_capitalized_package;
import "capitalized.proto"; import "capitalized.proto";
// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't. // Tests that we can import from a package with a capital name, that looks like a nested type, but isn't.

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package package.childpackage; package import_child_package_from_package.package.childpackage;
message ChildMessage { message ChildMessage {

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_child_package_from_package;
import "package_message.proto"; import "package_message.proto";
// Tests generated imports when a message in a package refers to a message in a nested child package. // Tests generated imports when a message in a package refers to a message in a nested child package.

View File

@@ -2,7 +2,7 @@ syntax = "proto3";
import "child.proto"; import "child.proto";
package package; package import_child_package_from_package.package;
message PackageMessage { message PackageMessage {
package.childpackage.ChildMessage c = 1; package.childpackage.ChildMessage c = 1;

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package childpackage; package import_child_package_from_root.childpackage;
message Message { message Message {

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_child_package_from_root;
import "child.proto"; import "child.proto";
// Tests generated imports when a message in root refers to a message in a child package. // Tests generated imports when a message in root refers to a message in a child package.

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_circular_dependency;
import "root.proto"; import "root.proto";
import "other.proto"; import "other.proto";

View File

@@ -1,7 +1,7 @@
syntax = "proto3"; syntax = "proto3";
import "root.proto"; import "root.proto";
package other; package import_circular_dependency.other;
message OtherPackageMessage { message OtherPackageMessage {
RootPackageMessage rootPackageMessage = 1; RootPackageMessage rootPackageMessage = 1;

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_circular_dependency;
message RootPackageMessage { message RootPackageMessage {
} }

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package cousin.cousin_subpackage; package import_cousin_package.cousin.cousin_subpackage;
message CousinMessage { message CousinMessage {
} }

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package test.subpackage; package import_cousin_package.test.subpackage;
import "cousin.proto"; import "cousin.proto";

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package cousin.subpackage; package import_cousin_package_same_name.cousin.subpackage;
message CousinMessage { message CousinMessage {
} }

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package test.subpackage; package import_cousin_package_same_name.test.subpackage;
import "cousin.proto"; import "cousin.proto";

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_packages_same_name;
import "users_v1.proto"; import "users_v1.proto";
import "posts_v1.proto"; import "posts_v1.proto";

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package posts.v1; package import_packages_same_name.posts.v1;
message Post { message Post {

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package users.v1; package import_packages_same_name.users.v1;
message User { message User {

View File

@@ -2,7 +2,7 @@ syntax = "proto3";
import "parent_package_message.proto"; import "parent_package_message.proto";
package parent.child; package import_parent_package_from_child.parent.child;
// Tests generated imports when a message refers to a message defined in its parent package // Tests generated imports when a message refers to a message defined in its parent package

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package parent; package import_parent_package_from_child.parent;
message ParentPackageMessage { message ParentPackageMessage {
} }

View File

@@ -1,6 +1,6 @@
syntax = "proto3"; syntax = "proto3";
package child; package import_root_package_from_child.child;
import "root.proto"; import "root.proto";

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_root_package_from_child;
message RootMessage { message RootMessage {
} }

View File

@@ -1,5 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package import_root_sibling;
import "sibling.proto"; import "sibling.proto";
// Tests generated imports when a message in the root package refers to another message in the root package // Tests generated imports when a message in the root package refers to another message in the root package

Some files were not shown because too many files have changed in this diff Show More