Compare commits
158 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
bde6d06835 | ||
|
eaa4f7f5d9 | ||
|
cdddb2f42a | ||
|
d21cd6e391 | ||
|
af7115429a | ||
|
0d9387abec | ||
|
f4ebcb0f65 | ||
|
81711d2427 | ||
|
e3135ce766 | ||
|
9532844929 | ||
|
0c5d1ff868 | ||
|
5fb4b4b7ff | ||
|
4f820b4a6a | ||
|
75a4c230da | ||
|
5c9a12e2f6 | ||
|
e1ccd540a9 | ||
|
4e78fe9579 | ||
|
50bb67bf5d | ||
|
1ecbf1a125 | ||
|
0814729c5a | ||
|
f7aa6150e2 | ||
|
159c30ddd8 | ||
|
c8229e53a7 | ||
|
3185c67098 | ||
|
52eea5ce4c | ||
|
4b6f55dce5 | ||
|
fdbe0205f1 | ||
|
09f821921f | ||
|
a757da1b29 | ||
|
e2d672a422 | ||
|
63f5191f02 | ||
|
87f4b34930 | ||
|
2c360a55f2 | ||
|
04dce524aa | ||
|
8edec81b11 | ||
|
32c8e77274 | ||
|
d9fa6d2dd3 | ||
|
c88edfd093 | ||
|
a46979c8a6 | ||
|
83e13aa606 | ||
|
3ca75dadd7 | ||
|
5d2f3a2cd9 | ||
|
65c1f366ef | ||
|
34c34bd15a | ||
|
fb54917f2c | ||
|
1a95a7988e | ||
|
76db2f153e | ||
|
8567892352 | ||
|
3105e952ea | ||
|
7c8d47de6d | ||
|
c00e2aef19 | ||
|
fdf3b2e764 | ||
|
f7c2fd1194 | ||
|
d8abb850f8 | ||
|
d7ba27de2b | ||
|
57523a9e7f | ||
|
e5e61c873c | ||
|
9fd1c058e6 | ||
|
d336153845 | ||
|
9a45ea9f16 | ||
|
bb7f5229fb | ||
|
f7769a19d1 | ||
|
d31f90be6b | ||
|
919b0a6a7d | ||
|
7ecf3fe0e6 | ||
|
ff14948a4e | ||
|
cb00273257 | ||
|
973d68a154 | ||
|
ab9857b5fd | ||
|
2f658df666 | ||
|
b813d1cedb | ||
|
f5ce1b7108 | ||
|
62fc421d60 | ||
|
eeed1c0db7 | ||
|
2a3e1e1827 | ||
|
53ce1255d3 | ||
|
e8991339e9 | ||
|
4556d67503 | ||
|
f087c6c9bd | ||
|
eec24e4ee8 | ||
|
91111ab7d8 | ||
|
fcff3dff74 | ||
|
5c4969ff1c | ||
|
ed33a48d64 | ||
|
ee362a7a73 | ||
|
261e55b2c8 | ||
|
98930ce0d7 | ||
|
d7d277eb0d | ||
|
3860c0ab11 | ||
|
cd1c2dc3b5 | ||
|
be2a24d15c | ||
|
a5effb219a | ||
|
b354aeb692 | ||
|
6d9e3fc580 | ||
|
72de590651 | ||
|
3c70f21074 | ||
|
4b7d5d3de4 | ||
|
2d57f0d122 | ||
|
142e976c40 | ||
|
382fabb96c | ||
|
18598e77d4 | ||
|
6871053ab2 | ||
|
5bb6931df7 | ||
|
e8a9960b73 | ||
|
f25c66777a | ||
|
a68505b80e | ||
|
2f9497e064 | ||
|
33964b883e | ||
|
ec7574086d | ||
|
8a42027bc9 | ||
|
71737cf696 | ||
|
659ddd9c44 | ||
|
5b6997870a | ||
|
cdf7645722 | ||
|
ca20069ca3 | ||
|
59a4a7da43 | ||
|
15af4367e5 | ||
|
ec5683e572 | ||
|
20150fdcf3 | ||
|
d11b7d04c5 | ||
|
e2d35f4696 | ||
|
c3f08b9ef2 | ||
|
24d44898f4 | ||
|
074448c996 | ||
|
0fe557bd3c | ||
|
1a87ea43a1 | ||
|
983e0895a2 | ||
|
4a2baf3f0a | ||
|
8f0caf1db2 | ||
|
c50d9e2fdc | ||
|
35548cb43e | ||
|
b711d1e11f | ||
|
917de09bb6 | ||
|
1f7f39049e | ||
|
3d001a2a1a | ||
|
de61ddab21 | ||
|
5e2d9febea | ||
|
f6af077ffe | ||
|
92088ebda8 | ||
|
c3e3837f71 | ||
|
6bd9c7835c | ||
|
6ec902c1b5 | ||
|
960dba2ae8 | ||
|
4b4bdefb6f | ||
|
dfa0a56b39 | ||
|
dd4873dfba | ||
|
91f586f7d7 | ||
|
33fb83faad | ||
|
77c04414f5 | ||
|
6969ff7ff6 | ||
|
13e08fdaa8 | ||
|
6775632f77 | ||
|
b12f1e4e61 | ||
|
7e9ba0866c | ||
|
3546f55146 | ||
|
499489f1d3 | ||
|
5759e323bd | ||
|
c762c9c549 |
70
.github/workflows/ci.yml
vendored
70
.github/workflows/ci.yml
vendored
@@ -4,51 +4,71 @@ on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
|
||||
check-formatting:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
name: Consult black on python formatting
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- uses: Gr1N/setup-poetry@v2
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pypoetry/virtualenvs
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-poetry-
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
- name: Run black
|
||||
run: make check-style
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
name: Run tests with tox
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ '3.6', '3.7' ]
|
||||
|
||||
name: Python ${{ matrix.python-version }} test
|
||||
python-version: [ '3.6', '3.7', '3.8']
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions/setup-python@v1
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- uses: dschep/install-pipenv-action@v1
|
||||
- uses: Gr1N/setup-poetry@v2
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pypoetry/virtualenvs
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-poetry-
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt install protobuf-compiler libprotobuf-dev
|
||||
pipenv install --dev --python ${pythonLocation}/python
|
||||
poetry install
|
||||
- name: Run tests
|
||||
run: |
|
||||
cp .env.default .env
|
||||
pipenv run pip install -e .
|
||||
pipenv run generate
|
||||
pipenv run test
|
||||
make generate
|
||||
make test
|
||||
|
||||
build-release:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions/setup-python@v1
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- uses: dschep/install-pipenv-action@v1
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt install protobuf-compiler libprotobuf-dev
|
||||
pipenv install --dev --python ${pythonLocation}/python
|
||||
- uses: Gr1N/setup-poetry@v2
|
||||
- name: Build package
|
||||
run: poetry build
|
||||
- name: Publish package to PyPI
|
||||
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
||||
run: pipenv run python setup.py sdist
|
||||
- name: Publish package
|
||||
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
||||
uses: pypa/gh-action-pypi-publish@v1.0.0a0
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.pypi }}
|
||||
run: poetry publish -n
|
||||
env:
|
||||
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }}
|
||||
|
11
.gitignore
vendored
11
.gitignore
vendored
@@ -1,15 +1,16 @@
|
||||
.coverage
|
||||
.DS_Store
|
||||
.env
|
||||
.vscode/settings.json
|
||||
.mypy_cache
|
||||
.pytest_cache
|
||||
.python-version
|
||||
build/
|
||||
betterproto/tests/*.bin
|
||||
betterproto/tests/*_pb2.py
|
||||
betterproto/tests/*.py
|
||||
!betterproto/tests/generate.py
|
||||
!betterproto/tests/test_*.py
|
||||
betterproto/tests/output_*
|
||||
**/__pycache__
|
||||
dist
|
||||
**/*.egg-info
|
||||
output
|
||||
.idea
|
||||
.DS_Store
|
||||
.tox
|
||||
|
39
CHANGELOG.md
39
CHANGELOG.md
@@ -5,6 +5,42 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
|
||||
|
||||
## [2.0.0b1] - 2020-07-04
|
||||
|
||||
[Upgrade Guide](./docs/upgrading.md)
|
||||
|
||||
> Several bugfixes and improvements required or will require small breaking changes, necessitating a new version.
|
||||
> `2.0.0` will be released once the interface is stable.
|
||||
|
||||
- Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
|
||||
- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
|
||||
- Fix No arguments are generated for stub methods when using import with proto definition
|
||||
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
|
||||
|
||||
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)
|
||||
- Fix Import child package from package [#58](https://github.com/danielgtaylor/python-betterproto/issues/58)
|
||||
- Fix Import parent package from child package [#59](https://github.com/danielgtaylor/python-betterproto/issues/59)
|
||||
- Fix Import root package from child package [#60](https://github.com/danielgtaylor/python-betterproto/issues/60)
|
||||
- Fix Import root package from root [#61](https://github.com/danielgtaylor/python-betterproto/issues/61)
|
||||
|
||||
- Fix ALL_CAPS message fields are parsed incorrectly. [#11](https://github.com/danielgtaylor/python-betterproto/issues/11)
|
||||
|
||||
## [1.2.5] - 2020-04-27
|
||||
|
||||
- Add .j2 suffix to python template names to avoid confusing certain build tools [#72](https://github.com/danielgtaylor/python-betterproto/pull/72)
|
||||
|
||||
## [1.2.4] - 2020-04-26
|
||||
|
||||
- Enforce utf-8 for reading the readme in setup.py [#67](https://github.com/danielgtaylor/python-betterproto/pull/67)
|
||||
- Only import types from grpclib when type checking [#52](https://github.com/danielgtaylor/python-betterproto/pull/52)
|
||||
- Improve performance of serialize/deserialize by caching type information of fields in class [#46](https://github.com/danielgtaylor/python-betterproto/pull/46)
|
||||
- Support using Google's wrapper types as RPC output values [#40](https://github.com/danielgtaylor/python-betterproto/pull/40)
|
||||
- Fixes issue where protoc did not recognize plugin.py as win32 application [#38](https://github.com/danielgtaylor/python-betterproto/pull/38)
|
||||
- Fix services using non-pythonified field names [#34](https://github.com/danielgtaylor/python-betterproto/pull/34)
|
||||
- Add ability to provide metadata, timeout & deadline args to requests [#32](https://github.com/danielgtaylor/python-betterproto/pull/32)
|
||||
|
||||
## [1.2.3] - 2020-04-15
|
||||
|
||||
- Exclude empty lists from `to_dict` by default [#16](https://github.com/danielgtaylor/python-betterproto/pull/16)
|
||||
@@ -44,7 +80,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Initial release
|
||||
|
||||
[unreleased]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.3...HEAD
|
||||
[1.2.5]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.4...v1.2.5
|
||||
[1.2.4]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.3...v1.2.4
|
||||
[1.2.3]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.2...v1.2.3
|
||||
[1.2.2]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.1...v1.2.2
|
||||
[1.2.1]: https://github.com/danielgtaylor/python-betterproto/compare/v1.2.0...v1.2.1
|
||||
|
42
Makefile
Normal file
42
Makefile
Normal file
@@ -0,0 +1,42 @@
|
||||
.PHONY: help setup generate test types format clean plugin full-test check-style
|
||||
|
||||
help: ## - Show this help.
|
||||
@fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//'
|
||||
|
||||
# Dev workflow tasks
|
||||
|
||||
generate: ## - Generate test cases (do this once before running test)
|
||||
poetry run ./betterproto/tests/generate.py
|
||||
|
||||
test: ## - Run tests
|
||||
poetry run pytest --cov betterproto
|
||||
|
||||
types: ## - Check types with mypy
|
||||
poetry run mypy betterproto --ignore-missing-imports
|
||||
|
||||
format: ## - Apply black formatting to source code
|
||||
poetry run black . --exclude tests/output_
|
||||
|
||||
clean: ## - Clean out generated files from the workspace
|
||||
rm -rf .coverage \
|
||||
.mypy_cache \
|
||||
.pytest_cache \
|
||||
dist \
|
||||
**/__pycache__ \
|
||||
betterproto/tests/output_*
|
||||
|
||||
# Manual testing
|
||||
|
||||
# By default write plugin output to a directory called output
|
||||
o=output
|
||||
plugin: ## - Execute the protoc plugin, with output write to `output` or the value passed to `-o`
|
||||
mkdir -p $(o)
|
||||
protoc --plugin=protoc-gen-custom=betterproto/plugin.py $(i) --custom_out=$(o)
|
||||
|
||||
# CI tasks
|
||||
|
||||
full-test: generate ## - Run full testing sequence with multiple pythons
|
||||
poetry run tox
|
||||
|
||||
check-style: ## - Check if code style is correct
|
||||
poetry run black . --check --diff --exclude tests/output_
|
32
Pipfile
32
Pipfile
@@ -1,32 +0,0 @@
|
||||
[[source]]
|
||||
name = "pypi"
|
||||
url = "https://pypi.org/simple"
|
||||
verify_ssl = true
|
||||
|
||||
[dev-packages]
|
||||
flake8 = "*"
|
||||
mypy = "*"
|
||||
isort = "*"
|
||||
pytest = "*"
|
||||
rope = "*"
|
||||
v = {editable = true,version = "*"}
|
||||
|
||||
[packages]
|
||||
protobuf = "*"
|
||||
jinja2 = "*"
|
||||
grpclib = "*"
|
||||
stringcase = "*"
|
||||
black = "*"
|
||||
backports-datetime-fromisoformat = "*"
|
||||
dataclasses = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.6"
|
||||
|
||||
[scripts]
|
||||
plugin = "protoc --plugin=protoc-gen-custom=betterproto/plugin.py --custom_out=output"
|
||||
generate = "python betterproto/tests/generate.py"
|
||||
test = "pytest ./betterproto/tests"
|
||||
|
||||
[pipenv]
|
||||
allow_prereleases = true
|
145
README.md
145
README.md
@@ -40,18 +40,24 @@ This project exists because I am unhappy with the state of the official Google p
|
||||
|
||||
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
|
||||
|
||||
## Installation & Getting Started
|
||||
## Installation
|
||||
|
||||
First, install the package. Note that the `[compiler]` feature flag tells it to install extra dependencies only needed by the `protoc` plugin:
|
||||
|
||||
```sh
|
||||
# Install both the library and compiler
|
||||
$ pip install "betterproto[compiler]"
|
||||
pip install "betterproto[compiler]"
|
||||
|
||||
# Install just the library (to use the generated code output)
|
||||
$ pip install betterproto
|
||||
pip install betterproto
|
||||
```
|
||||
|
||||
*Betterproto* is under active development. To install the latest beta version, use `pip install --pre betterproto`.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Compiling proto files
|
||||
|
||||
Now, given you installed the compiler and have a proto file, e.g `example.proto`:
|
||||
|
||||
```protobuf
|
||||
@@ -68,14 +74,15 @@ message Greeting {
|
||||
You can run the following:
|
||||
|
||||
```sh
|
||||
$ protoc -I . --python_betterproto_out=. example.proto
|
||||
mkdir lib
|
||||
protoc -I . --python_betterproto_out=lib example.proto
|
||||
```
|
||||
|
||||
This will generate `hello.py` which looks like:
|
||||
This will generate `lib/hello/__init__.py` which looks like:
|
||||
|
||||
```py
|
||||
```python
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# sources: hello.proto
|
||||
# sources: example.proto
|
||||
# plugin: python-betterproto
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -83,7 +90,7 @@ import betterproto
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hello(betterproto.Message):
|
||||
class Greeting(betterproto.Message):
|
||||
"""Greeting represents a message you can tell a user."""
|
||||
|
||||
message: str = betterproto.string_field(1)
|
||||
@@ -91,23 +98,23 @@ class Hello(betterproto.Message):
|
||||
|
||||
Now you can use it!
|
||||
|
||||
```py
|
||||
>>> from hello import Hello
|
||||
>>> test = Hello()
|
||||
```python
|
||||
>>> from lib.hello import Greeting
|
||||
>>> test = Greeting()
|
||||
>>> test
|
||||
Hello(message='')
|
||||
Greeting(message='')
|
||||
|
||||
>>> test.message = "Hey!"
|
||||
>>> test
|
||||
Hello(message="Hey!")
|
||||
Greeting(message="Hey!")
|
||||
|
||||
>>> serialized = bytes(test)
|
||||
>>> serialized
|
||||
b'\n\x04Hey!'
|
||||
|
||||
>>> another = Hello().parse(serialized)
|
||||
>>> another = Greeting().parse(serialized)
|
||||
>>> another
|
||||
Hello(message="Hey!")
|
||||
Greeting(message="Hey!")
|
||||
|
||||
>>> another.to_dict()
|
||||
{"message": "Hey!"}
|
||||
@@ -148,7 +155,7 @@ service Echo {
|
||||
|
||||
You can use it like so (enable async in the interactive shell first):
|
||||
|
||||
```py
|
||||
```python
|
||||
>>> import echo
|
||||
>>> from grpclib.client import Channel
|
||||
|
||||
@@ -173,8 +180,8 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries
|
||||
|
||||
For compatibility the default is to convert field names to `camelCase`. You can control this behavior by passing a casing value, e.g:
|
||||
|
||||
```py
|
||||
>>> MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
|
||||
```python
|
||||
MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
|
||||
```
|
||||
|
||||
### Determining if a message was sent
|
||||
@@ -256,6 +263,7 @@ Google provides several well-known message types like a timestamp, duration, and
|
||||
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
|
||||
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
|
||||
| `google.protobuf.*Value` | `Optional[...]` | `None` |
|
||||
| `google.protobuf.*` | `betterproto.lib.google.protobuf.*` | `None` |
|
||||
|
||||
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
|
||||
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
|
||||
@@ -296,36 +304,106 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
|
||||
|
||||
## Development
|
||||
|
||||
First, make sure you have Python 3.6+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
|
||||
Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
|
||||
|
||||
### Requirements
|
||||
|
||||
- Python (3.6 or higher)
|
||||
|
||||
- [protoc](https://grpc.io/docs/protoc-installation/) (3.12 or higher)
|
||||
*Needed to compile `.proto` files and run the tests*
|
||||
|
||||
- [poetry](https://python-poetry.org/docs/#installation)
|
||||
*Needed to install dependencies in a virtual environment*
|
||||
|
||||
- make ([ubuntu](https://www.howtoinstall.me/ubuntu/18-04/make/), [windows](https://stackoverflow.com/questions/32127524/how-to-install-and-use-make-in-windows), [mac](https://osxdaily.com/2014/02/12/install-command-line-tools-mac-os-x/))
|
||||
|
||||
*Needed to conveniently run development tasks.*
|
||||
*Alternatively, manually run the commands defined in the [Makefile](./Makefile)*
|
||||
|
||||
### Setup
|
||||
|
||||
```sh
|
||||
# Get set up with the virtual env & dependencies
|
||||
$ pipenv install --dev
|
||||
poetry install
|
||||
|
||||
# Link the local package
|
||||
$ pipenv shell
|
||||
$ pip install -e .
|
||||
# Activate the poetry environment
|
||||
poetry shell
|
||||
```
|
||||
|
||||
Run `make help` to see all available development tasks.
|
||||
|
||||
### Code style
|
||||
|
||||
This project enforces [black](https://github.com/psf/black) python code formatting.
|
||||
|
||||
Before committing changes run:
|
||||
|
||||
```sh
|
||||
make format
|
||||
```
|
||||
|
||||
To avoid merge conflicts later, non-black formatted python code will fail in CI.
|
||||
|
||||
### Tests
|
||||
|
||||
There are two types of tests:
|
||||
|
||||
1. Manually-written tests for some behavior of the library
|
||||
2. Proto files and JSON inputs for automated tests
|
||||
1. Standard tests
|
||||
2. Custom tests
|
||||
|
||||
For #2, you can add a new `*.proto` file into the `betterproto/tests` directory along with a sample `*.json` input and it will get automatically picked up.
|
||||
#### Standard tests
|
||||
|
||||
Adding a standard test case is easy.
|
||||
|
||||
- Create a new directory `betterproto/tests/inputs/<name>`
|
||||
- add `<name>.proto` with a message called `Test`
|
||||
- add `<name>.json` with some test data (optional)
|
||||
|
||||
It will be picked up automatically when you run the tests.
|
||||
|
||||
- See also: [Standard Tests Development Guide](betterproto/tests/README.md)
|
||||
|
||||
#### Custom tests
|
||||
|
||||
Custom tests are found in `tests/test_*.py` and are run with pytest.
|
||||
|
||||
#### Running
|
||||
|
||||
Here's how to run the tests.
|
||||
|
||||
```sh
|
||||
# Generate assets from sample .proto files
|
||||
$ pipenv run generate
|
||||
|
||||
# Generate assets from sample .proto files required by the tests
|
||||
make generate
|
||||
# Run the tests
|
||||
$ pipenv run test
|
||||
make test
|
||||
```
|
||||
|
||||
To run tests as they are run in CI (with tox) run:
|
||||
|
||||
```sh
|
||||
make full-test
|
||||
```
|
||||
|
||||
### (Re)compiling Google Well-known Types
|
||||
|
||||
Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google).
|
||||
Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests.
|
||||
|
||||
Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`.
|
||||
|
||||
Assuming your `google.protobuf` source files (included with all releases of `protoc`) are located in `/usr/local/include`, you can regenerate them as follows:
|
||||
|
||||
```sh
|
||||
protoc \
|
||||
--plugin=protoc-gen-custom=betterproto/plugin.py \
|
||||
--custom_opt=INCLUDE_GOOGLE \
|
||||
--custom_out=betterproto/lib \
|
||||
-I /usr/local/include/ \
|
||||
/usr/local/include/google/protobuf/*.proto
|
||||
```
|
||||
|
||||
|
||||
### TODO
|
||||
|
||||
- [x] Fixed length fields
|
||||
@@ -340,6 +418,9 @@ $ pipenv run test
|
||||
- [x] Refs to nested types
|
||||
- [x] Imports in proto files
|
||||
- [x] Well-known Google types
|
||||
- [ ] Support as request input
|
||||
- [ ] Support as response output
|
||||
- [ ] Automatically wrap/unwrap responses
|
||||
- [x] OneOf support
|
||||
- [x] Basic support on the wire
|
||||
- [x] Check which was set from the group
|
||||
@@ -363,6 +444,10 @@ $ pipenv run test
|
||||
- [x] Automate running tests
|
||||
- [ ] Cleanup!
|
||||
|
||||
## Community
|
||||
|
||||
Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
|
||||
|
||||
## License
|
||||
|
||||
Copyright © 2019 Daniel G. Taylor
|
||||
|
@@ -5,34 +5,25 @@ import json
|
||||
import struct
|
||||
import sys
|
||||
from abc import ABC
|
||||
from base64 import b64encode, b64decode
|
||||
from base64 import b64decode, b64encode
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
SupportsBytes,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_type_hints,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import grpclib.client
|
||||
import grpclib.const
|
||||
import stringcase
|
||||
|
||||
from .casing import safe_snake_case
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpclib._protocols import IProtoMessage
|
||||
from ._types import T
|
||||
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
|
||||
from .grpc.grpclib_client import ServiceStub
|
||||
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||
# Apply backport of datetime.fromisoformat from 3.7
|
||||
@@ -118,14 +109,18 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
def datetime_default_gen():
|
||||
return datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
DATETIME_ZERO = datetime_default_gen()
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
CAMEL = stringcase.camelcase
|
||||
SNAKE = stringcase.snakecase
|
||||
CAMEL = camel_case
|
||||
SNAKE = snake_case
|
||||
|
||||
|
||||
class _PLACEHOLDER:
|
||||
@@ -422,8 +417,82 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
)
|
||||
|
||||
|
||||
# Bound type variable to allow methods to return `self` of subclasses
|
||||
T = TypeVar("T", bound="Message")
|
||||
class ProtoClassMetadata:
|
||||
oneof_group_by_field: Dict[str, str]
|
||||
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||
default_gen: Dict[str, Callable]
|
||||
cls_by_field: Dict[str, Type]
|
||||
field_name_by_number: Dict[int, str]
|
||||
meta_by_field_name: Dict[str, FieldMetadata]
|
||||
__slots__ = (
|
||||
"oneof_group_by_field",
|
||||
"oneof_field_by_group",
|
||||
"default_gen",
|
||||
"cls_by_field",
|
||||
"field_name_by_number",
|
||||
"meta_by_field_name",
|
||||
)
|
||||
|
||||
def __init__(self, cls: Type["Message"]):
|
||||
by_field = {}
|
||||
by_group: Dict[str, Set] = {}
|
||||
by_field_name = {}
|
||||
by_field_number = {}
|
||||
|
||||
fields = dataclasses.fields(cls)
|
||||
for field in fields:
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if meta.group:
|
||||
# This is part of a one-of group.
|
||||
by_field[field.name] = meta.group
|
||||
|
||||
by_group.setdefault(meta.group, set()).add(field)
|
||||
|
||||
by_field_name[field.name] = meta
|
||||
by_field_number[meta.number] = field.name
|
||||
|
||||
self.oneof_group_by_field = by_field
|
||||
self.oneof_field_by_group = by_group
|
||||
self.field_name_by_number = by_field_number
|
||||
self.meta_by_field_name = by_field_name
|
||||
|
||||
self.default_gen = self._get_default_gen(cls, fields)
|
||||
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
||||
|
||||
@staticmethod
|
||||
def _get_default_gen(cls, fields):
|
||||
default_gen = {}
|
||||
|
||||
for field in fields:
|
||||
default_gen[field.name] = cls._get_field_default_gen(field)
|
||||
|
||||
return default_gen
|
||||
|
||||
@staticmethod
|
||||
def _get_cls_by_field(cls, fields):
|
||||
field_cls = {}
|
||||
|
||||
for field in fields:
|
||||
meta = FieldMetadata.get(field)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
assert meta.map_types
|
||||
kt = cls._cls_for(field, index=0)
|
||||
vt = cls._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
field_cls[field.name] = Entry
|
||||
field_cls[field.name + ".value"] = vt
|
||||
else:
|
||||
field_cls[field.name] = cls._cls_for(field)
|
||||
|
||||
return field_cls
|
||||
|
||||
|
||||
class Message(ABC):
|
||||
@@ -435,69 +504,74 @@ class Message(ABC):
|
||||
|
||||
_serialized_on_wire: bool
|
||||
_unknown_fields: bytes
|
||||
_group_map: Dict[str, dict]
|
||||
_group_current: Dict[str, str]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Keep track of whether every field was default
|
||||
all_sentinel = True
|
||||
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
# Set current field of each group after `__init__` has already been run.
|
||||
group_current: Dict[str, str] = {}
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
|
||||
if meta.group:
|
||||
# This is part of a one-of group.
|
||||
group_map["fields"][field.name] = meta.group
|
||||
group_current.setdefault(meta.group)
|
||||
|
||||
if meta.group not in group_map["groups"]:
|
||||
group_map["groups"][meta.group] = {"current": None, "fields": set()}
|
||||
group_map["groups"][meta.group]["fields"].add(field)
|
||||
|
||||
if getattr(self, field.name) != PLACEHOLDER:
|
||||
if getattr(self, field_name) != PLACEHOLDER:
|
||||
# Skip anything not set to the sentinel value
|
||||
all_sentinel = False
|
||||
|
||||
if meta.group:
|
||||
# This was set, so make it the selected value of the one-of.
|
||||
group_map["groups"][meta.group]["current"] = field
|
||||
group_current[meta.group] = field_name
|
||||
|
||||
continue
|
||||
|
||||
setattr(self, field.name, self._get_field_default(field, meta))
|
||||
setattr(self, field_name, self._get_field_default(field_name))
|
||||
|
||||
# Now that all the defaults are set, reset it!
|
||||
self.__dict__["_serialized_on_wire"] = not all_sentinel
|
||||
self.__dict__["_unknown_fields"] = b""
|
||||
self.__dict__["_group_map"] = group_map
|
||||
self.__dict__["_group_current"] = group_current
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
if attr != "_serialized_on_wire":
|
||||
# Track when a field has been set.
|
||||
self.__dict__["_serialized_on_wire"] = True
|
||||
|
||||
if attr in getattr(self, "_group_map", {}).get("fields", {}):
|
||||
group = self._group_map["fields"][attr]
|
||||
for field in self._group_map["groups"][group]["fields"]:
|
||||
if field.name == attr:
|
||||
self._group_map["groups"][group]["current"] = field
|
||||
else:
|
||||
super().__setattr__(
|
||||
field.name,
|
||||
self._get_field_default(field, FieldMetadata.get(field)),
|
||||
)
|
||||
if hasattr(self, "_group_current"): # __post_init__ had already run
|
||||
if attr in self._betterproto.oneof_group_by_field:
|
||||
group = self._betterproto.oneof_group_by_field[attr]
|
||||
for field in self._betterproto.oneof_field_by_group[group]:
|
||||
if field.name == attr:
|
||||
self._group_current[group] = field.name
|
||||
else:
|
||||
super().__setattr__(
|
||||
field.name, self._get_field_default(field.name),
|
||||
)
|
||||
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
@property
|
||||
def _betterproto(self):
|
||||
"""
|
||||
Lazy initialize metadata for each protobuf class.
|
||||
It may be initialized multiple times in a multi-threaded environment,
|
||||
but that won't affect the correctness.
|
||||
"""
|
||||
meta = getattr(self.__class__, "_betterproto_meta", None)
|
||||
if not meta:
|
||||
meta = ProtoClassMetadata(self.__class__)
|
||||
self.__class__._betterproto_meta = meta
|
||||
return meta
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Get the binary encoded Protobuf representation of this instance.
|
||||
"""
|
||||
output = b""
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
value = getattr(self, field.name)
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
value = getattr(self, field_name)
|
||||
|
||||
if value is None:
|
||||
# Optional items should be skipped. This is used for the Google
|
||||
@@ -508,16 +582,16 @@ class Message(ABC):
|
||||
# currently set in a `oneof` group, so it must be serialized even
|
||||
# if the value is the default zero value.
|
||||
selected_in_group = False
|
||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
||||
if meta.group and self._group_current[meta.group] == field_name:
|
||||
selected_in_group = True
|
||||
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
# Empty messages can still be sent on the wire if they were
|
||||
# set (or received empty).
|
||||
# set (or recieved empty).
|
||||
serialize_empty = True
|
||||
|
||||
if value == self._get_field_default(field, meta) and not (
|
||||
if value == self._get_field_default(field_name) and not (
|
||||
selected_in_group or serialize_empty
|
||||
):
|
||||
# Default (zero) values are not serialized. Two exceptions are
|
||||
@@ -560,50 +634,53 @@ class Message(ABC):
|
||||
# For compatibility with other libraries
|
||||
SerializeToString = __bytes__
|
||||
|
||||
def _type_hint(self, field_name: str) -> Type:
|
||||
module = inspect.getmodule(self.__class__)
|
||||
type_hints = get_type_hints(self.__class__, vars(module))
|
||||
@classmethod
|
||||
def _type_hint(cls, field_name: str) -> Type:
|
||||
module = inspect.getmodule(cls)
|
||||
type_hints = get_type_hints(cls, vars(module))
|
||||
return type_hints[field_name]
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
@classmethod
|
||||
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
cls = self._type_hint(field.name)
|
||||
if hasattr(cls, "__args__") and index >= 0:
|
||||
cls = cls.__args__[index]
|
||||
return cls
|
||||
field_cls = cls._type_hint(field.name)
|
||||
if hasattr(field_cls, "__args__") and index >= 0:
|
||||
field_cls = field_cls.__args__[index]
|
||||
return field_cls
|
||||
|
||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||
t = self._type_hint(field.name)
|
||||
def _get_field_default(self, field_name):
|
||||
return self._betterproto.default_gen[field_name]()
|
||||
|
||||
@classmethod
|
||||
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
|
||||
t = cls._type_hint(field.name)
|
||||
|
||||
value: Any = 0
|
||||
if hasattr(t, "__origin__"):
|
||||
if t.__origin__ in (dict, Dict):
|
||||
# This is some kind of map (dict in Python).
|
||||
value = {}
|
||||
return dict
|
||||
elif t.__origin__ in (list, List):
|
||||
# This is some kind of list (repeated) field.
|
||||
value = []
|
||||
return list
|
||||
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
||||
# This is an optional (wrapped) field. For setting the default we
|
||||
# really don't care what kind of field it is.
|
||||
value = None
|
||||
return type(None)
|
||||
else:
|
||||
value = t()
|
||||
return t
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
value = 0
|
||||
return int
|
||||
elif t == datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
value = DATETIME_ZERO
|
||||
return datetime_default_gen
|
||||
else:
|
||||
# This is either a primitive scalar or another message type. Calling
|
||||
# it should result in its zero value.
|
||||
value = t()
|
||||
|
||||
return value
|
||||
return t
|
||||
|
||||
def _postprocess_single(
|
||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
|
||||
) -> Any:
|
||||
"""Adjusts values after parsing."""
|
||||
if wire_type == WIRE_VARINT:
|
||||
@@ -625,7 +702,7 @@ class Message(ABC):
|
||||
if meta.proto_type == TYPE_STRING:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
|
||||
if cls == datetime:
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
@@ -639,20 +716,7 @@ class Message(ABC):
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
assert meta.map_types
|
||||
kt = self._cls_for(field, index=0)
|
||||
vt = self._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
value = Entry().parse(value)
|
||||
value = self._betterproto.cls_by_field[field_name]().parse(value)
|
||||
|
||||
return value
|
||||
|
||||
@@ -661,49 +725,46 @@ class Message(ABC):
|
||||
Parse the binary encoded Protobuf into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
|
||||
for parsed in parse_fields(data):
|
||||
if parsed.number in fields:
|
||||
field = fields[parsed.number]
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
value: Any
|
||||
if (
|
||||
parsed.wire_type == WIRE_LEN_DELIM
|
||||
and meta.proto_type in PACKED_TYPES
|
||||
):
|
||||
# This is a packed repeated field.
|
||||
pos = 0
|
||||
value = []
|
||||
while pos < len(parsed.value):
|
||||
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
|
||||
decoded, pos = parsed.value[pos : pos + 4], pos + 4
|
||||
wire_type = WIRE_FIXED_32
|
||||
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
|
||||
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
||||
wire_type = WIRE_FIXED_64
|
||||
else:
|
||||
decoded, pos = decode_varint(parsed.value, pos)
|
||||
wire_type = WIRE_VARINT
|
||||
decoded = self._postprocess_single(
|
||||
wire_type, meta, field, decoded
|
||||
)
|
||||
value.append(decoded)
|
||||
else:
|
||||
value = self._postprocess_single(
|
||||
parsed.wire_type, meta, field, parsed.value
|
||||
)
|
||||
|
||||
current = getattr(self, field.name)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Value represents a single key/value pair entry in the map.
|
||||
current[value.key] = value.value
|
||||
elif isinstance(current, list) and not isinstance(value, list):
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(self, field.name, value)
|
||||
else:
|
||||
field_name = self._betterproto.field_name_by_number.get(parsed.number)
|
||||
if not field_name:
|
||||
self._unknown_fields += parsed.raw
|
||||
continue
|
||||
|
||||
meta = self._betterproto.meta_by_field_name[field_name]
|
||||
|
||||
value: Any
|
||||
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
|
||||
# This is a packed repeated field.
|
||||
pos = 0
|
||||
value = []
|
||||
while pos < len(parsed.value):
|
||||
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
|
||||
decoded, pos = parsed.value[pos : pos + 4], pos + 4
|
||||
wire_type = WIRE_FIXED_32
|
||||
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
|
||||
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
||||
wire_type = WIRE_FIXED_64
|
||||
else:
|
||||
decoded, pos = decode_varint(parsed.value, pos)
|
||||
wire_type = WIRE_VARINT
|
||||
decoded = self._postprocess_single(
|
||||
wire_type, meta, field_name, decoded
|
||||
)
|
||||
value.append(decoded)
|
||||
else:
|
||||
value = self._postprocess_single(
|
||||
parsed.wire_type, meta, field_name, parsed.value
|
||||
)
|
||||
|
||||
current = getattr(self, field_name)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Value represents a single key/value pair entry in the map.
|
||||
current[value.key] = value.value
|
||||
elif isinstance(current, list) and not isinstance(value, list):
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(self, field_name, value)
|
||||
|
||||
return self
|
||||
|
||||
@@ -714,7 +775,7 @@ class Message(ABC):
|
||||
|
||||
def to_dict(
|
||||
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
|
||||
) -> dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dict representation of this message instance which can be
|
||||
used to serialize to e.g. JSON. Defaults to camel casing for
|
||||
@@ -726,10 +787,9 @@ class Message(ABC):
|
||||
`False`.
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name).rstrip("_") # type: ignore
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
v = getattr(self, field_name)
|
||||
cased_name = casing(field_name).rstrip("_") # type: ignore
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, datetime):
|
||||
if v != DATETIME_ZERO or include_default_values:
|
||||
@@ -755,7 +815,7 @@ class Message(ABC):
|
||||
|
||||
if v or include_default_values:
|
||||
output[cased_name] = v
|
||||
elif v != self._get_field_default(field, meta) or include_default_values:
|
||||
elif v != self._get_field_default(field_name) or include_default_values:
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [str(n) for n in v]
|
||||
@@ -767,7 +827,9 @@ class Message(ABC):
|
||||
else:
|
||||
output[cased_name] = b64encode(v).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_values = list(self._cls_for(field)) # type: ignore
|
||||
enum_values = list(
|
||||
self._betterproto.cls_by_field[field_name]
|
||||
) # type: ignore
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [enum_values[e].name for e in v]
|
||||
else:
|
||||
@@ -784,56 +846,54 @@ class Message(ABC):
|
||||
self._serialized_on_wire = True
|
||||
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
|
||||
for key in value:
|
||||
snake_cased = safe_snake_case(key)
|
||||
if snake_cased in fields_by_name:
|
||||
field = fields_by_name[snake_cased]
|
||||
meta = FieldMetadata.get(field)
|
||||
field_name = safe_snake_case(key)
|
||||
meta = self._betterproto.meta_by_field_name.get(field_name)
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(
|
||||
value[key].replace("Z", "+00:00")
|
||||
)
|
||||
setattr(self, field.name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field.name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field.name, value[key])
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field_name)
|
||||
if isinstance(v, list):
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||
setattr(self, field_name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field_name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field_name, value[key])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._cls_for(field)
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field_name)
|
||||
cls = self._betterproto.cls_by_field[field_name + ".value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._betterproto.cls_by_field[field_name]
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field.name, v)
|
||||
if v is not None:
|
||||
setattr(self, field_name, v)
|
||||
return self
|
||||
|
||||
def to_json(self, indent: Union[None, int, str] = None) -> str:
|
||||
@@ -859,25 +919,29 @@ def serialized_on_wire(message: Message) -> bool:
|
||||
|
||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
"""Return the name and value of a message's one-of field group."""
|
||||
field = message._group_map["groups"].get(group_name, {}).get("current")
|
||||
if not field:
|
||||
field_name = message._group_current.get(group_name)
|
||||
if not field_name:
|
||||
return ("", None)
|
||||
return (field.name, getattr(message, field.name))
|
||||
return (field_name, getattr(message, field_name))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Duration(Message):
|
||||
# Signed seconds of the span of time. Must be from -315,576,000,000 to
|
||||
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
|
||||
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
|
||||
seconds: int = int64_field(1)
|
||||
# Signed fractions of a second at nanosecond resolution of the span of time.
|
||||
# Durations less than one second are represented with a 0 `seconds` field and
|
||||
# a positive or negative `nanos` field. For durations of one second or more,
|
||||
# a non-zero value for the `nanos` field must be of the same sign as the
|
||||
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
# Circular import workaround: google.protobuf depends on base classes defined above.
|
||||
from .lib.google.protobuf import (
|
||||
Duration,
|
||||
Timestamp,
|
||||
BoolValue,
|
||||
BytesValue,
|
||||
DoubleValue,
|
||||
FloatValue,
|
||||
Int32Value,
|
||||
Int64Value,
|
||||
StringValue,
|
||||
UInt32Value,
|
||||
UInt64Value,
|
||||
)
|
||||
|
||||
|
||||
class _Duration(Duration):
|
||||
def to_timedelta(self) -> timedelta:
|
||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||
|
||||
@@ -890,16 +954,7 @@ class _Duration(Message):
|
||||
return ".".join(parts) + "s"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Timestamp(Message):
|
||||
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
|
||||
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
|
||||
seconds: int = int64_field(1)
|
||||
# Non-negative fractions of a second at nanosecond resolution. Negative
|
||||
# second values with fractions must still have non-negative nanos values that
|
||||
# count forward in time. Must be from 0 to 999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
class _Timestamp(Timestamp):
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
@@ -940,93 +995,16 @@ class _WrappedMessage(Message):
|
||||
return self
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BoolValue(_WrappedMessage):
|
||||
value: bool = bool_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Int32Value(_WrappedMessage):
|
||||
value: int = int32_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _UInt32Value(_WrappedMessage):
|
||||
value: int = uint32_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Int64Value(_WrappedMessage):
|
||||
value: int = int64_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _UInt64Value(_WrappedMessage):
|
||||
value: int = uint64_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _FloatValue(_WrappedMessage):
|
||||
value: float = float_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DoubleValue(_WrappedMessage):
|
||||
value: float = double_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _StringValue(_WrappedMessage):
|
||||
value: str = string_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BytesValue(_WrappedMessage):
|
||||
value: bytes = bytes_field(1)
|
||||
|
||||
|
||||
def _get_wrapper(proto_type: str) -> Type:
|
||||
"""Get the wrapper message class for a wrapped type."""
|
||||
return {
|
||||
TYPE_BOOL: _BoolValue,
|
||||
TYPE_INT32: _Int32Value,
|
||||
TYPE_UINT32: _UInt32Value,
|
||||
TYPE_INT64: _Int64Value,
|
||||
TYPE_UINT64: _UInt64Value,
|
||||
TYPE_FLOAT: _FloatValue,
|
||||
TYPE_DOUBLE: _DoubleValue,
|
||||
TYPE_STRING: _StringValue,
|
||||
TYPE_BYTES: _BytesValue,
|
||||
TYPE_BOOL: BoolValue,
|
||||
TYPE_INT32: Int32Value,
|
||||
TYPE_UINT32: UInt32Value,
|
||||
TYPE_INT64: Int64Value,
|
||||
TYPE_UINT64: UInt64Value,
|
||||
TYPE_FLOAT: FloatValue,
|
||||
TYPE_DOUBLE: DoubleValue,
|
||||
TYPE_STRING: StringValue,
|
||||
TYPE_BYTES: BytesValue,
|
||||
}[proto_type]
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: grpclib.client.Channel) -> None:
|
||||
self.channel = channel
|
||||
|
||||
async def _unary_unary(
|
||||
self, route: str, request: "IProtoMessage", response_type: Type[T]
|
||||
) -> T:
|
||||
"""Make a unary request and return the response."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
response = await stream.recv_message()
|
||||
assert response is not None
|
||||
return response
|
||||
|
||||
async def _unary_stream(
|
||||
self, route: str, request: "IProtoMessage", response_type: Type[T]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""Make a unary request and return the stream response iterator."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
async for message in stream:
|
||||
yield message
|
||||
|
9
betterproto/_types.py
Normal file
9
betterproto/_types.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Message
|
||||
from grpclib._protocols import IProtoMessage
|
||||
|
||||
# Bound type variable to allow methods to return `self` of subclasses
|
||||
T = TypeVar("T", bound="Message")
|
||||
ST = TypeVar("ST", bound="IProtoMessage")
|
@@ -1,9 +1,21 @@
|
||||
import stringcase
|
||||
import re
|
||||
|
||||
# Word delimiters and symbols that will not be preserved when re-casing.
|
||||
# language=PythonRegExp
|
||||
SYMBOLS = "[^a-zA-Z0-9]*"
|
||||
|
||||
# Optionally capitalized word.
|
||||
# language=PythonRegExp
|
||||
WORD = "[A-Z]*[a-z]*[0-9]*"
|
||||
|
||||
# Uppercase word, not followed by lowercase letters.
|
||||
# language=PythonRegExp
|
||||
WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
|
||||
|
||||
|
||||
def safe_snake_case(value: str) -> str:
|
||||
"""Snake case a value taking into account Python keywords."""
|
||||
value = stringcase.snakecase(value)
|
||||
value = snake_case(value)
|
||||
if value in [
|
||||
"and",
|
||||
"as",
|
||||
@@ -39,3 +51,70 @@ def safe_snake_case(value: str) -> str:
|
||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||
value += "_"
|
||||
return value
|
||||
|
||||
|
||||
def snake_case(value: str, strict: bool = True):
|
||||
"""
|
||||
Join words with an underscore into lowercase and remove symbols.
|
||||
@param value: value to convert
|
||||
@param strict: force single underscores
|
||||
"""
|
||||
|
||||
def substitute_word(symbols, word, is_start):
|
||||
if not word:
|
||||
return ""
|
||||
if strict:
|
||||
delimiter_count = 0 if is_start else 1 # Single underscore if strict.
|
||||
elif is_start:
|
||||
delimiter_count = len(symbols)
|
||||
elif word.isupper() or word.islower():
|
||||
delimiter_count = max(
|
||||
1, len(symbols)
|
||||
) # Preserve all delimiters if not strict.
|
||||
else:
|
||||
delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
|
||||
|
||||
return ("_" * delimiter_count) + word.lower()
|
||||
|
||||
snake = re.sub(
|
||||
f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
|
||||
lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
|
||||
value,
|
||||
)
|
||||
return snake
|
||||
|
||||
|
||||
def pascal_case(value: str, strict: bool = True):
|
||||
"""
|
||||
Capitalize each word and remove symbols.
|
||||
@param value: value to convert
|
||||
@param strict: output only alphanumeric characters
|
||||
"""
|
||||
|
||||
def substitute_word(symbols, word):
|
||||
if strict:
|
||||
return word.capitalize() # Remove all delimiters
|
||||
|
||||
if word.islower():
|
||||
delimiter_length = len(symbols[:-1]) # Lose one delimiter
|
||||
else:
|
||||
delimiter_length = len(symbols) # Preserve all delimiters
|
||||
|
||||
return ("_" * delimiter_length) + word.capitalize()
|
||||
|
||||
return re.sub(
|
||||
f"({SYMBOLS})({WORD_UPPER}|{WORD})",
|
||||
lambda groups: substitute_word(groups[1], groups[2]),
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
def camel_case(value: str, strict: bool = True):
|
||||
"""
|
||||
Capitalize all words except first and remove symbols.
|
||||
"""
|
||||
return lowercase_first(pascal_case(value, strict=strict))
|
||||
|
||||
|
||||
def lowercase_first(value: str):
|
||||
return value[0:1].lower() + value[1:]
|
||||
|
0
betterproto/compile/__init__.py
Normal file
0
betterproto/compile/__init__.py
Normal file
160
betterproto/compile/importing.py
Normal file
160
betterproto/compile/importing.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Set, Type
|
||||
|
||||
from betterproto import safe_snake_case
|
||||
from betterproto.compile.naming import pythonize_class_name
|
||||
from betterproto.lib.google import protobuf as google_protobuf
|
||||
|
||||
WRAPPER_TYPES: Dict[str, Type] = {
|
||||
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||
".google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||
".google.protobuf.Int32Value": google_protobuf.Int32Value,
|
||||
".google.protobuf.Int64Value": google_protobuf.Int64Value,
|
||||
".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
|
||||
".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
|
||||
".google.protobuf.BoolValue": google_protobuf.BoolValue,
|
||||
".google.protobuf.StringValue": google_protobuf.StringValue,
|
||||
".google.protobuf.BytesValue": google_protobuf.BytesValue,
|
||||
}
|
||||
|
||||
|
||||
def parse_source_type_name(field_type_name):
|
||||
"""
|
||||
Split full source type name into package and type name.
|
||||
E.g. 'root.package.Message' -> ('root.package', 'Message')
|
||||
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
|
||||
"""
|
||||
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
|
||||
if package_match:
|
||||
package = package_match.group(1)
|
||||
name = package_match.group(2)
|
||||
else:
|
||||
package = ""
|
||||
name = field_type_name.lstrip(".")
|
||||
return package, name
|
||||
|
||||
|
||||
def get_type_reference(
|
||||
package: str, imports: set, source_type: str, unwrap: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary. Unwraps well known type if required.
|
||||
"""
|
||||
if unwrap:
|
||||
if source_type in WRAPPER_TYPES:
|
||||
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
|
||||
if source_type == ".google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if source_type == ".google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
source_package, source_type = parse_source_type_name(source_type)
|
||||
|
||||
current_package: List[str] = package.split(".") if package else []
|
||||
py_package: List[str] = source_package.split(".") if source_package else []
|
||||
py_type: str = pythonize_class_name(source_type)
|
||||
|
||||
compiling_google_protobuf = current_package == ["google", "protobuf"]
|
||||
importing_google_protobuf = py_package == ["google", "protobuf"]
|
||||
if importing_google_protobuf and not compiling_google_protobuf:
|
||||
py_package = ["betterproto", "lib"] + py_package
|
||||
|
||||
if py_package[:1] == ["betterproto"]:
|
||||
return reference_absolute(imports, py_package, py_type)
|
||||
|
||||
if py_package == current_package:
|
||||
return reference_sibling(py_type)
|
||||
|
||||
if py_package[: len(current_package)] == current_package:
|
||||
return reference_descendent(current_package, imports, py_package, py_type)
|
||||
|
||||
if current_package[: len(py_package)] == py_package:
|
||||
return reference_ancestor(current_package, imports, py_package, py_type)
|
||||
|
||||
return reference_cousin(current_package, imports, py_package, py_type)
|
||||
|
||||
|
||||
def reference_absolute(imports, py_package, py_type):
|
||||
"""
|
||||
Returns a reference to a python type located in the root, i.e. sys.path.
|
||||
"""
|
||||
string_import = ".".join(py_package)
|
||||
string_alias = safe_snake_case(string_import)
|
||||
imports.add(f"import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
|
||||
|
||||
def reference_sibling(py_type: str) -> str:
|
||||
"""
|
||||
Returns a reference to a python type within the same package as the current package.
|
||||
"""
|
||||
return f'"{py_type}"'
|
||||
|
||||
|
||||
def reference_descendent(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package that is a descendent of the current package,
|
||||
and adds the required import that is aliased to avoid name conflicts.
|
||||
"""
|
||||
importing_descendent = py_package[len(current_package) :]
|
||||
string_from = ".".join(importing_descendent[:-1])
|
||||
string_import = importing_descendent[-1]
|
||||
if string_from:
|
||||
string_alias = "_".join(importing_descendent)
|
||||
imports.add(f"from .{string_from} import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
else:
|
||||
imports.add(f"from . import {string_import}")
|
||||
return f"{string_import}.{py_type}"
|
||||
|
||||
|
||||
def reference_ancestor(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package which is an ancestor to the current package,
|
||||
and adds the required import that is aliased (if possible) to avoid name conflicts.
|
||||
|
||||
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
|
||||
"""
|
||||
distance_up = len(current_package) - len(py_package)
|
||||
if py_package:
|
||||
string_import = py_package[-1]
|
||||
string_alias = f"_{'_' * distance_up}{string_import}__"
|
||||
string_from = f"..{'.' * distance_up}"
|
||||
imports.add(f"from {string_from} import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
else:
|
||||
string_alias = f"{'_' * distance_up}{py_type}__"
|
||||
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
|
||||
return string_alias
|
||||
|
||||
|
||||
def reference_cousin(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
|
||||
and adds the required import that is aliased to avoid name conflicts.
|
||||
"""
|
||||
shared_ancestry = os.path.commonprefix([current_package, py_package])
|
||||
distance_up = len(current_package) - len(shared_ancestry)
|
||||
string_from = f".{'.' * distance_up}" + ".".join(
|
||||
py_package[len(shared_ancestry) : -1]
|
||||
)
|
||||
string_import = py_package[-1]
|
||||
# Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
|
||||
string_alias = (
|
||||
f"{'_' * distance_up}"
|
||||
+ safe_snake_case(".".join(py_package[len(shared_ancestry) :]))
|
||||
+ "__"
|
||||
)
|
||||
imports.add(f"from {string_from} import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
13
betterproto/compile/naming.py
Normal file
13
betterproto/compile/naming.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from betterproto import casing
|
||||
|
||||
|
||||
def pythonize_class_name(name):
|
||||
return casing.pascal_case(name)
|
||||
|
||||
|
||||
def pythonize_field_name(name: str):
|
||||
return casing.safe_snake_case(name)
|
||||
|
||||
|
||||
def pythonize_method_name(name: str):
|
||||
return casing.safe_snake_case(name)
|
0
betterproto/grpc/__init__.py
Normal file
0
betterproto/grpc/__init__.py
Normal file
170
betterproto/grpc/grpclib_client.py
Normal file
170
betterproto/grpc/grpclib_client.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
import grpclib.const
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Collection,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from .._types import ST, T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpclib._protocols import IProtoMessage
|
||||
from grpclib.client import Channel, Stream
|
||||
from grpclib.metadata import Deadline
|
||||
|
||||
|
||||
_Value = Union[str, bytes]
|
||||
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
||||
_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC clients.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel: "Channel",
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> None:
|
||||
self.channel = channel
|
||||
self.timeout = timeout
|
||||
self.deadline = deadline
|
||||
self.metadata = metadata
|
||||
|
||||
def __resolve_request_kwargs(
|
||||
self,
|
||||
timeout: Optional[float],
|
||||
deadline: Optional["Deadline"],
|
||||
metadata: Optional[_MetadataLike],
|
||||
):
|
||||
return {
|
||||
"timeout": self.timeout if timeout is None else timeout,
|
||||
"deadline": self.deadline if deadline is None else deadline,
|
||||
"metadata": self.metadata if metadata is None else metadata,
|
||||
}
|
||||
|
||||
async def _unary_unary(
|
||||
self,
|
||||
route: str,
|
||||
request: "IProtoMessage",
|
||||
response_type: Type[T],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> T:
|
||||
"""Make a unary request and return the response."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
type(request),
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
response = await stream.recv_message()
|
||||
assert response is not None
|
||||
return response
|
||||
|
||||
async def _unary_stream(
|
||||
self,
|
||||
route: str,
|
||||
request: "IProtoMessage",
|
||||
response_type: Type[T],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> AsyncIterator[T]:
|
||||
"""Make a unary request and return the stream response iterator."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
type(request),
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
async for message in stream:
|
||||
yield message
|
||||
|
||||
async def _stream_unary(
|
||||
self,
|
||||
route: str,
|
||||
request_iterator: _MessageSource,
|
||||
request_type: Type[ST],
|
||||
response_type: Type[T],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> T:
|
||||
"""Make a stream request and return the response."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
request_type,
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
await self._send_messages(stream, request_iterator)
|
||||
response = await stream.recv_message()
|
||||
assert response is not None
|
||||
return response
|
||||
|
||||
async def _stream_stream(
|
||||
self,
|
||||
route: str,
|
||||
request_iterator: _MessageSource,
|
||||
request_type: Type[ST],
|
||||
response_type: Type[T],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> AsyncIterator[T]:
|
||||
"""
|
||||
Make a stream request and return an AsyncIterator to iterate over response
|
||||
messages.
|
||||
"""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
request_type,
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
await stream.send_request()
|
||||
sending_task = asyncio.ensure_future(
|
||||
self._send_messages(stream, request_iterator)
|
||||
)
|
||||
try:
|
||||
async for response in stream:
|
||||
yield response
|
||||
except:
|
||||
sending_task.cancel()
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def _send_messages(stream, messages: _MessageSource):
|
||||
if isinstance(messages, AsyncIterable):
|
||||
async for message in messages:
|
||||
await stream.send_message(message)
|
||||
else:
|
||||
for message in messages:
|
||||
await stream.send_message(message)
|
||||
await stream.end()
|
0
betterproto/grpc/util/__init__.py
Normal file
0
betterproto/grpc/util/__init__.py
Normal file
198
betterproto/grpc/util/async_channel.py
Normal file
198
betterproto/grpc/util/async_channel.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import asyncio
|
||||
from typing import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Iterable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ChannelClosed(Exception):
|
||||
"""
|
||||
An exception raised on an attempt to send through a closed channel
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelDone(Exception):
|
||||
"""
|
||||
An exception raised on an attempt to send recieve from a channel that is both closed
|
||||
and empty.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AsyncChannel(AsyncIterable[T]):
|
||||
"""
|
||||
A buffered async channel for sending items between coroutines with FIFO ordering.
|
||||
|
||||
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
||||
|
||||
.. code-block:: python
|
||||
client = GeneratedStub(grpclib_chan)
|
||||
request_chan = await AsyncChannel()
|
||||
# We can start be sending all the requests we already have
|
||||
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
|
||||
async for response in client.rpc_call(request_chan):
|
||||
# The response iterator will remain active until the connection is closed
|
||||
...
|
||||
# More items can be sent at any time
|
||||
await request_chan.send(ReqestObject(...))
|
||||
...
|
||||
# The channel must be closed to complete the gRPC connection
|
||||
request_chan.close()
|
||||
|
||||
Items can be sent through the channel by either:
|
||||
- providing an iterable to the send_from method
|
||||
- passing them to the send method one at a time
|
||||
|
||||
Items can be recieved from the channel by either:
|
||||
- iterating over the channel with a for loop to get all items
|
||||
- calling the recieve method to get one item at a time
|
||||
|
||||
If the channel is empty then recievers will wait until either an item appears or the
|
||||
channel is closed.
|
||||
|
||||
Once the channel is closed then subsequent attempt to send through the channel will
|
||||
fail with a ChannelClosed exception.
|
||||
|
||||
When th channel is closed and empty then it is done, and further attempts to recieve
|
||||
from it will fail with a ChannelDone exception
|
||||
|
||||
If multiple coroutines recieve from the channel concurrently, each item sent will be
|
||||
recieved by only one of the recievers.
|
||||
|
||||
:param source:
|
||||
An optional iterable will items that should be sent through the channel
|
||||
immediately.
|
||||
:param buffer_limit:
|
||||
Limit the number of items that can be buffered in the channel, A value less than
|
||||
1 implies no limit. If the channel is full then attempts to send more items will
|
||||
result in the sender waiting until an item is recieved from the channel.
|
||||
:param close:
|
||||
If set to True then the channel will automatically close after exhausting source
|
||||
or immediately if no source is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, buffer_limit: int = 0, close: bool = False,
|
||||
):
|
||||
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||
self._closed = False
|
||||
self._waiting_recievers: int = 0
|
||||
# Track whether flush has been invoked so it can only happen once
|
||||
self._flushed = False
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[T]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> T:
|
||||
if self.done():
|
||||
raise StopAsyncIteration
|
||||
self._waiting_recievers += 1
|
||||
try:
|
||||
result = await self._queue.get()
|
||||
if result is self.__flush:
|
||||
raise StopAsyncIteration
|
||||
return result
|
||||
finally:
|
||||
self._waiting_recievers -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def closed(self) -> bool:
|
||||
"""
|
||||
Returns True if this channel is closed and no-longer accepting new items
|
||||
"""
|
||||
return self._closed
|
||||
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
Check if this channel is done.
|
||||
|
||||
:return: True if this channel is closed and and has been drained of items in
|
||||
which case any further attempts to recieve an item from this channel will raise
|
||||
a ChannelDone exception.
|
||||
"""
|
||||
# After close the channel is not yet done until there is at least one waiting
|
||||
# reciever per enqueued item.
|
||||
return self._closed and self._queue.qsize() <= self._waiting_recievers
|
||||
|
||||
async def send_from(
|
||||
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||
) -> "AsyncChannel[T]":
|
||||
"""
|
||||
Iterates the given [Async]Iterable and sends all the resulting items.
|
||||
If close is set to True then subsequent send calls will be rejected with a
|
||||
ChannelClosed exception.
|
||||
:param source: an iterable of items to send
|
||||
:param close:
|
||||
if True then the channel will be closed after the source has been exhausted
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ChannelClosed("Cannot send through a closed channel")
|
||||
if isinstance(source, AsyncIterable):
|
||||
async for item in source:
|
||||
await self._queue.put(item)
|
||||
else:
|
||||
for item in source:
|
||||
await self._queue.put(item)
|
||||
if close:
|
||||
# Complete the closing process
|
||||
self.close()
|
||||
return self
|
||||
|
||||
async def send(self, item: T) -> "AsyncChannel[T]":
|
||||
"""
|
||||
Send a single item over this channel.
|
||||
:param item: The item to send
|
||||
"""
|
||||
if self._closed:
|
||||
raise ChannelClosed("Cannot send through a closed channel")
|
||||
await self._queue.put(item)
|
||||
return self
|
||||
|
||||
async def recieve(self) -> Optional[T]:
|
||||
"""
|
||||
Returns the next item from this channel when it becomes available,
|
||||
or None if the channel is closed before another item is sent.
|
||||
:return: An item from the channel
|
||||
"""
|
||||
if self.done():
|
||||
raise ChannelDone("Cannot recieve from a closed channel")
|
||||
self._waiting_recievers += 1
|
||||
try:
|
||||
result = await self._queue.get()
|
||||
if result is self.__flush:
|
||||
return None
|
||||
return result
|
||||
finally:
|
||||
self._waiting_recievers -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close this channel to new items
|
||||
"""
|
||||
self._closed = True
|
||||
asyncio.ensure_future(self._flush_queue())
|
||||
|
||||
async def _flush_queue(self):
|
||||
"""
|
||||
To be called after the channel is closed. Pushes a number of self.__flush
|
||||
objects to the queue to ensure no waiting consumers get deadlocked.
|
||||
"""
|
||||
if not self._flushed:
|
||||
self._flushed = True
|
||||
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
|
||||
for _ in range(deadlocked_recievers):
|
||||
await self._queue.put(self.__flush)
|
||||
|
||||
# A special signal object for flushing the queue when the channel is closed
|
||||
__flush = object()
|
0
betterproto/lib/__init__.py
Normal file
0
betterproto/lib/__init__.py
Normal file
0
betterproto/lib/google/__init__.py
Normal file
0
betterproto/lib/google/__init__.py
Normal file
1312
betterproto/lib/google/protobuf/__init__.py
Normal file
1312
betterproto/lib/google/protobuf/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
2
betterproto/plugin.bat
Normal file
2
betterproto/plugin.bat
Normal file
@@ -0,0 +1,2 @@
|
||||
@SET plugin_dir=%~dp0
|
||||
@python %plugin_dir%/plugin.py %*
|
@@ -1,112 +1,65 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Any, List, Tuple
|
||||
from typing import List, Union
|
||||
|
||||
import betterproto
|
||||
from betterproto.compile.importing import get_type_reference
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
import black
|
||||
except ImportError:
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
)
|
||||
import google.protobuf.wrappers_pb2 as google_wrappers
|
||||
import jinja2
|
||||
except ImportError as err:
|
||||
missing_import = err.args[0][17:-1]
|
||||
print(
|
||||
"Unable to import `black` formatter. Did you install the compiler feature with `pip install betterproto[compiler]`?"
|
||||
"\033[31m"
|
||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
import jinja2
|
||||
import stringcase
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
|
||||
from betterproto.casing import safe_snake_case
|
||||
|
||||
|
||||
WRAPPER_TYPES = {
|
||||
"google.protobuf.DoubleValue": "float",
|
||||
"google.protobuf.FloatValue": "float",
|
||||
"google.protobuf.Int64Value": "int",
|
||||
"google.protobuf.UInt64Value": "int",
|
||||
"google.protobuf.Int32Value": "int",
|
||||
"google.protobuf.UInt32Value": "int",
|
||||
"google.protobuf.BoolValue": "bool",
|
||||
"google.protobuf.StringValue": "str",
|
||||
"google.protobuf.BytesValue": "bytes",
|
||||
}
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
if type_name in WRAPPER_TYPES:
|
||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
|
||||
if "." in type_name:
|
||||
# This is imported from another package. No need
|
||||
# to use a forward ref and we need to add the import.
|
||||
parts = type_name.split(".")
|
||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
||||
|
||||
return type_name
|
||||
|
||||
|
||||
def py_type(
|
||||
package: str,
|
||||
imports: set,
|
||||
message: DescriptorProto,
|
||||
descriptor: FieldDescriptorProto,
|
||||
) -> str:
|
||||
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
||||
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
|
||||
if field.type in [1, 2]:
|
||||
return "float"
|
||||
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
||||
elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
|
||||
return "int"
|
||||
elif descriptor.type == 8:
|
||||
elif field.type == 8:
|
||||
return "bool"
|
||||
elif descriptor.type == 9:
|
||||
elif field.type == 9:
|
||||
return "str"
|
||||
elif descriptor.type in [11, 14]:
|
||||
elif field.type in [11, 14]:
|
||||
# Type referencing another defined Message or a named enum
|
||||
return get_ref_type(package, imports, descriptor.type_name)
|
||||
elif descriptor.type == 12:
|
||||
return get_type_reference(package, imports, field.type_name)
|
||||
elif field.type == 12:
|
||||
return "bytes"
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown type {descriptor.type}")
|
||||
raise NotImplementedError(f"Unknown type {field.type}")
|
||||
|
||||
|
||||
def get_py_zero(type_num: int) -> str:
|
||||
zero = 0
|
||||
def get_py_zero(type_num: int) -> Union[str, float]:
|
||||
zero: Union[str, float] = 0
|
||||
if type_num in []:
|
||||
zero = 0.0
|
||||
elif type_num == 8:
|
||||
@@ -122,7 +75,7 @@ def get_py_zero(type_num: int) -> str:
|
||||
|
||||
|
||||
def traverse(proto_file):
|
||||
def _traverse(path, items, prefix = ''):
|
||||
def _traverse(path, items, prefix=""):
|
||||
for i, item in enumerate(items):
|
||||
# Adjust the name since we flatten the heirarchy.
|
||||
item.name = next_prefix = prefix + item.name
|
||||
@@ -167,25 +120,28 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
|
||||
|
||||
|
||||
def generate_code(request, response):
|
||||
plugin_options = request.parameter.split(",") if request.parameter else []
|
||||
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
|
||||
)
|
||||
template = env.get_template("template.py")
|
||||
template = env.get_template("template.py.j2")
|
||||
|
||||
output_map = {}
|
||||
for proto_file in request.proto_file:
|
||||
out = proto_file.package
|
||||
if out == "google.protobuf":
|
||||
if (
|
||||
proto_file.package == "google.protobuf"
|
||||
and "INCLUDE_GOOGLE" not in plugin_options
|
||||
):
|
||||
continue
|
||||
|
||||
if not out:
|
||||
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
|
||||
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
|
||||
|
||||
if out not in output_map:
|
||||
output_map[out] = {"package": proto_file.package, "files": []}
|
||||
output_map[out]["files"].append(proto_file)
|
||||
if output_file not in output_map:
|
||||
output_map[output_file] = {"package": proto_file.package, "files": []}
|
||||
output_map[output_file]["files"].append(proto_file)
|
||||
|
||||
# TODO: Figure out how to handle gRPC request/response messages and add
|
||||
# processing below for Service.
|
||||
@@ -204,17 +160,10 @@ def generate_code(request, response):
|
||||
"services": [],
|
||||
}
|
||||
|
||||
type_mapping = {}
|
||||
|
||||
for proto_file in options["files"]:
|
||||
# print(proto_file.message_type, file=sys.stderr)
|
||||
# print(proto_file.service, file=sys.stderr)
|
||||
# print(proto_file.source_code_info, file=sys.stderr)
|
||||
|
||||
item: DescriptorProto
|
||||
for item, path in traverse(proto_file):
|
||||
# print(item, file=sys.stderr)
|
||||
# print(path, file=sys.stderr)
|
||||
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
|
||||
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# print(item, file=sys.stderr)
|
||||
@@ -231,7 +180,7 @@ def generate_code(request, response):
|
||||
)
|
||||
|
||||
for i, f in enumerate(item.field):
|
||||
t = py_type(package, output["imports"], item, f)
|
||||
t = py_type(package, output["imports"], f)
|
||||
zero = get_py_zero(f.type)
|
||||
|
||||
repeated = False
|
||||
@@ -240,11 +189,13 @@ def generate_code(request, response):
|
||||
field_type = f.Type.Name(f.type).lower()[5:]
|
||||
|
||||
field_wraps = ""
|
||||
if f.type_name.startswith(
|
||||
".google.protobuf"
|
||||
) and f.type_name.endswith("Value"):
|
||||
w = f.type_name.split(".").pop()[:-5].upper()
|
||||
field_wraps = f"betterproto.TYPE_{w}"
|
||||
match_wrapper = re.match(
|
||||
r"\.google\.protobuf\.(.+)Value", f.type_name
|
||||
)
|
||||
if match_wrapper:
|
||||
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||
if hasattr(betterproto, wrapped_type):
|
||||
field_wraps = f"betterproto.{wrapped_type}"
|
||||
|
||||
map_types = None
|
||||
if f.type == 11:
|
||||
@@ -264,13 +215,11 @@ def generate_code(request, response):
|
||||
k = py_type(
|
||||
package,
|
||||
output["imports"],
|
||||
item,
|
||||
nested.field[0],
|
||||
)
|
||||
v = py_type(
|
||||
package,
|
||||
output["imports"],
|
||||
item,
|
||||
nested.field[1],
|
||||
)
|
||||
t = f"Dict[{k}, {v}]"
|
||||
@@ -306,7 +255,7 @@ def generate_code(request, response):
|
||||
data["properties"].append(
|
||||
{
|
||||
"name": f.name,
|
||||
"py_name": safe_snake_case(f.name),
|
||||
"py_name": pythonize_field_name(f.name),
|
||||
"number": f.number,
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
@@ -347,17 +296,14 @@ def generate_code(request, response):
|
||||
|
||||
data = {
|
||||
"name": service.name,
|
||||
"py_name": stringcase.pascalcase(service.name),
|
||||
"py_name": pythonize_class_name(service.name),
|
||||
"comment": get_comment(proto_file, [6, i]),
|
||||
"methods": [],
|
||||
}
|
||||
|
||||
for j, method in enumerate(service.method):
|
||||
if method.client_streaming:
|
||||
raise NotImplementedError("Client streaming not yet supported")
|
||||
|
||||
input_message = None
|
||||
input_type = get_ref_type(
|
||||
input_type = get_type_reference(
|
||||
package, output["imports"], method.input_type
|
||||
).strip('"')
|
||||
for msg in output["messages"]:
|
||||
@@ -371,23 +317,30 @@ def generate_code(request, response):
|
||||
data["methods"].append(
|
||||
{
|
||||
"name": method.name,
|
||||
"py_name": stringcase.snakecase(method.name),
|
||||
"py_name": pythonize_method_name(method.name),
|
||||
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
||||
"route": f"/{package}.{service.name}/{method.name}",
|
||||
"input": get_ref_type(
|
||||
"input": get_type_reference(
|
||||
package, output["imports"], method.input_type
|
||||
).strip('"'),
|
||||
"input_message": input_message,
|
||||
"output": get_ref_type(
|
||||
package, output["imports"], method.output_type
|
||||
"output": get_type_reference(
|
||||
package,
|
||||
output["imports"],
|
||||
method.output_type,
|
||||
unwrap=False,
|
||||
).strip('"'),
|
||||
"client_streaming": method.client_streaming,
|
||||
"server_streaming": method.server_streaming,
|
||||
}
|
||||
)
|
||||
|
||||
if method.client_streaming:
|
||||
output["typing_imports"].add("AsyncIterable")
|
||||
output["typing_imports"].add("Iterable")
|
||||
output["typing_imports"].add("Union")
|
||||
if method.server_streaming:
|
||||
output["typing_imports"].add("AsyncGenerator")
|
||||
output["typing_imports"].add("AsyncIterator")
|
||||
|
||||
output["services"].append(data)
|
||||
|
||||
@@ -397,8 +350,7 @@ def generate_code(request, response):
|
||||
|
||||
# Fill response
|
||||
f = response.file.add()
|
||||
# print(filename, file=sys.stderr)
|
||||
f.name = filename.replace(".", os.path.sep) + ".py"
|
||||
f.name = filename
|
||||
|
||||
# Render and then format the output file.
|
||||
f.content = black.format_str(
|
||||
@@ -406,32 +358,23 @@ def generate_code(request, response):
|
||||
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
||||
)
|
||||
|
||||
inits = set([""])
|
||||
for f in response.file:
|
||||
# Ensure output paths exist
|
||||
# print(f.name, file=sys.stderr)
|
||||
dirnames = os.path.dirname(f.name)
|
||||
if dirnames:
|
||||
os.makedirs(dirnames, exist_ok=True)
|
||||
base = ""
|
||||
for part in dirnames.split(os.path.sep):
|
||||
base = os.path.join(base, part)
|
||||
inits.add(base)
|
||||
|
||||
for base in inits:
|
||||
name = os.path.join(base, "__init__.py")
|
||||
|
||||
if os.path.exists(name):
|
||||
# Never overwrite inits as they may have custom stuff in them.
|
||||
continue
|
||||
# Make each output directory a package with __init__ file
|
||||
output_paths = set(pathlib.Path(path) for path in output_map.keys())
|
||||
init_files = (
|
||||
set(
|
||||
directory.joinpath("__init__.py")
|
||||
for path in output_paths
|
||||
for directory in path.parents
|
||||
)
|
||||
- output_paths
|
||||
)
|
||||
|
||||
for init_file in init_files:
|
||||
init = response.file.add()
|
||||
init.name = name
|
||||
init.content = b""
|
||||
init.name = str(init_file)
|
||||
|
||||
filenames = sorted([f.name for f in response.file])
|
||||
for fname in filenames:
|
||||
print(f"Writing {fname}", file=sys.stderr)
|
||||
for filename in sorted(output_paths.union(init_files)):
|
||||
print(f"Writing {filename}", file=sys.stderr)
|
||||
|
||||
|
||||
def main():
|
||||
|
@@ -63,34 +63,72 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
|
||||
{% endif %}
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
async def {{ method.py_name }}(self
|
||||
{%- if not method.client_streaming -%}
|
||||
{%- if method.input_message and method.input_message.properties -%}, *,
|
||||
{%- for field in method.input_message.properties -%}
|
||||
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
|
||||
Optional[{{ field.type }}]
|
||||
{%- else -%}
|
||||
{{ field.type }}
|
||||
{%- endif -%} = {{ field.zero }}
|
||||
{%- if not loop.last %}, {% endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
|
||||
{%- endif -%}
|
||||
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
{% endif %}
|
||||
{% if not method.client_streaming %}
|
||||
request = {{ method.input }}()
|
||||
{% for field in method.input_message.properties %}
|
||||
{% if field.field_type == 'message' %}
|
||||
if {{ field.name }} is not None:
|
||||
request.{{ field.name }} = {{ field.name }}
|
||||
if {{ field.py_name }} is not None:
|
||||
request.{{ field.py_name }} = {{ field.py_name }}
|
||||
{% else %}
|
||||
request.{{ field.name }} = {{ field.name }}
|
||||
request.{{ field.py_name }} = {{ field.py_name }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if method.server_streaming %}
|
||||
{% if method.client_streaming %}
|
||||
async for response in self._stream_stream(
|
||||
"{{ method.route }}",
|
||||
request_iterator,
|
||||
{{ method.input }},
|
||||
{{ method.output }},
|
||||
):
|
||||
yield response
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
async for response in self._unary_stream(
|
||||
"{{ method.route }}",
|
||||
request,
|
||||
{{ method.output }},
|
||||
):
|
||||
yield response
|
||||
{% else %}
|
||||
|
||||
{% endif %}{# if client streaming #}
|
||||
{% else %}{# i.e. not server streaming #}
|
||||
{% if method.client_streaming %}
|
||||
return await self._stream_unary(
|
||||
"{{ method.route }}",
|
||||
request_iterator,
|
||||
{{ method.input }},
|
||||
{{ method.output }}
|
||||
)
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
return await self._unary_unary(
|
||||
"{{ method.route }}",
|
||||
request,
|
||||
{{ method.output }},
|
||||
{{ method.output }}
|
||||
)
|
||||
{% endif %}{# client streaming #}
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
91
betterproto/tests/README.md
Normal file
91
betterproto/tests/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# Standard Tests Development Guide
|
||||
|
||||
Standard test cases are found in [betterproto/tests/inputs](inputs), where each subdirectory represents a testcase, that is verified in isolation.
|
||||
|
||||
```
|
||||
inputs/
|
||||
bool/
|
||||
double/
|
||||
int32/
|
||||
...
|
||||
```
|
||||
|
||||
## Test case directory structure
|
||||
|
||||
Each testcase has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`.
|
||||
|
||||
```bash
|
||||
bool/
|
||||
bool.proto
|
||||
bool.json # optional
|
||||
test_bool.py # optional
|
||||
```
|
||||
|
||||
### proto
|
||||
|
||||
`<name>.proto` — *The protobuf message to test*
|
||||
|
||||
```protobuf
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
bool value = 1;
|
||||
}
|
||||
```
|
||||
|
||||
You can add multiple `.proto` files to the test case, as long as one file matches the directory name.
|
||||
|
||||
### json
|
||||
|
||||
`<name>.json` — *Test-data to validate the message with*
|
||||
|
||||
```json
|
||||
{
|
||||
"value": true
|
||||
}
|
||||
```
|
||||
|
||||
### pytest
|
||||
|
||||
`test_<name>.py` — *Custom test to validate specific aspects of the generated class*
|
||||
|
||||
```python
|
||||
from betterproto.tests.output_betterproto.bool.bool import Test
|
||||
|
||||
def test_value():
|
||||
message = Test()
|
||||
assert not message.value, "Boolean is False by default"
|
||||
```
|
||||
|
||||
## Standard tests
|
||||
|
||||
The following tests are automatically executed for all cases:
|
||||
|
||||
- [x] Can the generated python code be imported?
|
||||
- [x] Can the generated message class be instantiated?
|
||||
- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
|
||||
- _when `.json` is present_
|
||||
|
||||
## Running the tests
|
||||
|
||||
- `pipenv run generate`
|
||||
This generates:
|
||||
- `betterproto/tests/output_betterproto` — *the plugin generated python classes*
|
||||
- `betterproto/tests/output_reference` — *reference implementation classes*
|
||||
- `pipenv run test`
|
||||
|
||||
## Intentionally Failing tests
|
||||
|
||||
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future.
|
||||
|
||||
When running `pytest`, they show up as `x` or `X` in the test results.
|
||||
|
||||
```
|
||||
betterproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x.......x.xx....x...................... [ 84%]
|
||||
```
|
||||
|
||||
- `.` — PASSED
|
||||
- `x` — XFAIL: expected failure
|
||||
- `X` — XPASS: expected failure, but still passed
|
||||
|
||||
Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py)
|
0
betterproto/tests/__init__.py
Normal file
0
betterproto/tests/__init__.py
Normal file
189
betterproto/tests/generate.py
Normal file → Executable file
189
betterproto/tests/generate.py
Normal file → Executable file
@@ -1,84 +1,143 @@
|
||||
#!/usr/bin/env python
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
from betterproto.tests.util import (
|
||||
get_directories,
|
||||
inputs_path,
|
||||
output_path_betterproto,
|
||||
output_path_reference,
|
||||
protoc_plugin,
|
||||
protoc_reference,
|
||||
)
|
||||
|
||||
# Force pure-python implementation instead of C++, otherwise imports
|
||||
# break things because we can't properly reset the symbol database.
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from google.protobuf import symbol_database
|
||||
from google.protobuf.descriptor_pool import DescriptorPool
|
||||
from google.protobuf.json_format import MessageToJson, Parse
|
||||
def clear_directory(dir_path: Path):
|
||||
for file_or_directory in dir_path.glob("*"):
|
||||
if file_or_directory.is_dir():
|
||||
shutil.rmtree(file_or_directory)
|
||||
else:
|
||||
file_or_directory.unlink()
|
||||
|
||||
|
||||
root = os.path.dirname(os.path.realpath(__file__))
|
||||
async def generate(whitelist: Set[str], verbose: bool):
|
||||
test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
|
||||
|
||||
path_whitelist = set()
|
||||
name_whitelist = set()
|
||||
for item in whitelist:
|
||||
if item in test_case_names:
|
||||
name_whitelist.add(item)
|
||||
continue
|
||||
path_whitelist.add(item)
|
||||
|
||||
generation_tasks = []
|
||||
for test_case_name in sorted(test_case_names):
|
||||
test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
|
||||
if (
|
||||
whitelist
|
||||
and str(test_case_input_path) not in path_whitelist
|
||||
and test_case_name not in name_whitelist
|
||||
):
|
||||
continue
|
||||
generation_tasks.append(
|
||||
generate_test_case_output(test_case_input_path, test_case_name, verbose)
|
||||
)
|
||||
|
||||
failed_test_cases = []
|
||||
# Wait for all subprocs and match any failures to names to report
|
||||
for test_case_name, result in zip(
|
||||
sorted(test_case_names), await asyncio.gather(*generation_tasks)
|
||||
):
|
||||
if result != 0:
|
||||
failed_test_cases.append(test_case_name)
|
||||
|
||||
if failed_test_cases:
|
||||
sys.stderr.write(
|
||||
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
|
||||
)
|
||||
for failed_test_case in failed_test_cases:
|
||||
sys.stderr.write(f"- {failed_test_case}\n")
|
||||
|
||||
|
||||
def get_files(end: str) -> Generator[str, None, None]:
|
||||
for r, dirs, files in os.walk(root):
|
||||
for filename in [f for f in files if f.endswith(end)]:
|
||||
yield os.path.join(r, filename)
|
||||
async def generate_test_case_output(
|
||||
test_case_input_path: Path, test_case_name: str, verbose: bool
|
||||
) -> int:
|
||||
"""
|
||||
Returns the max of the subprocess return values
|
||||
"""
|
||||
|
||||
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
||||
test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name)
|
||||
|
||||
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
||||
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
||||
|
||||
clear_directory(test_case_output_path_reference)
|
||||
clear_directory(test_case_output_path_betterproto)
|
||||
|
||||
(
|
||||
(ref_out, ref_err, ref_code),
|
||||
(plg_out, plg_err, plg_code),
|
||||
) = await asyncio.gather(
|
||||
protoc_reference(test_case_input_path, test_case_output_path_reference),
|
||||
protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
|
||||
)
|
||||
|
||||
message = f"Generated output for {test_case_name!r}"
|
||||
if verbose:
|
||||
print(f"\033[31;1;4m{message}\033[0m")
|
||||
if ref_out:
|
||||
sys.stdout.buffer.write(ref_out)
|
||||
if ref_err:
|
||||
sys.stderr.buffer.write(ref_err)
|
||||
if plg_out:
|
||||
sys.stdout.buffer.write(plg_out)
|
||||
if plg_err:
|
||||
sys.stderr.buffer.write(plg_err)
|
||||
sys.stdout.buffer.flush()
|
||||
sys.stderr.buffer.flush()
|
||||
else:
|
||||
print(message)
|
||||
|
||||
return max(ref_code, plg_code)
|
||||
|
||||
|
||||
def get_base(filename: str) -> str:
|
||||
return os.path.splitext(os.path.basename(filename))[0]
|
||||
HELP = "\n".join(
|
||||
(
|
||||
"Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
|
||||
"Generate python classes for standard tests.",
|
||||
"",
|
||||
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
|
||||
" python generate.py inputs/bool inputs/double inputs/enum",
|
||||
"",
|
||||
"NAMES One or more test-case names to generate classes for.",
|
||||
" python generate.py bool double enums",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def ensure_ext(filename: str, ext: str) -> str:
|
||||
if not filename.endswith(ext):
|
||||
return filename + ext
|
||||
return filename
|
||||
def main():
|
||||
if set(sys.argv).intersection({"-h", "--help"}):
|
||||
print(HELP)
|
||||
return
|
||||
if sys.argv[1:2] == ["-v"]:
|
||||
verbose = True
|
||||
whitelist = set(sys.argv[2:])
|
||||
else:
|
||||
verbose = False
|
||||
whitelist = set(sys.argv[1:])
|
||||
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.chdir(root)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]]
|
||||
bases = {get_base(f) for f in proto_files}
|
||||
json_files = [
|
||||
f for f in get_files(".json") if get_base(f).split("-")[0] in bases
|
||||
]
|
||||
else:
|
||||
proto_files = get_files(".proto")
|
||||
json_files = get_files(".json")
|
||||
|
||||
for filename in proto_files:
|
||||
print(f"Generating code for {os.path.basename(filename)}")
|
||||
subprocess.run(
|
||||
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
||||
)
|
||||
subprocess.run(
|
||||
f"protoc --plugin=protoc-gen-custom=../plugin.py --custom_out=. {os.path.basename(filename)}",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
for filename in json_files:
|
||||
# Reset the internal symbol database so we can import the `Test` message
|
||||
# multiple times. Ugh.
|
||||
sym = symbol_database.Default()
|
||||
sym.pool = DescriptorPool()
|
||||
|
||||
parts = get_base(filename).split("-")
|
||||
out = filename.replace(".json", ".bin")
|
||||
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
||||
|
||||
imported = importlib.import_module(f"{parts[0]}_pb2")
|
||||
input_json = open(filename).read()
|
||||
parsed = Parse(input_json, imported.Test())
|
||||
serialized = parsed.SerializeToString()
|
||||
preserve = "casing" not in filename
|
||||
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
|
||||
|
||||
s_loaded = json.loads(serialized_json)
|
||||
in_loaded = json.loads(input_json)
|
||||
|
||||
if s_loaded != in_loaded:
|
||||
raise AssertionError("Expected JSON to be equal:", s_loaded, in_loaded)
|
||||
|
||||
open(out, "wb").write(serialized)
|
||||
main()
|
||||
|
0
betterproto/tests/grpc/__init__.py
Normal file
0
betterproto/tests/grpc/__init__.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from .thing_service import ThingService
|
||||
|
||||
|
||||
async def _test_client(client, name="clean room", **kwargs):
|
||||
response = await client.do_thing(name=name)
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
def _assert_request_meta_recieved(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
THING_TO_DO = "get milk"
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_unary_stream_request():
|
||||
thing_name = "my milkshakes"
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
expected_versions = [5, 4, 3, 2, 1]
|
||||
async for response in client.get_thing_versions(name=thing_name):
|
||||
assert response.name == thing_name
|
||||
assert response.version == expected_versions.pop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_stream_stream_request():
|
||||
some_things = ["cake", "cricket", "coral reef"]
|
||||
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
|
||||
expected_things = (*some_things, *more_things)
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||
# immediately and we'll use it to send more_things later, after recieving some
|
||||
# results
|
||||
request_chan = AsyncChannel()
|
||||
send_initial_requests = asyncio.ensure_future(
|
||||
request_chan.send_from(GetThingRequest(name) for name in some_things)
|
||||
)
|
||||
response_index = 0
|
||||
async for response in client.get_different_things(request_chan):
|
||||
assert response.name == expected_things[response_index]
|
||||
assert response.version == response_index + 1
|
||||
response_index += 1
|
||||
if more_things:
|
||||
# Send some more requests as we recieve reponses to be sure coordination of
|
||||
# send/recieve events doesn't matter
|
||||
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||
elif not send_initial_requests.done():
|
||||
# Make sure the sending task it completed
|
||||
await send_initial_requests
|
||||
else:
|
||||
# No more things to send make sure channel is closed
|
||||
request_chan.close()
|
||||
assert response_index == len(
|
||||
expected_things
|
||||
), "Didn't recieve all exptected responses"
|
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
import betterproto
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from dataclasses import dataclass
|
||||
import pytest
|
||||
from typing import AsyncIterator
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message(betterproto.Message):
|
||||
body: str = betterproto.string_field(1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_responses():
|
||||
return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
|
||||
|
||||
|
||||
class ClientStub:
|
||||
async def connect(self, requests: AsyncIterator):
|
||||
await asyncio.sleep(0.1)
|
||||
async for request in requests:
|
||||
await asyncio.sleep(0.1)
|
||||
yield request
|
||||
await asyncio.sleep(0.1)
|
||||
yield Message("Done")
|
||||
|
||||
|
||||
async def to_list(generator: AsyncIterator):
|
||||
result = []
|
||||
async for value in generator:
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# channel = Channel(host='127.0.0.1', port=50051)
|
||||
# return ClientStub(channel)
|
||||
return ClientStub()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_before_connect_and_close_automatically(
|
||||
client, expected_responses
|
||||
):
|
||||
requests = AsyncChannel()
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||
)
|
||||
responses = client.connect(requests)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_after_connect_and_close_automatically(
|
||||
client, expected_responses
|
||||
):
|
||||
requests = AsyncChannel()
|
||||
responses = client.connect(requests)
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||
)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_close_manually_immediately(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
responses = client.connect(requests)
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
|
||||
)
|
||||
requests.close()
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_individually_and_close_before_connect(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
await requests.send(Message(body="Hello world 1"))
|
||||
await requests.send(Message(body="Hello world 2"))
|
||||
requests.close()
|
||||
responses = client.connect(requests)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_individually_and_close_after_connect(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
await requests.send(Message(body="Hello world 1"))
|
||||
await requests.send(Message(body="Hello world 2"))
|
||||
responses = client.connect(requests)
|
||||
requests.close()
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
83
betterproto/tests/grpc/thing_service.py
Normal file
83
betterproto/tests/grpc/thing_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class ThingService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def do_thing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse([request.name]))
|
||||
|
||||
async def do_many_things(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
thing_names = [request.name for request in stream]
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse(thing_names))
|
||||
|
||||
async def get_thing_versions(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for version_num in range(1, 6):
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=version_num)
|
||||
)
|
||||
|
||||
async def get_different_things(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
# Respond to each input item immediately
|
||||
response_num = 0
|
||||
async for request in stream:
|
||||
response_num += 1
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=response_num)
|
||||
)
|
||||
|
||||
def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.do_thing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||
self.do_many_things,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||
self.get_thing_versions,
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||
self.get_different_things,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
}
|
6
betterproto/tests/inputs/bool/test_bool.py
Normal file
6
betterproto/tests/inputs/bool/test_bool.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from betterproto.tests.output_betterproto.bool import Test
|
||||
|
||||
|
||||
def test_value():
|
||||
message = Test()
|
||||
assert not message.value, "Boolean is False by default"
|
@@ -9,4 +9,10 @@ enum my_enum {
|
||||
message Test {
|
||||
int32 camelCase = 1;
|
||||
my_enum snake_case = 2;
|
||||
snake_case_message snake_case_message = 3;
|
||||
int32 UPPERCASE = 4;
|
||||
}
|
||||
|
||||
message snake_case_message {
|
||||
|
||||
}
|
23
betterproto/tests/inputs/casing/test_casing.py
Normal file
23
betterproto/tests/inputs/casing/test_casing.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import betterproto.tests.output_betterproto.casing as casing
|
||||
from betterproto.tests.output_betterproto.casing import Test
|
||||
|
||||
|
||||
def test_message_attributes():
|
||||
message = Test()
|
||||
assert hasattr(
|
||||
message, "snake_case_message"
|
||||
), "snake_case field name is same in python"
|
||||
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
|
||||
assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
|
||||
|
||||
|
||||
def test_message_casing():
|
||||
assert hasattr(
|
||||
casing, "SnakeCaseMessage"
|
||||
), "snake_case Message name is converted to CamelCase in python"
|
||||
|
||||
|
||||
def test_enum_casing():
|
||||
assert hasattr(
|
||||
casing, "MyEnum"
|
||||
), "snake_case Enum name is converted to CamelCase in python"
|
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
int32 UPPERCASE = 1;
|
||||
int32 UPPERCASE_V2 = 2;
|
||||
int32 UPPER_CAMEL_CASE = 3;
|
||||
}
|
@@ -0,0 +1,14 @@
|
||||
from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test
|
||||
|
||||
|
||||
def test_message_casing():
|
||||
message = Test()
|
||||
assert hasattr(
|
||||
message, "uppercase"
|
||||
), "UPPERCASE attribute is converted to 'uppercase' in python"
|
||||
assert hasattr(
|
||||
message, "uppercase_v2"
|
||||
), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
|
||||
assert hasattr(
|
||||
message, "upper_camel_case"
|
||||
), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"
|
21
betterproto/tests/inputs/config.py
Normal file
21
betterproto/tests/inputs/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
|
||||
# Remove from list when fixed.
|
||||
xfail = {
|
||||
"import_circular_dependency",
|
||||
"oneof_enum", # 63
|
||||
"namespace_keywords", # 70
|
||||
"namespace_builtin_types", # 53
|
||||
"googletypes_struct", # 9
|
||||
"googletypes_value", # 9,
|
||||
"import_capitalized_package",
|
||||
"example", # This is the example in the readme. Not a test.
|
||||
}
|
||||
|
||||
services = {
|
||||
"googletypes_response",
|
||||
"googletypes_response_embedded",
|
||||
"service",
|
||||
"import_service_input_message",
|
||||
"googletypes_service_returns_empty",
|
||||
"googletypes_service_returns_googletype",
|
||||
}
|
8
betterproto/tests/inputs/example/example.proto
Normal file
8
betterproto/tests/inputs/example/example.proto
Normal file
@@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package hello;
|
||||
|
||||
// Greeting represents a message you can tell a user.
|
||||
message Greeting {
|
||||
string message = 1;
|
||||
}
|
6
betterproto/tests/inputs/fixed/fixed.json
Normal file
6
betterproto/tests/inputs/fixed/fixed.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"foo": 4294967295,
|
||||
"bar": -2147483648,
|
||||
"baz": "18446744073709551615",
|
||||
"qux": "-9223372036854775808"
|
||||
}
|
8
betterproto/tests/inputs/fixed/fixed.proto
Normal file
8
betterproto/tests/inputs/fixed/fixed.proto
Normal file
@@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
fixed32 foo = 1;
|
||||
sfixed32 bar = 2;
|
||||
fixed64 baz = 3;
|
||||
sfixed64 qux = 4;
|
||||
}
|
@@ -1,5 +1,7 @@
|
||||
{
|
||||
"maybe": false,
|
||||
"ts": "1972-01-01T10:00:20.021Z",
|
||||
"duration": "1.200s"
|
||||
"duration": "1.200s",
|
||||
"important": 10,
|
||||
"empty": {}
|
||||
}
|
@@ -3,10 +3,12 @@ syntax = "proto3";
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
message Test {
|
||||
google.protobuf.BoolValue maybe = 1;
|
||||
google.protobuf.Timestamp ts = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
google.protobuf.Int32Value important = 4;
|
||||
google.protobuf.Empty empty = 5;
|
||||
}
|
@@ -0,0 +1,21 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
// Tests that wrapped values can be used directly as return values
|
||||
|
||||
service Test {
|
||||
rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
|
||||
rpc GetFloat (Input) returns (google.protobuf.FloatValue);
|
||||
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
|
||||
rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
|
||||
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
|
||||
rpc GetBool (Input) returns (google.protobuf.BoolValue);
|
||||
rpc GetString (Input) returns (google.protobuf.StringValue);
|
||||
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import betterproto.lib.google.protobuf as protobuf
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response import TestStub
|
||||
|
||||
test_cases = [
|
||||
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||
(TestStub.get_float, protobuf.FloatValue, 2.5),
|
||||
(TestStub.get_int64, protobuf.Int64Value, -64),
|
||||
(TestStub.get_u_int64, protobuf.UInt64Value, 64),
|
||||
(TestStub.get_int32, protobuf.Int32Value, -32),
|
||||
(TestStub.get_u_int32, protobuf.UInt32Value, 32),
|
||||
(TestStub.get_bool, protobuf.BoolValue, True),
|
||||
(TestStub.get_string, protobuf.StringValue, "string"),
|
||||
(TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_recieves_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
channel = MockChannel(responses=[wrapped_value])
|
||||
service = TestStub(channel)
|
||||
|
||||
await service_method(service)
|
||||
|
||||
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_service_unwraps_response(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
"""
|
||||
grpclib does not unwrap wrapper values returned by services
|
||||
"""
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||
|
||||
response_value = await service_method(service)
|
||||
|
||||
assert response_value == value
|
||||
assert type(response_value) == type(value)
|
@@ -0,0 +1,24 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
// Tests that wrapped values are supported as part of output message
|
||||
service Test {
|
||||
rpc getOutput (Input) returns (Output);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
||||
|
||||
message Output {
|
||||
google.protobuf.DoubleValue double_value = 1;
|
||||
google.protobuf.FloatValue float_value = 2;
|
||||
google.protobuf.Int64Value int64_value = 3;
|
||||
google.protobuf.UInt64Value uint64_value = 4;
|
||||
google.protobuf.Int32Value int32_value = 5;
|
||||
google.protobuf.UInt32Value uint32_value = 6;
|
||||
google.protobuf.BoolValue bool_value = 7;
|
||||
google.protobuf.StringValue string_value = 8;
|
||||
google.protobuf.BytesValue bytes_value = 9;
|
||||
}
|
@@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response_embedded import (
|
||||
Output,
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_passes_through_unwrapped_values_embedded_in_response():
|
||||
"""
|
||||
We do not not need to implement value unwrapping for embedded well-known types,
|
||||
as this is already handled by grpclib. This test merely shows that this is the case.
|
||||
"""
|
||||
output = Output(
|
||||
double_value=10.0,
|
||||
float_value=12.0,
|
||||
int64_value=-13,
|
||||
uint64_value=14,
|
||||
int32_value=-15,
|
||||
uint32_value=16,
|
||||
bool_value=True,
|
||||
string_value="string",
|
||||
bytes_value=bytes(0xFF)[0:4],
|
||||
)
|
||||
|
||||
service = TestStub(MockChannel(responses=[output]))
|
||||
response = await service.get_output()
|
||||
|
||||
assert response.double_value == 10.0
|
||||
assert response.float_value == 12.0
|
||||
assert response.int64_value == -13
|
||||
assert response.uint64_value == 14
|
||||
assert response.int32_value == -15
|
||||
assert response.uint32_value == 16
|
||||
assert response.bool_value
|
||||
assert response.string_value == "string"
|
||||
assert response.bytes_value == bytes(0xFF)[0:4]
|
@@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
service Test {
|
||||
rpc Send (RequestMessage) returns (google.protobuf.Empty) {
|
||||
}
|
||||
}
|
||||
|
||||
message RequestMessage {
|
||||
}
|
@@ -0,0 +1,16 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
import "google/protobuf/struct.proto";
|
||||
|
||||
// Tests that imports are generated correctly when returning Google well-known types
|
||||
|
||||
service Test {
|
||||
rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty);
|
||||
rpc GetStruct (RequestMessage) returns (google.protobuf.Struct);
|
||||
rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue);
|
||||
rpc GetValue (RequestMessage) returns (google.protobuf.Value);
|
||||
}
|
||||
|
||||
message RequestMessage {
|
||||
}
|
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"struct": {
|
||||
"key": true
|
||||
}
|
||||
}
|
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/struct.proto";
|
||||
|
||||
message Test {
|
||||
google.protobuf.Struct struct = 1;
|
||||
}
|
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"value1": "hello world",
|
||||
"value2": true,
|
||||
"value3": 1,
|
||||
"value4": null,
|
||||
"value5": [
|
||||
1,
|
||||
2,
|
||||
3
|
||||
]
|
||||
}
|
@@ -0,0 +1,13 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/struct.proto";
|
||||
|
||||
// Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values.
|
||||
|
||||
message Test {
|
||||
google.protobuf.Value value1 = 1;
|
||||
google.protobuf.Value value2 = 2;
|
||||
google.protobuf.Value value3 = 3;
|
||||
google.protobuf.Value value4 = 4;
|
||||
google.protobuf.Value value5 = 5;
|
||||
}
|
@@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
|
||||
package Capitalized;
|
||||
|
||||
message Message {
|
||||
|
||||
}
|
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "capitalized.proto";
|
||||
|
||||
// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't.
|
||||
|
||||
message Test {
|
||||
Capitalized.Message message = 1;
|
||||
}
|
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package package.childpackage;
|
||||
|
||||
message ChildMessage {
|
||||
|
||||
}
|
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "package_message.proto";
|
||||
|
||||
// Tests generated imports when a message in a package refers to a message in a nested child package.
|
||||
|
||||
message Test {
|
||||
package.PackageMessage message = 1;
|
||||
}
|
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "child.proto";
|
||||
|
||||
package package;
|
||||
|
||||
message PackageMessage {
|
||||
package.childpackage.ChildMessage c = 1;
|
||||
}
|
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package childpackage;
|
||||
|
||||
message Message {
|
||||
|
||||
}
|
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "child.proto";
|
||||
|
||||
// Tests generated imports when a message in root refers to a message in a child package.
|
||||
|
||||
message Test {
|
||||
childpackage.Message child = 1;
|
||||
}
|
@@ -0,0 +1,28 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "root.proto";
|
||||
import "other.proto";
|
||||
|
||||
// This test-case verifies support for circular dependencies in the generated python files.
|
||||
//
|
||||
// This is important because we generate 1 python file/module per package, rather than 1 file per proto file.
|
||||
//
|
||||
// Scenario:
|
||||
//
|
||||
// The proto messages depend on each other in a non-circular way:
|
||||
//
|
||||
// Test -------> RootPackageMessage <--------------.
|
||||
// `------------------------------------> OtherPackageMessage
|
||||
//
|
||||
// Test and RootPackageMessage are in different files, but belong to the same package (root):
|
||||
//
|
||||
// (Test -------> RootPackageMessage) <------------.
|
||||
// `------------------------------------> OtherPackageMessage
|
||||
//
|
||||
// After grouping the packages into single files or modules, a circular dependency is created:
|
||||
//
|
||||
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
|
||||
message Test {
|
||||
RootPackageMessage message = 1;
|
||||
other.OtherPackageMessage other = 2;
|
||||
}
|
@@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "root.proto";
|
||||
package other;
|
||||
|
||||
message OtherPackageMessage {
|
||||
RootPackageMessage rootPackageMessage = 1;
|
||||
}
|
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message RootPackageMessage {
|
||||
|
||||
}
|
@@ -0,0 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package cousin.cousin_subpackage;
|
||||
|
||||
message CousinMessage {
|
||||
}
|
11
betterproto/tests/inputs/import_cousin_package/test.proto
Normal file
11
betterproto/tests/inputs/import_cousin_package/test.proto
Normal file
@@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package test.subpackage;
|
||||
|
||||
import "cousin.proto";
|
||||
|
||||
// Verify that we can import message unrelated to us
|
||||
|
||||
message Test {
|
||||
cousin.cousin_subpackage.CousinMessage message = 1;
|
||||
}
|
@@ -0,0 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package cousin.subpackage;
|
||||
|
||||
message CousinMessage {
|
||||
}
|
@@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package test.subpackage;
|
||||
|
||||
import "cousin.proto";
|
||||
|
||||
// Verify that we can import a message unrelated to us, in a subpackage with the same name as us.
|
||||
|
||||
message Test {
|
||||
cousin.subpackage.CousinMessage message = 1;
|
||||
}
|
@@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "users_v1.proto";
|
||||
import "posts_v1.proto";
|
||||
|
||||
// Tests generated message can correctly reference two packages with the same leaf-name
|
||||
|
||||
message Test {
|
||||
users.v1.User user = 1;
|
||||
posts.v1.Post post = 2;
|
||||
}
|
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package posts.v1;
|
||||
|
||||
message Post {
|
||||
|
||||
}
|
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package users.v1;
|
||||
|
||||
message User {
|
||||
|
||||
}
|
@@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "parent_package_message.proto";
|
||||
|
||||
package parent.child;
|
||||
|
||||
// Tests generated imports when a message refers to a message defined in its parent package
|
||||
|
||||
message Test {
|
||||
ParentPackageMessage message_implicit = 1;
|
||||
parent.ParentPackageMessage message_explicit = 2;
|
||||
}
|
@@ -0,0 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package parent;
|
||||
|
||||
message ParentPackageMessage {
|
||||
}
|
@@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package child;
|
||||
|
||||
import "root.proto";
|
||||
|
||||
// Verify that we can import root message from child package
|
||||
|
||||
message Test {
|
||||
RootMessage message = 1;
|
||||
}
|
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
|
||||
message RootMessage {
|
||||
}
|
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "sibling.proto";
|
||||
|
||||
// Tests generated imports when a message in the root package refers to another message in the root package
|
||||
|
||||
message Test {
|
||||
SiblingMessage sibling = 1;
|
||||
}
|
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message SiblingMessage {
|
||||
|
||||
}
|
@@ -0,0 +1,15 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "request_message.proto";
|
||||
|
||||
// Tests generated service correctly imports the RequestMessage
|
||||
|
||||
service Test {
|
||||
rpc DoThing (RequestMessage) returns (RequestResponse);
|
||||
}
|
||||
|
||||
|
||||
message RequestResponse {
|
||||
int32 value = 1;
|
||||
}
|
||||
|
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message RequestMessage {
|
||||
int32 argument = 1;
|
||||
}
|
@@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.import_service_input_message import (
|
||||
RequestResponse,
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service")
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_correctly_imports_reference_message():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing()
|
||||
assert mock_response == response
|
4
betterproto/tests/inputs/int32/int32.json
Normal file
4
betterproto/tests/inputs/int32/int32.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"positive": 150,
|
||||
"negative": -150
|
||||
}
|
@@ -3,5 +3,6 @@ syntax = "proto3";
|
||||
// Some documentation about the Test message.
|
||||
message Test {
|
||||
// Some documentation about the count.
|
||||
int32 count = 1;
|
||||
int32 positive = 1;
|
||||
int32 negative = 2;
|
||||
}
|
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"int": "value-for-int",
|
||||
"float": "value-for-float",
|
||||
"complex": "value-for-complex",
|
||||
"list": "value-for-list",
|
||||
"tuple": "value-for-tuple",
|
||||
"range": "value-for-range",
|
||||
"str": "value-for-str",
|
||||
"bytearray": "value-for-bytearray",
|
||||
"bytes": "value-for-bytes",
|
||||
"memoryview": "value-for-memoryview",
|
||||
"set": "value-for-set",
|
||||
"frozenset": "value-for-frozenset",
|
||||
"map": "value-for-map",
|
||||
"bool": "value-for-bool"
|
||||
}
|
@@ -0,0 +1,38 @@
|
||||
syntax = "proto3";
|
||||
|
||||
// Tests that messages may contain fields with names that are python types
|
||||
|
||||
message Test {
|
||||
// https://docs.python.org/2/library/stdtypes.html#numeric-types-int-float-long-complex
|
||||
string int = 1;
|
||||
string float = 2;
|
||||
string complex = 3;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#sequence-types-list-tuple-range
|
||||
string list = 4;
|
||||
string tuple = 5;
|
||||
string range = 6;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#str
|
||||
string str = 7;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#bytearray-objects
|
||||
string bytearray = 8;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#bytes-and-bytearray-operations
|
||||
string bytes = 9;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#memory-views
|
||||
string memoryview = 10;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#set-types-set-frozenset
|
||||
string set = 11;
|
||||
string frozenset = 12;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#dict
|
||||
string map = 13;
|
||||
string dict = 14;
|
||||
|
||||
// https://docs.python.org/3/library/stdtypes.html#boolean-values
|
||||
string bool = 15;
|
||||
}
|
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"False": 1,
|
||||
"None": 2,
|
||||
"True": 3,
|
||||
"and": 4,
|
||||
"as": 5,
|
||||
"assert": 6,
|
||||
"async": 7,
|
||||
"await": 8,
|
||||
"break": 9,
|
||||
"class": 10,
|
||||
"continue": 11,
|
||||
"def": 12,
|
||||
"del": 13,
|
||||
"elif": 14,
|
||||
"else": 15,
|
||||
"except": 16,
|
||||
"finally": 17,
|
||||
"for": 18,
|
||||
"from": 19,
|
||||
"global": 20,
|
||||
"if": 21,
|
||||
"import": 22,
|
||||
"in": 23,
|
||||
"is": 24,
|
||||
"lambda": 25,
|
||||
"nonlocal": 26,
|
||||
"not": 27,
|
||||
"or": 28,
|
||||
"pass": 29,
|
||||
"raise": 30,
|
||||
"return": 31,
|
||||
"try": 32,
|
||||
"while": 33,
|
||||
"with": 34,
|
||||
"yield": 35
|
||||
}
|
@@ -0,0 +1,44 @@
|
||||
syntax = "proto3";
|
||||
|
||||
// Tests that messages may contain fields that are Python keywords
|
||||
//
|
||||
// Generated with Python 3.7.6
|
||||
// print('\n'.join(f'string {k} = {i+1};' for i,k in enumerate(keyword.kwlist)))
|
||||
|
||||
message Test {
|
||||
string False = 1;
|
||||
string None = 2;
|
||||
string True = 3;
|
||||
string and = 4;
|
||||
string as = 5;
|
||||
string assert = 6;
|
||||
string async = 7;
|
||||
string await = 8;
|
||||
string break = 9;
|
||||
string class = 10;
|
||||
string continue = 11;
|
||||
string def = 12;
|
||||
string del = 13;
|
||||
string elif = 14;
|
||||
string else = 15;
|
||||
string except = 16;
|
||||
string finally = 17;
|
||||
string for = 18;
|
||||
string from = 19;
|
||||
string global = 20;
|
||||
string if = 21;
|
||||
string import = 22;
|
||||
string in = 23;
|
||||
string is = 24;
|
||||
string lambda = 25;
|
||||
string nonlocal = 26;
|
||||
string not = 27;
|
||||
string or = 28;
|
||||
string pass = 29;
|
||||
string raise = 30;
|
||||
string return = 31;
|
||||
string try = 32;
|
||||
string while = 33;
|
||||
string with = 34;
|
||||
string yield = 35;
|
||||
}
|
@@ -15,4 +15,4 @@ message Test {
|
||||
|
||||
message Sibling {
|
||||
int32 foo = 1;
|
||||
}
|
||||
}
|
19
betterproto/tests/inputs/nested2/nested2.proto
Normal file
19
betterproto/tests/inputs/nested2/nested2.proto
Normal file
@@ -0,0 +1,19 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "package.proto";
|
||||
|
||||
message Game {
|
||||
message Player {
|
||||
enum Race {
|
||||
human = 0;
|
||||
orc = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message Test {
|
||||
Game game = 1;
|
||||
Game.Player GamePlayer = 2;
|
||||
Game.Player.Race GamePlayerRace = 3;
|
||||
equipment.Weapon Weapon = 4;
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user