Compare commits
165 Commits
v1.2.3
...
michael-sa
Author | SHA1 | Date | |
---|---|---|---|
|
0c02d1b21a | ||
|
ac32bcd25a | ||
|
cdddb2f42a | ||
|
d21cd6e391 | ||
|
af7115429a | ||
|
0d9387abec | ||
|
f4ebcb0f65 | ||
|
81711d2427 | ||
|
e3135ce766 | ||
|
72855227bd | ||
|
47081617c2 | ||
|
9532844929 | ||
|
d734206fe5 | ||
|
bbf40f9694 | ||
|
0c5d1ff868 | ||
|
5fb4b4b7ff | ||
|
4f820b4a6a | ||
|
75a4c230da | ||
|
6671d87cef | ||
|
5c9a12e2f6 | ||
|
e1ccd540a9 | ||
|
4e78fe9579 | ||
|
50bb67bf5d | ||
|
1ecbf1a125 | ||
|
0814729c5a | ||
|
f7aa6150e2 | ||
|
159c30ddd8 | ||
|
cd66b0511a | ||
|
c48ca2e386 | ||
|
c8229e53a7 | ||
|
3185c67098 | ||
|
52eea5ce4c | ||
|
4b6f55dce5 | ||
|
fdbe0205f1 | ||
|
09f821921f | ||
|
a757da1b29 | ||
|
e2d672a422 | ||
|
63f5191f02 | ||
|
87f4b34930 | ||
|
2c360a55f2 | ||
|
04dce524aa | ||
|
8edec81b11 | ||
|
32c8e77274 | ||
|
d9fa6d2dd3 | ||
|
c88edfd093 | ||
|
a46979c8a6 | ||
|
83e13aa606 | ||
|
3ca75dadd7 | ||
|
5d2f3a2cd9 | ||
|
65c1f366ef | ||
|
34c34bd15a | ||
|
fb54917f2c | ||
|
1a95a7988e | ||
|
76db2f153e | ||
|
8567892352 | ||
|
3105e952ea | ||
|
7c8d47de6d | ||
|
c00e2aef19 | ||
|
fdf3b2e764 | ||
|
f7c2fd1194 | ||
|
d8abb850f8 | ||
|
d7ba27de2b | ||
|
57523a9e7f | ||
|
e5e61c873c | ||
|
9fd1c058e6 | ||
|
d336153845 | ||
|
9a45ea9f16 | ||
|
bb7f5229fb | ||
|
f7769a19d1 | ||
|
d31f90be6b | ||
|
919b0a6a7d | ||
|
7ecf3fe0e6 | ||
|
ff14948a4e | ||
|
cb00273257 | ||
|
973d68a154 | ||
|
ab9857b5fd | ||
|
2f658df666 | ||
|
b813d1cedb | ||
|
f5ce1b7108 | ||
|
62fc421d60 | ||
|
eeed1c0db7 | ||
|
2a3e1e1827 | ||
|
53ce1255d3 | ||
|
e8991339e9 | ||
|
4556d67503 | ||
|
f087c6c9bd | ||
|
eec24e4ee8 | ||
|
91111ab7d8 | ||
|
fcff3dff74 | ||
|
5c4969ff1c | ||
|
ed33a48d64 | ||
|
ee362a7a73 | ||
|
261e55b2c8 | ||
|
98930ce0d7 | ||
|
d7d277eb0d | ||
|
3860c0ab11 | ||
|
cd1c2dc3b5 | ||
|
be2a24d15c | ||
|
a5effb219a | ||
|
b354aeb692 | ||
|
6d9e3fc580 | ||
|
72de590651 | ||
|
3c70f21074 | ||
|
4b7d5d3de4 | ||
|
2d57f0d122 | ||
|
142e976c40 | ||
|
382fabb96c | ||
|
18598e77d4 | ||
|
6871053ab2 | ||
|
5bb6931df7 | ||
|
e8a9960b73 | ||
|
f25c66777a | ||
|
a68505b80e | ||
|
2f9497e064 | ||
|
33964b883e | ||
|
ec7574086d | ||
|
8a42027bc9 | ||
|
71737cf696 | ||
|
659ddd9c44 | ||
|
5b6997870a | ||
|
cdf7645722 | ||
|
ca20069ca3 | ||
|
59a4a7da43 | ||
|
15af4367e5 | ||
|
ec5683e572 | ||
|
20150fdcf3 | ||
|
d11b7d04c5 | ||
|
e2d35f4696 | ||
|
c3f08b9ef2 | ||
|
24d44898f4 | ||
|
074448c996 | ||
|
0fe557bd3c | ||
|
1a87ea43a1 | ||
|
983e0895a2 | ||
|
4a2baf3f0a | ||
|
8f0caf1db2 | ||
|
c50d9e2fdc | ||
|
35548cb43e | ||
|
b711d1e11f | ||
|
917de09bb6 | ||
|
1f7f39049e | ||
|
3d001a2a1a | ||
|
de61ddab21 | ||
|
5e2d9febea | ||
|
f6af077ffe | ||
|
92088ebda8 | ||
|
c3e3837f71 | ||
|
6bd9c7835c | ||
|
6ec902c1b5 | ||
|
960dba2ae8 | ||
|
4b4bdefb6f | ||
|
dfa0a56b39 | ||
|
dd4873dfba | ||
|
91f586f7d7 | ||
|
33fb83faad | ||
|
77c04414f5 | ||
|
6969ff7ff6 | ||
|
13e08fdaa8 | ||
|
6775632f77 | ||
|
b12f1e4e61 | ||
|
7e9ba0866c | ||
|
3546f55146 | ||
|
499489f1d3 | ||
|
5759e323bd | ||
|
c762c9c549 |
70
.github/workflows/ci.yml
vendored
70
.github/workflows/ci.yml
vendored
@@ -4,51 +4,71 @@ on: [push, pull_request]
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
|
check-formatting:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
name: Consult black on python formatting
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.7
|
||||||
|
- uses: Gr1N/setup-poetry@v2
|
||||||
|
- uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pypoetry/virtualenvs
|
||||||
|
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-poetry-
|
||||||
|
- name: Install dependencies
|
||||||
|
run: poetry install
|
||||||
|
- name: Run black
|
||||||
|
run: make check-style
|
||||||
|
|
||||||
run-tests:
|
run-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
name: Run tests with tox
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [ '3.6', '3.7' ]
|
python-version: [ '3.6', '3.7', '3.8']
|
||||||
|
|
||||||
name: Python ${{ matrix.python-version }} test
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/setup-python@v1
|
- uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- uses: dschep/install-pipenv-action@v1
|
- uses: Gr1N/setup-poetry@v2
|
||||||
|
- uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pypoetry/virtualenvs
|
||||||
|
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-poetry-
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt install protobuf-compiler libprotobuf-dev
|
sudo apt install protobuf-compiler libprotobuf-dev
|
||||||
pipenv install --dev --python ${pythonLocation}/python
|
poetry install
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
cp .env.default .env
|
make generate
|
||||||
pipenv run pip install -e .
|
make test
|
||||||
pipenv run generate
|
|
||||||
pipenv run test
|
|
||||||
|
|
||||||
build-release:
|
build-release:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v2
|
||||||
- uses: actions/setup-python@v1
|
- uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: 3.7
|
python-version: 3.7
|
||||||
- uses: dschep/install-pipenv-action@v1
|
- uses: Gr1N/setup-poetry@v2
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
sudo apt install protobuf-compiler libprotobuf-dev
|
|
||||||
pipenv install --dev --python ${pythonLocation}/python
|
|
||||||
- name: Build package
|
- name: Build package
|
||||||
|
run: poetry build
|
||||||
|
- name: Publish package to PyPI
|
||||||
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
||||||
run: pipenv run python setup.py sdist
|
run: poetry publish -n
|
||||||
- name: Publish package
|
env:
|
||||||
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }}
|
||||||
uses: pypa/gh-action-pypi-publish@v1.0.0a0
|
|
||||||
with:
|
|
||||||
user: __token__
|
|
||||||
password: ${{ secrets.pypi }}
|
|
||||||
|
11
.gitignore
vendored
11
.gitignore
vendored
@@ -1,15 +1,16 @@
|
|||||||
|
.coverage
|
||||||
|
.DS_Store
|
||||||
.env
|
.env
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
.mypy_cache
|
.mypy_cache
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
.python-version
|
.python-version
|
||||||
build/
|
build/
|
||||||
betterproto/tests/*.bin
|
betterproto/tests/output_*
|
||||||
betterproto/tests/*_pb2.py
|
|
||||||
betterproto/tests/*.py
|
|
||||||
!betterproto/tests/generate.py
|
|
||||||
!betterproto/tests/test_*.py
|
|
||||||
**/__pycache__
|
**/__pycache__
|
||||||
dist
|
dist
|
||||||
**/*.egg-info
|
**/*.egg-info
|
||||||
output
|
output
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
.tox
|
||||||
|
17
CHANGELOG.md
17
CHANGELOG.md
@@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [1.2.5] - 2020-04-27
|
||||||
|
|
||||||
|
- Add .j2 suffix to python template names to avoid confusing certain build tools [#72](https://github.com/danielgtaylor/python-betterproto/pull/72)
|
||||||
|
|
||||||
|
## [1.2.4] - 2020-04-26
|
||||||
|
|
||||||
|
- Enforce utf-8 for reading the readme in setup.py [#67](https://github.com/danielgtaylor/python-betterproto/pull/67)
|
||||||
|
- Only import types from grpclib when type checking [#52](https://github.com/danielgtaylor/python-betterproto/pull/52)
|
||||||
|
- Improve performance of serialize/deserialize by caching type information of fields in class [#46](https://github.com/danielgtaylor/python-betterproto/pull/46)
|
||||||
|
- Support using Google's wrapper types as RPC output values [#40](https://github.com/danielgtaylor/python-betterproto/pull/40)
|
||||||
|
- Fixes issue where protoc did not recognize plugin.py as win32 application [#38](https://github.com/danielgtaylor/python-betterproto/pull/38)
|
||||||
|
- Fix services using non-pythonified field names [#34](https://github.com/danielgtaylor/python-betterproto/pull/34)
|
||||||
|
- Add ability to provide metadata, timeout & deadline args to requests [#32](https://github.com/danielgtaylor/python-betterproto/pull/32)
|
||||||
|
|
||||||
## [1.2.3] - 2020-04-15
|
## [1.2.3] - 2020-04-15
|
||||||
|
|
||||||
- Exclude empty lists from `to_dict` by default [#16](https://github.com/danielgtaylor/python-betterproto/pull/16)
|
- Exclude empty lists from `to_dict` by default [#16](https://github.com/danielgtaylor/python-betterproto/pull/16)
|
||||||
@@ -44,7 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
- Initial release
|
- Initial release
|
||||||
|
|
||||||
[unreleased]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.3...HEAD
|
[1.2.5]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.4...v1.2.5
|
||||||
|
[1.2.4]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.3...v1.2.4
|
||||||
[1.2.3]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.2...v1.2.3
|
[1.2.3]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.2...v1.2.3
|
||||||
[1.2.2]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.1...v1.2.2
|
[1.2.2]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.1...v1.2.2
|
||||||
[1.2.1]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.0...v1.2.1
|
[1.2.1]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.0...v1.2.1
|
||||||
|
42
Makefile
Normal file
42
Makefile
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
.PHONY: help setup generate test types format clean plugin full-test check-style
|
||||||
|
|
||||||
|
help: ## - Show this help.
|
||||||
|
@fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//'
|
||||||
|
|
||||||
|
# Dev workflow tasks
|
||||||
|
|
||||||
|
generate: ## - Generate test cases (do this once before running test)
|
||||||
|
poetry run ./betterproto/tests/generate.py
|
||||||
|
|
||||||
|
test: ## - Run tests
|
||||||
|
poetry run pytest --cov betterproto
|
||||||
|
|
||||||
|
types: ## - Check types with mypy
|
||||||
|
poetry run mypy betterproto --ignore-missing-imports
|
||||||
|
|
||||||
|
format: ## - Apply black formatting to source code
|
||||||
|
poetry run black . --exclude tests/output_
|
||||||
|
|
||||||
|
clean: ## - Clean out generated files from the workspace
|
||||||
|
rm -rf .coverage \
|
||||||
|
.mypy_cache \
|
||||||
|
.pytest_cache \
|
||||||
|
dist \
|
||||||
|
**/__pycache__ \
|
||||||
|
betterproto/tests/output_*
|
||||||
|
|
||||||
|
# Manual testing
|
||||||
|
|
||||||
|
# By default write plugin output to a directory called output
|
||||||
|
o=output
|
||||||
|
plugin: ## - Execute the protoc plugin, with output write to `output` or the value passed to `-o`
|
||||||
|
mkdir -p $(o)
|
||||||
|
protoc --plugin=protoc-gen-custom=betterproto/plugin.py $(i) --custom_out=$(o)
|
||||||
|
|
||||||
|
# CI tasks
|
||||||
|
|
||||||
|
full-test: generate ## - Run full testing sequence with multiple pythons
|
||||||
|
poetry run tox
|
||||||
|
|
||||||
|
check-style: ## - Check if code style is correct
|
||||||
|
poetry run black . --check --diff --exclude tests/output_
|
32
Pipfile
32
Pipfile
@@ -1,32 +0,0 @@
|
|||||||
[[source]]
|
|
||||||
name = "pypi"
|
|
||||||
url = "https://pypi.org/simple"
|
|
||||||
verify_ssl = true
|
|
||||||
|
|
||||||
[dev-packages]
|
|
||||||
flake8 = "*"
|
|
||||||
mypy = "*"
|
|
||||||
isort = "*"
|
|
||||||
pytest = "*"
|
|
||||||
rope = "*"
|
|
||||||
v = {editable = true,version = "*"}
|
|
||||||
|
|
||||||
[packages]
|
|
||||||
protobuf = "*"
|
|
||||||
jinja2 = "*"
|
|
||||||
grpclib = "*"
|
|
||||||
stringcase = "*"
|
|
||||||
black = "*"
|
|
||||||
backports-datetime-fromisoformat = "*"
|
|
||||||
dataclasses = "*"
|
|
||||||
|
|
||||||
[requires]
|
|
||||||
python_version = "3.6"
|
|
||||||
|
|
||||||
[scripts]
|
|
||||||
plugin = "protoc --plugin=protoc-gen-custom=betterproto/plugin.py --custom_out=output"
|
|
||||||
generate = "python betterproto/tests/generate.py"
|
|
||||||
test = "pytest ./betterproto/tests"
|
|
||||||
|
|
||||||
[pipenv]
|
|
||||||
allow_prereleases = true
|
|
116
README.md
116
README.md
@@ -46,10 +46,10 @@ First, install the package. Note that the `[compiler]` feature flag tells it to
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# Install both the library and compiler
|
# Install both the library and compiler
|
||||||
$ pip install "betterproto[compiler]"
|
pip install "betterproto[compiler]"
|
||||||
|
|
||||||
# Install just the library (to use the generated code output)
|
# Install just the library (to use the generated code output)
|
||||||
$ pip install betterproto
|
pip install betterproto
|
||||||
```
|
```
|
||||||
|
|
||||||
Now, given you installed the compiler and have a proto file, e.g `example.proto`:
|
Now, given you installed the compiler and have a proto file, e.g `example.proto`:
|
||||||
@@ -68,14 +68,15 @@ message Greeting {
|
|||||||
You can run the following:
|
You can run the following:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
$ protoc -I . --python_betterproto_out=. example.proto
|
mkdir lib
|
||||||
|
protoc -I . --python_betterproto_out=lib example.proto
|
||||||
```
|
```
|
||||||
|
|
||||||
This will generate `hello.py` which looks like:
|
This will generate `lib/hello/__init__.py` which looks like:
|
||||||
|
|
||||||
```py
|
```python
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# sources: hello.proto
|
# sources: example.proto
|
||||||
# plugin: python-betterproto
|
# plugin: python-betterproto
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -83,7 +84,7 @@ import betterproto
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Hello(betterproto.Message):
|
class Greeting(betterproto.Message):
|
||||||
"""Greeting represents a message you can tell a user."""
|
"""Greeting represents a message you can tell a user."""
|
||||||
|
|
||||||
message: str = betterproto.string_field(1)
|
message: str = betterproto.string_field(1)
|
||||||
@@ -91,23 +92,23 @@ class Hello(betterproto.Message):
|
|||||||
|
|
||||||
Now you can use it!
|
Now you can use it!
|
||||||
|
|
||||||
```py
|
```python
|
||||||
>>> from hello import Hello
|
>>> from lib.hello import Greeting
|
||||||
>>> test = Hello()
|
>>> test = Greeting()
|
||||||
>>> test
|
>>> test
|
||||||
Hello(message='')
|
Greeting(message='')
|
||||||
|
|
||||||
>>> test.message = "Hey!"
|
>>> test.message = "Hey!"
|
||||||
>>> test
|
>>> test
|
||||||
Hello(message="Hey!")
|
Greeting(message="Hey!")
|
||||||
|
|
||||||
>>> serialized = bytes(test)
|
>>> serialized = bytes(test)
|
||||||
>>> serialized
|
>>> serialized
|
||||||
b'\n\x04Hey!'
|
b'\n\x04Hey!'
|
||||||
|
|
||||||
>>> another = Hello().parse(serialized)
|
>>> another = Greeting().parse(serialized)
|
||||||
>>> another
|
>>> another
|
||||||
Hello(message="Hey!")
|
Greeting(message="Hey!")
|
||||||
|
|
||||||
>>> another.to_dict()
|
>>> another.to_dict()
|
||||||
{"message": "Hey!"}
|
{"message": "Hey!"}
|
||||||
@@ -256,6 +257,7 @@ Google provides several well-known message types like a timestamp, duration, and
|
|||||||
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
|
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
|
||||||
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
|
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
|
||||||
| `google.protobuf.*Value` | `Optional[...]` | `None` |
|
| `google.protobuf.*Value` | `Optional[...]` | `None` |
|
||||||
|
| `google.protobuf.*` | `betterproto.lib.google.protobuf.*` | `None` |
|
||||||
|
|
||||||
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
|
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
|
||||||
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
|
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
|
||||||
@@ -296,36 +298,91 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
|
|||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
First, make sure you have Python 3.6+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
|
Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
|
||||||
|
|
||||||
|
First, make sure you have Python 3.6+ and `poetry` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
# Get set up with the virtual env & dependencies
|
# Get set up with the virtual env & dependencies
|
||||||
$ pipenv install --dev
|
poetry install
|
||||||
|
|
||||||
# Link the local package
|
# Activate the poetry environment
|
||||||
$ pipenv shell
|
poetry shell
|
||||||
$ pip install -e .
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To benefit from the collection of standard development tasks ensure you have make installed and run `make help` to see available tasks.
|
||||||
|
|
||||||
|
### Code style
|
||||||
|
|
||||||
|
This project enforces [black](https://github.com/psf/black) python code formatting.
|
||||||
|
|
||||||
|
Before committing changes run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make format
|
||||||
|
```
|
||||||
|
|
||||||
|
To avoid merge conflicts later, non-black formatted python code will fail in CI.
|
||||||
|
|
||||||
### Tests
|
### Tests
|
||||||
|
|
||||||
There are two types of tests:
|
There are two types of tests:
|
||||||
|
|
||||||
1. Manually-written tests for some behavior of the library
|
1. Standard tests
|
||||||
2. Proto files and JSON inputs for automated tests
|
2. Custom tests
|
||||||
|
|
||||||
For #2, you can add a new `*.proto` file into the `betterproto/tests` directory along with a sample `*.json` input and it will get automatically picked up.
|
#### Standard tests
|
||||||
|
|
||||||
|
Adding a standard test case is easy.
|
||||||
|
|
||||||
|
- Create a new directory `betterproto/tests/inputs/<name>`
|
||||||
|
- add `<name>.proto` with a message called `Test`
|
||||||
|
- add `<name>.json` with some test data (optional)
|
||||||
|
|
||||||
|
It will be picked up automatically when you run the tests.
|
||||||
|
|
||||||
|
- See also: [Standard Tests Development Guide](betterproto/tests/README.md)
|
||||||
|
|
||||||
|
#### Custom tests
|
||||||
|
|
||||||
|
Custom tests are found in `tests/test_*.py` and are run with pytest.
|
||||||
|
|
||||||
|
#### Running
|
||||||
|
|
||||||
Here's how to run the tests.
|
Here's how to run the tests.
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
# Generate assets from sample .proto files
|
# Generate assets from sample .proto files required by the tests
|
||||||
$ pipenv run generate
|
make generate
|
||||||
|
|
||||||
# Run the tests
|
# Run the tests
|
||||||
$ pipenv run test
|
make test
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To run tests as they are run in CI (with tox) run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make full-test
|
||||||
|
```
|
||||||
|
|
||||||
|
### (Re)compiling Google Well-known Types
|
||||||
|
|
||||||
|
Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google).
|
||||||
|
Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests.
|
||||||
|
|
||||||
|
Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`.
|
||||||
|
|
||||||
|
Assuming your `google.protobuf` source files (included with all releases of `protoc`) are located in `/usr/local/include`, you can regenerate them as follows:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
protoc \
|
||||||
|
--plugin=protoc-gen-custom=betterproto/plugin.py \
|
||||||
|
--custom_opt=INCLUDE_GOOGLE \
|
||||||
|
--custom_out=betterproto/lib \
|
||||||
|
-I /usr/local/include/ \
|
||||||
|
/usr/local/include/google/protobuf/*.proto
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### TODO
|
### TODO
|
||||||
|
|
||||||
- [x] Fixed length fields
|
- [x] Fixed length fields
|
||||||
@@ -340,6 +397,9 @@ $ pipenv run test
|
|||||||
- [x] Refs to nested types
|
- [x] Refs to nested types
|
||||||
- [x] Imports in proto files
|
- [x] Imports in proto files
|
||||||
- [x] Well-known Google types
|
- [x] Well-known Google types
|
||||||
|
- [ ] Support as request input
|
||||||
|
- [ ] Support as response output
|
||||||
|
- [ ] Automatically wrap/unwrap responses
|
||||||
- [x] OneOf support
|
- [x] OneOf support
|
||||||
- [x] Basic support on the wire
|
- [x] Basic support on the wire
|
||||||
- [x] Check which was set from the group
|
- [x] Check which was set from the group
|
||||||
@@ -363,6 +423,10 @@ $ pipenv run test
|
|||||||
- [x] Automate running tests
|
- [x] Automate running tests
|
||||||
- [ ] Cleanup!
|
- [ ] Cleanup!
|
||||||
|
|
||||||
|
## Community
|
||||||
|
|
||||||
|
Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Copyright © 2019 Daniel G. Taylor
|
Copyright © 2019 Daniel G. Taylor
|
||||||
|
@@ -5,34 +5,25 @@ import json
|
|||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from base64 import b64encode, b64decode
|
from base64 import b64decode, b64encode
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
SupportsBytes,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
|
||||||
Union,
|
Union,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
TYPE_CHECKING,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import grpclib.client
|
from ._types import T
|
||||||
import grpclib.const
|
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
|
||||||
import stringcase
|
from .grpc.grpclib_client import ServiceStub
|
||||||
|
|
||||||
from .casing import safe_snake_case
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from grpclib._protocols import IProtoMessage
|
|
||||||
|
|
||||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||||
# Apply backport of datetime.fromisoformat from 3.7
|
# Apply backport of datetime.fromisoformat from 3.7
|
||||||
@@ -118,14 +109,18 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
|||||||
|
|
||||||
|
|
||||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||||
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
def datetime_default_gen():
|
||||||
|
return datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
DATETIME_ZERO = datetime_default_gen()
|
||||||
|
|
||||||
|
|
||||||
class Casing(enum.Enum):
|
class Casing(enum.Enum):
|
||||||
"""Casing constants for serialization."""
|
"""Casing constants for serialization."""
|
||||||
|
|
||||||
CAMEL = stringcase.camelcase
|
CAMEL = camel_case
|
||||||
SNAKE = stringcase.snakecase
|
SNAKE = snake_case
|
||||||
|
|
||||||
|
|
||||||
class _PLACEHOLDER:
|
class _PLACEHOLDER:
|
||||||
@@ -422,8 +417,82 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
class ProtoClassMetadata:
|
||||||
T = TypeVar("T", bound="Message")
|
oneof_group_by_field: Dict[str, str]
|
||||||
|
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||||
|
default_gen: Dict[str, Callable]
|
||||||
|
cls_by_field: Dict[str, Type]
|
||||||
|
field_name_by_number: Dict[int, str]
|
||||||
|
meta_by_field_name: Dict[str, FieldMetadata]
|
||||||
|
__slots__ = (
|
||||||
|
"oneof_group_by_field",
|
||||||
|
"oneof_field_by_group",
|
||||||
|
"default_gen",
|
||||||
|
"cls_by_field",
|
||||||
|
"field_name_by_number",
|
||||||
|
"meta_by_field_name",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, cls: Type["Message"]):
|
||||||
|
by_field = {}
|
||||||
|
by_group: Dict[str, Set] = {}
|
||||||
|
by_field_name = {}
|
||||||
|
by_field_number = {}
|
||||||
|
|
||||||
|
fields = dataclasses.fields(cls)
|
||||||
|
for field in fields:
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
|
||||||
|
if meta.group:
|
||||||
|
# This is part of a one-of group.
|
||||||
|
by_field[field.name] = meta.group
|
||||||
|
|
||||||
|
by_group.setdefault(meta.group, set()).add(field)
|
||||||
|
|
||||||
|
by_field_name[field.name] = meta
|
||||||
|
by_field_number[meta.number] = field.name
|
||||||
|
|
||||||
|
self.oneof_group_by_field = by_field
|
||||||
|
self.oneof_field_by_group = by_group
|
||||||
|
self.field_name_by_number = by_field_number
|
||||||
|
self.meta_by_field_name = by_field_name
|
||||||
|
|
||||||
|
self.default_gen = self._get_default_gen(cls, fields)
|
||||||
|
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_default_gen(cls, fields):
|
||||||
|
default_gen = {}
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
default_gen[field.name] = cls._get_field_default_gen(field)
|
||||||
|
|
||||||
|
return default_gen
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_cls_by_field(cls, fields):
|
||||||
|
field_cls = {}
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
if meta.proto_type == TYPE_MAP:
|
||||||
|
assert meta.map_types
|
||||||
|
kt = cls._cls_for(field, index=0)
|
||||||
|
vt = cls._cls_for(field, index=1)
|
||||||
|
Entry = dataclasses.make_dataclass(
|
||||||
|
"Entry",
|
||||||
|
[
|
||||||
|
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||||
|
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||||
|
],
|
||||||
|
bases=(Message,),
|
||||||
|
)
|
||||||
|
field_cls[field.name] = Entry
|
||||||
|
field_cls[field.name + ".value"] = vt
|
||||||
|
else:
|
||||||
|
field_cls[field.name] = cls._cls_for(field)
|
||||||
|
|
||||||
|
return field_cls
|
||||||
|
|
||||||
|
|
||||||
class Message(ABC):
|
class Message(ABC):
|
||||||
@@ -435,69 +504,74 @@ class Message(ABC):
|
|||||||
|
|
||||||
_serialized_on_wire: bool
|
_serialized_on_wire: bool
|
||||||
_unknown_fields: bytes
|
_unknown_fields: bytes
|
||||||
_group_map: Dict[str, dict]
|
_group_current: Dict[str, str]
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Keep track of whether every field was default
|
# Keep track of whether every field was default
|
||||||
all_sentinel = True
|
all_sentinel = True
|
||||||
|
|
||||||
# Set a default value for each field in the class after `__init__` has
|
# Set current field of each group after `__init__` has already been run.
|
||||||
# already been run.
|
group_current: Dict[str, str] = {}
|
||||||
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
for field in dataclasses.fields(self):
|
|
||||||
meta = FieldMetadata.get(field)
|
|
||||||
|
|
||||||
if meta.group:
|
if meta.group:
|
||||||
# This is part of a one-of group.
|
group_current.setdefault(meta.group)
|
||||||
group_map["fields"][field.name] = meta.group
|
|
||||||
|
|
||||||
if meta.group not in group_map["groups"]:
|
if getattr(self, field_name) != PLACEHOLDER:
|
||||||
group_map["groups"][meta.group] = {"current": None, "fields": set()}
|
|
||||||
group_map["groups"][meta.group]["fields"].add(field)
|
|
||||||
|
|
||||||
if getattr(self, field.name) != PLACEHOLDER:
|
|
||||||
# Skip anything not set to the sentinel value
|
# Skip anything not set to the sentinel value
|
||||||
all_sentinel = False
|
all_sentinel = False
|
||||||
|
|
||||||
if meta.group:
|
if meta.group:
|
||||||
# This was set, so make it the selected value of the one-of.
|
# This was set, so make it the selected value of the one-of.
|
||||||
group_map["groups"][meta.group]["current"] = field
|
group_current[meta.group] = field_name
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
setattr(self, field.name, self._get_field_default(field, meta))
|
setattr(self, field_name, self._get_field_default(field_name))
|
||||||
|
|
||||||
# Now that all the defaults are set, reset it!
|
# Now that all the defaults are set, reset it!
|
||||||
self.__dict__["_serialized_on_wire"] = not all_sentinel
|
self.__dict__["_serialized_on_wire"] = not all_sentinel
|
||||||
self.__dict__["_unknown_fields"] = b""
|
self.__dict__["_unknown_fields"] = b""
|
||||||
self.__dict__["_group_map"] = group_map
|
self.__dict__["_group_current"] = group_current
|
||||||
|
|
||||||
def __setattr__(self, attr: str, value: Any) -> None:
|
def __setattr__(self, attr: str, value: Any) -> None:
|
||||||
if attr != "_serialized_on_wire":
|
if attr != "_serialized_on_wire":
|
||||||
# Track when a field has been set.
|
# Track when a field has been set.
|
||||||
self.__dict__["_serialized_on_wire"] = True
|
self.__dict__["_serialized_on_wire"] = True
|
||||||
|
|
||||||
if attr in getattr(self, "_group_map", {}).get("fields", {}):
|
if hasattr(self, "_group_current"): # __post_init__ had already run
|
||||||
group = self._group_map["fields"][attr]
|
if attr in self._betterproto.oneof_group_by_field:
|
||||||
for field in self._group_map["groups"][group]["fields"]:
|
group = self._betterproto.oneof_group_by_field[attr]
|
||||||
|
for field in self._betterproto.oneof_field_by_group[group]:
|
||||||
if field.name == attr:
|
if field.name == attr:
|
||||||
self._group_map["groups"][group]["current"] = field
|
self._group_current[group] = field.name
|
||||||
else:
|
else:
|
||||||
super().__setattr__(
|
super().__setattr__(
|
||||||
field.name,
|
field.name, self._get_field_default(field.name),
|
||||||
self._get_field_default(field, FieldMetadata.get(field)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__setattr__(attr, value)
|
super().__setattr__(attr, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _betterproto(self):
|
||||||
|
"""
|
||||||
|
Lazy initialize metadata for each protobuf class.
|
||||||
|
It may be initialized multiple times in a multi-threaded environment,
|
||||||
|
but that won't affect the correctness.
|
||||||
|
"""
|
||||||
|
meta = getattr(self.__class__, "_betterproto_meta", None)
|
||||||
|
if not meta:
|
||||||
|
meta = ProtoClassMetadata(self.__class__)
|
||||||
|
self.__class__._betterproto_meta = meta
|
||||||
|
return meta
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
"""
|
"""
|
||||||
Get the binary encoded Protobuf representation of this instance.
|
Get the binary encoded Protobuf representation of this instance.
|
||||||
"""
|
"""
|
||||||
output = b""
|
output = b""
|
||||||
for field in dataclasses.fields(self):
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
meta = FieldMetadata.get(field)
|
value = getattr(self, field_name)
|
||||||
value = getattr(self, field.name)
|
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
# Optional items should be skipped. This is used for the Google
|
# Optional items should be skipped. This is used for the Google
|
||||||
@@ -508,16 +582,16 @@ class Message(ABC):
|
|||||||
# currently set in a `oneof` group, so it must be serialized even
|
# currently set in a `oneof` group, so it must be serialized even
|
||||||
# if the value is the default zero value.
|
# if the value is the default zero value.
|
||||||
selected_in_group = False
|
selected_in_group = False
|
||||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
if meta.group and self._group_current[meta.group] == field_name:
|
||||||
selected_in_group = True
|
selected_in_group = True
|
||||||
|
|
||||||
serialize_empty = False
|
serialize_empty = False
|
||||||
if isinstance(value, Message) and value._serialized_on_wire:
|
if isinstance(value, Message) and value._serialized_on_wire:
|
||||||
# Empty messages can still be sent on the wire if they were
|
# Empty messages can still be sent on the wire if they were
|
||||||
# set (or received empty).
|
# set (or recieved empty).
|
||||||
serialize_empty = True
|
serialize_empty = True
|
||||||
|
|
||||||
if value == self._get_field_default(field, meta) and not (
|
if value == self._get_field_default(field_name) and not (
|
||||||
selected_in_group or serialize_empty
|
selected_in_group or serialize_empty
|
||||||
):
|
):
|
||||||
# Default (zero) values are not serialized. Two exceptions are
|
# Default (zero) values are not serialized. Two exceptions are
|
||||||
@@ -560,50 +634,53 @@ class Message(ABC):
|
|||||||
# For compatibility with other libraries
|
# For compatibility with other libraries
|
||||||
SerializeToString = __bytes__
|
SerializeToString = __bytes__
|
||||||
|
|
||||||
def _type_hint(self, field_name: str) -> Type:
|
@classmethod
|
||||||
module = inspect.getmodule(self.__class__)
|
def _type_hint(cls, field_name: str) -> Type:
|
||||||
type_hints = get_type_hints(self.__class__, vars(module))
|
module = inspect.getmodule(cls)
|
||||||
|
type_hints = get_type_hints(cls, vars(module))
|
||||||
return type_hints[field_name]
|
return type_hints[field_name]
|
||||||
|
|
||||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
@classmethod
|
||||||
|
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||||
"""Get the message class for a field from the type hints."""
|
"""Get the message class for a field from the type hints."""
|
||||||
cls = self._type_hint(field.name)
|
field_cls = cls._type_hint(field.name)
|
||||||
if hasattr(cls, "__args__") and index >= 0:
|
if hasattr(field_cls, "__args__") and index >= 0:
|
||||||
cls = cls.__args__[index]
|
field_cls = field_cls.__args__[index]
|
||||||
return cls
|
return field_cls
|
||||||
|
|
||||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
def _get_field_default(self, field_name):
|
||||||
t = self._type_hint(field.name)
|
return self._betterproto.default_gen[field_name]()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
|
||||||
|
t = cls._type_hint(field.name)
|
||||||
|
|
||||||
value: Any = 0
|
|
||||||
if hasattr(t, "__origin__"):
|
if hasattr(t, "__origin__"):
|
||||||
if t.__origin__ in (dict, Dict):
|
if t.__origin__ in (dict, Dict):
|
||||||
# This is some kind of map (dict in Python).
|
# This is some kind of map (dict in Python).
|
||||||
value = {}
|
return dict
|
||||||
elif t.__origin__ in (list, List):
|
elif t.__origin__ in (list, List):
|
||||||
# This is some kind of list (repeated) field.
|
# This is some kind of list (repeated) field.
|
||||||
value = []
|
return list
|
||||||
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
||||||
# This is an optional (wrapped) field. For setting the default we
|
# This is an optional (wrapped) field. For setting the default we
|
||||||
# really don't care what kind of field it is.
|
# really don't care what kind of field it is.
|
||||||
value = None
|
return type(None)
|
||||||
else:
|
else:
|
||||||
value = t()
|
return t
|
||||||
elif issubclass(t, Enum):
|
elif issubclass(t, Enum):
|
||||||
# Enums always default to zero.
|
# Enums always default to zero.
|
||||||
value = 0
|
return int
|
||||||
elif t == datetime:
|
elif t == datetime:
|
||||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||||
value = DATETIME_ZERO
|
return datetime_default_gen
|
||||||
else:
|
else:
|
||||||
# This is either a primitive scalar or another message type. Calling
|
# This is either a primitive scalar or another message type. Calling
|
||||||
# it should result in its zero value.
|
# it should result in its zero value.
|
||||||
value = t()
|
return t
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _postprocess_single(
|
def _postprocess_single(
|
||||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Adjusts values after parsing."""
|
"""Adjusts values after parsing."""
|
||||||
if wire_type == WIRE_VARINT:
|
if wire_type == WIRE_VARINT:
|
||||||
@@ -625,7 +702,7 @@ class Message(ABC):
|
|||||||
if meta.proto_type == TYPE_STRING:
|
if meta.proto_type == TYPE_STRING:
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
elif meta.proto_type == TYPE_MESSAGE:
|
elif meta.proto_type == TYPE_MESSAGE:
|
||||||
cls = self._cls_for(field)
|
cls = self._betterproto.cls_by_field[field_name]
|
||||||
|
|
||||||
if cls == datetime:
|
if cls == datetime:
|
||||||
value = _Timestamp().parse(value).to_datetime()
|
value = _Timestamp().parse(value).to_datetime()
|
||||||
@@ -639,20 +716,7 @@ class Message(ABC):
|
|||||||
value = cls().parse(value)
|
value = cls().parse(value)
|
||||||
value._serialized_on_wire = True
|
value._serialized_on_wire = True
|
||||||
elif meta.proto_type == TYPE_MAP:
|
elif meta.proto_type == TYPE_MAP:
|
||||||
# TODO: This is slow, use a cache to make it faster since each
|
value = self._betterproto.cls_by_field[field_name]().parse(value)
|
||||||
# key/value pair will recreate the class.
|
|
||||||
assert meta.map_types
|
|
||||||
kt = self._cls_for(field, index=0)
|
|
||||||
vt = self._cls_for(field, index=1)
|
|
||||||
Entry = dataclasses.make_dataclass(
|
|
||||||
"Entry",
|
|
||||||
[
|
|
||||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
|
||||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
|
||||||
],
|
|
||||||
bases=(Message,),
|
|
||||||
)
|
|
||||||
value = Entry().parse(value)
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@@ -661,17 +725,16 @@ class Message(ABC):
|
|||||||
Parse the binary encoded Protobuf into this message instance. This
|
Parse the binary encoded Protobuf into this message instance. This
|
||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
"""
|
"""
|
||||||
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
|
|
||||||
for parsed in parse_fields(data):
|
for parsed in parse_fields(data):
|
||||||
if parsed.number in fields:
|
field_name = self._betterproto.field_name_by_number.get(parsed.number)
|
||||||
field = fields[parsed.number]
|
if not field_name:
|
||||||
meta = FieldMetadata.get(field)
|
self._unknown_fields += parsed.raw
|
||||||
|
continue
|
||||||
|
|
||||||
|
meta = self._betterproto.meta_by_field_name[field_name]
|
||||||
|
|
||||||
value: Any
|
value: Any
|
||||||
if (
|
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
|
||||||
parsed.wire_type == WIRE_LEN_DELIM
|
|
||||||
and meta.proto_type in PACKED_TYPES
|
|
||||||
):
|
|
||||||
# This is a packed repeated field.
|
# This is a packed repeated field.
|
||||||
pos = 0
|
pos = 0
|
||||||
value = []
|
value = []
|
||||||
@@ -686,24 +749,22 @@ class Message(ABC):
|
|||||||
decoded, pos = decode_varint(parsed.value, pos)
|
decoded, pos = decode_varint(parsed.value, pos)
|
||||||
wire_type = WIRE_VARINT
|
wire_type = WIRE_VARINT
|
||||||
decoded = self._postprocess_single(
|
decoded = self._postprocess_single(
|
||||||
wire_type, meta, field, decoded
|
wire_type, meta, field_name, decoded
|
||||||
)
|
)
|
||||||
value.append(decoded)
|
value.append(decoded)
|
||||||
else:
|
else:
|
||||||
value = self._postprocess_single(
|
value = self._postprocess_single(
|
||||||
parsed.wire_type, meta, field, parsed.value
|
parsed.wire_type, meta, field_name, parsed.value
|
||||||
)
|
)
|
||||||
|
|
||||||
current = getattr(self, field.name)
|
current = getattr(self, field_name)
|
||||||
if meta.proto_type == TYPE_MAP:
|
if meta.proto_type == TYPE_MAP:
|
||||||
# Value represents a single key/value pair entry in the map.
|
# Value represents a single key/value pair entry in the map.
|
||||||
current[value.key] = value.value
|
current[value.key] = value.value
|
||||||
elif isinstance(current, list) and not isinstance(value, list):
|
elif isinstance(current, list) and not isinstance(value, list):
|
||||||
current.append(value)
|
current.append(value)
|
||||||
else:
|
else:
|
||||||
setattr(self, field.name, value)
|
setattr(self, field_name, value)
|
||||||
else:
|
|
||||||
self._unknown_fields += parsed.raw
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -714,7 +775,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
def to_dict(
|
def to_dict(
|
||||||
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
|
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
|
||||||
) -> dict:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Returns a dict representation of this message instance which can be
|
Returns a dict representation of this message instance which can be
|
||||||
used to serialize to e.g. JSON. Defaults to camel casing for
|
used to serialize to e.g. JSON. Defaults to camel casing for
|
||||||
@@ -726,10 +787,9 @@ class Message(ABC):
|
|||||||
`False`.
|
`False`.
|
||||||
"""
|
"""
|
||||||
output: Dict[str, Any] = {}
|
output: Dict[str, Any] = {}
|
||||||
for field in dataclasses.fields(self):
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
meta = FieldMetadata.get(field)
|
v = getattr(self, field_name)
|
||||||
v = getattr(self, field.name)
|
cased_name = casing(field_name).rstrip("_") # type: ignore
|
||||||
cased_name = casing(field.name).rstrip("_") # type: ignore
|
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
if isinstance(v, datetime):
|
if isinstance(v, datetime):
|
||||||
if v != DATETIME_ZERO or include_default_values:
|
if v != DATETIME_ZERO or include_default_values:
|
||||||
@@ -755,7 +815,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
if v or include_default_values:
|
if v or include_default_values:
|
||||||
output[cased_name] = v
|
output[cased_name] = v
|
||||||
elif v != self._get_field_default(field, meta) or include_default_values:
|
elif v != self._get_field_default(field_name) or include_default_values:
|
||||||
if meta.proto_type in INT_64_TYPES:
|
if meta.proto_type in INT_64_TYPES:
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
output[cased_name] = [str(n) for n in v]
|
output[cased_name] = [str(n) for n in v]
|
||||||
@@ -767,7 +827,9 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
output[cased_name] = b64encode(v).decode("utf8")
|
output[cased_name] = b64encode(v).decode("utf8")
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_values = list(self._cls_for(field)) # type: ignore
|
enum_values = list(
|
||||||
|
self._betterproto.cls_by_field[field_name]
|
||||||
|
) # type: ignore
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
output[cased_name] = [enum_values[e].name for e in v]
|
output[cased_name] = [enum_values[e].name for e in v]
|
||||||
else:
|
else:
|
||||||
@@ -784,33 +846,31 @@ class Message(ABC):
|
|||||||
self._serialized_on_wire = True
|
self._serialized_on_wire = True
|
||||||
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
|
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
|
||||||
for key in value:
|
for key in value:
|
||||||
snake_cased = safe_snake_case(key)
|
field_name = safe_snake_case(key)
|
||||||
if snake_cased in fields_by_name:
|
meta = self._betterproto.meta_by_field_name.get(field_name)
|
||||||
field = fields_by_name[snake_cased]
|
if not meta:
|
||||||
meta = FieldMetadata.get(field)
|
continue
|
||||||
|
|
||||||
if value[key] is not None:
|
if value[key] is not None:
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field_name)
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
cls = self._cls_for(field)
|
cls = self._betterproto.cls_by_field[field_name]
|
||||||
for i in range(len(value[key])):
|
for i in range(len(value[key])):
|
||||||
v.append(cls().from_dict(value[key][i]))
|
v.append(cls().from_dict(value[key][i]))
|
||||||
elif isinstance(v, datetime):
|
elif isinstance(v, datetime):
|
||||||
v = datetime.fromisoformat(
|
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||||
value[key].replace("Z", "+00:00")
|
setattr(self, field_name, v)
|
||||||
)
|
|
||||||
setattr(self, field.name, v)
|
|
||||||
elif isinstance(v, timedelta):
|
elif isinstance(v, timedelta):
|
||||||
v = timedelta(seconds=float(value[key][:-1]))
|
v = timedelta(seconds=float(value[key][:-1]))
|
||||||
setattr(self, field.name, v)
|
setattr(self, field_name, v)
|
||||||
elif meta.wraps:
|
elif meta.wraps:
|
||||||
setattr(self, field.name, value[key])
|
setattr(self, field_name, value[key])
|
||||||
else:
|
else:
|
||||||
v.from_dict(value[key])
|
v.from_dict(value[key])
|
||||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field_name)
|
||||||
cls = self._cls_for(field, index=1)
|
cls = self._betterproto.cls_by_field[field_name + ".value"]
|
||||||
for k in value[key]:
|
for k in value[key]:
|
||||||
v[k] = cls().from_dict(value[key][k])
|
v[k] = cls().from_dict(value[key][k])
|
||||||
else:
|
else:
|
||||||
@@ -826,14 +886,14 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
v = b64decode(value[key])
|
v = b64decode(value[key])
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_cls = self._cls_for(field)
|
enum_cls = self._betterproto.cls_by_field[field_name]
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = [enum_cls.from_string(e) for e in v]
|
v = [enum_cls.from_string(e) for e in v]
|
||||||
elif isinstance(v, str):
|
elif isinstance(v, str):
|
||||||
v = enum_cls.from_string(v)
|
v = enum_cls.from_string(v)
|
||||||
|
|
||||||
if v is not None:
|
if v is not None:
|
||||||
setattr(self, field.name, v)
|
setattr(self, field_name, v)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_json(self, indent: Union[None, int, str] = None) -> str:
|
def to_json(self, indent: Union[None, int, str] = None) -> str:
|
||||||
@@ -859,25 +919,29 @@ def serialized_on_wire(message: Message) -> bool:
|
|||||||
|
|
||||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||||
"""Return the name and value of a message's one-of field group."""
|
"""Return the name and value of a message's one-of field group."""
|
||||||
field = message._group_map["groups"].get(group_name, {}).get("current")
|
field_name = message._group_current.get(group_name)
|
||||||
if not field:
|
if not field_name:
|
||||||
return ("", None)
|
return ("", None)
|
||||||
return (field.name, getattr(message, field.name))
|
return (field_name, getattr(message, field_name))
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
# Circular import workaround: google.protobuf depends on base classes defined above.
|
||||||
class _Duration(Message):
|
from .lib.google.protobuf import (
|
||||||
# Signed seconds of the span of time. Must be from -315,576,000,000 to
|
Duration,
|
||||||
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
|
Timestamp,
|
||||||
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
|
BoolValue,
|
||||||
seconds: int = int64_field(1)
|
BytesValue,
|
||||||
# Signed fractions of a second at nanosecond resolution of the span of time.
|
DoubleValue,
|
||||||
# Durations less than one second are represented with a 0 `seconds` field and
|
FloatValue,
|
||||||
# a positive or negative `nanos` field. For durations of one second or more,
|
Int32Value,
|
||||||
# a non-zero value for the `nanos` field must be of the same sign as the
|
Int64Value,
|
||||||
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
|
StringValue,
|
||||||
nanos: int = int32_field(2)
|
UInt32Value,
|
||||||
|
UInt64Value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _Duration(Duration):
|
||||||
def to_timedelta(self) -> timedelta:
|
def to_timedelta(self) -> timedelta:
|
||||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||||
|
|
||||||
@@ -890,16 +954,7 @@ class _Duration(Message):
|
|||||||
return ".".join(parts) + "s"
|
return ".".join(parts) + "s"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
class _Timestamp(Timestamp):
|
||||||
class _Timestamp(Message):
|
|
||||||
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
|
|
||||||
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
|
|
||||||
seconds: int = int64_field(1)
|
|
||||||
# Non-negative fractions of a second at nanosecond resolution. Negative
|
|
||||||
# second values with fractions must still have non-negative nanos values that
|
|
||||||
# count forward in time. Must be from 0 to 999,999,999 inclusive.
|
|
||||||
nanos: int = int32_field(2)
|
|
||||||
|
|
||||||
def to_datetime(self) -> datetime:
|
def to_datetime(self) -> datetime:
|
||||||
ts = self.seconds + (self.nanos / 1e9)
|
ts = self.seconds + (self.nanos / 1e9)
|
||||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||||
@@ -940,93 +995,16 @@ class _WrappedMessage(Message):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _BoolValue(_WrappedMessage):
|
|
||||||
value: bool = bool_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _Int32Value(_WrappedMessage):
|
|
||||||
value: int = int32_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _UInt32Value(_WrappedMessage):
|
|
||||||
value: int = uint32_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _Int64Value(_WrappedMessage):
|
|
||||||
value: int = int64_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _UInt64Value(_WrappedMessage):
|
|
||||||
value: int = uint64_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _FloatValue(_WrappedMessage):
|
|
||||||
value: float = float_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _DoubleValue(_WrappedMessage):
|
|
||||||
value: float = double_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _StringValue(_WrappedMessage):
|
|
||||||
value: str = string_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _BytesValue(_WrappedMessage):
|
|
||||||
value: bytes = bytes_field(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_wrapper(proto_type: str) -> Type:
|
def _get_wrapper(proto_type: str) -> Type:
|
||||||
"""Get the wrapper message class for a wrapped type."""
|
"""Get the wrapper message class for a wrapped type."""
|
||||||
return {
|
return {
|
||||||
TYPE_BOOL: _BoolValue,
|
TYPE_BOOL: BoolValue,
|
||||||
TYPE_INT32: _Int32Value,
|
TYPE_INT32: Int32Value,
|
||||||
TYPE_UINT32: _UInt32Value,
|
TYPE_UINT32: UInt32Value,
|
||||||
TYPE_INT64: _Int64Value,
|
TYPE_INT64: Int64Value,
|
||||||
TYPE_UINT64: _UInt64Value,
|
TYPE_UINT64: UInt64Value,
|
||||||
TYPE_FLOAT: _FloatValue,
|
TYPE_FLOAT: FloatValue,
|
||||||
TYPE_DOUBLE: _DoubleValue,
|
TYPE_DOUBLE: DoubleValue,
|
||||||
TYPE_STRING: _StringValue,
|
TYPE_STRING: StringValue,
|
||||||
TYPE_BYTES: _BytesValue,
|
TYPE_BYTES: BytesValue,
|
||||||
}[proto_type]
|
}[proto_type]
|
||||||
|
|
||||||
|
|
||||||
class ServiceStub(ABC):
|
|
||||||
"""
|
|
||||||
Base class for async gRPC service stubs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, channel: grpclib.client.Channel) -> None:
|
|
||||||
self.channel = channel
|
|
||||||
|
|
||||||
async def _unary_unary(
|
|
||||||
self, route: str, request: "IProtoMessage", response_type: Type[T]
|
|
||||||
) -> T:
|
|
||||||
"""Make a unary request and return the response."""
|
|
||||||
async with self.channel.request(
|
|
||||||
route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type
|
|
||||||
) as stream:
|
|
||||||
await stream.send_message(request, end=True)
|
|
||||||
response = await stream.recv_message()
|
|
||||||
assert response is not None
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def _unary_stream(
|
|
||||||
self, route: str, request: "IProtoMessage", response_type: Type[T]
|
|
||||||
) -> AsyncGenerator[T, None]:
|
|
||||||
"""Make a unary request and return the stream response iterator."""
|
|
||||||
async with self.channel.request(
|
|
||||||
route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type
|
|
||||||
) as stream:
|
|
||||||
await stream.send_message(request, end=True)
|
|
||||||
async for message in stream:
|
|
||||||
yield message
|
|
||||||
|
9
betterproto/_types.py
Normal file
9
betterproto/_types.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import Message
|
||||||
|
from grpclib._protocols import IProtoMessage
|
||||||
|
|
||||||
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
|
T = TypeVar("T", bound="Message")
|
||||||
|
ST = TypeVar("ST", bound="IProtoMessage")
|
@@ -1,9 +1,21 @@
|
|||||||
import stringcase
|
import re
|
||||||
|
|
||||||
|
# Word delimiters and symbols that will not be preserved when re-casing.
|
||||||
|
# language=PythonRegExp
|
||||||
|
SYMBOLS = "[^a-zA-Z0-9]*"
|
||||||
|
|
||||||
|
# Optionally capitalized word.
|
||||||
|
# language=PythonRegExp
|
||||||
|
WORD = "[A-Z]*[a-z]*[0-9]*"
|
||||||
|
|
||||||
|
# Uppercase word, not followed by lowercase letters.
|
||||||
|
# language=PythonRegExp
|
||||||
|
WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
|
||||||
|
|
||||||
|
|
||||||
def safe_snake_case(value: str) -> str:
|
def safe_snake_case(value: str) -> str:
|
||||||
"""Snake case a value taking into account Python keywords."""
|
"""Snake case a value taking into account Python keywords."""
|
||||||
value = stringcase.snakecase(value)
|
value = snake_case(value)
|
||||||
if value in [
|
if value in [
|
||||||
"and",
|
"and",
|
||||||
"as",
|
"as",
|
||||||
@@ -39,3 +51,70 @@ def safe_snake_case(value: str) -> str:
|
|||||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||||
value += "_"
|
value += "_"
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def snake_case(value: str, strict: bool = True):
|
||||||
|
"""
|
||||||
|
Join words with an underscore into lowercase and remove symbols.
|
||||||
|
@param value: value to convert
|
||||||
|
@param strict: force single underscores
|
||||||
|
"""
|
||||||
|
|
||||||
|
def substitute_word(symbols, word, is_start):
|
||||||
|
if not word:
|
||||||
|
return ""
|
||||||
|
if strict:
|
||||||
|
delimiter_count = 0 if is_start else 1 # Single underscore if strict.
|
||||||
|
elif is_start:
|
||||||
|
delimiter_count = len(symbols)
|
||||||
|
elif word.isupper() or word.islower():
|
||||||
|
delimiter_count = max(
|
||||||
|
1, len(symbols)
|
||||||
|
) # Preserve all delimiters if not strict.
|
||||||
|
else:
|
||||||
|
delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
|
||||||
|
|
||||||
|
return ("_" * delimiter_count) + word.lower()
|
||||||
|
|
||||||
|
snake = re.sub(
|
||||||
|
f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
|
||||||
|
lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
return snake
|
||||||
|
|
||||||
|
|
||||||
|
def pascal_case(value: str, strict: bool = True):
|
||||||
|
"""
|
||||||
|
Capitalize each word and remove symbols.
|
||||||
|
@param value: value to convert
|
||||||
|
@param strict: output only alphanumeric characters
|
||||||
|
"""
|
||||||
|
|
||||||
|
def substitute_word(symbols, word):
|
||||||
|
if strict:
|
||||||
|
return word.capitalize() # Remove all delimiters
|
||||||
|
|
||||||
|
if word.islower():
|
||||||
|
delimiter_length = len(symbols[:-1]) # Lose one delimiter
|
||||||
|
else:
|
||||||
|
delimiter_length = len(symbols) # Preserve all delimiters
|
||||||
|
|
||||||
|
return ("_" * delimiter_length) + word.capitalize()
|
||||||
|
|
||||||
|
return re.sub(
|
||||||
|
f"({SYMBOLS})({WORD_UPPER}|{WORD})",
|
||||||
|
lambda groups: substitute_word(groups[1], groups[2]),
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def camel_case(value: str, strict: bool = True):
|
||||||
|
"""
|
||||||
|
Capitalize all words except first and remove symbols.
|
||||||
|
"""
|
||||||
|
return lowercase_first(pascal_case(value, strict=strict))
|
||||||
|
|
||||||
|
|
||||||
|
def lowercase_first(value: str):
|
||||||
|
return value[0:1].lower() + value[1:]
|
||||||
|
0
betterproto/compile/__init__.py
Normal file
0
betterproto/compile/__init__.py
Normal file
160
betterproto/compile/importing.py
Normal file
160
betterproto/compile/importing.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Set, Type
|
||||||
|
|
||||||
|
from betterproto import safe_snake_case
|
||||||
|
from betterproto.compile.naming import pythonize_class_name
|
||||||
|
from betterproto.lib.google import protobuf as google_protobuf
|
||||||
|
|
||||||
|
WRAPPER_TYPES: Dict[str, Type] = {
|
||||||
|
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||||
|
".google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||||
|
".google.protobuf.Int32Value": google_protobuf.Int32Value,
|
||||||
|
".google.protobuf.Int64Value": google_protobuf.Int64Value,
|
||||||
|
".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
|
||||||
|
".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
|
||||||
|
".google.protobuf.BoolValue": google_protobuf.BoolValue,
|
||||||
|
".google.protobuf.StringValue": google_protobuf.StringValue,
|
||||||
|
".google.protobuf.BytesValue": google_protobuf.BytesValue,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_source_type_name(field_type_name):
|
||||||
|
"""
|
||||||
|
Split full source type name into package and type name.
|
||||||
|
E.g. 'root.package.Message' -> ('root.package', 'Message')
|
||||||
|
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
|
||||||
|
"""
|
||||||
|
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
|
||||||
|
if package_match:
|
||||||
|
package = package_match.group(1)
|
||||||
|
name = package_match.group(2)
|
||||||
|
else:
|
||||||
|
package = ""
|
||||||
|
name = field_type_name.lstrip(".")
|
||||||
|
return package, name
|
||||||
|
|
||||||
|
|
||||||
|
def get_type_reference(
|
||||||
|
package: str, imports: set, source_type: str, unwrap: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Return a Python type name for a proto type reference. Adds the import if
|
||||||
|
necessary. Unwraps well known type if required.
|
||||||
|
"""
|
||||||
|
if unwrap:
|
||||||
|
if source_type in WRAPPER_TYPES:
|
||||||
|
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||||
|
return f"Optional[{wrapped_type.__name__}]"
|
||||||
|
|
||||||
|
if source_type == ".google.protobuf.Duration":
|
||||||
|
return "timedelta"
|
||||||
|
|
||||||
|
if source_type == ".google.protobuf.Timestamp":
|
||||||
|
return "datetime"
|
||||||
|
|
||||||
|
source_package, source_type = parse_source_type_name(source_type)
|
||||||
|
|
||||||
|
current_package: List[str] = package.split(".") if package else []
|
||||||
|
py_package: List[str] = source_package.split(".") if source_package else []
|
||||||
|
py_type: str = pythonize_class_name(source_type)
|
||||||
|
|
||||||
|
compiling_google_protobuf = current_package == ["google", "protobuf"]
|
||||||
|
importing_google_protobuf = py_package == ["google", "protobuf"]
|
||||||
|
if importing_google_protobuf and not compiling_google_protobuf:
|
||||||
|
py_package = ["betterproto", "lib"] + py_package
|
||||||
|
|
||||||
|
if py_package[:1] == ["betterproto"]:
|
||||||
|
return reference_absolute(imports, py_package, py_type)
|
||||||
|
|
||||||
|
if py_package == current_package:
|
||||||
|
return reference_sibling(py_type)
|
||||||
|
|
||||||
|
if py_package[: len(current_package)] == current_package:
|
||||||
|
return reference_descendent(current_package, imports, py_package, py_type)
|
||||||
|
|
||||||
|
if current_package[: len(py_package)] == py_package:
|
||||||
|
return reference_ancestor(current_package, imports, py_package, py_type)
|
||||||
|
|
||||||
|
return reference_cousin(current_package, imports, py_package, py_type)
|
||||||
|
|
||||||
|
|
||||||
|
def reference_absolute(imports, py_package, py_type):
|
||||||
|
"""
|
||||||
|
Returns a reference to a python type located in the root, i.e. sys.path.
|
||||||
|
"""
|
||||||
|
string_import = ".".join(py_package)
|
||||||
|
string_alias = safe_snake_case(string_import)
|
||||||
|
imports.add(f"import {string_import} as {string_alias}")
|
||||||
|
return f"{string_alias}.{py_type}"
|
||||||
|
|
||||||
|
|
||||||
|
def reference_sibling(py_type: str) -> str:
|
||||||
|
"""
|
||||||
|
Returns a reference to a python type within the same package as the current package.
|
||||||
|
"""
|
||||||
|
return f'"{py_type}"'
|
||||||
|
|
||||||
|
|
||||||
|
def reference_descendent(
|
||||||
|
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Returns a reference to a python type in a package that is a descendent of the current package,
|
||||||
|
and adds the required import that is aliased to avoid name conflicts.
|
||||||
|
"""
|
||||||
|
importing_descendent = py_package[len(current_package) :]
|
||||||
|
string_from = ".".join(importing_descendent[:-1])
|
||||||
|
string_import = importing_descendent[-1]
|
||||||
|
if string_from:
|
||||||
|
string_alias = "_".join(importing_descendent)
|
||||||
|
imports.add(f"from .{string_from} import {string_import} as {string_alias}")
|
||||||
|
return f"{string_alias}.{py_type}"
|
||||||
|
else:
|
||||||
|
imports.add(f"from . import {string_import}")
|
||||||
|
return f"{string_import}.{py_type}"
|
||||||
|
|
||||||
|
|
||||||
|
def reference_ancestor(
|
||||||
|
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Returns a reference to a python type in a package which is an ancestor to the current package,
|
||||||
|
and adds the required import that is aliased (if possible) to avoid name conflicts.
|
||||||
|
|
||||||
|
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
|
||||||
|
"""
|
||||||
|
distance_up = len(current_package) - len(py_package)
|
||||||
|
if py_package:
|
||||||
|
string_import = py_package[-1]
|
||||||
|
string_alias = f"_{'_' * distance_up}{string_import}__"
|
||||||
|
string_from = f"..{'.' * distance_up}"
|
||||||
|
imports.add(f"from {string_from} import {string_import} as {string_alias}")
|
||||||
|
return f"{string_alias}.{py_type}"
|
||||||
|
else:
|
||||||
|
string_alias = f"{'_' * distance_up}{py_type}__"
|
||||||
|
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
|
||||||
|
return string_alias
|
||||||
|
|
||||||
|
|
||||||
|
def reference_cousin(
|
||||||
|
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
|
||||||
|
and adds the required import that is aliased to avoid name conflicts.
|
||||||
|
"""
|
||||||
|
shared_ancestry = os.path.commonprefix([current_package, py_package])
|
||||||
|
distance_up = len(current_package) - len(shared_ancestry)
|
||||||
|
string_from = f".{'.' * distance_up}" + ".".join(
|
||||||
|
py_package[len(shared_ancestry) : -1]
|
||||||
|
)
|
||||||
|
string_import = py_package[-1]
|
||||||
|
# Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
|
||||||
|
string_alias = (
|
||||||
|
f"{'_' * distance_up}"
|
||||||
|
+ safe_snake_case(".".join(py_package[len(shared_ancestry) :]))
|
||||||
|
+ "__"
|
||||||
|
)
|
||||||
|
imports.add(f"from {string_from} import {string_import} as {string_alias}")
|
||||||
|
return f"{string_alias}.{py_type}"
|
13
betterproto/compile/naming.py
Normal file
13
betterproto/compile/naming.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from betterproto import casing
|
||||||
|
|
||||||
|
|
||||||
|
def pythonize_class_name(name):
|
||||||
|
return casing.pascal_case(name)
|
||||||
|
|
||||||
|
|
||||||
|
def pythonize_field_name(name: str):
|
||||||
|
return casing.safe_snake_case(name)
|
||||||
|
|
||||||
|
|
||||||
|
def pythonize_method_name(name: str):
|
||||||
|
return casing.safe_snake_case(name)
|
0
betterproto/grpc/__init__.py
Normal file
0
betterproto/grpc/__init__.py
Normal file
170
betterproto/grpc/grpclib_client.py
Normal file
170
betterproto/grpc/grpclib_client.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
from abc import ABC
|
||||||
|
import asyncio
|
||||||
|
import grpclib.const
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterable,
|
||||||
|
AsyncIterator,
|
||||||
|
Collection,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
from .._types import ST, T
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from grpclib._protocols import IProtoMessage
|
||||||
|
from grpclib.client import Channel, Stream
|
||||||
|
from grpclib.metadata import Deadline
|
||||||
|
|
||||||
|
|
||||||
|
_Value = Union[str, bytes]
|
||||||
|
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
||||||
|
_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceStub(ABC):
|
||||||
|
"""
|
||||||
|
Base class for async gRPC clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channel: "Channel",
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel = channel
|
||||||
|
self.timeout = timeout
|
||||||
|
self.deadline = deadline
|
||||||
|
self.metadata = metadata
|
||||||
|
|
||||||
|
def __resolve_request_kwargs(
|
||||||
|
self,
|
||||||
|
timeout: Optional[float],
|
||||||
|
deadline: Optional["Deadline"],
|
||||||
|
metadata: Optional[_MetadataLike],
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"timeout": self.timeout if timeout is None else timeout,
|
||||||
|
"deadline": self.deadline if deadline is None else deadline,
|
||||||
|
"metadata": self.metadata if metadata is None else metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _unary_unary(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request: "IProtoMessage",
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> T:
|
||||||
|
"""Make a unary request and return the response."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.UNARY_UNARY,
|
||||||
|
type(request),
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await stream.send_message(request, end=True)
|
||||||
|
response = await stream.recv_message()
|
||||||
|
assert response is not None
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _unary_stream(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request: "IProtoMessage",
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> AsyncIterator[T]:
|
||||||
|
"""Make a unary request and return the stream response iterator."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.UNARY_STREAM,
|
||||||
|
type(request),
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await stream.send_message(request, end=True)
|
||||||
|
async for message in stream:
|
||||||
|
yield message
|
||||||
|
|
||||||
|
async def _stream_unary(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request_iterator: _MessageSource,
|
||||||
|
request_type: Type[ST],
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> T:
|
||||||
|
"""Make a stream request and return the response."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.STREAM_UNARY,
|
||||||
|
request_type,
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await self._send_messages(stream, request_iterator)
|
||||||
|
response = await stream.recv_message()
|
||||||
|
assert response is not None
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _stream_stream(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request_iterator: _MessageSource,
|
||||||
|
request_type: Type[ST],
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> AsyncIterator[T]:
|
||||||
|
"""
|
||||||
|
Make a stream request and return an AsyncIterator to iterate over response
|
||||||
|
messages.
|
||||||
|
"""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.STREAM_STREAM,
|
||||||
|
request_type,
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await stream.send_request()
|
||||||
|
sending_task = asyncio.ensure_future(
|
||||||
|
self._send_messages(stream, request_iterator)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async for response in stream:
|
||||||
|
yield response
|
||||||
|
except:
|
||||||
|
sending_task.cancel()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_messages(stream, messages: _MessageSource):
|
||||||
|
if isinstance(messages, AsyncIterable):
|
||||||
|
async for message in messages:
|
||||||
|
await stream.send_message(message)
|
||||||
|
else:
|
||||||
|
for message in messages:
|
||||||
|
await stream.send_message(message)
|
||||||
|
await stream.end()
|
0
betterproto/grpc/util/__init__.py
Normal file
0
betterproto/grpc/util/__init__.py
Normal file
198
betterproto/grpc/util/async_channel.py
Normal file
198
betterproto/grpc/util/async_channel.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import (
|
||||||
|
AsyncIterable,
|
||||||
|
AsyncIterator,
|
||||||
|
Iterable,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelClosed(Exception):
|
||||||
|
"""
|
||||||
|
An exception raised on an attempt to send through a closed channel
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDone(Exception):
|
||||||
|
"""
|
||||||
|
An exception raised on an attempt to send recieve from a channel that is both closed
|
||||||
|
and empty.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncChannel(AsyncIterable[T]):
|
||||||
|
"""
|
||||||
|
A buffered async channel for sending items between coroutines with FIFO ordering.
|
||||||
|
|
||||||
|
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
client = GeneratedStub(grpclib_chan)
|
||||||
|
request_chan = await AsyncChannel()
|
||||||
|
# We can start be sending all the requests we already have
|
||||||
|
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
|
||||||
|
async for response in client.rpc_call(request_chan):
|
||||||
|
# The response iterator will remain active until the connection is closed
|
||||||
|
...
|
||||||
|
# More items can be sent at any time
|
||||||
|
await request_chan.send(ReqestObject(...))
|
||||||
|
...
|
||||||
|
# The channel must be closed to complete the gRPC connection
|
||||||
|
request_chan.close()
|
||||||
|
|
||||||
|
Items can be sent through the channel by either:
|
||||||
|
- providing an iterable to the send_from method
|
||||||
|
- passing them to the send method one at a time
|
||||||
|
|
||||||
|
Items can be recieved from the channel by either:
|
||||||
|
- iterating over the channel with a for loop to get all items
|
||||||
|
- calling the recieve method to get one item at a time
|
||||||
|
|
||||||
|
If the channel is empty then recievers will wait until either an item appears or the
|
||||||
|
channel is closed.
|
||||||
|
|
||||||
|
Once the channel is closed then subsequent attempt to send through the channel will
|
||||||
|
fail with a ChannelClosed exception.
|
||||||
|
|
||||||
|
When th channel is closed and empty then it is done, and further attempts to recieve
|
||||||
|
from it will fail with a ChannelDone exception
|
||||||
|
|
||||||
|
If multiple coroutines recieve from the channel concurrently, each item sent will be
|
||||||
|
recieved by only one of the recievers.
|
||||||
|
|
||||||
|
:param source:
|
||||||
|
An optional iterable will items that should be sent through the channel
|
||||||
|
immediately.
|
||||||
|
:param buffer_limit:
|
||||||
|
Limit the number of items that can be buffered in the channel, A value less than
|
||||||
|
1 implies no limit. If the channel is full then attempts to send more items will
|
||||||
|
result in the sender waiting until an item is recieved from the channel.
|
||||||
|
:param close:
|
||||||
|
If set to True then the channel will automatically close after exhausting source
|
||||||
|
or immediately if no source is provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *, buffer_limit: int = 0, close: bool = False,
|
||||||
|
):
|
||||||
|
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||||
|
self._closed = False
|
||||||
|
self._waiting_recievers: int = 0
|
||||||
|
# Track whether flush has been invoked so it can only happen once
|
||||||
|
self._flushed = False
|
||||||
|
|
||||||
|
def __aiter__(self) -> AsyncIterator[T]:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> T:
|
||||||
|
if self.done():
|
||||||
|
raise StopAsyncIteration
|
||||||
|
self._waiting_recievers += 1
|
||||||
|
try:
|
||||||
|
result = await self._queue.get()
|
||||||
|
if result is self.__flush:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
self._waiting_recievers -= 1
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def closed(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if this channel is closed and no-longer accepting new items
|
||||||
|
"""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this channel is done.
|
||||||
|
|
||||||
|
:return: True if this channel is closed and and has been drained of items in
|
||||||
|
which case any further attempts to recieve an item from this channel will raise
|
||||||
|
a ChannelDone exception.
|
||||||
|
"""
|
||||||
|
# After close the channel is not yet done until there is at least one waiting
|
||||||
|
# reciever per enqueued item.
|
||||||
|
return self._closed and self._queue.qsize() <= self._waiting_recievers
|
||||||
|
|
||||||
|
async def send_from(
|
||||||
|
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||||
|
) -> "AsyncChannel[T]":
|
||||||
|
"""
|
||||||
|
Iterates the given [Async]Iterable and sends all the resulting items.
|
||||||
|
If close is set to True then subsequent send calls will be rejected with a
|
||||||
|
ChannelClosed exception.
|
||||||
|
:param source: an iterable of items to send
|
||||||
|
:param close:
|
||||||
|
if True then the channel will be closed after the source has been exhausted
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise ChannelClosed("Cannot send through a closed channel")
|
||||||
|
if isinstance(source, AsyncIterable):
|
||||||
|
async for item in source:
|
||||||
|
await self._queue.put(item)
|
||||||
|
else:
|
||||||
|
for item in source:
|
||||||
|
await self._queue.put(item)
|
||||||
|
if close:
|
||||||
|
# Complete the closing process
|
||||||
|
self.close()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def send(self, item: T) -> "AsyncChannel[T]":
|
||||||
|
"""
|
||||||
|
Send a single item over this channel.
|
||||||
|
:param item: The item to send
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise ChannelClosed("Cannot send through a closed channel")
|
||||||
|
await self._queue.put(item)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def recieve(self) -> Optional[T]:
|
||||||
|
"""
|
||||||
|
Returns the next item from this channel when it becomes available,
|
||||||
|
or None if the channel is closed before another item is sent.
|
||||||
|
:return: An item from the channel
|
||||||
|
"""
|
||||||
|
if self.done():
|
||||||
|
raise ChannelDone("Cannot recieve from a closed channel")
|
||||||
|
self._waiting_recievers += 1
|
||||||
|
try:
|
||||||
|
result = await self._queue.get()
|
||||||
|
if result is self.__flush:
|
||||||
|
return None
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
self._waiting_recievers -= 1
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Close this channel to new items
|
||||||
|
"""
|
||||||
|
self._closed = True
|
||||||
|
asyncio.ensure_future(self._flush_queue())
|
||||||
|
|
||||||
|
async def _flush_queue(self):
|
||||||
|
"""
|
||||||
|
To be called after the channel is closed. Pushes a number of self.__flush
|
||||||
|
objects to the queue to ensure no waiting consumers get deadlocked.
|
||||||
|
"""
|
||||||
|
if not self._flushed:
|
||||||
|
self._flushed = True
|
||||||
|
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
|
||||||
|
for _ in range(deadlocked_recievers):
|
||||||
|
await self._queue.put(self.__flush)
|
||||||
|
|
||||||
|
# A special signal object for flushing the queue when the channel is closed
|
||||||
|
__flush = object()
|
0
betterproto/lib/__init__.py
Normal file
0
betterproto/lib/__init__.py
Normal file
0
betterproto/lib/google/__init__.py
Normal file
0
betterproto/lib/google/__init__.py
Normal file
1312
betterproto/lib/google/protobuf/__init__.py
Normal file
1312
betterproto/lib/google/protobuf/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
2
betterproto/plugin.bat
Normal file
2
betterproto/plugin.bat
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
@SET plugin_dir=%~dp0
|
||||||
|
@python %plugin_dir%/plugin.py %*
|
@@ -1,112 +1,65 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
|
||||||
import os.path
|
import os.path
|
||||||
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Any, List, Tuple
|
from typing import List, Union
|
||||||
|
|
||||||
|
import betterproto
|
||||||
|
from betterproto.compile.importing import get_type_reference
|
||||||
|
from betterproto.compile.naming import (
|
||||||
|
pythonize_class_name,
|
||||||
|
pythonize_field_name,
|
||||||
|
pythonize_method_name,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# betterproto[compiler] specific dependencies
|
||||||
import black
|
import black
|
||||||
except ImportError:
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||||
print(
|
from google.protobuf.descriptor_pb2 import (
|
||||||
"Unable to import `black` formatter. Did you install the compiler feature with `pip install betterproto[compiler]`?"
|
|
||||||
)
|
|
||||||
raise SystemExit(1)
|
|
||||||
|
|
||||||
import jinja2
|
|
||||||
import stringcase
|
|
||||||
|
|
||||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
|
||||||
from google.protobuf.descriptor_pb2 import (
|
|
||||||
DescriptorProto,
|
DescriptorProto,
|
||||||
EnumDescriptorProto,
|
EnumDescriptorProto,
|
||||||
FieldDescriptorProto,
|
FieldDescriptorProto,
|
||||||
FileDescriptorProto,
|
)
|
||||||
ServiceDescriptorProto,
|
import google.protobuf.wrappers_pb2 as google_wrappers
|
||||||
)
|
import jinja2
|
||||||
|
except ImportError as err:
|
||||||
from betterproto.casing import safe_snake_case
|
missing_import = err.args[0][17:-1]
|
||||||
|
print(
|
||||||
|
"\033[31m"
|
||||||
|
f"Unable to import `{missing_import}` from betterproto plugin! "
|
||||||
|
"Please ensure that you've installed betterproto as "
|
||||||
|
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||||
|
"are included."
|
||||||
|
"\033[0m"
|
||||||
|
)
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
|
||||||
WRAPPER_TYPES = {
|
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
|
||||||
"google.protobuf.DoubleValue": "float",
|
if field.type in [1, 2]:
|
||||||
"google.protobuf.FloatValue": "float",
|
|
||||||
"google.protobuf.Int64Value": "int",
|
|
||||||
"google.protobuf.UInt64Value": "int",
|
|
||||||
"google.protobuf.Int32Value": "int",
|
|
||||||
"google.protobuf.UInt32Value": "int",
|
|
||||||
"google.protobuf.BoolValue": "bool",
|
|
||||||
"google.protobuf.StringValue": "str",
|
|
||||||
"google.protobuf.BytesValue": "bytes",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Return a Python type name for a proto type reference. Adds the import if
|
|
||||||
necessary.
|
|
||||||
"""
|
|
||||||
# If the package name is a blank string, then this should still work
|
|
||||||
# because by convention packages are lowercase and message/enum types are
|
|
||||||
# pascal-cased. May require refactoring in the future.
|
|
||||||
type_name = type_name.lstrip(".")
|
|
||||||
|
|
||||||
if type_name in WRAPPER_TYPES:
|
|
||||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
|
||||||
|
|
||||||
if type_name == "google.protobuf.Duration":
|
|
||||||
return "timedelta"
|
|
||||||
|
|
||||||
if type_name == "google.protobuf.Timestamp":
|
|
||||||
return "datetime"
|
|
||||||
|
|
||||||
if type_name.startswith(package):
|
|
||||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
|
||||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
|
||||||
# This is the current package, which has nested types flattened.
|
|
||||||
# foo.bar_thing => FooBarThing
|
|
||||||
cased = [stringcase.pascalcase(part) for part in parts]
|
|
||||||
type_name = f'"{"".join(cased)}"'
|
|
||||||
|
|
||||||
if "." in type_name:
|
|
||||||
# This is imported from another package. No need
|
|
||||||
# to use a forward ref and we need to add the import.
|
|
||||||
parts = type_name.split(".")
|
|
||||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
|
||||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
|
||||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
|
||||||
|
|
||||||
return type_name
|
|
||||||
|
|
||||||
|
|
||||||
def py_type(
|
|
||||||
package: str,
|
|
||||||
imports: set,
|
|
||||||
message: DescriptorProto,
|
|
||||||
descriptor: FieldDescriptorProto,
|
|
||||||
) -> str:
|
|
||||||
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
|
||||||
return "float"
|
return "float"
|
||||||
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
|
||||||
return "int"
|
return "int"
|
||||||
elif descriptor.type == 8:
|
elif field.type == 8:
|
||||||
return "bool"
|
return "bool"
|
||||||
elif descriptor.type == 9:
|
elif field.type == 9:
|
||||||
return "str"
|
return "str"
|
||||||
elif descriptor.type in [11, 14]:
|
elif field.type in [11, 14]:
|
||||||
# Type referencing another defined Message or a named enum
|
# Type referencing another defined Message or a named enum
|
||||||
return get_ref_type(package, imports, descriptor.type_name)
|
return get_type_reference(package, imports, field.type_name)
|
||||||
elif descriptor.type == 12:
|
elif field.type == 12:
|
||||||
return "bytes"
|
return "bytes"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown type {descriptor.type}")
|
raise NotImplementedError(f"Unknown type {field.type}")
|
||||||
|
|
||||||
|
|
||||||
def get_py_zero(type_num: int) -> str:
|
def get_py_zero(type_num: int) -> Union[str, float]:
|
||||||
zero = 0
|
zero: Union[str, float] = 0
|
||||||
if type_num in []:
|
if type_num in []:
|
||||||
zero = 0.0
|
zero = 0.0
|
||||||
elif type_num == 8:
|
elif type_num == 8:
|
||||||
@@ -122,7 +75,7 @@ def get_py_zero(type_num: int) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def traverse(proto_file):
|
def traverse(proto_file):
|
||||||
def _traverse(path, items, prefix = ''):
|
def _traverse(path, items, prefix=""):
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items):
|
||||||
# Adjust the name since we flatten the heirarchy.
|
# Adjust the name since we flatten the heirarchy.
|
||||||
item.name = next_prefix = prefix + item.name
|
item.name = next_prefix = prefix + item.name
|
||||||
@@ -167,25 +120,28 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def generate_code(request, response):
|
def generate_code(request, response):
|
||||||
|
plugin_options = request.parameter.split(",") if request.parameter else []
|
||||||
|
|
||||||
env = jinja2.Environment(
|
env = jinja2.Environment(
|
||||||
trim_blocks=True,
|
trim_blocks=True,
|
||||||
lstrip_blocks=True,
|
lstrip_blocks=True,
|
||||||
loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
|
loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
|
||||||
)
|
)
|
||||||
template = env.get_template("template.py")
|
template = env.get_template("template.py.j2")
|
||||||
|
|
||||||
output_map = {}
|
output_map = {}
|
||||||
for proto_file in request.proto_file:
|
for proto_file in request.proto_file:
|
||||||
out = proto_file.package
|
if (
|
||||||
if out == "google.protobuf":
|
proto_file.package == "google.protobuf"
|
||||||
|
and "INCLUDE_GOOGLE" not in plugin_options
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not out:
|
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
|
||||||
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
|
|
||||||
|
|
||||||
if out not in output_map:
|
if output_file not in output_map:
|
||||||
output_map[out] = {"package": proto_file.package, "files": []}
|
output_map[output_file] = {"package": proto_file.package, "files": []}
|
||||||
output_map[out]["files"].append(proto_file)
|
output_map[output_file]["files"].append(proto_file)
|
||||||
|
|
||||||
# TODO: Figure out how to handle gRPC request/response messages and add
|
# TODO: Figure out how to handle gRPC request/response messages and add
|
||||||
# processing below for Service.
|
# processing below for Service.
|
||||||
@@ -204,17 +160,10 @@ def generate_code(request, response):
|
|||||||
"services": [],
|
"services": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
type_mapping = {}
|
|
||||||
|
|
||||||
for proto_file in options["files"]:
|
for proto_file in options["files"]:
|
||||||
# print(proto_file.message_type, file=sys.stderr)
|
item: DescriptorProto
|
||||||
# print(proto_file.service, file=sys.stderr)
|
|
||||||
# print(proto_file.source_code_info, file=sys.stderr)
|
|
||||||
|
|
||||||
for item, path in traverse(proto_file):
|
for item, path in traverse(proto_file):
|
||||||
# print(item, file=sys.stderr)
|
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
||||||
# print(path, file=sys.stderr)
|
|
||||||
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
|
|
||||||
|
|
||||||
if isinstance(item, DescriptorProto):
|
if isinstance(item, DescriptorProto):
|
||||||
# print(item, file=sys.stderr)
|
# print(item, file=sys.stderr)
|
||||||
@@ -231,7 +180,7 @@ def generate_code(request, response):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for i, f in enumerate(item.field):
|
for i, f in enumerate(item.field):
|
||||||
t = py_type(package, output["imports"], item, f)
|
t = py_type(package, output["imports"], f)
|
||||||
zero = get_py_zero(f.type)
|
zero = get_py_zero(f.type)
|
||||||
|
|
||||||
repeated = False
|
repeated = False
|
||||||
@@ -240,11 +189,13 @@ def generate_code(request, response):
|
|||||||
field_type = f.Type.Name(f.type).lower()[5:]
|
field_type = f.Type.Name(f.type).lower()[5:]
|
||||||
|
|
||||||
field_wraps = ""
|
field_wraps = ""
|
||||||
if f.type_name.startswith(
|
match_wrapper = re.match(
|
||||||
".google.protobuf"
|
r"\.google\.protobuf\.(.+)Value", f.type_name
|
||||||
) and f.type_name.endswith("Value"):
|
)
|
||||||
w = f.type_name.split(".").pop()[:-5].upper()
|
if match_wrapper:
|
||||||
field_wraps = f"betterproto.TYPE_{w}"
|
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||||
|
if hasattr(betterproto, wrapped_type):
|
||||||
|
field_wraps = f"betterproto.{wrapped_type}"
|
||||||
|
|
||||||
map_types = None
|
map_types = None
|
||||||
if f.type == 11:
|
if f.type == 11:
|
||||||
@@ -264,13 +215,11 @@ def generate_code(request, response):
|
|||||||
k = py_type(
|
k = py_type(
|
||||||
package,
|
package,
|
||||||
output["imports"],
|
output["imports"],
|
||||||
item,
|
|
||||||
nested.field[0],
|
nested.field[0],
|
||||||
)
|
)
|
||||||
v = py_type(
|
v = py_type(
|
||||||
package,
|
package,
|
||||||
output["imports"],
|
output["imports"],
|
||||||
item,
|
|
||||||
nested.field[1],
|
nested.field[1],
|
||||||
)
|
)
|
||||||
t = f"Dict[{k}, {v}]"
|
t = f"Dict[{k}, {v}]"
|
||||||
@@ -306,7 +255,7 @@ def generate_code(request, response):
|
|||||||
data["properties"].append(
|
data["properties"].append(
|
||||||
{
|
{
|
||||||
"name": f.name,
|
"name": f.name,
|
||||||
"py_name": safe_snake_case(f.name),
|
"py_name": pythonize_field_name(f.name),
|
||||||
"number": f.number,
|
"number": f.number,
|
||||||
"comment": get_comment(proto_file, path + [2, i]),
|
"comment": get_comment(proto_file, path + [2, i]),
|
||||||
"proto_type": int(f.type),
|
"proto_type": int(f.type),
|
||||||
@@ -347,17 +296,14 @@ def generate_code(request, response):
|
|||||||
|
|
||||||
data = {
|
data = {
|
||||||
"name": service.name,
|
"name": service.name,
|
||||||
"py_name": stringcase.pascalcase(service.name),
|
"py_name": pythonize_class_name(service.name),
|
||||||
"comment": get_comment(proto_file, [6, i]),
|
"comment": get_comment(proto_file, [6, i]),
|
||||||
"methods": [],
|
"methods": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for j, method in enumerate(service.method):
|
for j, method in enumerate(service.method):
|
||||||
if method.client_streaming:
|
|
||||||
raise NotImplementedError("Client streaming not yet supported")
|
|
||||||
|
|
||||||
input_message = None
|
input_message = None
|
||||||
input_type = get_ref_type(
|
input_type = get_type_reference(
|
||||||
package, output["imports"], method.input_type
|
package, output["imports"], method.input_type
|
||||||
).strip('"')
|
).strip('"')
|
||||||
for msg in output["messages"]:
|
for msg in output["messages"]:
|
||||||
@@ -371,23 +317,30 @@ def generate_code(request, response):
|
|||||||
data["methods"].append(
|
data["methods"].append(
|
||||||
{
|
{
|
||||||
"name": method.name,
|
"name": method.name,
|
||||||
"py_name": stringcase.snakecase(method.name),
|
"py_name": pythonize_method_name(method.name),
|
||||||
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
||||||
"route": f"/{package}.{service.name}/{method.name}",
|
"route": f"/{package}.{service.name}/{method.name}",
|
||||||
"input": get_ref_type(
|
"input": get_type_reference(
|
||||||
package, output["imports"], method.input_type
|
package, output["imports"], method.input_type
|
||||||
).strip('"'),
|
).strip('"'),
|
||||||
"input_message": input_message,
|
"input_message": input_message,
|
||||||
"output": get_ref_type(
|
"output": get_type_reference(
|
||||||
package, output["imports"], method.output_type
|
package,
|
||||||
|
output["imports"],
|
||||||
|
method.output_type,
|
||||||
|
unwrap=False,
|
||||||
).strip('"'),
|
).strip('"'),
|
||||||
"client_streaming": method.client_streaming,
|
"client_streaming": method.client_streaming,
|
||||||
"server_streaming": method.server_streaming,
|
"server_streaming": method.server_streaming,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if method.client_streaming:
|
||||||
|
output["typing_imports"].add("AsyncIterable")
|
||||||
|
output["typing_imports"].add("Iterable")
|
||||||
|
output["typing_imports"].add("Union")
|
||||||
if method.server_streaming:
|
if method.server_streaming:
|
||||||
output["typing_imports"].add("AsyncGenerator")
|
output["typing_imports"].add("AsyncIterator")
|
||||||
|
|
||||||
output["services"].append(data)
|
output["services"].append(data)
|
||||||
|
|
||||||
@@ -397,8 +350,7 @@ def generate_code(request, response):
|
|||||||
|
|
||||||
# Fill response
|
# Fill response
|
||||||
f = response.file.add()
|
f = response.file.add()
|
||||||
# print(filename, file=sys.stderr)
|
f.name = filename
|
||||||
f.name = filename.replace(".", os.path.sep) + ".py"
|
|
||||||
|
|
||||||
# Render and then format the output file.
|
# Render and then format the output file.
|
||||||
f.content = black.format_str(
|
f.content = black.format_str(
|
||||||
@@ -406,32 +358,23 @@ def generate_code(request, response):
|
|||||||
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
||||||
)
|
)
|
||||||
|
|
||||||
inits = set([""])
|
# Make each output directory a package with __init__ file
|
||||||
for f in response.file:
|
output_paths = set(pathlib.Path(path) for path in output_map.keys())
|
||||||
# Ensure output paths exist
|
init_files = (
|
||||||
# print(f.name, file=sys.stderr)
|
set(
|
||||||
dirnames = os.path.dirname(f.name)
|
directory.joinpath("__init__.py")
|
||||||
if dirnames:
|
for path in output_paths
|
||||||
os.makedirs(dirnames, exist_ok=True)
|
for directory in path.parents
|
||||||
base = ""
|
)
|
||||||
for part in dirnames.split(os.path.sep):
|
- output_paths
|
||||||
base = os.path.join(base, part)
|
)
|
||||||
inits.add(base)
|
|
||||||
|
|
||||||
for base in inits:
|
|
||||||
name = os.path.join(base, "__init__.py")
|
|
||||||
|
|
||||||
if os.path.exists(name):
|
|
||||||
# Never overwrite inits as they may have custom stuff in them.
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
for init_file in init_files:
|
||||||
init = response.file.add()
|
init = response.file.add()
|
||||||
init.name = name
|
init.name = str(init_file)
|
||||||
init.content = b""
|
|
||||||
|
|
||||||
filenames = sorted([f.name for f in response.file])
|
for filename in sorted(output_paths.union(init_files)):
|
||||||
for fname in filenames:
|
print(f"Writing {filename}", file=sys.stderr)
|
||||||
print(f"Writing {fname}", file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@@ -63,34 +63,72 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% for method in service.methods %}
|
{% for method in service.methods %}
|
||||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
async def {{ method.py_name }}(self
|
||||||
|
{%- if not method.client_streaming -%}
|
||||||
|
{%- if method.input_message and method.input_message.properties -%}, *,
|
||||||
|
{%- for field in method.input_message.properties -%}
|
||||||
|
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
|
||||||
|
Optional[{{ field.type }}]
|
||||||
|
{%- else -%}
|
||||||
|
{{ field.type }}
|
||||||
|
{%- endif -%} = {{ field.zero }}
|
||||||
|
{%- if not loop.last %}, {% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{# Client streaming: need a request iterator instead #}
|
||||||
|
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
|
||||||
|
{%- endif -%}
|
||||||
|
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if not method.client_streaming %}
|
||||||
request = {{ method.input }}()
|
request = {{ method.input }}()
|
||||||
{% for field in method.input_message.properties %}
|
{% for field in method.input_message.properties %}
|
||||||
{% if field.field_type == 'message' %}
|
{% if field.field_type == 'message' %}
|
||||||
if {{ field.name }} is not None:
|
if {{ field.py_name }} is not None:
|
||||||
request.{{ field.name }} = {{ field.name }}
|
request.{{ field.py_name }} = {{ field.py_name }}
|
||||||
{% else %}
|
{% else %}
|
||||||
request.{{ field.name }} = {{ field.name }}
|
request.{{ field.py_name }} = {{ field.py_name }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if method.server_streaming %}
|
{% if method.server_streaming %}
|
||||||
|
{% if method.client_streaming %}
|
||||||
|
async for response in self._stream_stream(
|
||||||
|
"{{ method.route }}",
|
||||||
|
request_iterator,
|
||||||
|
{{ method.input }},
|
||||||
|
{{ method.output }},
|
||||||
|
):
|
||||||
|
yield response
|
||||||
|
{% else %}{# i.e. not client streaming #}
|
||||||
async for response in self._unary_stream(
|
async for response in self._unary_stream(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
request,
|
||||||
{{ method.output }},
|
{{ method.output }},
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
{% else %}
|
|
||||||
|
{% endif %}{# if client streaming #}
|
||||||
|
{% else %}{# i.e. not server streaming #}
|
||||||
|
{% if method.client_streaming %}
|
||||||
|
return await self._stream_unary(
|
||||||
|
"{{ method.route }}",
|
||||||
|
request_iterator,
|
||||||
|
{{ method.input }},
|
||||||
|
{{ method.output }}
|
||||||
|
)
|
||||||
|
{% else %}{# i.e. not client streaming #}
|
||||||
return await self._unary_unary(
|
return await self._unary_unary(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
request,
|
||||||
{{ method.output }},
|
{{ method.output }}
|
||||||
)
|
)
|
||||||
|
{% endif %}{# client streaming #}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
91
betterproto/tests/README.md
Normal file
91
betterproto/tests/README.md
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
# Standard Tests Development Guide
|
||||||
|
|
||||||
|
Standard test cases are found in [betterproto/tests/inputs](inputs), where each subdirectory represents a testcase, that is verified in isolation.
|
||||||
|
|
||||||
|
```
|
||||||
|
inputs/
|
||||||
|
bool/
|
||||||
|
double/
|
||||||
|
int32/
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test case directory structure
|
||||||
|
|
||||||
|
Each testcase has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bool/
|
||||||
|
bool.proto
|
||||||
|
bool.json # optional
|
||||||
|
test_bool.py # optional
|
||||||
|
```
|
||||||
|
|
||||||
|
### proto
|
||||||
|
|
||||||
|
`<name>.proto` — *The protobuf message to test*
|
||||||
|
|
||||||
|
```protobuf
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
bool value = 1;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can add multiple `.proto` files to the test case, as long as one file matches the directory name.
|
||||||
|
|
||||||
|
### json
|
||||||
|
|
||||||
|
`<name>.json` — *Test-data to validate the message with*
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"value": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### pytest
|
||||||
|
|
||||||
|
`test_<name>.py` — *Custom test to validate specific aspects of the generated class*
|
||||||
|
|
||||||
|
```python
|
||||||
|
from betterproto.tests.output_betterproto.bool.bool import Test
|
||||||
|
|
||||||
|
def test_value():
|
||||||
|
message = Test()
|
||||||
|
assert not message.value, "Boolean is False by default"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Standard tests
|
||||||
|
|
||||||
|
The following tests are automatically executed for all cases:
|
||||||
|
|
||||||
|
- [x] Can the generated python code be imported?
|
||||||
|
- [x] Can the generated message class be instantiated?
|
||||||
|
- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
|
||||||
|
- _when `.json` is present_
|
||||||
|
|
||||||
|
## Running the tests
|
||||||
|
|
||||||
|
- `pipenv run generate`
|
||||||
|
This generates:
|
||||||
|
- `betterproto/tests/output_betterproto` — *the plugin generated python classes*
|
||||||
|
- `betterproto/tests/output_reference` — *reference implementation classes*
|
||||||
|
- `pipenv run test`
|
||||||
|
|
||||||
|
## Intentionally Failing tests
|
||||||
|
|
||||||
|
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future.
|
||||||
|
|
||||||
|
When running `pytest`, they show up as `x` or `X` in the test results.
|
||||||
|
|
||||||
|
```
|
||||||
|
betterproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x.......x.xx....x...................... [ 84%]
|
||||||
|
```
|
||||||
|
|
||||||
|
- `.` — PASSED
|
||||||
|
- `x` — XFAIL: expected failure
|
||||||
|
- `X` — XPASS: expected failure, but still passed
|
||||||
|
|
||||||
|
Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py)
|
0
betterproto/tests/__init__.py
Normal file
0
betterproto/tests/__init__.py
Normal file
189
betterproto/tests/generate.py
Normal file → Executable file
189
betterproto/tests/generate.py
Normal file → Executable file
@@ -1,84 +1,143 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
|
from betterproto.tests.util import (
|
||||||
|
get_directories,
|
||||||
|
inputs_path,
|
||||||
|
output_path_betterproto,
|
||||||
|
output_path_reference,
|
||||||
|
protoc_plugin,
|
||||||
|
protoc_reference,
|
||||||
|
)
|
||||||
|
|
||||||
# Force pure-python implementation instead of C++, otherwise imports
|
# Force pure-python implementation instead of C++, otherwise imports
|
||||||
# break things because we can't properly reset the symbol database.
|
# break things because we can't properly reset the symbol database.
|
||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||||
|
|
||||||
import importlib
|
|
||||||
import json
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from typing import Generator, Tuple
|
|
||||||
|
|
||||||
from google.protobuf import symbol_database
|
def clear_directory(dir_path: Path):
|
||||||
from google.protobuf.descriptor_pool import DescriptorPool
|
for file_or_directory in dir_path.glob("*"):
|
||||||
from google.protobuf.json_format import MessageToJson, Parse
|
if file_or_directory.is_dir():
|
||||||
|
shutil.rmtree(file_or_directory)
|
||||||
|
else:
|
||||||
|
file_or_directory.unlink()
|
||||||
|
|
||||||
|
|
||||||
root = os.path.dirname(os.path.realpath(__file__))
|
async def generate(whitelist: Set[str], verbose: bool):
|
||||||
|
test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
|
||||||
|
|
||||||
|
path_whitelist = set()
|
||||||
|
name_whitelist = set()
|
||||||
|
for item in whitelist:
|
||||||
|
if item in test_case_names:
|
||||||
|
name_whitelist.add(item)
|
||||||
|
continue
|
||||||
|
path_whitelist.add(item)
|
||||||
|
|
||||||
|
generation_tasks = []
|
||||||
|
for test_case_name in sorted(test_case_names):
|
||||||
|
test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
|
||||||
|
if (
|
||||||
|
whitelist
|
||||||
|
and str(test_case_input_path) not in path_whitelist
|
||||||
|
and test_case_name not in name_whitelist
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
generation_tasks.append(
|
||||||
|
generate_test_case_output(test_case_input_path, test_case_name, verbose)
|
||||||
|
)
|
||||||
|
|
||||||
|
failed_test_cases = []
|
||||||
|
# Wait for all subprocs and match any failures to names to report
|
||||||
|
for test_case_name, result in zip(
|
||||||
|
sorted(test_case_names), await asyncio.gather(*generation_tasks)
|
||||||
|
):
|
||||||
|
if result != 0:
|
||||||
|
failed_test_cases.append(test_case_name)
|
||||||
|
|
||||||
|
if failed_test_cases:
|
||||||
|
sys.stderr.write(
|
||||||
|
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
|
||||||
|
)
|
||||||
|
for failed_test_case in failed_test_cases:
|
||||||
|
sys.stderr.write(f"- {failed_test_case}\n")
|
||||||
|
|
||||||
|
|
||||||
def get_files(end: str) -> Generator[str, None, None]:
|
async def generate_test_case_output(
|
||||||
for r, dirs, files in os.walk(root):
|
test_case_input_path: Path, test_case_name: str, verbose: bool
|
||||||
for filename in [f for f in files if f.endswith(end)]:
|
) -> int:
|
||||||
yield os.path.join(r, filename)
|
"""
|
||||||
|
Returns the max of the subprocess return values
|
||||||
|
"""
|
||||||
|
|
||||||
|
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
||||||
|
test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name)
|
||||||
|
|
||||||
|
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
||||||
|
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
||||||
|
|
||||||
|
clear_directory(test_case_output_path_reference)
|
||||||
|
clear_directory(test_case_output_path_betterproto)
|
||||||
|
|
||||||
|
(
|
||||||
|
(ref_out, ref_err, ref_code),
|
||||||
|
(plg_out, plg_err, plg_code),
|
||||||
|
) = await asyncio.gather(
|
||||||
|
protoc_reference(test_case_input_path, test_case_output_path_reference),
|
||||||
|
protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
|
||||||
|
)
|
||||||
|
|
||||||
|
message = f"Generated output for {test_case_name!r}"
|
||||||
|
if verbose:
|
||||||
|
print(f"\033[31;1;4m{message}\033[0m")
|
||||||
|
if ref_out:
|
||||||
|
sys.stdout.buffer.write(ref_out)
|
||||||
|
if ref_err:
|
||||||
|
sys.stderr.buffer.write(ref_err)
|
||||||
|
if plg_out:
|
||||||
|
sys.stdout.buffer.write(plg_out)
|
||||||
|
if plg_err:
|
||||||
|
sys.stderr.buffer.write(plg_err)
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
else:
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
return max(ref_code, plg_code)
|
||||||
|
|
||||||
|
|
||||||
def get_base(filename: str) -> str:
|
HELP = "\n".join(
|
||||||
return os.path.splitext(os.path.basename(filename))[0]
|
(
|
||||||
|
"Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
|
||||||
|
"Generate python classes for standard tests.",
|
||||||
|
"",
|
||||||
|
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
|
||||||
|
" python generate.py inputs/bool inputs/double inputs/enum",
|
||||||
|
"",
|
||||||
|
"NAMES One or more test-case names to generate classes for.",
|
||||||
|
" python generate.py bool double enums",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_ext(filename: str, ext: str) -> str:
|
def main():
|
||||||
if not filename.endswith(ext):
|
if set(sys.argv).intersection({"-h", "--help"}):
|
||||||
return filename + ext
|
print(HELP)
|
||||||
return filename
|
return
|
||||||
|
if sys.argv[1:2] == ["-v"]:
|
||||||
|
verbose = True
|
||||||
|
whitelist = set(sys.argv[2:])
|
||||||
|
else:
|
||||||
|
verbose = False
|
||||||
|
whitelist = set(sys.argv[1:])
|
||||||
|
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
os.chdir(root)
|
main()
|
||||||
|
|
||||||
if len(sys.argv) > 1:
|
|
||||||
proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]]
|
|
||||||
bases = {get_base(f) for f in proto_files}
|
|
||||||
json_files = [
|
|
||||||
f for f in get_files(".json") if get_base(f).split("-")[0] in bases
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
proto_files = get_files(".proto")
|
|
||||||
json_files = get_files(".json")
|
|
||||||
|
|
||||||
for filename in proto_files:
|
|
||||||
print(f"Generating code for {os.path.basename(filename)}")
|
|
||||||
subprocess.run(
|
|
||||||
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
|
||||||
)
|
|
||||||
subprocess.run(
|
|
||||||
f"protoc --plugin=protoc-gen-custom=../plugin.py --custom_out=. {os.path.basename(filename)}",
|
|
||||||
shell=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for filename in json_files:
|
|
||||||
# Reset the internal symbol database so we can import the `Test` message
|
|
||||||
# multiple times. Ugh.
|
|
||||||
sym = symbol_database.Default()
|
|
||||||
sym.pool = DescriptorPool()
|
|
||||||
|
|
||||||
parts = get_base(filename).split("-")
|
|
||||||
out = filename.replace(".json", ".bin")
|
|
||||||
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
|
||||||
|
|
||||||
imported = importlib.import_module(f"{parts[0]}_pb2")
|
|
||||||
input_json = open(filename).read()
|
|
||||||
parsed = Parse(input_json, imported.Test())
|
|
||||||
serialized = parsed.SerializeToString()
|
|
||||||
preserve = "casing" not in filename
|
|
||||||
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
|
|
||||||
|
|
||||||
s_loaded = json.loads(serialized_json)
|
|
||||||
in_loaded = json.loads(input_json)
|
|
||||||
|
|
||||||
if s_loaded != in_loaded:
|
|
||||||
raise AssertionError("Expected JSON to be equal:", s_loaded, in_loaded)
|
|
||||||
|
|
||||||
open(out, "wb").write(serialized)
|
|
||||||
|
0
betterproto/tests/grpc/__init__.py
Normal file
0
betterproto/tests/grpc/__init__.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import asyncio
|
||||||
|
from betterproto.tests.output_betterproto.service.service import (
|
||||||
|
DoThingResponse,
|
||||||
|
DoThingRequest,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
TestStub as ThingServiceClient,
|
||||||
|
)
|
||||||
|
import grpclib
|
||||||
|
from grpclib.testing import ChannelFor
|
||||||
|
import pytest
|
||||||
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
|
from .thing_service import ThingService
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_client(client, name="clean room", **kwargs):
|
||||||
|
response = await client.do_thing(name=name)
|
||||||
|
assert response.names == [name]
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_request_meta_recieved(deadline, metadata):
|
||||||
|
def server_side_test(stream):
|
||||||
|
assert stream.deadline._timestamp == pytest.approx(
|
||||||
|
deadline._timestamp, 1
|
||||||
|
), "The provided deadline should be recieved serverside"
|
||||||
|
assert (
|
||||||
|
stream.metadata["authorization"] == metadata["authorization"]
|
||||||
|
), "The provided authorization metadata should be recieved serverside"
|
||||||
|
|
||||||
|
return server_side_test
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_simple_service_call():
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
await _test_client(ThingServiceClient(channel))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_call_with_upfront_request_params():
|
||||||
|
# Setting deadline
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||||
|
) as channel:
|
||||||
|
await _test_client(
|
||||||
|
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setting timeout
|
||||||
|
timeout = 99
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||||
|
) as channel:
|
||||||
|
await _test_client(
|
||||||
|
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_call_lower_level_with_overrides():
|
||||||
|
THING_TO_DO = "get milk"
|
||||||
|
|
||||||
|
# Setting deadline
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||||
|
kwarg_metadata = {"authorization": "12345"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||||
|
) as channel:
|
||||||
|
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
|
response = await client._unary_unary(
|
||||||
|
"/service.Test/DoThing",
|
||||||
|
DoThingRequest(THING_TO_DO),
|
||||||
|
DoThingResponse,
|
||||||
|
deadline=kwarg_deadline,
|
||||||
|
metadata=kwarg_metadata,
|
||||||
|
)
|
||||||
|
assert response.names == [THING_TO_DO]
|
||||||
|
|
||||||
|
# Setting timeout
|
||||||
|
timeout = 99
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
kwarg_timeout = 9000
|
||||||
|
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||||
|
kwarg_metadata = {"authorization": "09876"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[
|
||||||
|
ThingService(
|
||||||
|
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
) as channel:
|
||||||
|
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
|
response = await client._unary_unary(
|
||||||
|
"/service.Test/DoThing",
|
||||||
|
DoThingRequest(THING_TO_DO),
|
||||||
|
DoThingResponse,
|
||||||
|
timeout=kwarg_timeout,
|
||||||
|
metadata=kwarg_metadata,
|
||||||
|
)
|
||||||
|
assert response.names == [THING_TO_DO]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_gen_for_unary_stream_request():
|
||||||
|
thing_name = "my milkshakes"
|
||||||
|
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
expected_versions = [5, 4, 3, 2, 1]
|
||||||
|
async for response in client.get_thing_versions(name=thing_name):
|
||||||
|
assert response.name == thing_name
|
||||||
|
assert response.version == expected_versions.pop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_gen_for_stream_stream_request():
|
||||||
|
some_things = ["cake", "cricket", "coral reef"]
|
||||||
|
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
|
||||||
|
expected_things = (*some_things, *more_things)
|
||||||
|
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||||
|
# immediately and we'll use it to send more_things later, after recieving some
|
||||||
|
# results
|
||||||
|
request_chan = AsyncChannel()
|
||||||
|
send_initial_requests = asyncio.ensure_future(
|
||||||
|
request_chan.send_from(GetThingRequest(name) for name in some_things)
|
||||||
|
)
|
||||||
|
response_index = 0
|
||||||
|
async for response in client.get_different_things(request_chan):
|
||||||
|
assert response.name == expected_things[response_index]
|
||||||
|
assert response.version == response_index + 1
|
||||||
|
response_index += 1
|
||||||
|
if more_things:
|
||||||
|
# Send some more requests as we recieve reponses to be sure coordination of
|
||||||
|
# send/recieve events doesn't matter
|
||||||
|
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||||
|
elif not send_initial_requests.done():
|
||||||
|
# Make sure the sending task it completed
|
||||||
|
await send_initial_requests
|
||||||
|
else:
|
||||||
|
# No more things to send make sure channel is closed
|
||||||
|
request_chan.close()
|
||||||
|
assert response_index == len(
|
||||||
|
expected_things
|
||||||
|
), "Didn't recieve all exptected responses"
|
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import asyncio
|
||||||
|
import betterproto
|
||||||
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import pytest
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message(betterproto.Message):
|
||||||
|
body: str = betterproto.string_field(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_responses():
|
||||||
|
return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientStub:
|
||||||
|
async def connect(self, requests: AsyncIterator):
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
async for request in requests:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
yield request
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
yield Message("Done")
|
||||||
|
|
||||||
|
|
||||||
|
async def to_list(generator: AsyncIterator):
|
||||||
|
result = []
|
||||||
|
async for value in generator:
|
||||||
|
result.append(value)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
# channel = Channel(host='127.0.0.1', port=50051)
|
||||||
|
# return ClientStub(channel)
|
||||||
|
return ClientStub()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_from_before_connect_and_close_automatically(
|
||||||
|
client, expected_responses
|
||||||
|
):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
await requests.send_from(
|
||||||
|
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||||
|
)
|
||||||
|
responses = client.connect(requests)
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_from_after_connect_and_close_automatically(
|
||||||
|
client, expected_responses
|
||||||
|
):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
responses = client.connect(requests)
|
||||||
|
await requests.send_from(
|
||||||
|
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_from_close_manually_immediately(client, expected_responses):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
responses = client.connect(requests)
|
||||||
|
await requests.send_from(
|
||||||
|
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
|
||||||
|
)
|
||||||
|
requests.close()
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_individually_and_close_before_connect(client, expected_responses):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
await requests.send(Message(body="Hello world 1"))
|
||||||
|
await requests.send(Message(body="Hello world 2"))
|
||||||
|
requests.close()
|
||||||
|
responses = client.connect(requests)
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_individually_and_close_after_connect(client, expected_responses):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
await requests.send(Message(body="Hello world 1"))
|
||||||
|
await requests.send(Message(body="Hello world 2"))
|
||||||
|
responses = client.connect(requests)
|
||||||
|
requests.close()
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
83
betterproto/tests/grpc/thing_service.py
Normal file
83
betterproto/tests/grpc/thing_service.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from betterproto.tests.output_betterproto.service.service import (
|
||||||
|
DoThingResponse,
|
||||||
|
DoThingRequest,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
TestStub as ThingServiceClient,
|
||||||
|
)
|
||||||
|
import grpclib
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class ThingService:
|
||||||
|
def __init__(self, test_hook=None):
|
||||||
|
# This lets us pass assertions to the servicer ;)
|
||||||
|
self.test_hook = test_hook
|
||||||
|
|
||||||
|
async def do_thing(
|
||||||
|
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||||
|
):
|
||||||
|
request = await stream.recv_message()
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
await stream.send_message(DoThingResponse([request.name]))
|
||||||
|
|
||||||
|
async def do_many_things(
|
||||||
|
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||||
|
):
|
||||||
|
thing_names = [request.name for request in stream]
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
await stream.send_message(DoThingResponse(thing_names))
|
||||||
|
|
||||||
|
async def get_thing_versions(
|
||||||
|
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||||
|
):
|
||||||
|
request = await stream.recv_message()
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
for version_num in range(1, 6):
|
||||||
|
await stream.send_message(
|
||||||
|
GetThingResponse(name=request.name, version=version_num)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_different_things(
|
||||||
|
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||||
|
):
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
# Respond to each input item immediately
|
||||||
|
response_num = 0
|
||||||
|
async for request in stream:
|
||||||
|
response_num += 1
|
||||||
|
await stream.send_message(
|
||||||
|
GetThingResponse(name=request.name, version=response_num)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
|
||||||
|
return {
|
||||||
|
"/service.Test/DoThing": grpclib.const.Handler(
|
||||||
|
self.do_thing,
|
||||||
|
grpclib.const.Cardinality.UNARY_UNARY,
|
||||||
|
DoThingRequest,
|
||||||
|
DoThingResponse,
|
||||||
|
),
|
||||||
|
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||||
|
self.do_many_things,
|
||||||
|
grpclib.const.Cardinality.STREAM_UNARY,
|
||||||
|
DoThingRequest,
|
||||||
|
DoThingResponse,
|
||||||
|
),
|
||||||
|
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||||
|
self.get_thing_versions,
|
||||||
|
grpclib.const.Cardinality.UNARY_STREAM,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
),
|
||||||
|
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||||
|
self.get_different_things,
|
||||||
|
grpclib.const.Cardinality.STREAM_STREAM,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
),
|
||||||
|
}
|
6
betterproto/tests/inputs/bool/test_bool.py
Normal file
6
betterproto/tests/inputs/bool/test_bool.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from betterproto.tests.output_betterproto.bool import Test
|
||||||
|
|
||||||
|
|
||||||
|
def test_value():
|
||||||
|
message = Test()
|
||||||
|
assert not message.value, "Boolean is False by default"
|
@@ -9,4 +9,10 @@ enum my_enum {
|
|||||||
message Test {
|
message Test {
|
||||||
int32 camelCase = 1;
|
int32 camelCase = 1;
|
||||||
my_enum snake_case = 2;
|
my_enum snake_case = 2;
|
||||||
|
snake_case_message snake_case_message = 3;
|
||||||
|
int32 UPPERCASE = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message snake_case_message {
|
||||||
|
|
||||||
}
|
}
|
23
betterproto/tests/inputs/casing/test_casing.py
Normal file
23
betterproto/tests/inputs/casing/test_casing.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import betterproto.tests.output_betterproto.casing as casing
|
||||||
|
from betterproto.tests.output_betterproto.casing import Test
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_attributes():
|
||||||
|
message = Test()
|
||||||
|
assert hasattr(
|
||||||
|
message, "snake_case_message"
|
||||||
|
), "snake_case field name is same in python"
|
||||||
|
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
|
||||||
|
assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_casing():
|
||||||
|
assert hasattr(
|
||||||
|
casing, "SnakeCaseMessage"
|
||||||
|
), "snake_case Message name is converted to CamelCase in python"
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_casing():
|
||||||
|
assert hasattr(
|
||||||
|
casing, "MyEnum"
|
||||||
|
), "snake_case Enum name is converted to CamelCase in python"
|
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
int32 UPPERCASE = 1;
|
||||||
|
int32 UPPERCASE_V2 = 2;
|
||||||
|
int32 UPPER_CAMEL_CASE = 3;
|
||||||
|
}
|
@@ -0,0 +1,14 @@
|
|||||||
|
from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_casing():
|
||||||
|
message = Test()
|
||||||
|
assert hasattr(
|
||||||
|
message, "uppercase"
|
||||||
|
), "UPPERCASE attribute is converted to 'uppercase' in python"
|
||||||
|
assert hasattr(
|
||||||
|
message, "uppercase_v2"
|
||||||
|
), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
|
||||||
|
assert hasattr(
|
||||||
|
message, "upper_camel_case"
|
||||||
|
), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"
|
22
betterproto/tests/inputs/config.py
Normal file
22
betterproto/tests/inputs/config.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
|
||||||
|
# Remove from list when fixed.
|
||||||
|
xfail = {
|
||||||
|
"import_circular_dependency",
|
||||||
|
"oneof_enum", # 63
|
||||||
|
"namespace_keywords", # 70
|
||||||
|
"namespace_builtin_types", # 53
|
||||||
|
"googletypes_struct", # 9
|
||||||
|
"googletypes_value", # 9
|
||||||
|
"enum_skipped_value", # 93
|
||||||
|
"import_capitalized_package",
|
||||||
|
"example", # This is the example in the readme. Not a test.
|
||||||
|
}
|
||||||
|
|
||||||
|
services = {
|
||||||
|
"googletypes_response",
|
||||||
|
"googletypes_response_embedded",
|
||||||
|
"service",
|
||||||
|
"import_service_input_message",
|
||||||
|
"googletypes_service_returns_empty",
|
||||||
|
"googletypes_service_returns_googletype",
|
||||||
|
}
|
@@ -0,0 +1,12 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
enum MyEnum {
|
||||||
|
ZERO = 0;
|
||||||
|
ONE = 1;
|
||||||
|
// TWO = 2;
|
||||||
|
THREE = 3;
|
||||||
|
FOUR = 4;
|
||||||
|
}
|
||||||
|
MyEnum x = 1;
|
||||||
|
}
|
@@ -0,0 +1,18 @@
|
|||||||
|
from betterproto.tests.output_betterproto.enum_skipped_value import (
|
||||||
|
Test,
|
||||||
|
TestMyEnum,
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="#93")
|
||||||
|
def test_message_attributes():
|
||||||
|
assert (
|
||||||
|
Test(x=TestMyEnum.ONE).to_dict()["x"] == "ONE"
|
||||||
|
), "MyEnum.ONE is not serialized to 'ONE'"
|
||||||
|
assert (
|
||||||
|
Test(x=TestMyEnum.THREE).to_dict()["x"] == "THREE"
|
||||||
|
), "MyEnum.THREE is not serialized to 'THREE'"
|
||||||
|
assert (
|
||||||
|
Test(x=TestMyEnum.FOUR).to_dict()["x"] == "FOUR"
|
||||||
|
), "MyEnum.FOUR is not serialized to 'FOUR'"
|
8
betterproto/tests/inputs/example/example.proto
Normal file
8
betterproto/tests/inputs/example/example.proto
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package hello;
|
||||||
|
|
||||||
|
// Greeting represents a message you can tell a user.
|
||||||
|
message Greeting {
|
||||||
|
string message = 1;
|
||||||
|
}
|
6
betterproto/tests/inputs/fixed/fixed.json
Normal file
6
betterproto/tests/inputs/fixed/fixed.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"foo": 4294967295,
|
||||||
|
"bar": -2147483648,
|
||||||
|
"baz": "18446744073709551615",
|
||||||
|
"qux": "-9223372036854775808"
|
||||||
|
}
|
8
betterproto/tests/inputs/fixed/fixed.proto
Normal file
8
betterproto/tests/inputs/fixed/fixed.proto
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
fixed32 foo = 1;
|
||||||
|
sfixed32 bar = 2;
|
||||||
|
fixed64 baz = 3;
|
||||||
|
sfixed64 qux = 4;
|
||||||
|
}
|
@@ -1,5 +1,7 @@
|
|||||||
{
|
{
|
||||||
"maybe": false,
|
"maybe": false,
|
||||||
"ts": "1972-01-01T10:00:20.021Z",
|
"ts": "1972-01-01T10:00:20.021Z",
|
||||||
"duration": "1.200s"
|
"duration": "1.200s",
|
||||||
|
"important": 10,
|
||||||
|
"empty": {}
|
||||||
}
|
}
|
@@ -3,10 +3,12 @@ syntax = "proto3";
|
|||||||
import "google/protobuf/duration.proto";
|
import "google/protobuf/duration.proto";
|
||||||
import "google/protobuf/timestamp.proto";
|
import "google/protobuf/timestamp.proto";
|
||||||
import "google/protobuf/wrappers.proto";
|
import "google/protobuf/wrappers.proto";
|
||||||
|
import "google/protobuf/empty.proto";
|
||||||
|
|
||||||
message Test {
|
message Test {
|
||||||
google.protobuf.BoolValue maybe = 1;
|
google.protobuf.BoolValue maybe = 1;
|
||||||
google.protobuf.Timestamp ts = 2;
|
google.protobuf.Timestamp ts = 2;
|
||||||
google.protobuf.Duration duration = 3;
|
google.protobuf.Duration duration = 3;
|
||||||
google.protobuf.Int32Value important = 4;
|
google.protobuf.Int32Value important = 4;
|
||||||
|
google.protobuf.Empty empty = 5;
|
||||||
}
|
}
|
@@ -0,0 +1,21 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/wrappers.proto";
|
||||||
|
|
||||||
|
// Tests that wrapped values can be used directly as return values
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
|
||||||
|
rpc GetFloat (Input) returns (google.protobuf.FloatValue);
|
||||||
|
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
|
||||||
|
rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
|
||||||
|
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||||
|
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
|
||||||
|
rpc GetBool (Input) returns (google.protobuf.BoolValue);
|
||||||
|
rpc GetString (Input) returns (google.protobuf.StringValue);
|
||||||
|
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
message Input {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,54 @@
|
|||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import betterproto.lib.google.protobuf as protobuf
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto.tests.mocks import MockChannel
|
||||||
|
from betterproto.tests.output_betterproto.googletypes_response import TestStub
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||||
|
(TestStub.get_float, protobuf.FloatValue, 2.5),
|
||||||
|
(TestStub.get_int64, protobuf.Int64Value, -64),
|
||||||
|
(TestStub.get_u_int64, protobuf.UInt64Value, 64),
|
||||||
|
(TestStub.get_int32, protobuf.Int32Value, -32),
|
||||||
|
(TestStub.get_u_int32, protobuf.UInt32Value, 32),
|
||||||
|
(TestStub.get_bool, protobuf.BoolValue, True),
|
||||||
|
(TestStub.get_string, protobuf.StringValue, "string"),
|
||||||
|
(TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||||
|
async def test_channel_recieves_wrapped_type(
|
||||||
|
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||||
|
):
|
||||||
|
wrapped_value = wrapper_class()
|
||||||
|
wrapped_value.value = value
|
||||||
|
channel = MockChannel(responses=[wrapped_value])
|
||||||
|
service = TestStub(channel)
|
||||||
|
|
||||||
|
await service_method(service)
|
||||||
|
|
||||||
|
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||||
|
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.xfail
|
||||||
|
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||||
|
async def test_service_unwraps_response(
|
||||||
|
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
grpclib does not unwrap wrapper values returned by services
|
||||||
|
"""
|
||||||
|
wrapped_value = wrapper_class()
|
||||||
|
wrapped_value.value = value
|
||||||
|
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||||
|
|
||||||
|
response_value = await service_method(service)
|
||||||
|
|
||||||
|
assert response_value == value
|
||||||
|
assert type(response_value) == type(value)
|
@@ -0,0 +1,24 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/wrappers.proto";
|
||||||
|
|
||||||
|
// Tests that wrapped values are supported as part of output message
|
||||||
|
service Test {
|
||||||
|
rpc getOutput (Input) returns (Output);
|
||||||
|
}
|
||||||
|
|
||||||
|
message Input {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
message Output {
|
||||||
|
google.protobuf.DoubleValue double_value = 1;
|
||||||
|
google.protobuf.FloatValue float_value = 2;
|
||||||
|
google.protobuf.Int64Value int64_value = 3;
|
||||||
|
google.protobuf.UInt64Value uint64_value = 4;
|
||||||
|
google.protobuf.Int32Value int32_value = 5;
|
||||||
|
google.protobuf.UInt32Value uint32_value = 6;
|
||||||
|
google.protobuf.BoolValue bool_value = 7;
|
||||||
|
google.protobuf.StringValue string_value = 8;
|
||||||
|
google.protobuf.BytesValue bytes_value = 9;
|
||||||
|
}
|
@@ -0,0 +1,39 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto.tests.mocks import MockChannel
|
||||||
|
from betterproto.tests.output_betterproto.googletypes_response_embedded import (
|
||||||
|
Output,
|
||||||
|
TestStub,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_passes_through_unwrapped_values_embedded_in_response():
|
||||||
|
"""
|
||||||
|
We do not not need to implement value unwrapping for embedded well-known types,
|
||||||
|
as this is already handled by grpclib. This test merely shows that this is the case.
|
||||||
|
"""
|
||||||
|
output = Output(
|
||||||
|
double_value=10.0,
|
||||||
|
float_value=12.0,
|
||||||
|
int64_value=-13,
|
||||||
|
uint64_value=14,
|
||||||
|
int32_value=-15,
|
||||||
|
uint32_value=16,
|
||||||
|
bool_value=True,
|
||||||
|
string_value="string",
|
||||||
|
bytes_value=bytes(0xFF)[0:4],
|
||||||
|
)
|
||||||
|
|
||||||
|
service = TestStub(MockChannel(responses=[output]))
|
||||||
|
response = await service.get_output()
|
||||||
|
|
||||||
|
assert response.double_value == 10.0
|
||||||
|
assert response.float_value == 12.0
|
||||||
|
assert response.int64_value == -13
|
||||||
|
assert response.uint64_value == 14
|
||||||
|
assert response.int32_value == -15
|
||||||
|
assert response.uint32_value == 16
|
||||||
|
assert response.bool_value
|
||||||
|
assert response.string_value == "string"
|
||||||
|
assert response.bytes_value == bytes(0xFF)[0:4]
|
@@ -0,0 +1,11 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/empty.proto";
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
rpc Send (RequestMessage) returns (google.protobuf.Empty) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message RequestMessage {
|
||||||
|
}
|
@@ -0,0 +1,16 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/empty.proto";
|
||||||
|
import "google/protobuf/struct.proto";
|
||||||
|
|
||||||
|
// Tests that imports are generated correctly when returning Google well-known types
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty);
|
||||||
|
rpc GetStruct (RequestMessage) returns (google.protobuf.Struct);
|
||||||
|
rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue);
|
||||||
|
rpc GetValue (RequestMessage) returns (google.protobuf.Value);
|
||||||
|
}
|
||||||
|
|
||||||
|
message RequestMessage {
|
||||||
|
}
|
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"struct": {
|
||||||
|
"key": true
|
||||||
|
}
|
||||||
|
}
|
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/struct.proto";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
google.protobuf.Struct struct = 1;
|
||||||
|
}
|
@@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"value1": "hello world",
|
||||||
|
"value2": true,
|
||||||
|
"value3": 1,
|
||||||
|
"value4": null,
|
||||||
|
"value5": [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3
|
||||||
|
]
|
||||||
|
}
|
@@ -0,0 +1,13 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/struct.proto";
|
||||||
|
|
||||||
|
// Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values.
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
google.protobuf.Value value1 = 1;
|
||||||
|
google.protobuf.Value value2 = 2;
|
||||||
|
google.protobuf.Value value3 = 3;
|
||||||
|
google.protobuf.Value value4 = 4;
|
||||||
|
google.protobuf.Value value5 = 5;
|
||||||
|
}
|
@@ -0,0 +1,8 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
|
||||||
|
package Capitalized;
|
||||||
|
|
||||||
|
message Message {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "capitalized.proto";
|
||||||
|
|
||||||
|
// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't.
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
Capitalized.Message message = 1;
|
||||||
|
}
|
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package package.childpackage;
|
||||||
|
|
||||||
|
message ChildMessage {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "package_message.proto";
|
||||||
|
|
||||||
|
// Tests generated imports when a message in a package refers to a message in a nested child package.
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
package.PackageMessage message = 1;
|
||||||
|
}
|
@@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "child.proto";
|
||||||
|
|
||||||
|
package package;
|
||||||
|
|
||||||
|
message PackageMessage {
|
||||||
|
package.childpackage.ChildMessage c = 1;
|
||||||
|
}
|
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package childpackage;
|
||||||
|
|
||||||
|
message Message {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "child.proto";
|
||||||
|
|
||||||
|
// Tests generated imports when a message in root refers to a message in a child package.
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
childpackage.Message child = 1;
|
||||||
|
}
|
@@ -0,0 +1,28 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "root.proto";
|
||||||
|
import "other.proto";
|
||||||
|
|
||||||
|
// This test-case verifies support for circular dependencies in the generated python files.
|
||||||
|
//
|
||||||
|
// This is important because we generate 1 python file/module per package, rather than 1 file per proto file.
|
||||||
|
//
|
||||||
|
// Scenario:
|
||||||
|
//
|
||||||
|
// The proto messages depend on each other in a non-circular way:
|
||||||
|
//
|
||||||
|
// Test -------> RootPackageMessage <--------------.
|
||||||
|
// `------------------------------------> OtherPackageMessage
|
||||||
|
//
|
||||||
|
// Test and RootPackageMessage are in different files, but belong to the same package (root):
|
||||||
|
//
|
||||||
|
// (Test -------> RootPackageMessage) <------------.
|
||||||
|
// `------------------------------------> OtherPackageMessage
|
||||||
|
//
|
||||||
|
// After grouping the packages into single files or modules, a circular dependency is created:
|
||||||
|
//
|
||||||
|
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
|
||||||
|
message Test {
|
||||||
|
RootPackageMessage message = 1;
|
||||||
|
other.OtherPackageMessage other = 2;
|
||||||
|
}
|
@@ -0,0 +1,8 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "root.proto";
|
||||||
|
package other;
|
||||||
|
|
||||||
|
message OtherPackageMessage {
|
||||||
|
RootPackageMessage rootPackageMessage = 1;
|
||||||
|
}
|
@@ -0,0 +1,5 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message RootPackageMessage {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,6 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package cousin.cousin_subpackage;
|
||||||
|
|
||||||
|
message CousinMessage {
|
||||||
|
}
|
11
betterproto/tests/inputs/import_cousin_package/test.proto
Normal file
11
betterproto/tests/inputs/import_cousin_package/test.proto
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package test.subpackage;
|
||||||
|
|
||||||
|
import "cousin.proto";
|
||||||
|
|
||||||
|
// Verify that we can import message unrelated to us
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
cousin.cousin_subpackage.CousinMessage message = 1;
|
||||||
|
}
|
@@ -0,0 +1,6 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package cousin.subpackage;
|
||||||
|
|
||||||
|
message CousinMessage {
|
||||||
|
}
|
@@ -0,0 +1,11 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package test.subpackage;
|
||||||
|
|
||||||
|
import "cousin.proto";
|
||||||
|
|
||||||
|
// Verify that we can import a message unrelated to us, in a subpackage with the same name as us.
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
cousin.subpackage.CousinMessage message = 1;
|
||||||
|
}
|
@@ -0,0 +1,11 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "users_v1.proto";
|
||||||
|
import "posts_v1.proto";
|
||||||
|
|
||||||
|
// Tests generated message can correctly reference two packages with the same leaf-name
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
users.v1.User user = 1;
|
||||||
|
posts.v1.Post post = 2;
|
||||||
|
}
|
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package posts.v1;
|
||||||
|
|
||||||
|
message Post {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package users.v1;
|
||||||
|
|
||||||
|
message User {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,12 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "parent_package_message.proto";
|
||||||
|
|
||||||
|
package parent.child;
|
||||||
|
|
||||||
|
// Tests generated imports when a message refers to a message defined in its parent package
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
ParentPackageMessage message_implicit = 1;
|
||||||
|
parent.ParentPackageMessage message_explicit = 2;
|
||||||
|
}
|
@@ -0,0 +1,6 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package parent;
|
||||||
|
|
||||||
|
message ParentPackageMessage {
|
||||||
|
}
|
@@ -0,0 +1,11 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package child;
|
||||||
|
|
||||||
|
import "root.proto";
|
||||||
|
|
||||||
|
// Verify that we can import root message from child package
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
RootMessage message = 1;
|
||||||
|
}
|
@@ -0,0 +1,5 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
|
||||||
|
message RootMessage {
|
||||||
|
}
|
@@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "sibling.proto";
|
||||||
|
|
||||||
|
// Tests generated imports when a message in the root package refers to another message in the root package
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
SiblingMessage sibling = 1;
|
||||||
|
}
|
@@ -0,0 +1,5 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message SiblingMessage {
|
||||||
|
|
||||||
|
}
|
@@ -0,0 +1,15 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "request_message.proto";
|
||||||
|
|
||||||
|
// Tests generated service correctly imports the RequestMessage
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
rpc DoThing (RequestMessage) returns (RequestResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message RequestResponse {
|
||||||
|
int32 value = 1;
|
||||||
|
}
|
||||||
|
|
@@ -0,0 +1,5 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message RequestMessage {
|
||||||
|
int32 argument = 1;
|
||||||
|
}
|
@@ -0,0 +1,16 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto.tests.mocks import MockChannel
|
||||||
|
from betterproto.tests.output_betterproto.import_service_input_message import (
|
||||||
|
RequestResponse,
|
||||||
|
TestStub,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_correctly_imports_reference_message():
|
||||||
|
mock_response = RequestResponse(value=10)
|
||||||
|
service = TestStub(MockChannel([mock_response]))
|
||||||
|
response = await service.do_thing()
|
||||||
|
assert mock_response == response
|
4
betterproto/tests/inputs/int32/int32.json
Normal file
4
betterproto/tests/inputs/int32/int32.json
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"positive": 150,
|
||||||
|
"negative": -150
|
||||||
|
}
|
@@ -3,5 +3,6 @@ syntax = "proto3";
|
|||||||
// Some documentation about the Test message.
|
// Some documentation about the Test message.
|
||||||
message Test {
|
message Test {
|
||||||
// Some documentation about the count.
|
// Some documentation about the count.
|
||||||
int32 count = 1;
|
int32 positive = 1;
|
||||||
|
int32 negative = 2;
|
||||||
}
|
}
|
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"int": "value-for-int",
|
||||||
|
"float": "value-for-float",
|
||||||
|
"complex": "value-for-complex",
|
||||||
|
"list": "value-for-list",
|
||||||
|
"tuple": "value-for-tuple",
|
||||||
|
"range": "value-for-range",
|
||||||
|
"str": "value-for-str",
|
||||||
|
"bytearray": "value-for-bytearray",
|
||||||
|
"bytes": "value-for-bytes",
|
||||||
|
"memoryview": "value-for-memoryview",
|
||||||
|
"set": "value-for-set",
|
||||||
|
"frozenset": "value-for-frozenset",
|
||||||
|
"map": "value-for-map",
|
||||||
|
"bool": "value-for-bool"
|
||||||
|
}
|
@@ -0,0 +1,38 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
// Tests that messages may contain fields with names that are python types
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
// https://docs.python.org/2/library/stdtypes.html#numeric-types-int-float-long-complex
|
||||||
|
string int = 1;
|
||||||
|
string float = 2;
|
||||||
|
string complex = 3;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#sequence-types-list-tuple-range
|
||||||
|
string list = 4;
|
||||||
|
string tuple = 5;
|
||||||
|
string range = 6;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#str
|
||||||
|
string str = 7;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#bytearray-objects
|
||||||
|
string bytearray = 8;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#bytes-and-bytearray-operations
|
||||||
|
string bytes = 9;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#memory-views
|
||||||
|
string memoryview = 10;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#set-types-set-frozenset
|
||||||
|
string set = 11;
|
||||||
|
string frozenset = 12;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#dict
|
||||||
|
string map = 13;
|
||||||
|
string dict = 14;
|
||||||
|
|
||||||
|
// https://docs.python.org/3/library/stdtypes.html#boolean-values
|
||||||
|
string bool = 15;
|
||||||
|
}
|
@@ -0,0 +1,37 @@
|
|||||||
|
{
|
||||||
|
"False": 1,
|
||||||
|
"None": 2,
|
||||||
|
"True": 3,
|
||||||
|
"and": 4,
|
||||||
|
"as": 5,
|
||||||
|
"assert": 6,
|
||||||
|
"async": 7,
|
||||||
|
"await": 8,
|
||||||
|
"break": 9,
|
||||||
|
"class": 10,
|
||||||
|
"continue": 11,
|
||||||
|
"def": 12,
|
||||||
|
"del": 13,
|
||||||
|
"elif": 14,
|
||||||
|
"else": 15,
|
||||||
|
"except": 16,
|
||||||
|
"finally": 17,
|
||||||
|
"for": 18,
|
||||||
|
"from": 19,
|
||||||
|
"global": 20,
|
||||||
|
"if": 21,
|
||||||
|
"import": 22,
|
||||||
|
"in": 23,
|
||||||
|
"is": 24,
|
||||||
|
"lambda": 25,
|
||||||
|
"nonlocal": 26,
|
||||||
|
"not": 27,
|
||||||
|
"or": 28,
|
||||||
|
"pass": 29,
|
||||||
|
"raise": 30,
|
||||||
|
"return": 31,
|
||||||
|
"try": 32,
|
||||||
|
"while": 33,
|
||||||
|
"with": 34,
|
||||||
|
"yield": 35
|
||||||
|
}
|
@@ -0,0 +1,44 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
// Tests that messages may contain fields that are Python keywords
|
||||||
|
//
|
||||||
|
// Generated with Python 3.7.6
|
||||||
|
// print('\n'.join(f'string {k} = {i+1};' for i,k in enumerate(keyword.kwlist)))
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
string False = 1;
|
||||||
|
string None = 2;
|
||||||
|
string True = 3;
|
||||||
|
string and = 4;
|
||||||
|
string as = 5;
|
||||||
|
string assert = 6;
|
||||||
|
string async = 7;
|
||||||
|
string await = 8;
|
||||||
|
string break = 9;
|
||||||
|
string class = 10;
|
||||||
|
string continue = 11;
|
||||||
|
string def = 12;
|
||||||
|
string del = 13;
|
||||||
|
string elif = 14;
|
||||||
|
string else = 15;
|
||||||
|
string except = 16;
|
||||||
|
string finally = 17;
|
||||||
|
string for = 18;
|
||||||
|
string from = 19;
|
||||||
|
string global = 20;
|
||||||
|
string if = 21;
|
||||||
|
string import = 22;
|
||||||
|
string in = 23;
|
||||||
|
string is = 24;
|
||||||
|
string lambda = 25;
|
||||||
|
string nonlocal = 26;
|
||||||
|
string not = 27;
|
||||||
|
string or = 28;
|
||||||
|
string pass = 29;
|
||||||
|
string raise = 30;
|
||||||
|
string return = 31;
|
||||||
|
string try = 32;
|
||||||
|
string while = 33;
|
||||||
|
string with = 34;
|
||||||
|
string yield = 35;
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user