9 Commits

Author SHA1 Message Date
boukeversteegh
0c02d1b21a Update with master 2020-07-04 18:54:26 +02:00
Bouke Versteegh
ac32bcd25a Merge branch 'master' into michael-sayapin/master 2020-07-04 11:23:42 +02:00
boukeversteegh
72855227bd Fix import 2020-06-25 15:52:43 +02:00
boukeversteegh
47081617c2 Merge branch 'master' into michael-sayapin/master 2020-06-25 15:02:50 +02:00
boukeversteegh
d734206fe5 Rename test-case to keep it close with other enum test 2020-06-24 21:55:31 +02:00
Bouke Versteegh
bbf40f9694 Mark test xfail 2020-06-24 21:48:26 +02:00
Michael Sayapin
6671d87cef Conformance formatting 2020-06-17 11:37:36 +08:00
Michael Sayapin
cd66b0511a Fixes enum class name 2020-06-15 13:52:58 +08:00
Michael Sayapin
c48ca2e386 Test to_dict with missing enum values 2020-06-15 12:51:51 +08:00
210 changed files with 2086 additions and 7165 deletions

View File

@@ -1,23 +0,0 @@
# Contributing
There's lots to do, and we're working hard, so any help is welcome!
- :speech_balloon: Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
What can you do?
- :+1: Vote on [issues](https://github.com/danielgtaylor/python-betterproto/issues).
- :speech_balloon: Give feedback on [Pull Requests](https://github.com/danielgtaylor/python-betterproto/pulls) and [Issues](https://github.com/danielgtaylor/python-betterproto/issues):
- Suggestions
- Express approval
- Raise concerns
- :small_red_triangle: Create an issue:
- File a bug (please check its not a duplicate)
- Propose an enhancement
- :white_check_mark: Create a PR:
- [Creating a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/betterproto/tests/README.md) to make bug-fixing easier
- Fix any of the open issues
- [Good first issues](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
- [Issues with tests](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22has+test%22)
- New bugfix or idea
- If you'd like to discuss your idea first, join us on Slack!

View File

@@ -1,69 +1,74 @@
name: CI name: CI
on: on: [push, pull_request]
push:
branches:
- master
pull_request:
branches:
- '**'
jobs: jobs:
tests:
name: ${{ matrix.os }} / ${{ matrix.python-version }} check-formatting:
runs-on: ${{ matrix.os }}-latest runs-on: ubuntu-latest
strategy:
matrix: name: Consult black on python formatting
os: [Ubuntu, MacOS, Windows]
python-version: ['3.6.7', '3.7', '3.8', '3.9', '3.10']
exclude:
- os: Windows
python-version: 3.6
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7
- uses: Gr1N/setup-poetry@v2
- uses: actions/cache@v2
with:
path: ~/.cache/pypoetry/virtualenvs
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry-
- name: Install dependencies
run: poetry install
- name: Run black
run: make check-style
- name: Set up Python ${{ matrix.python-version }} run-tests:
uses: actions/setup-python@v2 runs-on: ubuntu-latest
name: Run tests with tox
strategy:
matrix:
python-version: [ '3.6', '3.7', '3.8']
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- uses: Gr1N/setup-poetry@v2
- name: Get full Python version - uses: actions/cache@v2
id: full-python-version
shell: bash
run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")
- name: Install poetry
shell: bash
run: |
python -m pip install poetry
echo "$HOME/.poetry/bin" >> $GITHUB_PATH
- name: Configure poetry
shell: bash
run: poetry config virtualenvs.in-project true
- name: Set up cache
uses: actions/cache@v2
id: cache
with: with:
path: .venv path: ~/.cache/pypoetry/virtualenvs
key: venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('**/poetry.lock') }} key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
restore-keys: |
- name: Ensure cache is healthy ${{ runner.os }}-poetry-
if: steps.cache.outputs.cache-hit == 'true'
shell: bash
run: poetry run pip --version >/dev/null 2>&1 || rm -rf .venv
- name: Install dependencies - name: Install dependencies
shell: bash
run: | run: |
poetry run python -m pip install pip -U sudo apt install protobuf-compiler libprotobuf-dev
poetry install poetry install
- name: Run tests
run: |
make generate
make test
- name: Generate code from proto files build-release:
shell: bash runs-on: ubuntu-latest
run: poetry run python -m tests.generate -v
- name: Execute test suite steps:
shell: bash - uses: actions/checkout@v2
run: poetry run python -m pytest tests/ - uses: actions/setup-python@v2
with:
python-version: 3.7
- uses: Gr1N/setup-poetry@v2
- name: Build package
run: poetry build
- name: Publish package to PyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
run: poetry publish -n
env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }}

View File

@@ -1,26 +0,0 @@
name: Code Quality
on:
push:
branches:
- master
pull_request:
branches:
- '**'
jobs:
check-formatting:
name: Check code/doc formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Run Black
uses: lgeiger/black-action@master
with:
args: --check src/ tests/ benchmarks/
- name: Install rST dependcies
run: python -m pip install doc8
- name: Lint documentation for errors
run: python -m doc8 docs --max-line-length 88 --ignore-path-errors "docs/migrating.rst;D001"
# it has a table which is longer than 88 characters long

View File

@@ -1,31 +0,0 @@
name: Release
on:
push:
branches:
- master
tags:
- '**'
pull_request:
branches:
- '**'
jobs:
packaging:
name: Distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install poetry
run: python -m pip install poetry
- name: Build package
run: poetry build
- name: Publish package to PyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }}
run: poetry publish -n

5
.gitignore vendored
View File

@@ -6,7 +6,7 @@
.pytest_cache .pytest_cache
.python-version .python-version
build/ build/
tests/output_* betterproto/tests/output_*
**/__pycache__ **/__pycache__
dist dist
**/*.egg-info **/*.egg-info
@@ -14,6 +14,3 @@ output
.idea .idea
.DS_Store .DS_Store
.tox .tox
.venv
.asv
venv

View File

@@ -1,17 +0,0 @@
version: 2
formats: []
build:
image: latest
sphinx:
configuration: docs/conf.py
fail_on_warning: false
python:
version: 3.7
install:
- method: pip
path: .
extra_requirements:
- dev

View File

@@ -5,83 +5,6 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [2.0.0b4] - 2022-01-03
- **Breaking**: the minimum Python version has been bumped to `3.6.2`
- Always add `AsyncIterator` to imports if there are services [#264](https://github.com/danielgtaylor/python-betterproto/pull/264)
- Allow parsing of messages from `ByteStrings` [#266](https://github.com/danielgtaylor/python-betterproto/pull/266)
- Add support for proto3 optional [#281](https://github.com/danielgtaylor/python-betterproto/pull/281)
- Fix compilation of fields with names identical to builtin types [#294](https://github.com/danielgtaylor/python-betterproto/pull/294)
- Fix default values for enum service args [#299](https://github.com/danielgtaylor/python-betterproto/pull/299)
## [2.0.0b3] - 2021-04-07
- Generate grpclib service stubs [#170](https://github.com/danielgtaylor/python-betterproto/pull/170)
- Add \_\_version\_\_ attribute to package [#134](https://github.com/danielgtaylor/python-betterproto/pull/134)
- Use betterproto generated messages in the plugin [#161](https://github.com/danielgtaylor/python-betterproto/pull/161)
- Sort the list of sources in generated file headers [#164](https://github.com/danielgtaylor/python-betterproto/pull/164)
- Micro-optimization: use tuples instead of lists for conditions [#228](https://github.com/danielgtaylor/python-betterproto/pull/228)
- Improve datestring parsing [#213](https://github.com/danielgtaylor/python-betterproto/pull/213)
- Fix serialization of repeated fields with empty messages [#180](https://github.com/danielgtaylor/python-betterproto/pull/180)
- Fix compilation of fields named 'bytes' or 'str' [#226](https://github.com/danielgtaylor/python-betterproto/pull/226)
- Fix json serialization of infinite and nan floats/doubles [#215](https://github.com/danielgtaylor/python-betterproto/pull/215)
- Fix template bug resulting in empty \_\_post_init\_\_ methods [#162](https://github.com/danielgtaylor/python-betterproto/pull/162)
- Fix serialization of zero-value messages in a oneof group [#176](https://github.com/danielgtaylor/python-betterproto/pull/176)
- Fix missing typing and datetime imports [#183](https://github.com/danielgtaylor/python-betterproto/pull/183)
- Fix code generation for empty services [#222](https://github.com/danielgtaylor/python-betterproto/pull/222)
- Fix Message.to_dict and from_dict handling of repeated timestamps and durations [#211](https://github.com/danielgtaylor/python-betterproto/pull/211)
- Fix incorrect routes in generated client when service is not in a package [#177](https://github.com/danielgtaylor/python-betterproto/pull/177)
## [2.0.0b2] - 2020-11-24
- Add support for deprecated message and fields [#126](https://github.com/danielgtaylor/python-betterproto/pull/126)
- Add support for recursive messages [#130](https://github.com/danielgtaylor/python-betterproto/pull/130)
- Add support for `bool(Message)` [#142](https://github.com/danielgtaylor/python-betterproto/pull/142)
- Improve support for Python 3.9 [#140](https://github.com/danielgtaylor/python-betterproto/pull/140) [#173](https://github.com/danielgtaylor/python-betterproto/pull/173)
- Improve keyword sanitisation for generated code [#137](https://github.com/danielgtaylor/python-betterproto/pull/137)
- Fix missing serialized_on_wire when message contains only lists [#81](https://github.com/danielgtaylor/python-betterproto/pull/81)
- Fix circular dependencies [#100](https://github.com/danielgtaylor/python-betterproto/pull/100)
- Fix to_dict enum fields when numbering is not consecutive [#102](https://github.com/danielgtaylor/python-betterproto/pull/102)
- Fix argument generation for stub methods when using `import` with proto definition [#103](https://github.com/danielgtaylor/python-betterproto/pull/103)
- Fix missing async/await keywords when casing [#104](https://github.com/danielgtaylor/python-betterproto/pull/104)
- Fix mutable default arguments in generated code [#105](https://github.com/danielgtaylor/python-betterproto/pull/105)
- Fix serialisation of default values in oneofs when calling to_dict() or to_json() [#110](https://github.com/danielgtaylor/python-betterproto/pull/110)
- Fix static type checking for grpclib client [#124](https://github.com/danielgtaylor/python-betterproto/pull/124)
- Fix python3.6 compatibility issue with dataclasses [#124](https://github.com/danielgtaylor/python-betterproto/pull/124)
- Fix handling of trailer-only responses [#127](https://github.com/danielgtaylor/python-betterproto/pull/127)
- Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data [#121](https://github.com/danielgtaylor/python-betterproto/pull/121)
- Refactor template compilation logic [#136](https://github.com/danielgtaylor/python-betterproto/pull/136)
- Replace use of platform provided protoc with development dependency on grpcio-tools [#107](https://github.com/danielgtaylor/python-betterproto/pull/107)
- Switch to using `poe` from `make` to manage project development tasks [#118](https://github.com/danielgtaylor/python-betterproto/pull/118)
- Improve CI platform coverage [#128](https://github.com/danielgtaylor/python-betterproto/pull/128)
## [2.0.0b1] - 2020-07-04
[Upgrade Guide](./docs/upgrading.md)
> Several bugfixes and improvements required or will require small breaking changes, necessitating a new version.
> `2.0.0` will be released once the interface is stable.
- Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)
- Fix Import child package from package [#58](https://github.com/danielgtaylor/python-betterproto/issues/58)
- Fix Import parent package from child package [#59](https://github.com/danielgtaylor/python-betterproto/issues/59)
- Fix Import root package from child package [#60](https://github.com/danielgtaylor/python-betterproto/issues/60)
- Fix Import root package from root [#61](https://github.com/danielgtaylor/python-betterproto/issues/61)
- Fix ALL_CAPS message fields are parsed incorrectly. [#11](https://github.com/danielgtaylor/python-betterproto/issues/11)
## [1.2.5] - 2020-04-27 ## [1.2.5] - 2020-04-27
- Add .j2 suffix to python template names to avoid confusing certain build tools [#72](https://github.com/danielgtaylor/python-betterproto/pull/72) - Add .j2 suffix to python template names to avoid confusing certain build tools [#72](https://github.com/danielgtaylor/python-betterproto/pull/72)

42
Makefile Normal file
View File

@@ -0,0 +1,42 @@
.PHONY: help setup generate test types format clean plugin full-test check-style
help: ## - Show this help.
@fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//'
# Dev workflow tasks
generate: ## - Generate test cases (do this once before running test)
poetry run ./betterproto/tests/generate.py
test: ## - Run tests
poetry run pytest --cov betterproto
types: ## - Check types with mypy
poetry run mypy betterproto --ignore-missing-imports
format: ## - Apply black formatting to source code
poetry run black . --exclude tests/output_
clean: ## - Clean out generated files from the workspace
rm -rf .coverage \
.mypy_cache \
.pytest_cache \
dist \
**/__pycache__ \
betterproto/tests/output_*
# Manual testing
# By default write plugin output to a directory called output
o=output
plugin: ## - Execute the protoc plugin, with output write to `output` or the value passed to `-o`
mkdir -p $(o)
protoc --plugin=protoc-gen-custom=betterproto/plugin.py $(i) --custom_out=$(o)
# CI tasks
full-test: generate ## - Run full testing sequence with multiple pythons
poetry run tox
check-style: ## - Check if code style is correct
poetry run black . --check --diff --exclude tests/output_

135
README.md
View File

@@ -1,7 +1,6 @@
# Better Protobuf / gRPC Support for Python # Better Protobuf / gRPC Support for Python
![](https://github.com/danielgtaylor/python-betterproto/workflows/CI/badge.svg) ![](https://github.com/danielgtaylor/python-betterproto/workflows/CI/badge.svg)
> :octocat: If you're reading this on github, please be aware that it might mention unreleased features! See the latest released README on [pypi](https://pypi.org/project/betterproto/).
This project aims to provide an improved experience when using Protobuf / gRPC in a modern Python environment by making use of modern language features and generating readable, understandable, idiomatic Python code. It will not support legacy features or environments (e.g. Protobuf 2). The following are supported: This project aims to provide an improved experience when using Protobuf / gRPC in a modern Python environment by making use of modern language features and generating readable, understandable, idiomatic Python code. It will not support legacy features or environments (e.g. Protobuf 2). The following are supported:
@@ -38,9 +37,10 @@ This project exists because I am unhappy with the state of the official Google p
- Uses `SerializeToString()` rather than the built-in `__bytes__()` - Uses `SerializeToString()` rather than the built-in `__bytes__()`
- Special wrapped types don't use Python's `None` - Special wrapped types don't use Python's `None`
- Timestamp/duration types don't use Python's built-in `datetime` module - Timestamp/duration types don't use Python's built-in `datetime` module
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
## Installation ## Installation & Getting Started
First, install the package. Note that the `[compiler]` feature flag tells it to install extra dependencies only needed by the `protoc` plugin: First, install the package. Note that the `[compiler]` feature flag tells it to install extra dependencies only needed by the `protoc` plugin:
@@ -52,12 +52,6 @@ pip install "betterproto[compiler]"
pip install betterproto pip install betterproto
``` ```
*Betterproto* is under active development. To install the latest beta version, use `pip install --pre betterproto`.
## Getting Started
### Compiling proto files
Now, given you installed the compiler and have a proto file, e.g `example.proto`: Now, given you installed the compiler and have a proto file, e.g `example.proto`:
```protobuf ```protobuf
@@ -71,20 +65,13 @@ message Greeting {
} }
``` ```
You can run the following to invoke protoc directly: You can run the following:
```sh ```sh
mkdir lib mkdir lib
protoc -I . --python_betterproto_out=lib example.proto protoc -I . --python_betterproto_out=lib example.proto
``` ```
or run the following to invoke protoc via grpcio-tools:
```sh
pip install grpcio-tools
python -m grpc_tools.protoc -I . --python_betterproto_out=lib example.proto
```
This will generate `lib/hello/__init__.py` which looks like: This will generate `lib/hello/__init__.py` which looks like:
```python ```python
@@ -133,7 +120,7 @@ Greeting(message="Hey!")
The generated Protobuf `Message` classes are compatible with [grpclib](https://github.com/vmagamedov/grpclib) so you are free to use it if you like. That said, this project also includes support for async gRPC stub generation with better static type checking and code completion support. It is enabled by default. The generated Protobuf `Message` classes are compatible with [grpclib](https://github.com/vmagamedov/grpclib) so you are free to use it if you like. That said, this project also includes support for async gRPC stub generation with better static type checking and code completion support. It is enabled by default.
Given an example service definition: Given an example like:
```protobuf ```protobuf
syntax = "proto3"; syntax = "proto3";
@@ -160,74 +147,22 @@ service Echo {
} }
``` ```
Generate echo proto file: You can use it like so (enable async in the interactive shell first):
``` ```py
python -m grpc_tools.protoc -I . --python_betterproto_out=. echo.proto >>> import echo
``` >>> from grpclib.client import Channel
A client can be implemented as follows: >>> channel = Channel(host="127.0.0.1", port=1234)
```python >>> service = echo.EchoStub(channel)
import asyncio >>> await service.echo(value="hello", extra_times=1)
import echo EchoResponse(values=["hello", "hello"])
from grpclib.client import Channel >>> async for response in service.echo_stream(value="hello", extra_times=1)
async def main():
channel = Channel(host="127.0.0.1", port=50051)
service = echo.EchoStub(channel)
response = await service.echo(value="hello", extra_times=1)
print(response) print(response)
async for response in service.echo_stream(value="hello", extra_times=1): EchoStreamResponse(value="hello")
print(response) EchoStreamResponse(value="hello")
# don't forget to close the channel when done!
channel.close()
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
```
which would output
```python
EchoResponse(values=['hello', 'hello'])
EchoStreamResponse(value='hello')
EchoStreamResponse(value='hello')
```
This project also produces server-facing stubs that can be used to implement a Python
gRPC server.
To use them, simply subclass the base class in the generated files and override the
service methods:
```python
import asyncio
from echo import EchoBase, EchoResponse, EchoStreamResponse
from grpclib.server import Server
from typing import AsyncIterator
class EchoService(EchoBase):
async def echo(self, value: str, extra_times: int) -> "EchoResponse":
return EchoResponse([value for _ in range(extra_times)])
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]:
for _ in range(extra_times):
yield EchoStreamResponse(value)
async def main():
server = Server([EchoService()])
await server.start("127.0.0.1", 50051)
await server.wait_closed()
if __name__ == '__main__':
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
``` ```
### JSON ### JSON
@@ -239,8 +174,8 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries
For compatibility the default is to convert field names to `camelCase`. You can control this behavior by passing a casing value, e.g: For compatibility the default is to convert field names to `camelCase`. You can control this behavior by passing a casing value, e.g:
```python ```py
MyMessage().to_dict(casing=betterproto.Casing.SNAKE) >>> MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
``` ```
### Determining if a message was sent ### Determining if a message was sent
@@ -363,31 +298,20 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
## Development ## Development
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_ Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
- _See how you can help → [Contributing](.github/CONTRIBUTING.md)_
### Requirements First, make sure you have Python 3.6+ and `poetry` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
- Python (3.6 or higher)
- [poetry](https://python-poetry.org/docs/#installation)
*Needed to install dependencies in a virtual environment*
- [poethepoet](https://github.com/nat-n/poethepoet) for running development tasks as defined in pyproject.toml
- Can be installed to your host environment via `pip install poethepoet` then executed as simple `poe`
- or run from the poetry venv as `poetry run poe`
### Setup
```sh ```sh
# Get set up with the virtual env & dependencies # Get set up with the virtual env & dependencies
poetry run pip install --upgrade pip
poetry install poetry install
# Activate the poetry environment # Activate the poetry environment
poetry shell poetry shell
``` ```
To benefit from the collection of standard development tasks ensure you have make installed and run `make help` to see available tasks.
### Code style ### Code style
This project enforces [black](https://github.com/psf/black) python code formatting. This project enforces [black](https://github.com/psf/black) python code formatting.
@@ -395,7 +319,7 @@ This project enforces [black](https://github.com/psf/black) python code formatti
Before committing changes run: Before committing changes run:
```sh ```sh
poe format make format
``` ```
To avoid merge conflicts later, non-black formatted python code will fail in CI. To avoid merge conflicts later, non-black formatted python code will fail in CI.
@@ -429,15 +353,15 @@ Here's how to run the tests.
```sh ```sh
# Generate assets from sample .proto files required by the tests # Generate assets from sample .proto files required by the tests
poe generate make generate
# Run the tests # Run the tests
poe test make test
``` ```
To run tests as they are run in CI (with tox) run: To run tests as they are run in CI (with tox) run:
```sh ```sh
poe full-test make full-test
``` ```
### (Re)compiling Google Well-known Types ### (Re)compiling Google Well-known Types
@@ -451,13 +375,14 @@ Assuming your `google.protobuf` source files (included with all releases of `pro
```sh ```sh
protoc \ protoc \
--plugin=protoc-gen-custom=src/betterproto/plugin/main.py \ --plugin=protoc-gen-custom=betterproto/plugin.py \
--custom_opt=INCLUDE_GOOGLE \ --custom_opt=INCLUDE_GOOGLE \
--custom_out=src/betterproto/lib \ --custom_out=betterproto/lib \
-I /usr/local/include/ \ -I /usr/local/include/ \
/usr/local/include/google/protobuf/*.proto /usr/local/include/google/protobuf/*.proto
``` ```
### TODO ### TODO
- [x] Fixed length fields - [x] Fixed length fields
@@ -488,10 +413,10 @@ protoc \
- [x] Enum strings - [x] Enum strings
- [x] Well known types support (timestamp, duration, wrappers) - [x] Well known types support (timestamp, duration, wrappers)
- [x] Support different casing (orig vs. camel vs. others?) - [x] Support different casing (orig vs. camel vs. others?)
- [x] Async service stubs - [ ] Async service stubs
- [x] Unary-unary - [x] Unary-unary
- [x] Server streaming response - [x] Server streaming response
- [x] Client streaming request - [ ] Client streaming request
- [x] Renaming messages and fields to conform to Python name standards - [x] Renaming messages and fields to conform to Python name standards
- [x] Renaming clashes with language keywords - [x] Renaming clashes with language keywords
- [x] Python package - [x] Python package

View File

@@ -1,157 +0,0 @@
{
// The version of the config file format. Do not change, unless
// you know what you are doing.
"version": 1,
// The name of the project being benchmarked
"project": "python-betterproto",
// The project's homepage
"project_url": "https://github.com/danielgtaylor/python-betterproto",
// The URL or local path of the source code repository for the
// project being benchmarked
"repo": ".",
// The Python project's subdirectory in your repo. If missing or
// the empty string, the project is assumed to be located at the root
// of the repository.
// "repo_subdir": "",
// Customizable commands for building, installing, and
// uninstalling the project. See asv.conf.json documentation.
//
"install_command": ["python -m pip install ."],
"uninstall_command": ["return-code=any python -m pip uninstall -y {project}"],
"build_command": ["python -m pip wheel -w {build_cache_dir} {build_dir}"],
// List of branches to benchmark. If not provided, defaults to "master"
// (for git) or "default" (for mercurial).
// "branches": ["master"], // for git
// "branches": ["default"], // for mercurial
// The DVCS being used. If not set, it will be automatically
// determined from "repo" by looking at the protocol in the URL
// (if remote), or by looking for special directories, such as
// ".git" (if local).
// "dvcs": "git",
// The tool to use to create environments. May be "conda",
// "virtualenv" or other value depending on the plugins in use.
// If missing or the empty string, the tool will be automatically
// determined by looking for tools on the PATH environment
// variable.
"environment_type": "virtualenv",
// timeout in seconds for installing any dependencies in environment
// defaults to 10 min
//"install_timeout": 600,
// the base URL to show a commit for the project.
// "show_commit_url": "http://github.com/owner/project/commit/",
// The Pythons you'd like to test against. If not provided, defaults
// to the current version of Python used to run `asv`.
// "pythons": ["2.7", "3.6"],
// The list of conda channel names to be searched for benchmark
// dependency packages in the specified order
// "conda_channels": ["conda-forge", "defaults"],
// The matrix of dependencies to test. Each key is the name of a
// package (in PyPI) and the values are version numbers. An empty
// list or empty string indicates to just test against the default
// (latest) version. null indicates that the package is to not be
// installed. If the package to be tested is only available from
// PyPi, and the 'environment_type' is conda, then you can preface
// the package name by 'pip+', and the package will be installed via
// pip (with all the conda available packages installed first,
// followed by the pip installed packages).
//
// "matrix": {
// "numpy": ["1.6", "1.7"],
// "six": ["", null], // test with and without six installed
// "pip+emcee": [""], // emcee is only available for install with pip.
// },
// Combinations of libraries/python versions can be excluded/included
// from the set to test. Each entry is a dictionary containing additional
// key-value pairs to include/exclude.
//
// An exclude entry excludes entries where all values match. The
// values are regexps that should match the whole string.
//
// An include entry adds an environment. Only the packages listed
// are installed. The 'python' key is required. The exclude rules
// do not apply to includes.
//
// In addition to package names, the following keys are available:
//
// - python
// Python version, as in the *pythons* variable above.
// - environment_type
// Environment type, as above.
// - sys_platform
// Platform, as in sys.platform. Possible values for the common
// cases: 'linux2', 'win32', 'cygwin', 'darwin'.
//
// "exclude": [
// {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows
// {"environment_type": "conda", "six": null}, // don't run without six on conda
// ],
//
// "include": [
// // additional env for python2.7
// {"python": "2.7", "numpy": "1.8"},
// // additional env if run on windows+conda
// {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""},
// ],
// The directory (relative to the current directory) that benchmarks are
// stored in. If not provided, defaults to "benchmarks"
// "benchmark_dir": "benchmarks",
// The directory (relative to the current directory) to cache the Python
// environments in. If not provided, defaults to "env"
"env_dir": ".asv/env",
// The directory (relative to the current directory) that raw benchmark
// results are stored in. If not provided, defaults to "results".
"results_dir": ".asv/results",
// The directory (relative to the current directory) that the html tree
// should be written to. If not provided, defaults to "html".
"html_dir": ".asv/html",
// The number of characters to retain in the commit hashes.
// "hash_length": 8,
// `asv` will cache results of the recent builds in each
// environment, making them faster to install next time. This is
// the number of builds to keep, per environment.
// "build_cache_size": 2,
// The commits after which the regression search in `asv publish`
// should start looking for regressions. Dictionary whose keys are
// regexps matching to benchmark names, and values corresponding to
// the commit (exclusive) after which to start looking for
// regressions. The default is to start from the first commit
// with results. If the commit is `null`, regression detection is
// skipped for the matching benchmark.
//
// "regressions_first_commits": {
// "some_benchmark": "352cdf", // Consider regressions only after this commit
// "another_benchmark": null, // Skip regression detection altogether
// },
// The thresholds for relative change in results, after which `asv
// publish` starts reporting regressions. Dictionary of the same
// form as in ``regressions_first_commits``, with values
// indicating the thresholds. If multiple entries match, the
// maximum is taken. If no entry matches, the default is 5%.
//
// "regressions_thresholds": {
// "some_benchmark": 0.01, // Threshold of 1%
// "another_benchmark": 0.5, // Threshold of 50%
// },
}

View File

@@ -1 +0,0 @@

View File

@@ -1,128 +0,0 @@
import betterproto
from dataclasses import dataclass
from typing import List
@dataclass
class TestMessage(betterproto.Message):
    # Flat message with three scalar fields; used as the baseline payload
    # for the serialization/deserialization benchmarks below.
    # NOTE(review): protobuf field numbers are normally 1-based; field
    # number 0 is reserved/invalid in .proto files -- confirm betterproto
    # accepts 0 for these synthetic benchmark messages.
    foo: int = betterproto.uint32_field(0)
    bar: str = betterproto.string_field(1)
    baz: float = betterproto.float_field(2)
@dataclass
class TestNestedChildMessage(betterproto.Message):
    # Leaf message covering one field of each common scalar wire type;
    # embedded three times inside TestNestedMessage for nested benchmarks.
    str_key: str = betterproto.string_field(0)
    bytes_key: bytes = betterproto.bytes_field(1)
    bool_key: bool = betterproto.bool_field(2)
    float_key: float = betterproto.float_field(3)
    int_key: int = betterproto.uint64_field(4)
@dataclass
class TestNestedMessage(betterproto.Message):
    # Message composed of three sub-messages; exercises nested
    # (de)serialization in BenchMessage.time_serialize_nested /
    # time_deserialize_nested.
    foo: TestNestedChildMessage = betterproto.message_field(0)
    bar: TestNestedChildMessage = betterproto.message_field(1)
    baz: TestNestedChildMessage = betterproto.message_field(2)
@dataclass
class TestRepeatedMessage(betterproto.Message):
    # Message with three repeated fields; exercises repeated-field
    # (de)serialization in BenchMessage.time_serialize_repeated /
    # time_deserialize_repeated.
    foo_repeat: List[str] = betterproto.string_field(0)
    bar_repeat: List[int] = betterproto.int64_field(1)
    baz_repeat: List[bool] = betterproto.bool_field(2)
class BenchMessage:
    """Benchmark creation and usage of a proto message.

    asv timing suite: every ``time_*`` method body is the payload being
    measured, so each body contains only the operation under test.
    """

    def setup(self):
        # Pre-build message instances and their serialized bytes once per
        # benchmark run, so the timed methods measure only the target
        # operation and not fixture construction.
        self.cls = TestMessage
        self.instance = TestMessage()
        self.instance_filled = TestMessage(0, "test", 0.0)
        self.instance_filled_bytes = bytes(self.instance_filled)
        self.instance_filled_nested = TestNestedMessage(
            TestNestedChildMessage("foo", bytearray(b"test1"), True, 0.1234, 500),
            TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, -302),
            TestNestedChildMessage("baz", bytearray(b"test3"), False, 1e5, 300),
        )
        self.instance_filled_nested_bytes = bytes(self.instance_filled_nested)
        self.instance_filled_repeated = TestRepeatedMessage(
            [
                "test1",
                "test2",
                "test3",
                "test4",
                "test5",
                "test6",
                "test7",
                "test8",
                "test9",
                "test10",
            ],
            [2, -100, 0, 500000, 600, -425678, 1000000000, -300, 1, -694214214466],
            [True, False, False, False, True, True, False, True, False, False],
        )
        self.instance_filled_repeated_bytes = bytes(self.instance_filled_repeated)

    def time_overhead(self):
        """Overhead of defining a betterproto message class."""

        # The class definition itself (dataclass + field descriptors) is
        # the work being timed here; the class is intentionally unused.
        @dataclass
        class Message(betterproto.Message):
            foo: int = betterproto.uint32_field(0)
            bar: str = betterproto.string_field(1)
            baz: float = betterproto.float_field(2)

    def time_instantiation(self):
        """Time instantiation of an empty message."""
        self.cls()

    def time_attribute_access(self):
        """Time to access each attribute of a default instance."""
        self.instance.foo
        self.instance.bar
        self.instance.baz

    def time_init_with_values(self):
        """Time instantiation with all fields given positionally."""
        self.cls(0, "test", 0.0)

    def time_attribute_setting(self):
        """Time to set each attribute on an existing instance."""
        self.instance.foo = 0
        self.instance.bar = "test"
        self.instance.baz = 0.0

    def time_serialize(self):
        """Time serializing a flat message to wire format."""
        bytes(self.instance_filled)

    def time_deserialize(self):
        """Time deserializing a flat message from wire format."""
        TestMessage().parse(self.instance_filled_bytes)

    def time_serialize_nested(self):
        """Time serializing a nested message to wire format."""
        bytes(self.instance_filled_nested)

    def time_deserialize_nested(self):
        """Time deserializing a nested message from wire format."""
        TestNestedMessage().parse(self.instance_filled_nested_bytes)

    def time_serialize_repeated(self):
        """Time serializing a message with repeated fields to wire format."""
        bytes(self.instance_filled_repeated)

    def time_deserialize_repeated(self):
        """Time deserializing a message with repeated fields."""
        TestRepeatedMessage().parse(self.instance_filled_repeated_bytes)
class MemSuite:
    """Memory benchmark suite.

    NOTE(review): asv ``mem_*`` benchmarks report the footprint of the
    returned object -- confirm against the project's asv configuration.
    """

    def setup(self):
        self.cls = TestMessage

    def mem_instance(self):
        # Return a fresh default instance for asv to measure.
        return self.cls()

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from grpclib._typing import IProtoMessage
from . import Message from . import Message
from grpclib._protocols import IProtoMessage
# Bound type variable to allow methods to return `self` of subclasses # Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message") T = TypeVar("T", bound="Message")

View File

@@ -1,4 +1,3 @@
import keyword
import re import re
# Word delimiters and symbols that will not be preserved when re-casing. # Word delimiters and symbols that will not be preserved when re-casing.
@@ -17,28 +16,51 @@ WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
def safe_snake_case(value: str) -> str: def safe_snake_case(value: str) -> str:
"""Snake case a value taking into account Python keywords.""" """Snake case a value taking into account Python keywords."""
value = snake_case(value) value = snake_case(value)
value = sanitize_name(value) if value in [
"and",
"as",
"assert",
"break",
"class",
"continue",
"def",
"del",
"elif",
"else",
"except",
"finally",
"for",
"from",
"global",
"if",
"import",
"in",
"is",
"lambda",
"nonlocal",
"not",
"or",
"pass",
"raise",
"return",
"try",
"while",
"with",
"yield",
]:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
value += "_"
return value return value
def snake_case(value: str, strict: bool = True) -> str: def snake_case(value: str, strict: bool = True):
""" """
Join words with an underscore into lowercase and remove symbols. Join words with an underscore into lowercase and remove symbols.
@param value: value to convert
Parameters @param strict: force single underscores
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to force single underscores.
Returns
--------
:class:`str`
The value in snake_case.
""" """
def substitute_word(symbols: str, word: str, is_start: bool) -> str: def substitute_word(symbols, word, is_start):
if not word: if not word:
return "" return ""
if strict: if strict:
@@ -62,21 +84,11 @@ def snake_case(value: str, strict: bool = True) -> str:
return snake return snake
def pascal_case(value: str, strict: bool = True) -> str: def pascal_case(value: str, strict: bool = True):
""" """
Capitalize each word and remove symbols. Capitalize each word and remove symbols.
@param value: value to convert
Parameters @param strict: output only alphanumeric characters
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to output only alphanumeric characters.
Returns
--------
:class:`str`
The value in PascalCase.
""" """
def substitute_word(symbols, word): def substitute_word(symbols, word):
@@ -97,42 +109,12 @@ def pascal_case(value: str, strict: bool = True) -> str:
) )
def camel_case(value: str, strict: bool = True) -> str: def camel_case(value: str, strict: bool = True):
""" """
Capitalize all words except first and remove symbols. Capitalize all words except first and remove symbols.
Parameters
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to output only alphanumeric characters.
Returns
--------
:class:`str`
The value in camelCase.
""" """
return lowercase_first(pascal_case(value, strict=strict)) return lowercase_first(pascal_case(value, strict=strict))
def lowercase_first(value: str) -> str: def lowercase_first(value: str):
"""
Lower cases the first character of the value.
Parameters
----------
value: :class:`str`
The value to lower case.
Returns
-------
:class:`str`
The lower cased string.
"""
return value[0:1].lower() + value[1:] return value[0:1].lower() + value[1:]
def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
return f"{value}_" if keyword.iskeyword(value) else value

View File

@@ -1,10 +1,10 @@
import os import os
import re import re
from typing import Dict, List, Set, Tuple, Type from typing import Dict, List, Set, Type
from ..casing import safe_snake_case from betterproto import safe_snake_case
from ..lib.google import protobuf as google_protobuf from betterproto.compile.naming import pythonize_class_name
from .naming import pythonize_class_name from betterproto.lib.google import protobuf as google_protobuf
WRAPPER_TYPES: Dict[str, Type] = { WRAPPER_TYPES: Dict[str, Type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue, ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
@@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
} }
def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: def parse_source_type_name(field_type_name):
""" """
Split full source type name into package and type name. Split full source type name into package and type name.
E.g. 'root.package.Message' -> ('root.package', 'Message') E.g. 'root.package.Message' -> ('root.package', 'Message')
@@ -36,7 +36,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
def get_type_reference( def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True package: str, imports: set, source_type: str, unwrap: bool = True,
) -> str: ) -> str:
""" """
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
@@ -50,7 +50,7 @@ def get_type_reference(
if source_type == ".google.protobuf.Duration": if source_type == ".google.protobuf.Duration":
return "timedelta" return "timedelta"
elif source_type == ".google.protobuf.Timestamp": if source_type == ".google.protobuf.Timestamp":
return "datetime" return "datetime"
source_package, source_type = parse_source_type_name(source_type) source_package, source_type = parse_source_type_name(source_type)
@@ -79,14 +79,14 @@ def get_type_reference(
return reference_cousin(current_package, imports, py_package, py_type) return reference_cousin(current_package, imports, py_package, py_type)
def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str: def reference_absolute(imports, py_package, py_type):
""" """
Returns a reference to a python type located in the root, i.e. sys.path. Returns a reference to a python type located in the root, i.e. sys.path.
""" """
string_import = ".".join(py_package) string_import = ".".join(py_package)
string_alias = safe_snake_case(string_import) string_alias = safe_snake_case(string_import)
imports.add(f"import {string_import} as {string_alias}") imports.add(f"import {string_import} as {string_alias}")
return f'"{string_alias}.{py_type}"' return f"{string_alias}.{py_type}"
def reference_sibling(py_type: str) -> str: def reference_sibling(py_type: str) -> str:
@@ -100,9 +100,8 @@ def reference_descendent(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str: ) -> str:
""" """
Returns a reference to a python type in a package that is a descendent of the Returns a reference to a python type in a package that is a descendent of the current package,
current package, and adds the required import that is aliased to avoid name and adds the required import that is aliased to avoid name conflicts.
conflicts.
""" """
importing_descendent = py_package[len(current_package) :] importing_descendent = py_package[len(current_package) :]
string_from = ".".join(importing_descendent[:-1]) string_from = ".".join(importing_descendent[:-1])
@@ -110,19 +109,18 @@ def reference_descendent(
if string_from: if string_from:
string_alias = "_".join(importing_descendent) string_alias = "_".join(importing_descendent)
imports.add(f"from .{string_from} import {string_import} as {string_alias}") imports.add(f"from .{string_from} import {string_import} as {string_alias}")
return f'"{string_alias}.{py_type}"' return f"{string_alias}.{py_type}"
else: else:
imports.add(f"from . import {string_import}") imports.add(f"from . import {string_import}")
return f'"{string_import}.{py_type}"' return f"{string_import}.{py_type}"
def reference_ancestor( def reference_ancestor(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str: ) -> str:
""" """
Returns a reference to a python type in a package which is an ancestor to the Returns a reference to a python type in a package which is an ancestor to the current package,
current package, and adds the required import that is aliased (if possible) to avoid and adds the required import that is aliased (if possible) to avoid name conflicts.
name conflicts.
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34). Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
""" """
@@ -132,21 +130,21 @@ def reference_ancestor(
string_alias = f"_{'_' * distance_up}{string_import}__" string_alias = f"_{'_' * distance_up}{string_import}__"
string_from = f"..{'.' * distance_up}" string_from = f"..{'.' * distance_up}"
imports.add(f"from {string_from} import {string_import} as {string_alias}") imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f'"{string_alias}.{py_type}"' return f"{string_alias}.{py_type}"
else: else:
string_alias = f"{'_' * distance_up}{py_type}__" string_alias = f"{'_' * distance_up}{py_type}__"
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
return f'"{string_alias}"' return string_alias
def reference_cousin( def reference_cousin(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str: ) -> str:
""" """
Returns a reference to a python type in a package that is not descendent, ancestor Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
or sibling, and adds the required import that is aliased to avoid name conflicts. and adds the required import that is aliased to avoid name conflicts.
""" """
shared_ancestry = os.path.commonprefix([current_package, py_package]) # type: ignore shared_ancestry = os.path.commonprefix([current_package, py_package])
distance_up = len(current_package) - len(shared_ancestry) distance_up = len(current_package) - len(shared_ancestry)
string_from = f".{'.' * distance_up}" + ".".join( string_from = f".{'.' * distance_up}" + ".".join(
py_package[len(shared_ancestry) : -1] py_package[len(shared_ancestry) : -1]
@@ -159,4 +157,4 @@ def reference_cousin(
+ "__" + "__"
) )
imports.add(f"from {string_from} import {string_import} as {string_alias}") imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f'"{string_alias}.{py_type}"' return f"{string_alias}.{py_type}"

View File

@@ -1,13 +1,13 @@
from betterproto import casing from betterproto import casing
def pythonize_class_name(name: str) -> str: def pythonize_class_name(name):
return casing.pascal_case(name) return casing.pascal_case(name)
def pythonize_field_name(name: str) -> str: def pythonize_field_name(name: str):
return casing.safe_snake_case(name) return casing.safe_snake_case(name)
def pythonize_method_name(name: str) -> str: def pythonize_method_name(name: str):
return casing.safe_snake_case(name) return casing.safe_snake_case(name)

View File

@@ -1,7 +1,8 @@
import asyncio
from abc import ABC from abc import ABC
import asyncio
import grpclib.const
from typing import ( from typing import (
TYPE_CHECKING, Any,
AsyncIterable, AsyncIterable,
AsyncIterator, AsyncIterator,
Collection, Collection,
@@ -9,23 +10,21 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Tuple, Tuple,
TYPE_CHECKING,
Type, Type,
Union, Union,
) )
import grpclib.const
from .._types import ST, T from .._types import ST, T
if TYPE_CHECKING: if TYPE_CHECKING:
from grpclib.client import Channel from grpclib._protocols import IProtoMessage
from grpclib.client import Channel, Stream
from grpclib.metadata import Deadline from grpclib.metadata import Deadline
_Value = Union[str, bytes] _Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
_MessageLike = Union[T, ST] _MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
class ServiceStub(ABC): class ServiceStub(ABC):
@@ -61,7 +60,7 @@ class ServiceStub(ABC):
async def _unary_unary( async def _unary_unary(
self, self,
route: str, route: str,
request: _MessageLike, request: "IProtoMessage",
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
@@ -84,7 +83,7 @@ class ServiceStub(ABC):
async def _unary_stream( async def _unary_stream(
self, self,
route: str, route: str,
request: _MessageLike, request: "IProtoMessage",
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,

View File

@@ -1,5 +1,12 @@
import asyncio import asyncio
from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union from typing import (
AsyncIterable,
AsyncIterator,
Iterable,
Optional,
TypeVar,
Union,
)
T = TypeVar("T") T = TypeVar("T")
@@ -9,53 +16,57 @@ class ChannelClosed(Exception):
An exception raised on an attempt to send through a closed channel An exception raised on an attempt to send through a closed channel
""" """
pass
class ChannelDone(Exception): class ChannelDone(Exception):
""" """
An exception raised on an attempt to send receive from a channel that is both closed An exception raised on an attempt to send recieve from a channel that is both closed
and empty. and empty.
""" """
pass
class AsyncChannel(AsyncIterable[T]): class AsyncChannel(AsyncIterable[T]):
""" """
A buffered async channel for sending items between coroutines with FIFO ordering. A buffered async channel for sending items between coroutines with FIFO ordering.
This makes decoupled bidirectional steaming gRPC requests easy if used like: This makes decoupled bidirection steaming gRPC requests easy if used like:
.. code-block:: python .. code-block:: python
client = GeneratedStub(grpclib_chan) client = GeneratedStub(grpclib_chan)
request_channel = await AsyncChannel() request_chan = await AsyncChannel()
# We can start be sending all the requests we already have # We can start be sending all the requests we already have
await request_channel.send_from([RequestObject(...), RequestObject(...)]) await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
async for response in client.rpc_call(request_channel): async for response in client.rpc_call(request_chan):
# The response iterator will remain active until the connection is closed # The response iterator will remain active until the connection is closed
... ...
# More items can be sent at any time # More items can be sent at any time
await request_channel.send(RequestObject(...)) await request_chan.send(ReqestObject(...))
... ...
# The channel must be closed to complete the gRPC connection # The channel must be closed to complete the gRPC connection
request_channel.close() request_chan.close()
Items can be sent through the channel by either: Items can be sent through the channel by either:
- providing an iterable to the send_from method - providing an iterable to the send_from method
- passing them to the send method one at a time - passing them to the send method one at a time
Items can be received from the channel by either: Items can be recieved from the channel by either:
- iterating over the channel with a for loop to get all items - iterating over the channel with a for loop to get all items
- calling the receive method to get one item at a time - calling the recieve method to get one item at a time
If the channel is empty then receivers will wait until either an item appears or the If the channel is empty then recievers will wait until either an item appears or the
channel is closed. channel is closed.
Once the channel is closed then subsequent attempt to send through the channel will Once the channel is closed then subsequent attempt to send through the channel will
fail with a ChannelClosed exception. fail with a ChannelClosed exception.
When th channel is closed and empty then it is done, and further attempts to receive When th channel is closed and empty then it is done, and further attempts to recieve
from it will fail with a ChannelDone exception from it will fail with a ChannelDone exception
If multiple coroutines receive from the channel concurrently, each item sent will be If multiple coroutines recieve from the channel concurrently, each item sent will be
received by only one of the receivers. recieved by only one of the recievers.
:param source: :param source:
An optional iterable will items that should be sent through the channel An optional iterable will items that should be sent through the channel
@@ -63,16 +74,18 @@ class AsyncChannel(AsyncIterable[T]):
:param buffer_limit: :param buffer_limit:
Limit the number of items that can be buffered in the channel, A value less than Limit the number of items that can be buffered in the channel, A value less than
1 implies no limit. If the channel is full then attempts to send more items will 1 implies no limit. If the channel is full then attempts to send more items will
result in the sender waiting until an item is received from the channel. result in the sender waiting until an item is recieved from the channel.
:param close: :param close:
If set to True then the channel will automatically close after exhausting source If set to True then the channel will automatically close after exhausting source
or immediately if no source is provided. or immediately if no source is provided.
""" """
def __init__(self, *, buffer_limit: int = 0, close: bool = False): def __init__(
self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit) self, *, buffer_limit: int = 0, close: bool = False,
):
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
self._closed = False self._closed = False
self._waiting_receivers: int = 0 self._waiting_recievers: int = 0
# Track whether flush has been invoked so it can only happen once # Track whether flush has been invoked so it can only happen once
self._flushed = False self._flushed = False
@@ -82,14 +95,14 @@ class AsyncChannel(AsyncIterable[T]):
async def __anext__(self) -> T: async def __anext__(self) -> T:
if self.done(): if self.done():
raise StopAsyncIteration raise StopAsyncIteration
self._waiting_receivers += 1 self._waiting_recievers += 1
try: try:
result = await self._queue.get() result = await self._queue.get()
if result is self.__flush: if result is self.__flush:
raise StopAsyncIteration raise StopAsyncIteration
return result return result
finally: finally:
self._waiting_receivers -= 1 self._waiting_recievers -= 1
self._queue.task_done() self._queue.task_done()
def closed(self) -> bool: def closed(self) -> bool:
@@ -103,12 +116,12 @@ class AsyncChannel(AsyncIterable[T]):
Check if this channel is done. Check if this channel is done.
:return: True if this channel is closed and and has been drained of items in :return: True if this channel is closed and and has been drained of items in
which case any further attempts to receive an item from this channel will raise which case any further attempts to recieve an item from this channel will raise
a ChannelDone exception. a ChannelDone exception.
""" """
# After close the channel is not yet done until there is at least one waiting # After close the channel is not yet done until there is at least one waiting
# receiver per enqueued item. # reciever per enqueued item.
return self._closed and self._queue.qsize() <= self._waiting_receivers return self._closed and self._queue.qsize() <= self._waiting_recievers
async def send_from( async def send_from(
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
@@ -145,22 +158,22 @@ class AsyncChannel(AsyncIterable[T]):
await self._queue.put(item) await self._queue.put(item)
return self return self
async def receive(self) -> Optional[T]: async def recieve(self) -> Optional[T]:
""" """
Returns the next item from this channel when it becomes available, Returns the next item from this channel when it becomes available,
or None if the channel is closed before another item is sent. or None if the channel is closed before another item is sent.
:return: An item from the channel :return: An item from the channel
""" """
if self.done(): if self.done():
raise ChannelDone("Cannot receive from a closed channel") raise ChannelDone("Cannot recieve from a closed channel")
self._waiting_receivers += 1 self._waiting_recievers += 1
try: try:
result = await self._queue.get() result = await self._queue.get()
if result is self.__flush: if result is self.__flush:
return None return None
return result return result
finally: finally:
self._waiting_receivers -= 1 self._waiting_recievers -= 1
self._queue.task_done() self._queue.task_done()
def close(self): def close(self):
@@ -177,8 +190,8 @@ class AsyncChannel(AsyncIterable[T]):
""" """
if not self._flushed: if not self._flushed:
self._flushed = True self._flushed = True
deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize()) deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
for _ in range(deadlocked_receivers): for _ in range(deadlocked_recievers):
await self._queue.put(self.__flush) await self._queue.put(self.__flush)
# A special signal object for flushing the queue when the channel is closed # A special signal object for flushing the queue when the channel is closed

View File

@@ -1,12 +1,10 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto # sources: google/protobuf/any.proto, google/protobuf/source_context.proto, google/protobuf/type.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/wrappers.proto
# plugin: python-betterproto # plugin: python-betterproto
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List
import betterproto import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
class Syntax(betterproto.Enum): class Syntax(betterproto.Enum):
@@ -109,7 +107,7 @@ class NullValue(betterproto.Enum):
NULL_VALUE = 0 NULL_VALUE = 0
@dataclass(eq=False, repr=False) @dataclass
class Any(betterproto.Message): class Any(betterproto.Message):
""" """
`Any` contains an arbitrary serialized protocol buffer message along with a `Any` contains an arbitrary serialized protocol buffer message along with a
@@ -123,25 +121,24 @@ class Any(betterproto.Message):
Example 3: Pack and unpack a message in Python. foo = Foo(...) any Example 3: Pack and unpack a message in Python. foo = Foo(...) any
= Any() any.Pack(foo) ... if any.Is(Foo.DESCRIPTOR): = Any() any.Pack(foo) ... if any.Is(Foo.DESCRIPTOR):
any.Unpack(foo) ... Example 4: Pack and unpack a message in Go any.Unpack(foo) ... Example 4: Pack and unpack a message in Go
foo := &pb.Foo{...} any, err := anypb.New(foo) if err != nil { foo := &pb.Foo{...} any, err := ptypes.MarshalAny(foo) ...
... } ... foo := &pb.Foo{} if err := foo := &pb.Foo{} if err := ptypes.UnmarshalAny(any, foo); err != nil {
any.UnmarshalTo(foo); err != nil { ... } The pack methods ... } The pack methods provided by protobuf library will by default
provided by protobuf library will by default use use 'type.googleapis.com/full.type.name' as the type URL and the unpack
'type.googleapis.com/full.type.name' as the type URL and the unpack methods methods only use the fully qualified type name after the last '/' in the
only use the fully qualified type name after the last '/' in the type URL, type URL, for example "foo.bar.com/x/y.z" will yield type name "y.z". JSON
for example "foo.bar.com/x/y.z" will yield type name "y.z". JSON ==== The ==== The JSON representation of an `Any` value uses the regular
JSON representation of an `Any` value uses the regular representation of representation of the deserialized, embedded message, with an additional
the deserialized, embedded message, with an additional field `@type` which field `@type` which contains the type URL. Example: package
contains the type URL. Example: package google.profile; message google.profile; message Person { string first_name = 1;
Person { string first_name = 1; string last_name = 2; } string last_name = 2; } { "@type":
{ "@type": "type.googleapis.com/google.profile.Person", "type.googleapis.com/google.profile.Person", "firstName": <string>,
"firstName": <string>, "lastName": <string> } If the embedded "lastName": <string> } If the embedded message type is well-known and
message type is well-known and has a custom JSON representation, that has a custom JSON representation, that representation will be embedded
representation will be embedded adding a field `value` which holds the adding a field `value` which holds the custom JSON in addition to the
custom JSON in addition to the `@type` field. Example (for message `@type` field. Example (for message [google.protobuf.Duration][]): {
[google.protobuf.Duration][]): { "@type": "@type": "type.googleapis.com/google.protobuf.Duration", "value":
"type.googleapis.com/google.protobuf.Duration", "value": "1.212s" "1.212s" }
}
""" """
# A URL/resource name that uniquely identifies the type of the serialized # A URL/resource name that uniquely identifies the type of the serialized
@@ -168,7 +165,7 @@ class Any(betterproto.Message):
value: bytes = betterproto.bytes_field(2) value: bytes = betterproto.bytes_field(2)
@dataclass(eq=False, repr=False) @dataclass
class SourceContext(betterproto.Message): class SourceContext(betterproto.Message):
""" """
`SourceContext` represents information about the source of a protobuf `SourceContext` represents information about the source of a protobuf
@@ -180,7 +177,7 @@ class SourceContext(betterproto.Message):
file_name: str = betterproto.string_field(1) file_name: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False) @dataclass
class Type(betterproto.Message): class Type(betterproto.Message):
"""A protocol buffer message type.""" """A protocol buffer message type."""
@@ -198,7 +195,7 @@ class Type(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(6) syntax: "Syntax" = betterproto.enum_field(6)
@dataclass(eq=False, repr=False) @dataclass
class Field(betterproto.Message): class Field(betterproto.Message):
"""A single field of a message type.""" """A single field of a message type."""
@@ -226,7 +223,7 @@ class Field(betterproto.Message):
default_value: str = betterproto.string_field(11) default_value: str = betterproto.string_field(11)
@dataclass(eq=False, repr=False) @dataclass
class Enum(betterproto.Message): class Enum(betterproto.Message):
"""Enum type definition.""" """Enum type definition."""
@@ -244,7 +241,7 @@ class Enum(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(5) syntax: "Syntax" = betterproto.enum_field(5)
@dataclass(eq=False, repr=False) @dataclass
class EnumValue(betterproto.Message): class EnumValue(betterproto.Message):
"""Enum value definition.""" """Enum value definition."""
@@ -256,7 +253,7 @@ class EnumValue(betterproto.Message):
options: List["Option"] = betterproto.message_field(3) options: List["Option"] = betterproto.message_field(3)
@dataclass(eq=False, repr=False) @dataclass
class Option(betterproto.Message): class Option(betterproto.Message):
""" """
A protocol buffer option, which can be attached to a message, field, A protocol buffer option, which can be attached to a message, field,
@@ -275,7 +272,7 @@ class Option(betterproto.Message):
value: "Any" = betterproto.message_field(2) value: "Any" = betterproto.message_field(2)
@dataclass(eq=False, repr=False) @dataclass
class Api(betterproto.Message): class Api(betterproto.Message):
""" """
Api is a light-weight descriptor for an API Interface. Interfaces are also Api is a light-weight descriptor for an API Interface. Interfaces are also
@@ -318,7 +315,7 @@ class Api(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(7) syntax: "Syntax" = betterproto.enum_field(7)
@dataclass(eq=False, repr=False) @dataclass
class Method(betterproto.Message): class Method(betterproto.Message):
"""Method represents a method of an API interface.""" """Method represents a method of an API interface."""
@@ -338,7 +335,7 @@ class Method(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(7) syntax: "Syntax" = betterproto.enum_field(7)
@dataclass(eq=False, repr=False) @dataclass
class Mixin(betterproto.Message): class Mixin(betterproto.Message):
""" """
Declares an API Interface to be included in this interface. The including Declares an API Interface to be included in this interface. The including
@@ -363,7 +360,7 @@ class Mixin(betterproto.Message):
implies that all methods in `AccessControl` are also declared with same implies that all methods in `AccessControl` are also declared with same
name and request/response types in `Storage`. A documentation generator or name and request/response types in `Storage`. A documentation generator or
annotation processor will see the effective `Storage.GetAcl` method after annotation processor will see the effective `Storage.GetAcl` method after
inheriting documentation and annotations as follows: service Storage { inherting documentation and annotations as follows: service Storage {
// Get the underlying ACL object. rpc GetAcl(GetAclRequest) returns // Get the underlying ACL object. rpc GetAcl(GetAclRequest) returns
(Acl) { option (google.api.http).get = "/v2/{resource=**}:getAcl"; (Acl) { option (google.api.http).get = "/v2/{resource=**}:getAcl";
} ... } Note how the version in the path pattern changed from } ... } Note how the version in the path pattern changed from
@@ -383,7 +380,7 @@ class Mixin(betterproto.Message):
root: str = betterproto.string_field(2) root: str = betterproto.string_field(2)
@dataclass(eq=False, repr=False) @dataclass
class FileDescriptorSet(betterproto.Message): class FileDescriptorSet(betterproto.Message):
""" """
The protocol compiler can output a FileDescriptorSet containing the .proto The protocol compiler can output a FileDescriptorSet containing the .proto
@@ -393,7 +390,7 @@ class FileDescriptorSet(betterproto.Message):
file: List["FileDescriptorProto"] = betterproto.message_field(1) file: List["FileDescriptorProto"] = betterproto.message_field(1)
@dataclass(eq=False, repr=False) @dataclass
class FileDescriptorProto(betterproto.Message): class FileDescriptorProto(betterproto.Message):
"""Describes a complete .proto file.""" """Describes a complete .proto file."""
@@ -422,7 +419,7 @@ class FileDescriptorProto(betterproto.Message):
syntax: str = betterproto.string_field(12) syntax: str = betterproto.string_field(12)
@dataclass(eq=False, repr=False) @dataclass
class DescriptorProto(betterproto.Message): class DescriptorProto(betterproto.Message):
"""Describes a message type.""" """Describes a message type."""
@@ -442,14 +439,14 @@ class DescriptorProto(betterproto.Message):
reserved_name: List[str] = betterproto.string_field(10) reserved_name: List[str] = betterproto.string_field(10)
@dataclass(eq=False, repr=False) @dataclass
class DescriptorProtoExtensionRange(betterproto.Message): class DescriptorProtoExtensionRange(betterproto.Message):
start: int = betterproto.int32_field(1) start: int = betterproto.int32_field(1)
end: int = betterproto.int32_field(2) end: int = betterproto.int32_field(2)
options: "ExtensionRangeOptions" = betterproto.message_field(3) options: "ExtensionRangeOptions" = betterproto.message_field(3)
@dataclass(eq=False, repr=False) @dataclass
class DescriptorProtoReservedRange(betterproto.Message): class DescriptorProtoReservedRange(betterproto.Message):
""" """
Range of reserved tag numbers. Reserved tag numbers may not be used by Range of reserved tag numbers. Reserved tag numbers may not be used by
@@ -461,13 +458,13 @@ class DescriptorProtoReservedRange(betterproto.Message):
end: int = betterproto.int32_field(2) end: int = betterproto.int32_field(2)
@dataclass(eq=False, repr=False) @dataclass
class ExtensionRangeOptions(betterproto.Message): class ExtensionRangeOptions(betterproto.Message):
# The parser stores options it doesn't recognize here. See above. # The parser stores options it doesn't recognize here. See above.
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class FieldDescriptorProto(betterproto.Message): class FieldDescriptorProto(betterproto.Message):
"""Describes a field within a message.""" """Describes a field within a message."""
@@ -499,26 +496,9 @@ class FieldDescriptorProto(betterproto.Message):
# camelCase. # camelCase.
json_name: str = betterproto.string_field(10) json_name: str = betterproto.string_field(10)
options: "FieldOptions" = betterproto.message_field(8) options: "FieldOptions" = betterproto.message_field(8)
# If true, this is a proto3 "optional". When a proto3 field is optional, it
# tracks presence regardless of field type. When proto3_optional is true,
# this field must be belong to a oneof to signal to old proto3 clients that
# presence is tracked for this field. This oneof is known as a "synthetic"
# oneof, and this field must be its sole member (each proto3 optional field
# gets its own synthetic oneof). Synthetic oneofs exist in the descriptor
# only, and do not generate any API. Synthetic oneofs must be ordered after
# all "real" oneofs. For message fields, proto3_optional doesn't create any
# semantic change, since non-repeated message fields always track presence.
# However it still indicates the semantic detail of whether the user wrote
# "optional" or not. This can be useful for round-tripping the .proto file.
# For consistency we give message fields a synthetic oneof also, even though
# it is not required to track presence. This is especially important because
# the parser can't tell if a field is a message or an enum, so it must always
# create a synthetic oneof. Proto2 optional fields do not set this flag,
# because they already indicate optional with `LABEL_OPTIONAL`.
proto3_optional: bool = betterproto.bool_field(17)
@dataclass(eq=False, repr=False) @dataclass
class OneofDescriptorProto(betterproto.Message): class OneofDescriptorProto(betterproto.Message):
"""Describes a oneof.""" """Describes a oneof."""
@@ -526,12 +506,14 @@ class OneofDescriptorProto(betterproto.Message):
options: "OneofOptions" = betterproto.message_field(2) options: "OneofOptions" = betterproto.message_field(2)
@dataclass(eq=False, repr=False) @dataclass
class EnumDescriptorProto(betterproto.Message): class EnumDescriptorProto(betterproto.Message):
"""Describes an enum type.""" """Describes an enum type."""
name: str = betterproto.string_field(1) name: str = betterproto.string_field(1)
value: List["EnumValueDescriptorProto"] = betterproto.message_field(2) value: List["EnumValueDescriptorProto"] = betterproto.message_field(
2, wraps=betterproto.TYPE_ENUM
)
options: "EnumOptions" = betterproto.message_field(3) options: "EnumOptions" = betterproto.message_field(3)
# Range of reserved numeric values. Reserved numeric values may not be used # Range of reserved numeric values. Reserved numeric values may not be used
# by enum values in the same enum declaration. Reserved ranges may not # by enum values in the same enum declaration. Reserved ranges may not
@@ -544,7 +526,7 @@ class EnumDescriptorProto(betterproto.Message):
reserved_name: List[str] = betterproto.string_field(5) reserved_name: List[str] = betterproto.string_field(5)
@dataclass(eq=False, repr=False) @dataclass
class EnumDescriptorProtoEnumReservedRange(betterproto.Message): class EnumDescriptorProtoEnumReservedRange(betterproto.Message):
""" """
Range of reserved numeric values. Reserved values may not be used by Range of reserved numeric values. Reserved values may not be used by
@@ -557,16 +539,18 @@ class EnumDescriptorProtoEnumReservedRange(betterproto.Message):
end: int = betterproto.int32_field(2) end: int = betterproto.int32_field(2)
@dataclass(eq=False, repr=False) @dataclass
class EnumValueDescriptorProto(betterproto.Message): class EnumValueDescriptorProto(betterproto.Message):
"""Describes a value within an enum.""" """Describes a value within an enum."""
name: str = betterproto.string_field(1) name: str = betterproto.string_field(1)
number: int = betterproto.int32_field(2) number: int = betterproto.int32_field(2)
options: "EnumValueOptions" = betterproto.message_field(3) options: "EnumValueOptions" = betterproto.message_field(
3, wraps=betterproto.TYPE_ENUM
)
@dataclass(eq=False, repr=False) @dataclass
class ServiceDescriptorProto(betterproto.Message): class ServiceDescriptorProto(betterproto.Message):
"""Describes a service.""" """Describes a service."""
@@ -575,7 +559,7 @@ class ServiceDescriptorProto(betterproto.Message):
options: "ServiceOptions" = betterproto.message_field(3) options: "ServiceOptions" = betterproto.message_field(3)
@dataclass(eq=False, repr=False) @dataclass
class MethodDescriptorProto(betterproto.Message): class MethodDescriptorProto(betterproto.Message):
"""Describes a method of a service.""" """Describes a method of a service."""
@@ -591,25 +575,24 @@ class MethodDescriptorProto(betterproto.Message):
server_streaming: bool = betterproto.bool_field(6) server_streaming: bool = betterproto.bool_field(6)
@dataclass(eq=False, repr=False) @dataclass
class FileOptions(betterproto.Message): class FileOptions(betterproto.Message):
# Sets the Java package where classes generated from this .proto will be # Sets the Java package where classes generated from this .proto will be
# placed. By default, the proto package is used, but this is often # placed. By default, the proto package is used, but this is often
# inappropriate because proto packages do not normally start with backwards # inappropriate because proto packages do not normally start with backwards
# domain names. # domain names.
java_package: str = betterproto.string_field(1) java_package: str = betterproto.string_field(1)
# Controls the name of the wrapper Java class generated for the .proto file. # If set, all the classes from the .proto file are wrapped in a single outer
# That class will always contain the .proto file's getDescriptor() method as # class with the given name. This applies to both Proto1 (equivalent to the
# well as any top-level extensions defined in the .proto file. If # old "--one_java_file" option) and Proto2 (where a .proto always translates
# java_multiple_files is disabled, then all the other classes from the .proto # to a single class, but you may want to explicitly choose the class name).
# file will be nested inside the single wrapper outer class.
java_outer_classname: str = betterproto.string_field(8) java_outer_classname: str = betterproto.string_field(8)
# If enabled, then the Java code generator will generate a separate .java # If set true, then the Java code generator will generate a separate .java
# file for each top-level message, enum, and service defined in the .proto # file for each top-level message, enum, and service defined in the .proto
# file. Thus, these types will *not* be nested inside the wrapper class # file. Thus, these types will *not* be nested inside the outer class named
# named by java_outer_classname. However, the wrapper class will still be # by java_outer_classname. However, the outer class will still be generated
# generated to contain the file's getDescriptor() method as well as any top- # to contain the file's getDescriptor() method as well as any top-level
# level extensions defined in the file. # extensions defined in the file.
java_multiple_files: bool = betterproto.bool_field(10) java_multiple_files: bool = betterproto.bool_field(10)
# This option does nothing. # This option does nothing.
java_generate_equals_and_hash: bool = betterproto.bool_field(20) java_generate_equals_and_hash: bool = betterproto.bool_field(20)
@@ -674,16 +657,8 @@ class FileOptions(betterproto.Message):
# for the "Options" section above. # for the "Options" section above.
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
def __post_init__(self) -> None:
super().__post_init__()
if self.java_generate_equals_and_hash:
warnings.warn(
"FileOptions.java_generate_equals_and_hash is deprecated",
DeprecationWarning,
)
@dataclass
@dataclass(eq=False, repr=False)
class MessageOptions(betterproto.Message): class MessageOptions(betterproto.Message):
# Set true to use the old proto1 MessageSet wire format for extensions. This # Set true to use the old proto1 MessageSet wire format for extensions. This
# is provided for backwards-compatibility with the MessageSet wire format. # is provided for backwards-compatibility with the MessageSet wire format.
@@ -720,7 +695,7 @@ class MessageOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class FieldOptions(betterproto.Message): class FieldOptions(betterproto.Message):
# The ctype option instructs the C++ code generator to use a different # The ctype option instructs the C++ code generator to use a different
# representation of the field than it normally would. See the specific # representation of the field than it normally would. See the specific
@@ -777,13 +752,13 @@ class FieldOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class OneofOptions(betterproto.Message): class OneofOptions(betterproto.Message):
# The parser stores options it doesn't recognize here. See above. # The parser stores options it doesn't recognize here. See above.
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class EnumOptions(betterproto.Message): class EnumOptions(betterproto.Message):
# Set this option to true to allow mapping different tag names to the same # Set this option to true to allow mapping different tag names to the same
# value. # value.
@@ -796,7 +771,7 @@ class EnumOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class EnumValueOptions(betterproto.Message): class EnumValueOptions(betterproto.Message):
# Is this enum value deprecated? Depending on the target platform, this can # Is this enum value deprecated? Depending on the target platform, this can
# emit Deprecated annotations for the enum value, or it will be completely # emit Deprecated annotations for the enum value, or it will be completely
@@ -807,7 +782,7 @@ class EnumValueOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class ServiceOptions(betterproto.Message): class ServiceOptions(betterproto.Message):
# Is this service deprecated? Depending on the target platform, this can emit # Is this service deprecated? Depending on the target platform, this can emit
# Deprecated annotations for the service, or it will be completely ignored; # Deprecated annotations for the service, or it will be completely ignored;
@@ -817,7 +792,7 @@ class ServiceOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class MethodOptions(betterproto.Message): class MethodOptions(betterproto.Message):
# Is this method deprecated? Depending on the target platform, this can emit # Is this method deprecated? Depending on the target platform, this can emit
# Deprecated annotations for the method, or it will be completely ignored; in # Deprecated annotations for the method, or it will be completely ignored; in
@@ -828,7 +803,7 @@ class MethodOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass(eq=False, repr=False) @dataclass
class UninterpretedOption(betterproto.Message): class UninterpretedOption(betterproto.Message):
""" """
A message representing a option the parser does not recognize. This only A message representing a option the parser does not recognize. This only
@@ -850,7 +825,7 @@ class UninterpretedOption(betterproto.Message):
aggregate_value: str = betterproto.string_field(8) aggregate_value: str = betterproto.string_field(8)
@dataclass(eq=False, repr=False) @dataclass
class UninterpretedOptionNamePart(betterproto.Message): class UninterpretedOptionNamePart(betterproto.Message):
""" """
The name of the uninterpreted option. Each string represents a segment in The name of the uninterpreted option. Each string represents a segment in
@@ -864,7 +839,7 @@ class UninterpretedOptionNamePart(betterproto.Message):
is_extension: bool = betterproto.bool_field(2) is_extension: bool = betterproto.bool_field(2)
@dataclass(eq=False, repr=False) @dataclass
class SourceCodeInfo(betterproto.Message): class SourceCodeInfo(betterproto.Message):
""" """
Encapsulates information about the original source file from which a Encapsulates information about the original source file from which a
@@ -903,7 +878,7 @@ class SourceCodeInfo(betterproto.Message):
location: List["SourceCodeInfoLocation"] = betterproto.message_field(1) location: List["SourceCodeInfoLocation"] = betterproto.message_field(1)
@dataclass(eq=False, repr=False) @dataclass
class SourceCodeInfoLocation(betterproto.Message): class SourceCodeInfoLocation(betterproto.Message):
# Identifies which part of the FileDescriptorProto was defined at this # Identifies which part of the FileDescriptorProto was defined at this
# location. Each element is a field number or an index. They form a path # location. Each element is a field number or an index. They form a path
@@ -950,7 +925,7 @@ class SourceCodeInfoLocation(betterproto.Message):
leading_detached_comments: List[str] = betterproto.string_field(6) leading_detached_comments: List[str] = betterproto.string_field(6)
@dataclass(eq=False, repr=False) @dataclass
class GeneratedCodeInfo(betterproto.Message): class GeneratedCodeInfo(betterproto.Message):
""" """
Describes the relationship between generated code and its original source Describes the relationship between generated code and its original source
@@ -963,7 +938,7 @@ class GeneratedCodeInfo(betterproto.Message):
annotation: List["GeneratedCodeInfoAnnotation"] = betterproto.message_field(1) annotation: List["GeneratedCodeInfoAnnotation"] = betterproto.message_field(1)
@dataclass(eq=False, repr=False) @dataclass
class GeneratedCodeInfoAnnotation(betterproto.Message): class GeneratedCodeInfoAnnotation(betterproto.Message):
# Identifies the element in the original source .proto file. This field is # Identifies the element in the original source .proto file. This field is
# formatted the same as SourceCodeInfo.Location.path. # formatted the same as SourceCodeInfo.Location.path.
@@ -979,7 +954,7 @@ class GeneratedCodeInfoAnnotation(betterproto.Message):
end: int = betterproto.int32_field(4) end: int = betterproto.int32_field(4)
@dataclass(eq=False, repr=False) @dataclass
class Duration(betterproto.Message): class Duration(betterproto.Message):
""" """
A Duration represents a signed, fixed-length span of time represented as a A Duration represents a signed, fixed-length span of time represented as a
@@ -1024,7 +999,7 @@ class Duration(betterproto.Message):
nanos: int = betterproto.int32_field(2) nanos: int = betterproto.int32_field(2)
@dataclass(eq=False, repr=False) @dataclass
class Empty(betterproto.Message): class Empty(betterproto.Message):
""" """
A generic empty message that you can re-use to avoid defining duplicated A generic empty message that you can re-use to avoid defining duplicated
@@ -1037,7 +1012,7 @@ class Empty(betterproto.Message):
pass pass
@dataclass(eq=False, repr=False) @dataclass
class FieldMask(betterproto.Message): class FieldMask(betterproto.Message):
""" """
`FieldMask` represents a set of symbolic field paths, for example: `FieldMask` represents a set of symbolic field paths, for example:
@@ -1121,7 +1096,7 @@ class FieldMask(betterproto.Message):
paths: List[str] = betterproto.string_field(1) paths: List[str] = betterproto.string_field(1)
@dataclass(eq=False, repr=False) @dataclass
class Struct(betterproto.Message): class Struct(betterproto.Message):
""" """
`Struct` represents a structured data value, consisting of fields which map `Struct` represents a structured data value, consisting of fields which map
@@ -1138,7 +1113,7 @@ class Struct(betterproto.Message):
) )
@dataclass(eq=False, repr=False) @dataclass
class Value(betterproto.Message): class Value(betterproto.Message):
""" """
`Value` represents a dynamically typed value which can be either null, a `Value` represents a dynamically typed value which can be either null, a
@@ -1162,7 +1137,7 @@ class Value(betterproto.Message):
list_value: "ListValue" = betterproto.message_field(6, group="kind") list_value: "ListValue" = betterproto.message_field(6, group="kind")
@dataclass(eq=False, repr=False) @dataclass
class ListValue(betterproto.Message): class ListValue(betterproto.Message):
""" """
`ListValue` is a wrapper around a repeated field of values. The JSON `ListValue` is a wrapper around a repeated field of values. The JSON
@@ -1173,7 +1148,7 @@ class ListValue(betterproto.Message):
values: List["Value"] = betterproto.message_field(1) values: List["Value"] = betterproto.message_field(1)
@dataclass(eq=False, repr=False) @dataclass
class Timestamp(betterproto.Message): class Timestamp(betterproto.Message):
""" """
A Timestamp represents a point in time independent of any time zone or A Timestamp represents a point in time independent of any time zone or
@@ -1203,22 +1178,20 @@ class Timestamp(betterproto.Message):
long millis = System.currentTimeMillis(); Timestamp timestamp = long millis = System.currentTimeMillis(); Timestamp timestamp =
Timestamp.newBuilder().setSeconds(millis / 1000) .setNanos((int) Timestamp.newBuilder().setSeconds(millis / 1000) .setNanos((int)
((millis % 1000) * 1000000)).build(); Example 5: Compute Timestamp from ((millis % 1000) * 1000000)).build(); Example 5: Compute Timestamp from
Java `Instant.now()`. Instant now = Instant.now(); Timestamp current time in Python. timestamp = Timestamp()
timestamp = Timestamp.newBuilder().setSeconds(now.getEpochSecond()) timestamp.GetCurrentTime() # JSON Mapping In JSON format, the Timestamp
.setNanos(now.getNano()).build(); Example 6: Compute Timestamp from current type is encoded as a string in the [RFC
time in Python. timestamp = Timestamp() timestamp.GetCurrentTime() 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the format is
# JSON Mapping In JSON format, the Timestamp type is encoded as a string in "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" where {year} is
the [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the always expressed using four digits while {month}, {day}, {hour}, {min}, and
format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" where {sec} are zero-padded to two digits each. The fractional seconds, which can
{year} is always expressed using four digits while {month}, {day}, {hour}, go up to 9 digits (i.e. up to 1 nanosecond resolution), are optional. The
{min}, and {sec} are zero-padded to two digits each. The fractional "Z" suffix indicates the timezone ("UTC"); the timezone is required. A
seconds, which can go up to 9 digits (i.e. up to 1 nanosecond resolution), proto3 JSON serializer should always use UTC (as indicated by "Z") when
are optional. The "Z" suffix indicates the timezone ("UTC"); the timezone printing the Timestamp type and a proto3 JSON parser should be able to
is required. A proto3 JSON serializer should always use UTC (as indicated accept both UTC and other timezones (as indicated by an offset). For
by "Z") when printing the Timestamp type and a proto3 JSON parser should be example, "2017-01-15T01:30:15.01Z" encodes 15.01 seconds past 01:30 UTC on
able to accept both UTC and other timezones (as indicated by an offset). January 15, 2017. In JavaScript, one can convert a Date object to this
For example, "2017-01-15T01:30:15.01Z" encodes 15.01 seconds past 01:30 UTC
on January 15, 2017. In JavaScript, one can convert a Date object to this
format using the standard [toISOString()](https://developer.mozilla.org/en- format using the standard [toISOString()](https://developer.mozilla.org/en-
US/docs/Web/JavaScript/Reference/Global_Objects/Date/toISOString) method. US/docs/Web/JavaScript/Reference/Global_Objects/Date/toISOString) method.
In Python, a standard `datetime.datetime` object can be converted to this In Python, a standard `datetime.datetime` object can be converted to this
@@ -1240,7 +1213,7 @@ class Timestamp(betterproto.Message):
nanos: int = betterproto.int32_field(2) nanos: int = betterproto.int32_field(2)
@dataclass(eq=False, repr=False) @dataclass
class DoubleValue(betterproto.Message): class DoubleValue(betterproto.Message):
""" """
Wrapper message for `double`. The JSON representation for `DoubleValue` is Wrapper message for `double`. The JSON representation for `DoubleValue` is
@@ -1251,7 +1224,7 @@ class DoubleValue(betterproto.Message):
value: float = betterproto.double_field(1) value: float = betterproto.double_field(1)
@dataclass(eq=False, repr=False) @dataclass
class FloatValue(betterproto.Message): class FloatValue(betterproto.Message):
""" """
Wrapper message for `float`. The JSON representation for `FloatValue` is Wrapper message for `float`. The JSON representation for `FloatValue` is
@@ -1262,7 +1235,7 @@ class FloatValue(betterproto.Message):
value: float = betterproto.float_field(1) value: float = betterproto.float_field(1)
@dataclass(eq=False, repr=False) @dataclass
class Int64Value(betterproto.Message): class Int64Value(betterproto.Message):
""" """
Wrapper message for `int64`. The JSON representation for `Int64Value` is Wrapper message for `int64`. The JSON representation for `Int64Value` is
@@ -1273,7 +1246,7 @@ class Int64Value(betterproto.Message):
value: int = betterproto.int64_field(1) value: int = betterproto.int64_field(1)
@dataclass(eq=False, repr=False) @dataclass
class UInt64Value(betterproto.Message): class UInt64Value(betterproto.Message):
""" """
Wrapper message for `uint64`. The JSON representation for `UInt64Value` is Wrapper message for `uint64`. The JSON representation for `UInt64Value` is
@@ -1284,7 +1257,7 @@ class UInt64Value(betterproto.Message):
value: int = betterproto.uint64_field(1) value: int = betterproto.uint64_field(1)
@dataclass(eq=False, repr=False) @dataclass
class Int32Value(betterproto.Message): class Int32Value(betterproto.Message):
""" """
Wrapper message for `int32`. The JSON representation for `Int32Value` is Wrapper message for `int32`. The JSON representation for `Int32Value` is
@@ -1295,7 +1268,7 @@ class Int32Value(betterproto.Message):
value: int = betterproto.int32_field(1) value: int = betterproto.int32_field(1)
@dataclass(eq=False, repr=False) @dataclass
class UInt32Value(betterproto.Message): class UInt32Value(betterproto.Message):
""" """
Wrapper message for `uint32`. The JSON representation for `UInt32Value` is Wrapper message for `uint32`. The JSON representation for `UInt32Value` is
@@ -1306,7 +1279,7 @@ class UInt32Value(betterproto.Message):
value: int = betterproto.uint32_field(1) value: int = betterproto.uint32_field(1)
@dataclass(eq=False, repr=False) @dataclass
class BoolValue(betterproto.Message): class BoolValue(betterproto.Message):
""" """
Wrapper message for `bool`. The JSON representation for `BoolValue` is JSON Wrapper message for `bool`. The JSON representation for `BoolValue` is JSON
@@ -1317,7 +1290,7 @@ class BoolValue(betterproto.Message):
value: bool = betterproto.bool_field(1) value: bool = betterproto.bool_field(1)
@dataclass(eq=False, repr=False) @dataclass
class StringValue(betterproto.Message): class StringValue(betterproto.Message):
""" """
Wrapper message for `string`. The JSON representation for `StringValue` is Wrapper message for `string`. The JSON representation for `StringValue` is
@@ -1328,7 +1301,7 @@ class StringValue(betterproto.Message):
value: str = betterproto.string_field(1) value: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False) @dataclass
class BytesValue(betterproto.Message): class BytesValue(betterproto.Message):
""" """
Wrapper message for `bytes`. The JSON representation for `BytesValue` is Wrapper message for `bytes`. The JSON representation for `BytesValue` is

2
betterproto/plugin.bat Normal file
View File

@@ -0,0 +1,2 @@
@SET plugin_dir=%~dp0
@python %plugin_dir%/plugin.py %*

403
betterproto/plugin.py Executable file
View File

@@ -0,0 +1,403 @@
#!/usr/bin/env python
import itertools
import os.path
import pathlib
import re
import sys
import textwrap
from typing import List, Union
import betterproto
from betterproto.compile.importing import get_type_reference
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
# The compiler extras are optional at install time; importing them eagerly here
# (rather than at call sites) lets the plugin fail fast with a friendly message
# when the user installed plain `betterproto` instead of `betterproto[compiler]`.
try:
    # betterproto[compiler] specific dependencies
    import black
    from google.protobuf.compiler import plugin_pb2 as plugin
    from google.protobuf.descriptor_pb2 import (
        DescriptorProto,
        EnumDescriptorProto,
        FieldDescriptorProto,
    )
    import google.protobuf.wrappers_pb2 as google_wrappers
    import jinja2
except ImportError as err:
    # Extract the module name from the standard "No module named '<name>'"
    # message by slicing off the 17-char prefix and the trailing quote.
    # NOTE(review): fragile if the message format ever changes; `err.name`
    # would be more robust — confirm before changing.
    missing_import = err.args[0][17:-1]
    print(
        "\033[31m"
        f"Unable to import `{missing_import}` from betterproto plugin! "
        "Please ensure that you've installed betterproto as "
        '`pip install "betterproto[compiler]"` so that compiler dependencies '
        "are included."
        "\033[0m"
    )
    # Exit code 1 tells protoc the plugin failed.
    raise SystemExit(1)
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
    """Return the Python type name used in generated code for a proto field.

    Scalar proto type numbers map onto builtins; message and enum fields
    (type numbers 11 and 14) resolve to a reference on the target class,
    possibly adding to ``imports`` as a side effect.

    Raises:
        NotImplementedError: for a type number this generator does not know.
    """
    # Scalar type-number -> builtin name table (1/2 are the floating types;
    # 8 bool, 9 string, 12 bytes; the rest of the listed numbers are the
    # various integer encodings).
    scalar_types = {1: "float", 2: "float", 8: "bool", 9: "str", 12: "bytes"}
    scalar_types.update((num, "int") for num in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18))

    if field.type in scalar_types:
        return scalar_types[field.type]
    if field.type in (11, 14):
        # Type referencing another defined Message or a named enum.
        return get_type_reference(package, imports, field.type_name)
    raise NotImplementedError(f"Unknown type {field.type}")
def get_py_zero(type_num: int) -> Union[str, float]:
    """Return the zero/default value for a protobuf field type number.

    Used by the template to emit default arguments in generated service stub
    signatures. String results are literal Python source fragments (e.g.
    ``'""'`` renders as an empty-string default); numbers are real numbers.
    """
    zero: Union[str, float] = 0
    if type_num in [1, 2]:
        # TYPE_DOUBLE / TYPE_FLOAT. BUGFIX: this previously read `in []`,
        # making the branch unreachable, so float fields fell through to the
        # integer default of 0 instead of 0.0.
        zero = 0.0
    elif type_num == 8:
        zero = "False"
    elif type_num == 9:
        zero = '""'
    elif type_num == 11:
        zero = "None"
    elif type_num == 12:
        zero = 'b""'
    return zero
def traverse(proto_file):
    """Yield ``(item, path)`` pairs for every enum and message in the file,
    flattening nested definitions by prefixing each name with its parent's."""

    def _walk(path, items, prefix=""):
        for index, item in enumerate(items):
            # Flatten the hierarchy: a nested item's name absorbs its parent's.
            child_prefix = prefix + item.name
            item.name = child_prefix
            yield item, path + [index]

            if isinstance(item, DescriptorProto):
                # Enums declared inside a message.
                for enum in item.enum_type:
                    enum.name = child_prefix + enum.name
                    yield enum, path + [index, 4]

                # Messages declared inside a message.
                if item.nested_type:
                    yield from _walk(path + [index, 3], item.nested_type, child_prefix)

    return itertools.chain(
        _walk([5], proto_file.enum_type), _walk([4], proto_file.message_type)
    )
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
    """Look up the leading comment attached to the item at ``path``.

    Field comments are rendered as ``#`` comment lines; anything else
    (message, enum, service, or method) becomes a triple-quoted docstring.
    Returns an empty string when no comment exists for that path.
    """
    pad = " " * indent
    for location in proto_file.source_code_info.location:
        if list(location.path) != path or not location.leading_comments:
            continue
        lines = textwrap.wrap(
            location.leading_comments.strip().replace("\n", ""), width=79 - indent
        )

        if path[-2] == 2 and path[-4] != 6:
            # This is a field: render as a '#' comment block.
            return f"{pad}# " + f"\n{pad}# ".join(lines)

        # This is a message, enum, service, or method: render as a docstring.
        if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
            # Short enough for a one-line docstring; drop stray quote chars.
            text = lines[0].strip('"')
            return f'{pad}"""{text}"""'
        body = f"\n{pad}".join(lines)
        return f'{pad}"""\n{pad}{body}\n{pad}"""'
    return ""
def generate_code(request, response):
    """Generate Python source for every proto file in ``request``.

    Populates ``response`` (a ``CodeGeneratorResponse``) with one
    ``__init__.py`` per proto package, plus empty ``__init__.py`` files for
    intermediate directories so every output directory is importable.
    """
    plugin_options = request.parameter.split(",") if request.parameter else []

    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
    )
    template = env.get_template("template.py.j2")

    # Group input files by output path: one generated module per package.
    output_map = {}
    for proto_file in request.proto_file:
        if (
            proto_file.package == "google.protobuf"
            and "INCLUDE_GOOGLE" not in plugin_options
        ):
            # betterproto ships its own google.protobuf types; only generate
            # them when explicitly requested via the INCLUDE_GOOGLE option.
            continue
        output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
        if output_file not in output_map:
            output_map[output_file] = {"package": proto_file.package, "files": []}
        output_map[output_file]["files"].append(proto_file)

    # TODO: Figure out how to handle gRPC request/response messages and add
    # processing below for Service.

    for filename, options in output_map.items():
        package = options["package"]
        # print(package, filename, file=sys.stderr)

        # Template context accumulated across all files in this package.
        output = {
            "package": package,
            "files": [f.name for f in options["files"]],
            "imports": set(),
            "datetime_imports": set(),
            "typing_imports": set(),
            "messages": [],
            "enums": [],
            "services": [],
        }

        for proto_file in options["files"]:
            item: DescriptorProto
            for item, path in traverse(proto_file):
                data = {"name": item.name, "py_name": pythonize_class_name(item.name)}

                if isinstance(item, DescriptorProto):
                    # print(item, file=sys.stderr)
                    if item.options.map_entry:
                        # Skip generated map entry messages since we just use dicts
                        continue

                    data.update(
                        {
                            "type": "Message",
                            "comment": get_comment(proto_file, path),
                            "properties": [],
                        }
                    )

                    for i, f in enumerate(item.field):
                        t = py_type(package, output["imports"], f)
                        zero = get_py_zero(f.type)

                        repeated = False
                        packed = False

                        # e.g. "TYPE_STRING" -> "string"
                        field_type = f.Type.Name(f.type).lower()[5:]

                        # Google wrapper types (e.g. google.protobuf.Int32Value)
                        # are marked so they can be unwrapped at runtime.
                        field_wraps = ""
                        match_wrapper = re.match(
                            r"\.google\.protobuf\.(.+)Value", f.type_name
                        )
                        if match_wrapper:
                            wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
                            if hasattr(betterproto, wrapped_type):
                                field_wraps = f"betterproto.{wrapped_type}"

                        map_types = None
                        if f.type == 11:
                            # This might be a map...
                            message_type = f.type_name.split(".").pop().lower()
                            # message_type = py_type(package)
                            map_entry = f"{f.name.replace('_', '').lower()}entry"

                            if message_type == map_entry:
                                for nested in item.nested_type:
                                    if (
                                        nested.name.replace("_", "").lower()
                                        == map_entry
                                    ):
                                        if nested.options.map_entry:
                                            # print("Found a map!", file=sys.stderr)
                                            k = py_type(
                                                package,
                                                output["imports"],
                                                nested.field[0],
                                            )
                                            v = py_type(
                                                package,
                                                output["imports"],
                                                nested.field[1],
                                            )
                                            t = f"Dict[{k}, {v}]"
                                            field_type = "map"
                                            map_types = (
                                                f.Type.Name(nested.field[0].type),
                                                f.Type.Name(nested.field[1].type),
                                            )
                                            output["typing_imports"].add("Dict")

                        if f.label == 3 and field_type != "map":
                            # Repeated field
                            repeated = True
                            t = f"List[{t}]"
                            zero = "[]"
                            output["typing_imports"].add("List")

                            if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
                                # Repeated scalar fields use packed encoding.
                                packed = True

                        one_of = ""
                        if f.HasField("oneof_index"):
                            one_of = item.oneof_decl[f.oneof_index].name

                        if "Optional[" in t:
                            output["typing_imports"].add("Optional")

                        if "timedelta" in t:
                            output["datetime_imports"].add("timedelta")
                        elif "datetime" in t:
                            output["datetime_imports"].add("datetime")

                        data["properties"].append(
                            {
                                "name": f.name,
                                "py_name": pythonize_field_name(f.name),
                                "number": f.number,
                                "comment": get_comment(proto_file, path + [2, i]),
                                "proto_type": int(f.type),
                                "field_type": field_type,
                                "field_wraps": field_wraps,
                                "map_types": map_types,
                                "type": t,
                                "zero": zero,
                                "repeated": repeated,
                                "packed": packed,
                                "one_of": one_of,
                            }
                        )
                        # print(f, file=sys.stderr)

                    output["messages"].append(data)
                elif isinstance(item, EnumDescriptorProto):
                    # print(item.name, path, file=sys.stderr)
                    data.update(
                        {
                            "type": "Enum",
                            "comment": get_comment(proto_file, path),
                            "entries": [
                                {
                                    "name": v.name,
                                    "value": v.number,
                                    "comment": get_comment(proto_file, path + [2, i]),
                                }
                                for i, v in enumerate(item.value)
                            ],
                        }
                    )

                    output["enums"].append(data)

            for i, service in enumerate(proto_file.service):
                # print(service, file=sys.stderr)

                data = {
                    "name": service.name,
                    "py_name": pythonize_class_name(service.name),
                    "comment": get_comment(proto_file, [6, i]),
                    "methods": [],
                }

                for j, method in enumerate(service.method):
                    input_message = None
                    input_type = get_type_reference(
                        package, output["imports"], method.input_type
                    ).strip('"')
                    for msg in output["messages"]:
                        if msg["name"] == input_type:
                            input_message = msg
                            for field in msg["properties"]:
                                if field["zero"] == "None":
                                    # Message-typed params default to None, so
                                    # the stub signature needs Optional.
                                    output["typing_imports"].add("Optional")
                            break

                    data["methods"].append(
                        {
                            "name": method.name,
                            "py_name": pythonize_method_name(method.name),
                            "comment": get_comment(proto_file, [6, i, 2, j], indent=8),
                            "route": f"/{package}.{service.name}/{method.name}",
                            "input": get_type_reference(
                                package, output["imports"], method.input_type
                            ).strip('"'),
                            "input_message": input_message,
                            "output": get_type_reference(
                                package,
                                output["imports"],
                                method.output_type,
                                unwrap=False,
                            ).strip('"'),
                            "client_streaming": method.client_streaming,
                            "server_streaming": method.server_streaming,
                        }
                    )

                    if method.client_streaming:
                        output["typing_imports"].add("AsyncIterable")
                        output["typing_imports"].add("Iterable")
                        output["typing_imports"].add("Union")
                    if method.server_streaming:
                        output["typing_imports"].add("AsyncIterator")

                output["services"].append(data)

        output["imports"] = sorted(output["imports"])
        output["datetime_imports"] = sorted(output["datetime_imports"])
        output["typing_imports"] = sorted(output["typing_imports"])

        # Fill response
        f = response.file.add()
        f.name = filename

        # Render and then format the output file.
        f.content = black.format_str(
            template.render(description=output),
            mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
        )

    # Make each output directory a package with __init__ file
    output_paths = set(pathlib.Path(path) for path in output_map.keys())
    init_files = (
        set(
            directory.joinpath("__init__.py")
            for path in output_paths
            for directory in path.parents
        )
        - output_paths
    )
    for init_file in init_files:
        init = response.file.add()
        init.name = str(init_file)

    for filename in sorted(output_paths.union(init_files)):
        # BUGFIX: the f-string previously contained no placeholder, so every
        # iteration printed the same literal text instead of the path written.
        print(f"Writing {filename}", file=sys.stderr)
def main():
    """The plugin's main entry point.

    Implements the protoc plugin protocol: a serialized CodeGeneratorRequest
    arrives on stdin and a serialized CodeGeneratorResponse is written to
    stdout.
    """
    # Parse the incoming request from stdin.
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.buffer.read())

    # Build the response by generating code for every requested file.
    response = plugin.CodeGeneratorResponse()
    generate_code(request, response)

    # Serialize the response back to protoc on stdout.
    sys.stdout.buffer.write(response.SerializeToString())

View File

@@ -0,0 +1,135 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto
from dataclasses import dataclass
{% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if description.typing_imports %}
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if description.services %}
import grpclib
{% endif %}
{% for i in description.imports %}
{{ i }}
{% endfor %}
{% if description.enums %}{% for enum in description.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}
{% endif %}
{% for entry in enum.entries %}
{% if entry.comment %}
{{ entry.comment }}
{% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %}
{% endfor %}
{% endif %}
{% for message in description.messages %}
@dataclass
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
{% endif %}
{% for field in message.properties %}
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{% endfor %}
{% if not message.properties %}
pass
{% endif %}
{% endfor %}
{% for service in description.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.input_message and method.input_message.properties -%}, *,
{%- for field in method.input_message.properties -%}
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
Optional[{{ field.type }}]
{%- else -%}
{{ field.type }}
{%- endif -%} = {{ field.zero }}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{% if not method.client_streaming %}
request = {{ method.input }}()
{% for field in method.input_message.properties %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output }},
):
yield response
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream(
"{{ method.route }}",
request,
{{ method.output }},
):
yield response
{% endif %}{# if client streaming #}
{% else %}{# i.e. not server streaming #}
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary(
"{{ method.route }}",
request,
{{ method.output }}
)
{% endif %}{# client streaming #}
{% endif %}
{% endfor %}
{% endfor %}

View File

@@ -50,7 +50,7 @@ You can add multiple `.proto` files to the test case, as long as one file matche
`test_<name>.py` &mdash; *Custom test to validate specific aspects of the generated class* `test_<name>.py` &mdash; *Custom test to validate specific aspects of the generated class*
```python ```python
from tests.output_betterproto.bool.bool import Test from betterproto.tests.output_betterproto.bool.bool import Test
def test_value(): def test_value():
message = Test() message = Test()

View File

@@ -2,17 +2,18 @@
import asyncio import asyncio
import os import os
from pathlib import Path from pathlib import Path
import platform
import shutil import shutil
import subprocess
import sys import sys
from typing import Set from typing import Set
from tests.util import ( from betterproto.tests.util import (
get_directories, get_directories,
inputs_path, inputs_path,
output_path_betterproto, output_path_betterproto,
output_path_reference, output_path_reference,
protoc, protoc_plugin,
protoc_reference,
) )
# Force pure-python implementation instead of C++, otherwise imports # Force pure-python implementation instead of C++, otherwise imports
@@ -60,15 +61,13 @@ async def generate(whitelist: Set[str], verbose: bool):
if result != 0: if result != 0:
failed_test_cases.append(test_case_name) failed_test_cases.append(test_case_name)
if len(failed_test_cases) > 0: if failed_test_cases:
sys.stderr.write( sys.stderr.write(
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n" "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
) )
for failed_test_case in failed_test_cases: for failed_test_case in failed_test_cases:
sys.stderr.write(f"- {failed_test_case}\n") sys.stderr.write(f"- {failed_test_case}\n")
sys.exit(1)
async def generate_test_case_output( async def generate_test_case_output(
test_case_input_path: Path, test_case_name: str, verbose: bool test_case_input_path: Path, test_case_name: str, verbose: bool
@@ -90,45 +89,25 @@ async def generate_test_case_output(
(ref_out, ref_err, ref_code), (ref_out, ref_err, ref_code),
(plg_out, plg_err, plg_code), (plg_out, plg_err, plg_code),
) = await asyncio.gather( ) = await asyncio.gather(
protoc(test_case_input_path, test_case_output_path_reference, True), protoc_reference(test_case_input_path, test_case_output_path_reference),
protoc(test_case_input_path, test_case_output_path_betterproto, False), protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
)
if ref_code == 0:
print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m")
else:
print(
f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
) )
message = f"Generated output for {test_case_name!r}"
if verbose: if verbose:
print(f"\033[31;1;4m{message}\033[0m")
if ref_out: if ref_out:
print("Reference stdout:")
sys.stdout.buffer.write(ref_out) sys.stdout.buffer.write(ref_out)
sys.stdout.buffer.flush()
if ref_err: if ref_err:
print("Reference stderr:")
sys.stderr.buffer.write(ref_err) sys.stderr.buffer.write(ref_err)
sys.stderr.buffer.flush()
if plg_code == 0:
print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m")
else:
print(
f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
)
if verbose:
if plg_out: if plg_out:
print("Plugin stdout:")
sys.stdout.buffer.write(plg_out) sys.stdout.buffer.write(plg_out)
sys.stdout.buffer.flush()
if plg_err: if plg_err:
print("Plugin stderr:")
sys.stderr.buffer.write(plg_err) sys.stderr.buffer.write(plg_err)
sys.stdout.buffer.flush()
sys.stderr.buffer.flush() sys.stderr.buffer.flush()
else:
print(message)
return max(ref_code, plg_code) return max(ref_code, plg_code)
@@ -157,10 +136,6 @@ def main():
else: else:
verbose = False verbose = False
whitelist = set(sys.argv[1:]) whitelist = set(sys.argv[1:])
if platform.system() == "Windows":
asyncio.set_event_loop(asyncio.ProactorEventLoop())
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))

View File

@@ -1,15 +1,12 @@
import asyncio import asyncio
import sys from betterproto.tests.output_betterproto.service.service import (
from tests.output_betterproto.service.service import (
DoThingRequest,
DoThingResponse, DoThingResponse,
DoThingRequest,
GetThingRequest, GetThingRequest,
GetThingResponse,
TestStub as ThingServiceClient, TestStub as ThingServiceClient,
) )
import grpclib import grpclib
import grpclib.metadata
import grpclib.server
from grpclib.testing import ChannelFor from grpclib.testing import ChannelFor
import pytest import pytest
from betterproto.grpc.util.async_channel import AsyncChannel from betterproto.grpc.util.async_channel import AsyncChannel
@@ -21,92 +18,31 @@ async def _test_client(client, name="clean room", **kwargs):
assert response.names == [name] assert response.names == [name]
def _assert_request_meta_received(deadline, metadata): def _assert_request_meta_recieved(deadline, metadata):
def server_side_test(stream): def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx( assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1 deadline._timestamp, 1
), "The provided deadline should be received serverside" ), "The provided deadline should be recieved serverside"
assert ( assert (
stream.metadata["authorization"] == metadata["authorization"] stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be received serverside" ), "The provided authorization metadata should be recieved serverside"
return server_side_test return server_side_test
@pytest.fixture
def handler_trailer_only_unauthenticated():
async def handler(stream: grpclib.server.Stream):
await stream.recv_message()
await stream.send_initial_metadata()
await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
return handler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_simple_service_call(): async def test_simple_service_call():
async with ChannelFor([ThingService()]) as channel: async with ChannelFor([ThingService()]) as channel:
await _test_client(ThingServiceClient(channel)) await _test_client(ThingServiceClient(channel))
@pytest.mark.asyncio
async def test_trailer_only_error_unary_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_thing",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(name="something")
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
async def test_trailer_only_error_stream_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_many_things",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things(
request_iterator=[DoThingRequest(name="something")]
)
await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
spy = mocker.spy(client, "_unary_unary")
await _test_client(client)
comments = spy.call_args_list[-1].args[1].comments
await _test_client(client)
assert spy.call_args_list[-1].args[1].comments is not comments
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_call_with_upfront_request_params(): async def test_service_call_with_upfront_request_params():
# Setting deadline # Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22) deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"} metadata = {"authorization": "12345"}
async with ChannelFor( async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
) as channel: ) as channel:
await _test_client( await _test_client(
ThingServiceClient(channel, deadline=deadline, metadata=metadata) ThingServiceClient(channel, deadline=deadline, metadata=metadata)
@@ -117,7 +53,7 @@ async def test_service_call_with_upfront_request_params():
deadline = grpclib.metadata.Deadline.from_timeout(timeout) deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"} metadata = {"authorization": "12345"}
async with ChannelFor( async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
) as channel: ) as channel:
await _test_client( await _test_client(
ThingServiceClient(channel, timeout=timeout, metadata=metadata) ThingServiceClient(channel, timeout=timeout, metadata=metadata)
@@ -134,7 +70,7 @@ async def test_service_call_lower_level_with_overrides():
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"} kwarg_metadata = {"authorization": "12345"}
async with ChannelFor( async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
) as channel: ) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary( response = await client._unary_unary(
@@ -156,7 +92,7 @@ async def test_service_call_lower_level_with_overrides():
async with ChannelFor( async with ChannelFor(
[ [
ThingService( ThingService(
test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata), test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
) )
] ]
) as channel: ) as channel:
@@ -204,8 +140,8 @@ async def test_async_gen_for_stream_stream_request():
assert response.version == response_index + 1 assert response.version == response_index + 1
response_index += 1 response_index += 1
if more_things: if more_things:
# Send some more requests as we receive responses to be sure coordination of # Send some more requests as we recieve reponses to be sure coordination of
# send/receive events doesn't matter # send/recieve events doesn't matter
await request_chan.send(GetThingRequest(more_things.pop(0))) await request_chan.send(GetThingRequest(more_things.pop(0)))
elif not send_initial_requests.done(): elif not send_initial_requests.done():
# Make sure the sending task it completed # Make sure the sending task it completed
@@ -215,4 +151,4 @@ async def test_async_gen_for_stream_stream_request():
request_chan.close() request_chan.close()
assert response_index == len( assert response_index == len(
expected_things expected_things
), "Didn't receive all expected responses" ), "Didn't recieve all exptected responses"

View File

@@ -27,7 +27,10 @@ class ClientStub:
async def to_list(generator: AsyncIterator): async def to_list(generator: AsyncIterator):
return [value async for value in generator] result = []
async for value in generator:
result.append(value)
return result
@pytest.fixture @pytest.fixture

View File

@@ -1,12 +1,12 @@
from tests.output_betterproto.service.service import ( from betterproto.tests.output_betterproto.service.service import (
DoThingResponse, DoThingResponse,
DoThingRequest, DoThingRequest,
GetThingRequest, GetThingRequest,
GetThingResponse, GetThingResponse,
TestStub as ThingServiceClient,
) )
import grpclib import grpclib
import grpclib.server from typing import Any, Dict
from typing import Dict
class ThingService: class ThingService:

View File

@@ -1,4 +1,4 @@
from tests.output_betterproto.bool import Test from betterproto.tests.output_betterproto.bool import Test
def test_value(): def test_value():

View File

@@ -1,5 +1,5 @@
import tests.output_betterproto.casing as casing import betterproto.tests.output_betterproto.casing as casing
from tests.output_betterproto.casing import Test from betterproto.tests.output_betterproto.casing import Test
def test_message_attributes(): def test_message_attributes():

View File

@@ -1,4 +1,4 @@
from tests.output_betterproto.casing_message_field_uppercase import Test from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test
def test_message_casing(): def test_message_casing():

View File

@@ -1,9 +1,13 @@
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
# Remove from list when fixed. # Remove from list when fixed.
xfail = { xfail = {
"import_circular_dependency",
"oneof_enum", # 63
"namespace_keywords", # 70 "namespace_keywords", # 70
"namespace_builtin_types", # 53
"googletypes_struct", # 9 "googletypes_struct", # 9
"googletypes_value", # 9 "googletypes_value", # 9
"enum_skipped_value", # 93
"import_capitalized_package", "import_capitalized_package",
"example", # This is the example in the readme. Not a test. "example", # This is the example in the readme. Not a test.
} }
@@ -12,17 +16,7 @@ services = {
"googletypes_response", "googletypes_response",
"googletypes_response_embedded", "googletypes_response_embedded",
"service", "service",
"service_separate_packages",
"import_service_input_message", "import_service_input_message",
"googletypes_service_returns_empty", "googletypes_service_returns_empty",
"googletypes_service_returns_googletype", "googletypes_service_returns_googletype",
"example_service",
"empty_service",
} }
# Indicate json sample messages to skip when testing that json (de)serialization
# is symmetrical because some cases legitimately are not symmetrical.
# Each key references the name of the test scenario and the values in the tuple
# Are the names of the json files.
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}

View File

@@ -0,0 +1,12 @@
syntax = "proto3";
message Test {
enum MyEnum {
ZERO = 0;
ONE = 1;
// TWO = 2;
THREE = 3;
FOUR = 4;
}
MyEnum x = 1;
}

View File

@@ -0,0 +1,18 @@
from betterproto.tests.output_betterproto.enum_skipped_value import (
Test,
TestMyEnum,
)
import pytest
@pytest.mark.xfail(reason="#93")
def test_message_attributes():
assert (
Test(x=TestMyEnum.ONE).to_dict()["x"] == "ONE"
), "MyEnum.ONE is not serialized to 'ONE'"
assert (
Test(x=TestMyEnum.THREE).to_dict()["x"] == "THREE"
), "MyEnum.THREE is not serialized to 'THREE'"
assert (
Test(x=TestMyEnum.FOUR).to_dict()["x"] == "FOUR"
), "MyEnum.FOUR is not serialized to 'FOUR'"

View File

@@ -0,0 +1,3 @@
{
"greeting": "HEY"
}

View File

@@ -0,0 +1,14 @@
syntax = "proto3";
// Enum for the different greeting types
enum Greeting {
HI = 0;
HEY = 1;
// Formal greeting
HELLO = 2;
}
message Test {
// Greeting enum example
Greeting greeting = 1;
}

View File

@@ -0,0 +1,8 @@
syntax = "proto3";
package hello;
// Greeting represents a message you can tell a user.
message Greeting {
string message = 1;
}

View File

@@ -3,8 +3,8 @@ from typing import Any, Callable, Optional
import betterproto.lib.google.protobuf as protobuf import betterproto.lib.google.protobuf as protobuf
import pytest import pytest
from tests.mocks import MockChannel from betterproto.tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response import TestStub from betterproto.tests.output_betterproto.googletypes_response import TestStub
test_cases = [ test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5), (TestStub.get_double, protobuf.DoubleValue, 2.5),
@@ -21,7 +21,7 @@ test_cases = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type( async def test_channel_recieves_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
): ):
wrapped_value = wrapper_class() wrapped_value = wrapper_class()

View File

@@ -1,7 +1,7 @@
import pytest import pytest
from tests.mocks import MockChannel from betterproto.tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response_embedded import ( from betterproto.tests.output_betterproto.googletypes_response_embedded import (
Output, Output,
TestStub, TestStub,
) )

View File

@@ -0,0 +1,15 @@
syntax = "proto3";
import "request_message.proto";
// Tests generated service correctly imports the RequestMessage
service Test {
rpc DoThing (RequestMessage) returns (RequestResponse);
}
message RequestResponse {
int32 value = 1;
}

View File

@@ -0,0 +1,16 @@
import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.import_service_input_message import (
RequestResponse,
TestStub,
)
@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service")
@pytest.mark.asyncio
async def test_service_correctly_imports_reference_message():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing()
assert mock_response == response

Some files were not shown because too many files have changed in this diff Show More