61 Commits

Author SHA1 Message Date
Arun Babu Neelicattu
a157f05480 Release v2.0.0b2 (#175) 2020-11-24 23:04:33 +01:00
James
69dfe9cafc Implement Message.__bool__ (#142)
* Implement Message.__bool__ with similar semantics to a collection, such that any value being set on the message (i.e. having a non-default value) make the Message value truthy .

Co-authored-by: nat <n@natn.me>
2020-11-24 19:35:09 +01:00
Arun Babu Neelicattu
a8a082e4e7 Update dependencies and add ci checks for python 3.9 (#173)
* Update locked dependencies to fix grpcio compile issue with python 3.9
* ci: add python 3.9
2020-11-24 19:28:28 +01:00
Tim Schmidt
e44de6da06 replace now-disabled set-env command (#172)
thanks @abn
2020-11-21 14:42:50 +01:00
James
a5e0ef910f Fixes for Python 3.9 (#140)
Fix issue in logic for evaluating field types affecting python 3.9
2020-11-01 15:23:02 +01:00
James
8f7af272cc QOL fixes (#141)
- Add missing type annotations
- Various style improvements
- Use constants more consistently
- enforce black on benchmark code
2020-10-17 19:27:11 +02:00
Arun Babu Neelicattu
bf9412e083 Use poetry-core as PEP 517 build backend (#108)
This change replaces the use of poetry as the build backend in favour
of the leaner poetry-core. This speeds up PEP-517 builds for source
installs, tox environment setup etc.
2020-10-01 14:45:45 +02:00
Keerthan Jaic
4630c1cc67 bump grpclib to 0.4.1 (#150) 2020-09-23 21:55:23 +02:00
James
d3e4fbb311 Add Documentation (#125)
Add sphinx docs with readthedocs integration.

Docs can be built locally with `poe docs`.
2020-09-20 22:00:02 +02:00
Jonas Kalderstam
58556e0eb6 Update README with example of calling protoc from python (#149) 2020-09-19 17:03:49 +02:00
Adrian Garcia Badaracco
a3f5f21738 Add benchmarks (#148)
Add asv based benchmarks to guide future optimisation work.
2020-09-19 16:28:16 +02:00
Arun Babu Neelicattu
0028cc384a Relax black version constraints (#146)
This change ensures that the wheel built only requests for the minimum
version of black it requires to function as intended. Without this
change any project that uses betterproto[compiler] would break while
resolving dependencies.
2020-08-31 22:10:57 +02:00
Chris Chambers
034e2e7da0 Add support for recursive messages (#130)
Changes message initialization (`__post_init__`) so that default values
are no longer eagerly created to prevent infinite recursion when
initializing recursive messages.

As a result, `PLACEHOLDER` will be present in the message for any
uninitialized fields.  So, an implementation of `__get_attribute__` is
added that checks for `PLACEHOLDER` and lazily creates and stores
default field values.

And, because `PLACEHOLDER` values don't compare equal with zero values,
a custom implementation of `__eq__` is provided, and the code generation
template is updated so that messages generate with `@dataclass(eq=False)`.

Also add new Message __repr__ implementation that skips PLACEHOLDER 
values and orders keys by number from the proto.

Co-authored-by: Christopher Chambers <chris@peanutcode.com>
Co-authored-by: nat <n@natn.me>
Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com>
2020-08-30 21:04:36 +02:00
James
ca16b6ed34 Various micro-optimizations (#139) 2020-08-30 17:23:57 +02:00
James
16d554db75 Update black 2020-08-29 17:15:59 +02:00
Adrian Garcia Badaracco
9ef5503728 Small improvements to models.py 2020-08-23 14:26:15 +02:00
Adrian Garcia Badaracco
c93351ef21 Factor code template compilation out into a separate module 2020-08-09 20:06:39 +02:00
James
80bef7c94f Improve logic to avoid keyword collisions in generated code
Use the standard library keyword module instead of a hard coded list and applying it to enum keys as well.
2020-08-09 12:41:41 +02:00
nat
804805f0f5 Update poe (#132)
- This update improves support for windows & removes the direct dependency on poetry
2020-08-06 22:16:25 +02:00
Arun Babu Neelicattu
43c134d27c ci: refactor jobs and improve platform coverage (#128) 2020-07-30 14:47:38 +02:00
Arun Babu Neelicattu
0cd9510b54 Support deprecated message and fields (#126) 2020-07-30 14:47:01 +02:00
Arun Babu Neelicattu
beafc812ff Fix static type checking for grpclib client (#124)
* Fix static type checking in grpclib client
* Fix python3.6 compatibility issue with dataclasses
2020-07-30 11:30:58 +02:00
Arun Babu Neelicattu
3d8c0cb713 grpclib_client: handle trailer-only responses (#127)
Resolves: #123
2020-07-25 19:57:46 +02:00
nat
c513853301 Replace Makefile with poe tasks in pyproject.yaml (#118)
https://github.com/nat-n/poethepoet
2020-07-25 19:54:40 +02:00
Brady Kieffer
c1a76a5f5e Serialize default values in oneofs when calling to_dict() or to_json() (#110)
* Serialize default values in oneofs when calling to_dict() or to_json()

This change is consistent with the official protobuf implementation. If
a default value is set when using a oneof, and then a message is
translated from message -> JSON -> message, the default value is kept in
tact. Also, if no default value is set, they remain null.

* Some cleanup + testing for nested messages with oneofs

* Cleanup oneof_enum test cases, they should be fixed

This _should_ address:
https://github.com/danielgtaylor/python-betterproto/issues/63

* Include default value oneof fields when serializing to bytes

This will cause oneof fields with default values to explicitly be sent
to clients. Note that does not mean that all fields are serialized and
sent to clients, just those that _could_ be null and are not.

* Remove assignment when populating a sub-message within a proto

Also, move setattr out one indentation level

* Properly transform proto with empty string in oneof to bytes

Also, updated tests to ensure that which_one_of picks up the set field

* Formatting betterproto/__init__.py

* Adding test cases demonstrating equivalent behaviour with google impl

* Removing a temporary file I made locally

* Adding some clarifying comments

* Fixing tests for python38
2020-07-25 19:51:40 +02:00
Joshua Salzedo
2745953a8e Fix the readme gRPC usage example (#122)
* re-implement README gRPC client example to be a self-contained script
 - fix a syntax error
 - fix a usage error

* asyncio.run() was added in 3.7
 - this lib targets >= 3.6

* Apply suggestions from code review

Optimized imports, store RPC call result before printing

Co-authored-by: Arun Babu Neelicattu <arun.neelicattu@gmail.com>

* add entry-point check to example

Co-authored-by: Arun Babu Neelicattu <arun.neelicattu@gmail.com>
2020-07-25 19:45:26 +02:00
Adrian Garcia Badaracco
b5dcac1250 REF: Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data (#121)
Refactor plugin to parse input into data-class based hierarchical structure
2020-07-25 19:44:02 +02:00
James
cbd3437080 Some minor consistency changes
- replace some usages of `==` with `is`
- use available constants instead of magic strings for type names

Co-authored-by: nat <nat.noordanus@gmail.com>
2020-07-12 16:07:27 +02:00
boukeversteegh
2585a07fcf Improve poetry install speed by first upgrading pip 2020-07-12 15:42:31 +02:00
Bouke Versteegh
6c29771f4c Fix: to_dict returns wrong enum fields when numbering is not consecutive (#102)
Fixes #93 to_dict returns wrong enum fields when numbering is not consecutive
2020-07-12 15:06:55 +02:00
Arun Babu Neelicattu
0ba0692dec Handle mutable default arguments cleanly
When generating code, ensure that default list/dict arguments are
initialised in local scope if unspecified or `None`.
2020-07-11 22:33:44 +02:00
Arun Babu Neelicattu
42e197f985 Ensure we clean up egg-info directories 2020-07-11 19:51:01 +02:00
Arun Babu Neelicattu
459d12b24d Move betterproto → src/betterproto
This change avoids some nasty import issues and also ensures that the
right code is tested and arbitrary code is not included when packaging.
2020-07-11 19:51:01 +02:00
Arun Babu Neelicattu
cebf9176a3 Move betterproto/tests → tests 2020-07-11 19:51:01 +02:00
Bouke Versteegh
8864f4fdbd Merge pull request #103 from boukeversteegh/fix/service-input-message
Fix - No arguments are generated for stub methods when using `import` with proto definition
2020-07-10 22:55:05 +02:00
Arun Babu Neelicattu
03211604bc Replace dependency on protoc with grpcio-tools
This change removes the dependency on platform provided protobuf tools
in favour of `grpcio-tools` dependency. This makes both development and
compiler use independent from platform dependencies.
2020-07-10 13:16:40 +02:00
boukeversteegh
1d7ba850e9 Reorder methods, use BETTERPROTO_DUMP for dump env var, docs. 2020-07-09 23:09:34 +02:00
Bouke Versteegh
b2651335ce Merge pull request #112 from danielgtaylor/pr/readme-contribution
Updated readme with contribution section. More help welcome 😃
2020-07-09 22:53:22 +02:00
nat
5a591ef2a4 Add link to testing README in CONTRIBUTING.md 2020-07-09 20:41:13 +02:00
boukeversteegh
8d7d0efb9b Move contributing guide to CONTRIBUTING.md 2020-07-09 09:31:04 +02:00
boukeversteegh
b891d257f6 Updated readme with contribution section. More help welcome 😃 2020-07-09 00:16:36 +02:00
Bouke Versteegh
8bcb67b66f Merge pull request #81 from discord/serialized_on_wire_repeated
Always set serialized_on_wire for all parsed message fields
2020-07-08 23:10:14 +02:00
boukeversteegh
72d72b4603 Merge remote-tracking branch 'daniel/master' into fix/service-input-message
# Conflicts:
#	betterproto/plugin.py
2020-07-08 23:00:32 +02:00
Bouke Versteegh
3273ae4d2c Merge pull request #100 from boukeversteegh/fix/circular-dependencies
Import bug - Circular Dependencies
2020-07-07 21:45:06 +02:00
Bouke Versteegh
6fe666473d Merge pull request #106 from abn/minor-formatting
Minor non-functional improvements
2020-07-07 20:22:44 +02:00
Arun Babu Neelicattu
0338fcba29 Ignore commonly used .venv directory 2020-07-07 19:23:38 +02:00
Arun Babu Neelicattu
0f3ad25770 Minor non-functional changes
- fix few typos
- remove unused imports
- fix minor code-quality issues
- replace `grpclib._protocols` with `grpclib._typing`
- fix boolean and None assertions in test cases
2020-07-07 19:23:38 +02:00
Bouke Versteegh
586e28d2dc Merge pull request #104 from abn/fix-casing
Add missing async/await keywords when casing
2020-07-07 14:32:51 +02:00
Arun Babu Neelicattu
a8d8159d27 Add missing async/await keywords when casing 2020-07-07 13:15:46 +02:00
boukeversteegh
3f519d4fb1 Fixes #23 again, a broken test made it seem the issue was fixed before. 2020-07-05 17:14:53 +02:00
boukeversteegh
dedead048f Read proto objects before services 2020-07-05 13:10:25 +02:00
boukeversteegh
87b3a4b86d Move parsing of protobuf data types and services into separate methods 2020-07-05 12:27:06 +02:00
boukeversteegh
f2e87192b0 Clarify variable names 2020-07-05 12:24:21 +02:00
boukeversteegh
98d00f0d21 Supports running plugin.py standalone by reading from a dump-file, so its possible to debug it. 2020-07-05 12:20:55 +02:00
boukeversteegh
23dcbc2695 Fixes circular import problem when a non-circular dependency triangle is flattened into two python packages 2020-07-04 15:49:55 +02:00
boukeversteegh
0af0cf4bfb Fixes circular import problem when a non-circular dependency triangle is flattened into two python packages 2020-07-04 15:35:42 +02:00
Danny Weinberg
28a288924f Change to have parse *always* set serialized_on_wire 2020-06-04 16:20:32 -07:00
Danny Weinberg
5c700618fd Black again lol 2020-06-04 13:42:43 -07:00
Danny Weinberg
a914306f33 Put test into test_features, simplify to call parse directly 2020-06-04 13:42:07 -07:00
Danny Weinberg
67422db6b9 Fix formatting 2020-06-04 11:34:20 -07:00
Danny Weinberg
061bf86a9c Set serialized_on_wire when message contains only lists
This fixes a bug where serialized_on_wire was not set when a message contained only repeated values (eg in a list or map). The fix here is to just set it to true in the `parse` method as soon as we receive any valid data. This also adds a test to expose the behavior.
2020-06-04 11:04:36 -07:00
171 changed files with 4163 additions and 1545 deletions

23
.github/CONTRIBUTING.md vendored Normal file
View File

@@ -0,0 +1,23 @@
# Contributing
There's lots to do, and we're working hard, so any help is welcome!
- :speech_balloon: Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
What can you do?
- :+1: Vote on [issues](https://github.com/danielgtaylor/python-betterproto/issues).
- :speech_balloon: Give feedback on [Pull Requests](https://github.com/danielgtaylor/python-betterproto/pulls) and [Issues](https://github.com/danielgtaylor/python-betterproto/issues):
- Suggestions
- Express approval
- Raise concerns
- :small_red_triangle: Create an issue:
- File a bug (please check its not a duplicate)
- Propose an enhancement
- :white_check_mark: Create a PR:
- [Creating a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/betterproto/tests/README.md) to make bug-fixing easier
- Fix any of the open issues
- [Good first issues](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
- [Issues with tests](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22has+test%22)
- New bugfix or idea
- If you'd like to discuss your idea first, join us on Slack!

View File

@@ -1,74 +1,69 @@
name: CI name: CI
on: [push, pull_request] on:
push:
branches:
- master
pull_request:
branches:
- '**'
jobs: jobs:
tests:
check-formatting: name: ${{ matrix.os }} / ${{ matrix.python-version }}
runs-on: ubuntu-latest runs-on: ${{ matrix.os }}-latest
name: Consult black on python formatting
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7
- uses: Gr1N/setup-poetry@v2
- uses: actions/cache@v2
with:
path: ~/.cache/pypoetry/virtualenvs
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry-
- name: Install dependencies
run: poetry install
- name: Run black
run: make check-style
run-tests:
runs-on: ubuntu-latest
name: Run tests with tox
strategy: strategy:
matrix: matrix:
python-version: [ '3.6', '3.7', '3.8'] os: [Ubuntu, MacOS, Windows]
python-version: [3.6, 3.7, 3.8, 3.9]
exclude:
- os: Windows
python-version: 3.6
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- uses: actions/setup-python@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- uses: Gr1N/setup-poetry@v2
- uses: actions/cache@v2 - name: Get full Python version
id: full-python-version
shell: bash
run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")
- name: Install poetry
shell: bash
run: |
python -m pip install poetry
echo "$HOME/.poetry/bin" >> $GITHUB_PATH
- name: Configure poetry
shell: bash
run: poetry config virtualenvs.in-project true
- name: Set up cache
uses: actions/cache@v2
id: cache
with: with:
path: ~/.cache/pypoetry/virtualenvs path: .venv
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }} key: venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry- - name: Ensure cache is healthy
if: steps.cache.outputs.cache-hit == 'true'
shell: bash
run: poetry run pip --version >/dev/null 2>&1 || rm -rf .venv
- name: Install dependencies - name: Install dependencies
shell: bash
run: | run: |
sudo apt install protobuf-compiler libprotobuf-dev poetry run python -m pip install pip -U
poetry install poetry install
- name: Run tests
run: |
make generate
make test
build-release: - name: Generate code from proto files
runs-on: ubuntu-latest shell: bash
run: poetry run python -m tests.generate -v
steps: - name: Execute test suite
- uses: actions/checkout@v2 shell: bash
- uses: actions/setup-python@v2 run: poetry run pytest tests/
with:
python-version: 3.7
- uses: Gr1N/setup-poetry@v2
- name: Build package
run: poetry build
- name: Publish package to PyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
run: poetry publish -n
env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }}

26
.github/workflows/code-quality.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: Code Quality
on:
push:
branches:
- master
pull_request:
branches:
- '**'
jobs:
check-formatting:
name: Check code/doc formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Run Black
uses: lgeiger/black-action@master
with:
args: --check src/ tests/ benchmarks/
- name: Install rST dependcies
run: python -m pip install doc8
- name: Lint documentation for errors
run: python -m doc8 docs --max-line-length 88 --ignore-path-errors "docs/migrating.rst;D001"
# it has a table which is longer than 88 characters long

31
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,31 @@
name: Release
on:
push:
branches:
- master
tags:
- '**'
pull_request:
branches:
- '**'
jobs:
packaging:
name: Distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install poetry
run: python -m pip install poetry
- name: Build package
run: poetry build
- name: Publish package to PyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi }}
run: poetry publish -n

5
.gitignore vendored
View File

@@ -6,7 +6,7 @@
.pytest_cache .pytest_cache
.python-version .python-version
build/ build/
betterproto/tests/output_* tests/output_*
**/__pycache__ **/__pycache__
dist dist
**/*.egg-info **/*.egg-info
@@ -14,3 +14,6 @@ output
.idea .idea
.DS_Store .DS_Store
.tox .tox
.venv
.asv
venv

17
.readthedocs.yml Normal file
View File

@@ -0,0 +1,17 @@
version: 2
formats: []
build:
image: latest
sphinx:
configuration: docs/conf.py
fail_on_warning: false
python:
version: 3.7
install:
- method: pip
path: .
extra_requirements:
- dev

View File

@@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`. - Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [2.0.0b2] - 2020-11-24
- Add support for deprecated message and fields [#126](https://github.com/danielgtaylor/python-betterproto/pull/126)
- Add support for recursive messages [#130](https://github.com/danielgtaylor/python-betterproto/pull/130)
- Add support for `bool(Message)` [#142](https://github.com/danielgtaylor/python-betterproto/pull/142)
- Improve support for Python 3.9 [#140](https://github.com/danielgtaylor/python-betterproto/pull/140) [#173](https://github.com/danielgtaylor/python-betterproto/pull/173)
- Improve keyword sanitisation for generated code [#137](https://github.com/danielgtaylor/python-betterproto/pull/137)
- Fix missing serialized_on_wire when message contains only lists [#81](https://github.com/danielgtaylor/python-betterproto/pull/81)
- Fix circular dependencies [#100](https://github.com/danielgtaylor/python-betterproto/pull/100)
- Fix to_dict enum fields when numbering is not consecutive [#102](https://github.com/danielgtaylor/python-betterproto/pull/102)
- Fix argument generation for stub methods when using `import` with proto definition [#103](https://github.com/danielgtaylor/python-betterproto/pull/103)
- Fix missing async/await keywords when casing [#104](https://github.com/danielgtaylor/python-betterproto/pull/104)
- Fix mutable default arguments in generated code [#105](https://github.com/danielgtaylor/python-betterproto/pull/105)
- Fix serialisation of default values in oneofs when calling to_dict() or to_json() [#110](https://github.com/danielgtaylor/python-betterproto/pull/110)
- Fix static type checking for grpclib client [#124](https://github.com/danielgtaylor/python-betterproto/pull/124)
- Fix python3.6 compatibility issue with dataclasses [#124](https://github.com/danielgtaylor/python-betterproto/pull/124)
- Fix handling of trailer-only responses [#127](https://github.com/danielgtaylor/python-betterproto/pull/127)
- Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data [#121](https://github.com/danielgtaylor/python-betterproto/pull/121)
- Refactor template compilation logic [#136](https://github.com/danielgtaylor/python-betterproto/pull/136)
- Replace use of platform provided protoc with development dependency on grpcio-tools [#107](https://github.com/danielgtaylor/python-betterproto/pull/107)
- Switch to using `poe` from `make` to manage project development tasks [#118](https://github.com/danielgtaylor/python-betterproto/pull/118)
- Improve CI platform coverage [#128](https://github.com/danielgtaylor/python-betterproto/pull/128)
## [2.0.0b1] - 2020-07-04 ## [2.0.0b1] - 2020-07-04
[Upgrade Guide](./docs/upgrading.md) [Upgrade Guide](./docs/upgrading.md)
@@ -15,8 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
> `2.0.0` will be released once the interface is stable. > `2.0.0` will be released once the interface is stable.
- Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83) - Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) - Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
- Fix No arguments are generated for stub methods when using import with proto definition
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25) - Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57) - Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)

View File

@@ -1,42 +0,0 @@
.PHONY: help setup generate test types format clean plugin full-test check-style
help: ## - Show this help.
@fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//'
# Dev workflow tasks
generate: ## - Generate test cases (do this once before running test)
poetry run ./betterproto/tests/generate.py
test: ## - Run tests
poetry run pytest --cov betterproto
types: ## - Check types with mypy
poetry run mypy betterproto --ignore-missing-imports
format: ## - Apply black formatting to source code
poetry run black . --exclude tests/output_
clean: ## - Clean out generated files from the workspace
rm -rf .coverage \
.mypy_cache \
.pytest_cache \
dist \
**/__pycache__ \
betterproto/tests/output_*
# Manual testing
# By default write plugin output to a directory called output
o=output
plugin: ## - Execute the protoc plugin, with output write to `output` or the value passed to `-o`
mkdir -p $(o)
protoc --plugin=protoc-gen-custom=betterproto/plugin.py $(i) --custom_out=$(o)
# CI tasks
full-test: generate ## - Run full testing sequence with multiple pythons
poetry run tox
check-style: ## - Check if code style is correct
poetry run black . --check --diff --exclude tests/output_

View File

@@ -37,7 +37,6 @@ This project exists because I am unhappy with the state of the official Google p
- Uses `SerializeToString()` rather than the built-in `__bytes__()` - Uses `SerializeToString()` rather than the built-in `__bytes__()`
- Special wrapped types don't use Python's `None` - Special wrapped types don't use Python's `None`
- Timestamp/duration types don't use Python's built-in `datetime` module - Timestamp/duration types don't use Python's built-in `datetime` module
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
## Installation ## Installation
@@ -71,13 +70,20 @@ message Greeting {
} }
``` ```
You can run the following: You can run the following to invoke protoc directly:
```sh ```sh
mkdir lib mkdir lib
protoc -I . --python_betterproto_out=lib example.proto protoc -I . --python_betterproto_out=lib example.proto
``` ```
or run the following to invoke protoc via grpcio-tools:
```sh
pip install grpcio-tools
python -m grpc_tools.protoc -I . --python_betterproto_out=lib example.proto
```
This will generate `lib/hello/__init__.py` which looks like: This will generate `lib/hello/__init__.py` which looks like:
```python ```python
@@ -126,7 +132,7 @@ Greeting(message="Hey!")
The generated Protobuf `Message` classes are compatible with [grpclib](https://github.com/vmagamedov/grpclib) so you are free to use it if you like. That said, this project also includes support for async gRPC stub generation with better static type checking and code completion support. It is enabled by default. The generated Protobuf `Message` classes are compatible with [grpclib](https://github.com/vmagamedov/grpclib) so you are free to use it if you like. That said, this project also includes support for async gRPC stub generation with better static type checking and code completion support. It is enabled by default.
Given an example like: Given an example service definition:
```protobuf ```protobuf
syntax = "proto3"; syntax = "proto3";
@@ -153,22 +159,37 @@ service Echo {
} }
``` ```
You can use it like so (enable async in the interactive shell first): A client can be implemented as follows:
```python ```python
>>> import echo import asyncio
>>> from grpclib.client import Channel import echo
>>> channel = Channel(host="127.0.0.1", port=1234) from grpclib.client import Channel
>>> service = echo.EchoStub(channel)
>>> await service.echo(value="hello", extra_times=1)
EchoResponse(values=["hello", "hello"])
>>> async for response in service.echo_stream(value="hello", extra_times=1)
async def main():
channel = Channel(host="127.0.0.1", port=50051)
service = echo.EchoStub(channel)
response = await service.echo(value="hello", extra_times=1)
print(response)
async for response in service.echo_stream(value="hello", extra_times=1):
print(response) print(response)
EchoStreamResponse(value="hello") # don't forget to close the channel when done!
EchoStreamResponse(value="hello") channel.close()
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
```
which would output
```python
EchoResponse(values=['hello', 'hello'])
EchoStreamResponse(value='hello')
EchoStreamResponse(value='hello')
``` ```
### JSON ### JSON
@@ -304,35 +325,31 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
## Development ## Development
Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)! - _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
- _See how you can help &rarr; [Contributing](.github/CONTRIBUTING.md)_
### Requirements ### Requirements
- Python (3.6 or higher) - Python (3.6 or higher)
- [protoc](https://grpc.io/docs/protoc-installation/) (3.12 or higher)
*Needed to compile `.proto` files and run the tests*
- [poetry](https://python-poetry.org/docs/#installation) - [poetry](https://python-poetry.org/docs/#installation)
*Needed to install dependencies in a virtual environment* *Needed to install dependencies in a virtual environment*
- make ([ubuntu](https://www.howtoinstall.me/ubuntu/18-04/make/), [windows](https://stackoverflow.com/questions/32127524/how-to-install-and-use-make-in-windows), [mac](https://osxdaily.com/2014/02/12/install-command-line-tools-mac-os-x/)) - [poethepoet](https://github.com/nat-n/poethepoet) for running development tasks as defined in pyproject.toml
- Can be installed to your host environment via `pip install poethepoet` then executed as simple `poe`
*Needed to conveniently run development tasks.* - or run from the poetry venv as `poetry run poe`
*Alternatively, manually run the commands defined in the [Makefile](./Makefile)*
### Setup ### Setup
```sh ```sh
# Get set up with the virtual env & dependencies # Get set up with the virtual env & dependencies
poetry run pip install --upgrade pip
poetry install poetry install
# Activate the poetry environment # Activate the poetry environment
poetry shell poetry shell
``` ```
Run `make help` to see all available development tasks.
### Code style ### Code style
This project enforces [black](https://github.com/psf/black) python code formatting. This project enforces [black](https://github.com/psf/black) python code formatting.
@@ -340,7 +357,7 @@ This project enforces [black](https://github.com/psf/black) python code formatti
Before committing changes run: Before committing changes run:
```sh ```sh
make format poe format
``` ```
To avoid merge conflicts later, non-black formatted python code will fail in CI. To avoid merge conflicts later, non-black formatted python code will fail in CI.
@@ -374,15 +391,15 @@ Here's how to run the tests.
```sh ```sh
# Generate assets from sample .proto files required by the tests # Generate assets from sample .proto files required by the tests
make generate poe generate
# Run the tests # Run the tests
make test poe test
``` ```
To run tests as they are run in CI (with tox) run: To run tests as they are run in CI (with tox) run:
```sh ```sh
make full-test poe full-test
``` ```
### (Re)compiling Google Well-known Types ### (Re)compiling Google Well-known Types
@@ -403,7 +420,6 @@ protoc \
/usr/local/include/google/protobuf/*.proto /usr/local/include/google/protobuf/*.proto
``` ```
### TODO ### TODO
- [x] Fixed length fields - [x] Fixed length fields
@@ -434,10 +450,10 @@ protoc \
- [x] Enum strings - [x] Enum strings
- [x] Well known types support (timestamp, duration, wrappers) - [x] Well known types support (timestamp, duration, wrappers)
- [x] Support different casing (orig vs. camel vs. others?) - [x] Support different casing (orig vs. camel vs. others?)
- [ ] Async service stubs - [x] Async service stubs
- [x] Unary-unary - [x] Unary-unary
- [x] Server streaming response - [x] Server streaming response
- [ ] Client streaming request - [x] Client streaming request
- [x] Renaming messages and fields to conform to Python name standards - [x] Renaming messages and fields to conform to Python name standards
- [x] Renaming clashes with language keywords - [x] Renaming clashes with language keywords
- [x] Python package - [x] Python package

157
asv.conf.json Normal file
View File

@@ -0,0 +1,157 @@
{
// The version of the config file format. Do not change, unless
// you know what you are doing.
"version": 1,
// The name of the project being benchmarked
"project": "python-betterproto",
// The project's homepage
"project_url": "https://github.com/danielgtaylor/python-betterproto",
// The URL or local path of the source code repository for the
// project being benchmarked
"repo": ".",
// The Python project's subdirectory in your repo. If missing or
// the empty string, the project is assumed to be located at the root
// of the repository.
// "repo_subdir": "",
// Customizable commands for building, installing, and
// uninstalling the project. See asv.conf.json documentation.
//
"install_command": ["python -m pip install ."],
"uninstall_command": ["return-code=any python -m pip uninstall -y {project}"],
"build_command": ["python -m pip wheel -w {build_cache_dir} {build_dir}"],
// List of branches to benchmark. If not provided, defaults to "master"
// (for git) or "default" (for mercurial).
// "branches": ["master"], // for git
// "branches": ["default"], // for mercurial
// The DVCS being used. If not set, it will be automatically
// determined from "repo" by looking at the protocol in the URL
// (if remote), or by looking for special directories, such as
// ".git" (if local).
// "dvcs": "git",
// The tool to use to create environments. May be "conda",
// "virtualenv" or other value depending on the plugins in use.
// If missing or the empty string, the tool will be automatically
// determined by looking for tools on the PATH environment
// variable.
"environment_type": "virtualenv",
// timeout in seconds for installing any dependencies in environment
// defaults to 10 min
//"install_timeout": 600,
// the base URL to show a commit for the project.
// "show_commit_url": "http://github.com/owner/project/commit/",
// The Pythons you'd like to test against. If not provided, defaults
// to the current version of Python used to run `asv`.
// "pythons": ["2.7", "3.6"],
// The list of conda channel names to be searched for benchmark
// dependency packages in the specified order
// "conda_channels": ["conda-forge", "defaults"],
// The matrix of dependencies to test. Each key is the name of a
// package (in PyPI) and the values are version numbers. An empty
// list or empty string indicates to just test against the default
// (latest) version. null indicates that the package is to not be
// installed. If the package to be tested is only available from
// PyPi, and the 'environment_type' is conda, then you can preface
// the package name by 'pip+', and the package will be installed via
// pip (with all the conda available packages installed first,
// followed by the pip installed packages).
//
// "matrix": {
// "numpy": ["1.6", "1.7"],
// "six": ["", null], // test with and without six installed
// "pip+emcee": [""], // emcee is only available for install with pip.
// },
// Combinations of libraries/python versions can be excluded/included
// from the set to test. Each entry is a dictionary containing additional
// key-value pairs to include/exclude.
//
// An exclude entry excludes entries where all values match. The
// values are regexps that should match the whole string.
//
// An include entry adds an environment. Only the packages listed
// are installed. The 'python' key is required. The exclude rules
// do not apply to includes.
//
// In addition to package names, the following keys are available:
//
// - python
// Python version, as in the *pythons* variable above.
// - environment_type
// Environment type, as above.
// - sys_platform
// Platform, as in sys.platform. Possible values for the common
// cases: 'linux2', 'win32', 'cygwin', 'darwin'.
//
// "exclude": [
// {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows
// {"environment_type": "conda", "six": null}, // don't run without six on conda
// ],
//
// "include": [
// // additional env for python2.7
// {"python": "2.7", "numpy": "1.8"},
// // additional env if run on windows+conda
// {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""},
// ],
// The directory (relative to the current directory) that benchmarks are
// stored in. If not provided, defaults to "benchmarks"
// "benchmark_dir": "benchmarks",
// The directory (relative to the current directory) to cache the Python
// environments in. If not provided, defaults to "env"
"env_dir": ".asv/env",
// The directory (relative to the current directory) that raw benchmark
// results are stored in. If not provided, defaults to "results".
"results_dir": ".asv/results",
// The directory (relative to the current directory) that the html tree
// should be written to. If not provided, defaults to "html".
"html_dir": ".asv/html",
// The number of characters to retain in the commit hashes.
// "hash_length": 8,
// `asv` will cache results of the recent builds in each
// environment, making them faster to install next time. This is
// the number of builds to keep, per environment.
// "build_cache_size": 2,
// The commits after which the regression search in `asv publish`
// should start looking for regressions. Dictionary whose keys are
// regexps matching to benchmark names, and values corresponding to
// the commit (exclusive) after which to start looking for
// regressions. The default is to start from the first commit
// with results. If the commit is `null`, regression detection is
// skipped for the matching benchmark.
//
// "regressions_first_commits": {
// "some_benchmark": "352cdf", // Consider regressions only after this commit
// "another_benchmark": null, // Skip regression detection altogether
// },
// The thresholds for relative change in results, after which `asv
// publish` starts reporting regressions. Dictionary of the same
// form as in ``regressions_first_commits``, with values
// indicating the thresholds. If multiple entries match, the
// maximum is taken. If no entry matches, the default is 5%.
//
// "regressions_thresholds": {
// "some_benchmark": 0.01, // Threshold of 1%
// "another_benchmark": 0.5, // Threshold of 50%
// },
}

1
benchmarks/__init__.py Normal file
View File

@@ -0,0 +1 @@

59
benchmarks/benchmarks.py Normal file
View File

@@ -0,0 +1,59 @@
import betterproto
from dataclasses import dataclass
@dataclass
class TestMessage(betterproto.Message):
foo: int = betterproto.uint32_field(0)
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
class BenchMessage:
"""Test creation and usage a proto message."""
def setup(self):
self.cls = TestMessage
self.instance = TestMessage()
self.instance_filled = TestMessage(0, "test", 0.0)
def time_overhead(self):
"""Overhead in class definition."""
@dataclass
class Message(betterproto.Message):
foo: int = betterproto.uint32_field(0)
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
def time_instantiation(self):
"""Time instantiation"""
self.cls()
def time_attribute_access(self):
"""Time to access an attribute"""
self.instance.foo
self.instance.bar
self.instance.baz
def time_init_with_values(self):
"""Time to set an attribute"""
self.cls(0, "test", 0.0)
def time_attribute_setting(self):
"""Time to set attributes"""
self.instance.foo = 0
self.instance.bar = "test"
self.instance.baz = 0.0
def time_serialize(self):
"""Time serializing a message to wire."""
bytes(self.instance_filled)
class MemSuite:
def setup(self):
self.cls = TestMessage
def mem_instance(self):
return self.cls()

View File

@@ -1,2 +0,0 @@
@SET plugin_dir=%~dp0
@python %plugin_dir%/plugin.py %*

View File

@@ -1,403 +0,0 @@
#!/usr/bin/env python
import itertools
import os.path
import pathlib
import re
import sys
import textwrap
from typing import List, Union
import betterproto
from betterproto.compile.importing import get_type_reference
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
try:
# betterproto[compiler] specific dependencies
import black
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
)
import google.protobuf.wrappers_pb2 as google_wrappers
import jinja2
except ImportError as err:
missing_import = err.args[0][17:-1]
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
if field.type in [1, 2]:
return "float"
elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
return "int"
elif field.type == 8:
return "bool"
elif field.type == 9:
return "str"
elif field.type in [11, 14]:
# Type referencing another defined Message or a named enum
return get_type_reference(package, imports, field.type_name)
elif field.type == 12:
return "bytes"
else:
raise NotImplementedError(f"Unknown type {field.type}")
def get_py_zero(type_num: int) -> Union[str, float]:
zero: Union[str, float] = 0
if type_num in []:
zero = 0.0
elif type_num == 8:
zero = "False"
elif type_num == 9:
zero = '""'
elif type_num == 11:
zero = "None"
elif type_num == 12:
zero = 'b""'
return zero
def traverse(proto_file):
def _traverse(path, items, prefix=""):
for i, item in enumerate(items):
# Adjust the name since we flatten the heirarchy.
item.name = next_prefix = prefix + item.name
yield item, path + [i]
if isinstance(item, DescriptorProto):
for enum in item.enum_type:
enum.name = next_prefix + enum.name
yield enum, path + [i, 4]
if item.nested_type:
for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix):
yield n, p
return itertools.chain(
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
)
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
pad = " " * indent
for sci in proto_file.source_code_info.location:
# print(list(sci.path), path, file=sys.stderr)
if list(sci.path) == path and sci.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
)
if path[-2] == 2 and path[-4] != 6:
# This is a field
return f"{pad}# " + f"\n{pad}# ".join(lines)
else:
# This is a message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
lines[0] = lines[0].strip('"')
return f'{pad}"""{lines[0]}"""'
else:
joined = f"\n{pad}".join(lines)
return f'{pad}"""\n{pad}{joined}\n{pad}"""'
return ""
def generate_code(request, response):
plugin_options = request.parameter.split(",") if request.parameter else []
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
)
template = env.get_template("template.py.j2")
output_map = {}
for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
continue
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
if output_file not in output_map:
output_map[output_file] = {"package": proto_file.package, "files": []}
output_map[output_file]["files"].append(proto_file)
# TODO: Figure out how to handle gRPC request/response messages and add
# processing below for Service.
for filename, options in output_map.items():
package = options["package"]
# print(package, filename, file=sys.stderr)
output = {
"package": package,
"files": [f.name for f in options["files"]],
"imports": set(),
"datetime_imports": set(),
"typing_imports": set(),
"messages": [],
"enums": [],
"services": [],
}
for proto_file in options["files"]:
item: DescriptorProto
for item, path in traverse(proto_file):
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
continue
data.update(
{
"type": "Message",
"comment": get_comment(proto_file, path),
"properties": [],
}
)
for i, f in enumerate(item.field):
t = py_type(package, output["imports"], f)
zero = get_py_zero(f.type)
repeated = False
packed = False
field_type = f.Type.Name(f.type).lower()[5:]
field_wraps = ""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", f.type_name
)
if match_wrapper:
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
if hasattr(betterproto, wrapped_type):
field_wraps = f"betterproto.{wrapped_type}"
map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop().lower()
# message_type = py_type(package)
map_entry = f"{f.name.replace('_', '').lower()}entry"
if message_type == map_entry:
for nested in item.nested_type:
if (
nested.name.replace("_", "").lower()
== map_entry
):
if nested.options.map_entry:
# print("Found a map!", file=sys.stderr)
k = py_type(
package,
output["imports"],
nested.field[0],
)
v = py_type(
package,
output["imports"],
nested.field[1],
)
t = f"Dict[{k}, {v}]"
field_type = "map"
map_types = (
f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type),
)
output["typing_imports"].add("Dict")
if f.label == 3 and field_type != "map":
# Repeated field
repeated = True
t = f"List[{t}]"
zero = "[]"
output["typing_imports"].add("List")
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True
one_of = ""
if f.HasField("oneof_index"):
one_of = item.oneof_decl[f.oneof_index].name
if "Optional[" in t:
output["typing_imports"].add("Optional")
if "timedelta" in t:
output["datetime_imports"].add("timedelta")
elif "datetime" in t:
output["datetime_imports"].add("datetime")
data["properties"].append(
{
"name": f.name,
"py_name": pythonize_field_name(f.name),
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types,
"type": t,
"zero": zero,
"repeated": repeated,
"packed": packed,
"one_of": one_of,
}
)
# print(f, file=sys.stderr)
output["messages"].append(data)
elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr)
data.update(
{
"type": "Enum",
"comment": get_comment(proto_file, path),
"entries": [
{
"name": v.name,
"value": v.number,
"comment": get_comment(proto_file, path + [2, i]),
}
for i, v in enumerate(item.value)
],
}
)
output["enums"].append(data)
for i, service in enumerate(proto_file.service):
# print(service, file=sys.stderr)
data = {
"name": service.name,
"py_name": pythonize_class_name(service.name),
"comment": get_comment(proto_file, [6, i]),
"methods": [],
}
for j, method in enumerate(service.method):
input_message = None
input_type = get_type_reference(
package, output["imports"], method.input_type
).strip('"')
for msg in output["messages"]:
if msg["name"] == input_type:
input_message = msg
for field in msg["properties"]:
if field["zero"] == "None":
output["typing_imports"].add("Optional")
break
data["methods"].append(
{
"name": method.name,
"py_name": pythonize_method_name(method.name),
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
"route": f"/{package}.{service.name}/{method.name}",
"input": get_type_reference(
package, output["imports"], method.input_type
).strip('"'),
"input_message": input_message,
"output": get_type_reference(
package,
output["imports"],
method.output_type,
unwrap=False,
).strip('"'),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
}
)
if method.client_streaming:
output["typing_imports"].add("AsyncIterable")
output["typing_imports"].add("Iterable")
output["typing_imports"].add("Union")
if method.server_streaming:
output["typing_imports"].add("AsyncIterator")
output["services"].append(data)
output["imports"] = sorted(output["imports"])
output["datetime_imports"] = sorted(output["datetime_imports"])
output["typing_imports"] = sorted(output["typing_imports"])
# Fill response
f = response.file.add()
f.name = filename
# Render and then format the output file.
f.content = black.format_str(
template.render(description=output),
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
)
# Make each output directory a package with __init__ file
output_paths = set(pathlib.Path(path) for path in output_map.keys())
init_files = (
set(
directory.joinpath("__init__.py")
for path in output_paths
for directory in path.parents
)
- output_paths
)
for init_file in init_files:
init = response.file.add()
init.name = str(init_file)
for filename in sorted(output_paths.union(init_files)):
print(f"Writing {filename}", file=sys.stderr)
def main():
"""The plugin's main entry point."""
# Read request message from stdin
data = sys.stdin.buffer.read()
# Parse request
request = plugin.CodeGeneratorRequest()
request.ParseFromString(data)
# Create response
response = plugin.CodeGeneratorResponse()
# Generate code
generate_code(request, response)
# Serialise response message
output = response.SerializeToString()
# Write to stdout
sys.stdout.buffer.write(output)
if __name__ == "__main__":
main()

View File

@@ -1,135 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto
from dataclasses import dataclass
{% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if description.typing_imports %}
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if description.services %}
import grpclib
{% endif %}
{% for i in description.imports %}
{{ i }}
{% endfor %}
{% if description.enums %}{% for enum in description.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}
{% endif %}
{% for entry in enum.entries %}
{% if entry.comment %}
{{ entry.comment }}
{% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %}
{% endfor %}
{% endif %}
{% for message in description.messages %}
@dataclass
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
{% endif %}
{% for field in message.properties %}
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{% endfor %}
{% if not message.properties %}
pass
{% endif %}
{% endfor %}
{% for service in description.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.input_message and method.input_message.properties -%}, *,
{%- for field in method.input_message.properties -%}
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
Optional[{{ field.type }}]
{%- else -%}
{{ field.type }}
{%- endif -%} = {{ field.zero }}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{% if not method.client_streaming %}
request = {{ method.input }}()
{% for field in method.input_message.properties %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output }},
):
yield response
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream(
"{{ method.route }}",
request,
{{ method.output }},
):
yield response
{% endif %}{# if client streaming #}
{% else %}{# i.e. not server streaming #}
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary(
"{{ method.route }}",
request,
{{ method.output }}
)
{% endif %}{# client streaming #}
{% endif %}
{% endfor %}
{% endfor %}

View File

@@ -1,3 +0,0 @@
{
"greeting": "HEY"
}

View File

@@ -1,14 +0,0 @@
syntax = "proto3";
// Enum for the different greeting types
enum Greeting {
HI = 0;
HEY = 1;
// Formal greeting
HELLO = 2;
}
message Test {
// Greeting enum example
Greeting greeting = 1;
}

View File

@@ -1,15 +0,0 @@
syntax = "proto3";
import "request_message.proto";
// Tests generated service correctly imports the RequestMessage
service Test {
rpc DoThing (RequestMessage) returns (RequestResponse);
}
message RequestResponse {
int32 value = 1;
}

View File

@@ -1,16 +0,0 @@
import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.import_service_input_message import (
RequestResponse,
TestStub,
)
@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service")
@pytest.mark.asyncio
async def test_service_correctly_imports_reference_message():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing()
assert mock_response == response

31
docs/api.rst Normal file
View File

@@ -0,0 +1,31 @@
.. currentmodule:: betterproto
API reference
=============
The following document outlines betterproto's api. **None** of these classes should be
extended by the user manually.
Message
--------
.. autoclass:: betterproto.Message
:members:
:special-members: __bytes__, __bool__
.. autofunction:: betterproto.serialized_on_wire
.. autofunction:: betterproto.which_one_of
Enumerations
-------------
.. autoclass:: betterproto.Enum()
:members:
.. autoclass:: betterproto.Casing()
:members:

60
docs/conf.py Normal file
View File

@@ -0,0 +1,60 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import pathlib
import toml
# -- Project information -----------------------------------------------------
project = "betterproto"
copyright = "2019 Daniel G. Taylor"
author = "danielgtaylor"
pyproject = toml.load(open(pathlib.Path(__file__).parent.parent / "pyproject.toml"))
# The full version, including alpha/beta/rc tags.
release = pyproject["tool"]["poetry"]["version"]
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
]
autodoc_member_order = "bysource"
autodoc_typehints = "none"
extlinks = {
"issue": ("https://github.com/danielgtaylor/python-betterproto/issues/%s", "GH-"),
}
# Links used for cross-referencing stuff in other documentation
intersphinx_mapping = {
"py": ("https://docs.python.org/3", None),
}
# -- Options for HTML output -------------------------------------------------
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "friendly"
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = "sphinx_rtd_theme"

33
docs/index.rst Normal file
View File

@@ -0,0 +1,33 @@
Welcome to betterproto's documentation!
=======================================
betterproto is a protobuf compiler and interpreter. It improves the experience of using
Protobuf and gRPC in Python, by generating readable, understandable, and idiomatic
Python code, using modern language features.
Features:
~~~~~~~~~
- Generated messages are both binary & JSON serializable
- Messages use relevant python types, e.g. ``Enum``, ``datetime`` and ``timedelta``
objects
- ``async``/``await`` support for gRPC Clients
- Generates modern, readable, idiomatic python code
Contents:
~~~~~~~~~
.. toctree::
:maxdepth: 2
quick-start
api
migrating
If you still can't find what you're looking for, try in one of the following pages:
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

157
docs/migrating.rst Normal file
View File

@@ -0,0 +1,157 @@
Migrating Guide
===============
Google's protocolbuffers
------------------------
betterproto has a mostly 1 to 1 drop in replacement for Google's protocolbuffers (after
regenerating your protobufs of course) although there are some minor differences.
.. note::
betterproto implements the same basic methods including:
- :meth:`betterproto.Message.FromString`
- :meth:`betterproto.Message.SerializeToString`
for compatibility purposes, however it is important to note that these are
effectively aliases for :meth:`betterproto.Message.parse` and
:meth:`betterproto.Message.__bytes__` respectively.
Determining if a message was sent
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sometimes it is useful to be able to determine whether a message has been sent on
the wire. This is how the Google wrapper types work to let you know whether a value is
unset (set as the default/zero value), or set as something else, for example.
Use ``betterproto.serialized_on_wire(message)`` to determine if it was sent. This is
a little bit different from the official Google generated Python code, and it lives
outside the generated ``Message`` class to prevent name clashes. Note that it only
supports Proto 3 and thus can only be used to check if ``Message`` fields are set.
You cannot check if a scalar was sent on the wire.
.. code-block:: python
# Old way (official Google Protobuf package)
>>> mymessage.HasField('myfield')
True
# New way (this project)
>>> betterproto.serialized_on_wire(mymessage.myfield)
True
One-of Support
~~~~~~~~~~~~~~
Protobuf supports grouping fields in a oneof clause. Only one of the fields in the group
may be set at a given time. For example, given the proto:
.. code-block:: proto
syntax = "proto3";
message Test {
oneof foo {
bool on = 1;
int32 count = 2;
string name = 3;
}
}
You can use ``betterproto.which_one_of(message, group_name)`` to determine which of the
fields was set. It returns a tuple of the field name and value, or a blank string and
``None`` if unset. Again this is a little different than the official Google code
generator:
.. code-block:: python
# Old way (official Google protobuf package)
>>> message.WhichOneof("group")
"foo"
# New way (this project)
>>> betterproto.which_one_of(message, "group")
("foo", "foo's value")
Well-Known Google Types
~~~~~~~~~~~~~~~~~~~~~~~
Google provides several well-known message types like a timestamp, duration, and several
wrappers used to provide optional zero value support. Each of these has a special JSON
representation and is handled a little differently from normal messages. The Python
mapping for these is as follows:
+-------------------------------+-----------------------------------------------+--------------------------+
| ``Google Message`` | ``Python Type`` | ``Default`` |
+===============================+===============================================+==========================+
| ``google.protobuf.duration`` | :class:`datetime.timedelta` | ``0`` |
+-------------------------------+-----------------------------------------------+--------------------------+
| ``google.protobuf.timestamp`` | ``Timezone-aware`` :class:`datetime.datetime` | ``1970-01-01T00:00:00Z`` |
+-------------------------------+-----------------------------------------------+--------------------------+
| ``google.protobuf.*Value`` | ``Optional[...]``/``None`` | ``None`` |
+-------------------------------+-----------------------------------------------+--------------------------+
| ``google.protobuf.*`` | ``betterproto.lib.google.protobuf.*`` | ``None`` |
+-------------------------------+-----------------------------------------------+--------------------------+
For the wrapper types, the Python type corresponds to the wrapped type, e.g.
``google.protobuf.BoolValue`` becomes ``Optional[bool]`` while
``google.protobuf.Int32Value`` becomes ``Optional[int]``. All of the optional values
default to None, so don't forget to check for that possible state.
Given:
.. code-block:: proto
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
}
You can use it as such:
.. code-block:: python
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
>>> t
Test(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
>>> t.ts - t.duration
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
>>> t.ts.isoformat()
'2019-01-01T12:00:00+00:00'
>>> t.maybe = None
>>> t.to_dict()
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
[1.2.5] to [2.0.0b1]
--------------------
Updated package structures
~~~~~~~~~~~~~~~~~~~~~~~~~~
Generated code now strictly follows the *package structure* of the ``.proto`` files.
Consequently ``.proto`` files without a package will be combined in a single
``__init__.py`` file. To avoid overwriting existing ``__init__.py`` files, its best
to compile into a dedicated subdirectory.
Upgrading:
- Remove your previously compiled ``.py`` files.
- Create a new *empty* directory, e.g. ``generated`` or ``lib/generated/proto`` etc.
- Regenerate your python files into this directory
- Update import statements, e.g. ``import ExampleMessage from generated``

192
docs/quick-start.rst Normal file
View File

@@ -0,0 +1,192 @@
Getting Started
===============
Installation
++++++++++++
Installation from PyPI is as simple as running:
.. code-block:: sh
python3 -m pip install -U betterproto
If you are using Windows, then the following should be used instead:
.. code-block:: sh
py -3 -m pip install -U betterproto
To include the protoc plugin, install betterproto[compiler] instead of betterproto,
e.g.
.. code-block:: sh
python3 -m pip install -U "betterproto[compiler]"
Compiling proto files
+++++++++++++++++++++
Given you installed the compiler and have a proto file, e.g ``example.proto``:
.. code-block:: proto
syntax = "proto3";
package hello;
// Greeting represents a message you can tell a user.
message Greeting {
string message = 1;
}
To compile the proto you would run the following:
You can run the following to invoke protoc directly:
.. code-block:: sh
mkdir hello
protoc -I . --python_betterproto_out=lib example.proto
or run the following to invoke protoc via grpcio-tools:
.. code-block:: sh
pip install grpcio-tools
python -m grpc_tools.protoc -I . --python_betterproto_out=lib example.proto
This will generate ``lib/__init__.py`` which looks like:
.. code-block:: python
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: example.proto
# plugin: python-betterproto
from dataclasses import dataclass
import betterproto
@dataclass
class Greeting(betterproto.Message):
"""Greeting represents a message you can tell a user."""
message: str = betterproto.string_field(1)
Then to use it:
.. code-block:: python
>>> from lib import Greeting
>>> test = Greeting()
>>> test
Greeting(message='')
>>> test.message = "Hey!"
>>> test
Greeting(message="Hey!")
>>> bytes(test)
b'\n\x04Hey!'
>>> Greeting().parse(serialized)
Greeting(message="Hey!")
Async gRPC Support
++++++++++++++++++
The generated code includes `grpclib <https://grpclib.readthedocs.io/en/latest>`_ based
stub (client) classes for rpc services declared in the input proto files.
It is enabled by default.
Given a service definition similar to the one below:
.. code-block:: proto
syntax = "proto3";
package echo;
message EchoRequest {
string value = 1;
// Number of extra times to echo
uint32 extra_times = 2;
}
message EchoResponse {
repeated string values = 1;
}
message EchoStreamResponse {
string value = 1;
}
service Echo {
rpc Echo(EchoRequest) returns (EchoResponse);
rpc EchoStream(EchoRequest) returns (stream EchoStreamResponse);
}
The generated client can be used like so:
.. code-block:: python
import asyncio
from grpclib.client import Channel
import echo
async def main():
channel = Channel(host="127.0.0.1", port=50051)
service = echo.EchoStub(channel)
response = await service.echo(value="hello", extra_times=1)
print(response)
async for response in service.echo_stream(value="hello", extra_times=1):
print(response)
# don't forget to close the channel when you're done!
channel.close()
asyncio.run(main()) # python 3.7 only
# outputs
EchoResponse(values=['hello', 'hello'])
EchoStreamResponse(value='hello')
EchoStreamResponse(value='hello')
JSON
++++
Message objects include :meth:`betterproto.Message.to_json` and
:meth:`betterproto.Message.from_json` methods for JSON (de)serialisation, and
:meth:`betterproto.Message.to_dict`, :meth:`betterproto.Message.from_dict` for
converting back and forth from JSON serializable dicts.
For compatibility the default is to convert field names to
:attr:`betterproto.Casing.CAMEL`. You can control this behavior by passing a
different casing value, e.g:
.. code-block:: python
@dataclass
class MyMessage(betterproto.Message):
a_long_field_name: str = betterproto.string_field(1)
>>> test = MyMessage(a_long_field_name="Hello World!")
>>> test.to_dict(betterproto.Casing.SNAKE)
{"a_long_field_name": "Hello World!"}
>>> test.to_dict(betterproto.Casing.CAMEL)
{"aLongFieldName": "Hello World!"}
>>> test.to_json(indent=2)
'{\n "aLongFieldName": "Hello World!"\n}'
>>> test.from_dict({"aLongFieldName": "Goodbye World!"})
>>> test.a_long_field_name
"Goodbye World!"

View File

@@ -1,16 +0,0 @@
# Upgrade Guide
## [1.2.5] to [2.0.0b1]
### Updated package structures
Generated code now strictly follows the *package structure* of the `.proto` files.
Consequently `.proto` files without a package will be combined in a single `__init__.py` file.
To avoid overwriting existing `__init__.py` files, its best to compile into a dedicated subdirectory.
Upgrading:
- Remove your previously compiled `.py` files.
- Create a new *empty* directory, e.g. `generated` or `lib/generated/proto` etcetera.
- Regenerate your python files into this directory
- Update import statements, e.g. `import ExampleMessage from generated`

1230
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,34 +1,41 @@
[tool.poetry] [tool.poetry]
name = "betterproto" name = "betterproto"
version = "2.0.0b1" version = "2.0.0b2"
description = "A better Protobuf / gRPC generator & library" description = "A better Protobuf / gRPC generator & library"
authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"] authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"]
readme = "README.md" readme = "README.md"
repository = "https://github.com/danielgtaylor/python-betterproto" repository = "https://github.com/danielgtaylor/python-betterproto"
keywords = ["protobuf", "gRPC"] keywords = ["protobuf", "gRPC"]
license = "MIT" license = "MIT"
packages = [
exclude = ["betterproto/tests"] { include = "betterproto", from = "src" }
]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.6" python = "^3.6"
backports-datetime-fromisoformat = { version = "^1.0.0", python = "<3.7" } backports-datetime-fromisoformat = { version = "^1.0.0", python = "<3.7" }
black = { version = "^19.10b0", optional = true } black = { version = ">=19.3b0", optional = true }
dataclasses = { version = "^0.7", python = ">=3.6, <3.7" } dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
grpclib = "^0.3.1" grpclib = "^0.4.1"
jinja2 = { version = "^2.11.2", optional = true } jinja2 = { version = "^2.11.2", optional = true }
protobuf = { version = "^3.12.2", optional = true } protobuf = { version = "^3.12.2", optional = true }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
black = "^19.10b0" black = "^20.8b1"
bpython = "^0.19" bpython = "^0.19"
grpcio-tools = "^1.30.0"
jinja2 = "^2.11.2" jinja2 = "^2.11.2"
mypy = "^0.770" mypy = "^0.770"
poethepoet = "^0.5.0"
protobuf = "^3.12.2" protobuf = "^3.12.2"
pytest = "^5.4.2" pytest = "^5.4.2"
pytest-asyncio = "^0.12.0" pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0" pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1"
tox = "^3.15.1" tox = "^3.15.1"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
asv = "^0.4.2"
[tool.poetry.scripts] [tool.poetry.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main" protoc-gen-python_betterproto = "betterproto.plugin:main"
@@ -36,6 +43,20 @@ protoc-gen-python_betterproto = "betterproto.plugin:main"
[tool.poetry.extras] [tool.poetry.extras]
compiler = ["black", "jinja2", "protobuf"] compiler = ["black", "jinja2", "protobuf"]
[tool.poe.tasks]
# Dev workflow tasks
generate = { script = "tests.generate:main", help = "Generate test cases (do this once before running test)" }
test = { cmd = "pytest --cov src", help = "Run tests" }
types = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" }
format = { cmd = "black . --exclude tests/output_", help = "Apply black formatting to source code" }
clean = { cmd = "rm -rf .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" }
docs = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"}
bench = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"}
# CI tasks
full-test = { shell = "poe generate && tox", help = "Run tests with multiple pythons" }
check-style = { cmd = "black . --check --diff --exclude tests/output_", help = "Check if code style is correct"}
[tool.black] [tool.black]
target-version = ['py36'] target-version = ['py36']
@@ -56,5 +77,5 @@ commands =
""" """
[build-system] [build-system]
requires = ["poetry>=0.12"] requires = ["poetry-core>=1.0.0,<2"]
build-backend = "poetry.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@@ -4,6 +4,7 @@ import inspect
import json import json
import struct import struct
import sys import sys
import typing
from abc import ABC from abc import ABC
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@@ -22,10 +23,10 @@ from typing import (
) )
from ._types import T from ._types import T
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub from .grpc.grpclib_client import ServiceStub
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): if sys.version_info[:2] < (3, 7):
# Apply backport of datetime.fromisoformat from 3.7 # Apply backport of datetime.fromisoformat from 3.7
from backports.datetime_fromisoformat import MonkeyPatch from backports.datetime_fromisoformat import MonkeyPatch
@@ -109,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC. # Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen(): def datetime_default_gen() -> datetime:
return datetime(1970, 1, 1, tzinfo=timezone.utc) return datetime(1970, 1, 1, tzinfo=timezone.utc)
@@ -119,15 +120,11 @@ DATETIME_ZERO = datetime_default_gen()
class Casing(enum.Enum): class Casing(enum.Enum):
"""Casing constants for serialization.""" """Casing constants for serialization."""
CAMEL = camel_case CAMEL = camel_case #: A camelCase sterilization function.
SNAKE = snake_case SNAKE = snake_case #: A snake_case sterilization function.
class _PLACEHOLDER: PLACEHOLDER: Any = object()
pass
PLACEHOLDER: Any = _PLACEHOLDER()
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@@ -251,14 +248,28 @@ def map_field(
) )
class Enum(int, enum.Enum): class Enum(enum.IntEnum):
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" """
The base class for protobuf enumerations, all generated enumerations will inherit
from this. Bases :class:`enum.IntEnum`.
"""
@classmethod @classmethod
def from_string(cls, name: str) -> int: def from_string(cls, name: str) -> "Enum":
"""Return the value which corresponds to the string name.""" """Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try: try:
return cls.__members__[name] return cls._member_map_[name]
except KeyError as e: except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
@@ -304,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
return encode_varint(value) return encode_varint(value)
elif proto_type in [TYPE_SINT32, TYPE_SINT64]: elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Handle zig-zag encoding. # Handle zig-zag encoding.
if value >= 0: return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
value = value << 1
else:
value = (value << 1) ^ (~0)
return encode_varint(value)
elif proto_type in FIXED_TYPES: elif proto_type in FIXED_TYPES:
return struct.pack(_pack_fmt(proto_type), value) return struct.pack(_pack_fmt(proto_type), value)
elif proto_type == TYPE_STRING: elif proto_type == TYPE_STRING:
@@ -346,7 +353,7 @@ def _serialize_single(
"""Serializes a single field and value.""" """Serializes a single field and value."""
value = _preprocess_single(proto_type, wraps, value) value = _preprocess_single(proto_type, wraps, value)
output = b"" output = bytearray()
if proto_type in WIRE_VARINT_TYPES: if proto_type in WIRE_VARINT_TYPES:
key = encode_varint(field_number << 3) key = encode_varint(field_number << 3)
output += key + value output += key + value
@@ -363,10 +370,10 @@ def _serialize_single(
else: else:
raise NotImplementedError(proto_type) raise NotImplementedError(proto_type)
return output return bytes(output)
def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]: def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
""" """
Decode a single varint value from a byte buffer. Returns the value and the Decode a single varint value from a byte buffer. Returns the value and the
new position in the buffer. new position in the buffer.
@@ -378,7 +385,7 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
result |= (b & 0x7F) << shift result |= (b & 0x7F) << shift
pos += 1 pos += 1
if not (b & 0x80): if not (b & 0x80):
return (result, pos) return result, pos
shift += 7 shift += 7
if shift >= 64: if shift >= 64:
raise ValueError("Too many bytes when decoding varint.") raise ValueError("Too many bytes when decoding varint.")
@@ -401,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
wire_type = num_wire & 0x7 wire_type = num_wire & 0x7
decoded: Any = None decoded: Any = None
if wire_type == 0: if wire_type == WIRE_VARINT:
decoded, i = decode_varint(value, i) decoded, i = decode_varint(value, i)
elif wire_type == 1: elif wire_type == WIRE_FIXED_64:
decoded, i = value[i : i + 8], i + 8 decoded, i = value[i : i + 8], i + 8
elif wire_type == 2: elif wire_type == WIRE_LEN_DELIM:
length, i = decode_varint(value, i) length, i = decode_varint(value, i)
decoded = value[i : i + length] decoded = value[i : i + length]
i += length i += length
elif wire_type == 5: elif wire_type == WIRE_FIXED_32:
decoded, i = value[i : i + 4], i + 4 decoded, i = value[i : i + 4], i + 4
yield ParsedField( yield ParsedField(
@@ -418,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
class ProtoClassMetadata: class ProtoClassMetadata:
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
default_gen: Dict[str, Callable]
cls_by_field: Dict[str, Type]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
__slots__ = ( __slots__ = (
"oneof_group_by_field", "oneof_group_by_field",
"oneof_field_by_group", "oneof_field_by_group",
@@ -431,8 +432,17 @@ class ProtoClassMetadata:
"cls_by_field", "cls_by_field",
"field_name_by_number", "field_name_by_number",
"meta_by_field_name", "meta_by_field_name",
"sorted_field_names",
) )
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
sorted_field_names: Tuple[str, ...]
default_gen: Dict[str, Callable[[], Any]]
cls_by_field: Dict[str, Type]
def __init__(self, cls: Type["Message"]): def __init__(self, cls: Type["Message"]):
by_field = {} by_field = {}
by_group: Dict[str, Set] = {} by_group: Dict[str, Set] = {}
@@ -456,21 +466,22 @@ class ProtoClassMetadata:
self.oneof_field_by_group = by_group self.oneof_field_by_group = by_group
self.field_name_by_number = by_field_number self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name self.meta_by_field_name = by_field_name
self.sorted_field_names = tuple(
by_field_number[number] for number in sorted(by_field_number)
)
self.default_gen = self._get_default_gen(cls, fields) self.default_gen = self._get_default_gen(cls, fields)
self.cls_by_field = self._get_cls_by_field(cls, fields) self.cls_by_field = self._get_cls_by_field(cls, fields)
@staticmethod @staticmethod
def _get_default_gen(cls, fields): def _get_default_gen(
default_gen = {} cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Callable[[], Any]]:
for field in fields: return {field.name: cls._get_field_default_gen(field) for field in fields}
default_gen[field.name] = cls._get_field_default_gen(field)
return default_gen
@staticmethod @staticmethod
def _get_cls_by_field(cls, fields): def _get_cls_by_field(
cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Type]:
field_cls = {} field_cls = {}
for field in fields: for field in fields:
@@ -479,7 +490,7 @@ class ProtoClassMetadata:
assert meta.map_types assert meta.map_types
kt = cls._cls_for(field, index=0) kt = cls._cls_for(field, index=0)
vt = cls._cls_for(field, index=1) vt = cls._cls_for(field, index=1)
Entry = dataclasses.make_dataclass( field_cls[field.name] = dataclasses.make_dataclass(
"Entry", "Entry",
[ [
("key", kt, dataclass_field(1, meta.map_types[0])), ("key", kt, dataclass_field(1, meta.map_types[0])),
@@ -487,8 +498,7 @@ class ProtoClassMetadata:
], ],
bases=(Message,), bases=(Message,),
) )
field_cls[field.name] = Entry field_cls[f"{field.name}.value"] = vt
field_cls[field.name + ".value"] = vt
else: else:
field_cls[field.name] = cls._cls_for(field) field_cls[field.name] = cls._cls_for(field)
@@ -497,9 +507,19 @@ class ProtoClassMetadata:
class Message(ABC): class Message(ABC):
""" """
A protobuf message base class. Generated code will inherit from this and The base class for protobuf messages, all generated messages will inherit from
register the message fields which get used by the serializers and parsers this. This class registers the message fields which are used by the serializers and
to go between Python, binary and JSON protobuf message representations. parsers to go between the Python, binary and JSON representations of the message.
.. container:: operations
.. describe:: bytes(x)
Calls :meth:`__bytes__`.
.. describe:: bool(x)
Calls :meth:`__bool__`.
""" """
_serialized_on_wire: bool _serialized_on_wire: bool
@@ -511,29 +531,69 @@ class Message(ABC):
all_sentinel = True all_sentinel = True
# Set current field of each group after `__init__` has already been run. # Set current field of each group after `__init__` has already been run.
group_current: Dict[str, str] = {} group_current: Dict[str, Optional[str]] = {}
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
if meta.group: if meta.group:
group_current.setdefault(meta.group) group_current.setdefault(meta.group)
if getattr(self, field_name) != PLACEHOLDER: if self.__raw_get(field_name) != PLACEHOLDER:
# Skip anything not set to the sentinel value # Found a non-sentinel value
all_sentinel = False all_sentinel = False
if meta.group: if meta.group:
# This was set, so make it the selected value of the one-of. # This was set, so make it the selected value of the one-of.
group_current[meta.group] = field_name group_current[meta.group] = field_name
continue
setattr(self, field_name, self._get_field_default(field_name))
# Now that all the defaults are set, reset it! # Now that all the defaults are set, reset it!
self.__dict__["_serialized_on_wire"] = not all_sentinel self.__dict__["_serialized_on_wire"] = not all_sentinel
self.__dict__["_unknown_fields"] = b"" self.__dict__["_unknown_fields"] = b""
self.__dict__["_group_current"] = group_current self.__dict__["_group_current"] = group_current
def __raw_get(self, name: str) -> Any:
return super().__getattribute__(name)
def __eq__(self, other) -> bool:
if type(self) is not type(other):
return False
for field_name in self._betterproto.meta_by_field_name:
self_val = self.__raw_get(field_name)
other_val = other.__raw_get(field_name)
if self_val is PLACEHOLDER:
if other_val is PLACEHOLDER:
continue
self_val = self._get_field_default(field_name)
elif other_val is PLACEHOLDER:
other_val = other._get_field_default(field_name)
if self_val != other_val:
return False
return True
def __repr__(self) -> str:
parts = [
f"{field_name}={value!r}"
for field_name in self._betterproto.sorted_field_names
for value in (self.__raw_get(field_name),)
if value is not PLACEHOLDER
]
return f"{self.__class__.__name__}({', '.join(parts)})"
def __getattribute__(self, name: str) -> Any:
"""
Lazily initialize default values to avoid infinite recursion for recursive
message types
"""
value = super().__getattribute__(name)
if value is not PLACEHOLDER:
return value
value = self._get_field_default(name)
super().__setattr__(name, value)
return value
def __setattr__(self, attr: str, value: Any) -> None: def __setattr__(self, attr: str, value: Any) -> None:
if attr != "_serialized_on_wire": if attr != "_serialized_on_wire":
# Track when a field has been set. # Track when a field has been set.
@@ -546,14 +606,20 @@ class Message(ABC):
if field.name == attr: if field.name == attr:
self._group_current[group] = field.name self._group_current[group] = field.name
else: else:
super().__setattr__( super().__setattr__(field.name, PLACEHOLDER)
field.name, self._get_field_default(field.name),
)
super().__setattr__(attr, value) super().__setattr__(attr, value)
def __bool__(self) -> bool:
"""True if the Message has any fields with non-default values."""
return any(
self.__raw_get(field_name)
not in (PLACEHOLDER, self._get_field_default(field_name))
for field_name in self._betterproto.meta_by_field_name
)
@property @property
def _betterproto(self): def _betterproto(self) -> ProtoClassMetadata:
""" """
Lazy initialize metadata for each protobuf class. Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment, It may be initialized multiple times in a multi-threaded environment,
@@ -567,9 +633,9 @@ class Message(ABC):
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
""" """
Get the binary encoded Protobuf representation of this instance. Get the binary encoded Protobuf representation of this message instance.
""" """
output = b"" output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name) value = getattr(self, field_name)
@@ -581,18 +647,20 @@ class Message(ABC):
# Being selected in a a group means this field is the one that is # Being selected in a a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even # currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value. # if the value is the default zero value.
selected_in_group = False selected_in_group = (
if meta.group and self._group_current[meta.group] == field_name: meta.group and self._group_current[meta.group] == field_name
selected_in_group = True )
serialize_empty = False # Empty messages can still be sent on the wire if they were
if isinstance(value, Message) and value._serialized_on_wire: # set (or received empty).
# Empty messages can still be sent on the wire if they were serialize_empty = isinstance(value, Message) and value._serialized_on_wire
# set (or recieved empty).
serialize_empty = True include_default_value_for_oneof = self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
if value == self._get_field_default(field_name) and not ( if value == self._get_field_default(field_name) and not (
selected_in_group or serialize_empty selected_in_group or serialize_empty or include_default_value_for_oneof
): ):
# Default (zero) values are not serialized. Two exceptions are # Default (zero) values are not serialized. Two exceptions are
# if this is the selected oneof item or if we know we have to # if this is the selected oneof item or if we know we have to
@@ -605,7 +673,7 @@ class Message(ABC):
# Packed lists look like a length-delimited field. First, # Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then # preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes. # treat it like a field of raw bytes.
buf = b"" buf = bytearray()
for item in value: for item in value:
buf += _preprocess_single(meta.proto_type, "", item) buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf) output += _serialize_single(meta.number, TYPE_BYTES, buf)
@@ -621,6 +689,17 @@ class Message(ABC):
sv = _serialize_single(2, meta.map_types[1], v) sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv) output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else: else:
# If we have an empty string and we're including the default value for
# a oneof, make sure we serialize it. This ensures that the byte string
# output isn't simply an empty string. This also ensures that round trip
# serialization will keep `which_one_of` calls consistent.
if (
isinstance(value, str)
and value == ""
and include_default_value_for_oneof
):
serialize_empty = True
output += _serialize_single( output += _serialize_single(
meta.number, meta.number,
meta.proto_type, meta.proto_type,
@@ -629,26 +708,44 @@ class Message(ABC):
wraps=meta.wraps or "", wraps=meta.wraps or "",
) )
return output + self._unknown_fields output += self._unknown_fields
return bytes(output)
# For compatibility with other libraries # For compatibility with other libraries
SerializeToString = __bytes__ def SerializeToString(self: T) -> bytes:
"""
Get the binary encoded Protobuf representation of this message instance.
.. note::
This is a method for compatibility with other libraries,
you should really use ``bytes(x)``.
Returns
--------
:class:`bytes`
The binary encoded Protobuf representation of this message instance
"""
return bytes(self)
@classmethod @classmethod
def _type_hint(cls, field_name: str) -> Type: def _type_hint(cls, field_name: str) -> Type:
module = inspect.getmodule(cls) return cls._type_hints()[field_name]
type_hints = get_type_hints(cls, vars(module))
return type_hints[field_name] @classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = sys.modules[cls.__module__]
return get_type_hints(cls, vars(module))
@classmethod @classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints.""" """Get the message class for a field from the type hints."""
field_cls = cls._type_hint(field.name) field_cls = cls._type_hint(field.name)
if hasattr(field_cls, "__args__") and index >= 0: if hasattr(field_cls, "__args__") and index >= 0:
field_cls = field_cls.__args__[index] if field_cls.__args__ is not None:
field_cls = field_cls.__args__[index]
return field_cls return field_cls
def _get_field_default(self, field_name): def _get_field_default(self, field_name: str) -> Any:
return self._betterproto.default_gen[field_name]() return self._betterproto.default_gen[field_name]()
@classmethod @classmethod
@@ -662,7 +759,7 @@ class Message(ABC):
elif t.__origin__ in (list, List): elif t.__origin__ in (list, List):
# This is some kind of list (repeated) field. # This is some kind of list (repeated) field.
return list return list
elif t.__origin__ == Union and t.__args__[1] == type(None): elif t.__origin__ is Union and t.__args__[1] is type(None):
# This is an optional (wrapped) field. For setting the default we # This is an optional (wrapped) field. For setting the default we
# really don't care what kind of field it is. # really don't care what kind of field it is.
return type(None) return type(None)
@@ -671,7 +768,7 @@ class Message(ABC):
elif issubclass(t, Enum): elif issubclass(t, Enum):
# Enums always default to zero. # Enums always default to zero.
return int return int
elif t == datetime: elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z # Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen return datetime_default_gen
else: else:
@@ -720,18 +817,38 @@ class Message(ABC):
return value return value
def _include_default_value_for_oneof(
self, field_name: str, meta: FieldMetadata
) -> bool:
return (
meta.group is not None and self._group_current.get(meta.group) == field_name
)
def parse(self: T, data: bytes) -> T: def parse(self: T, data: bytes) -> T:
""" """
Parse the binary encoded Protobuf into this message instance. This Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
Parameters
-----------
data: :class:`bytes`
The data to parse the protobuf from.
Returns
--------
:class:`Message`
The initialized message.
""" """
# Got some data over the wire
self._serialized_on_wire = True
proto_meta = self._betterproto
for parsed in parse_fields(data): for parsed in parse_fields(data):
field_name = self._betterproto.field_name_by_number.get(parsed.number) field_name = proto_meta.field_name_by_number.get(parsed.number)
if not field_name: if not field_name:
self._unknown_fields += parsed.raw self._unknown_fields += parsed.raw
continue continue
meta = self._betterproto.meta_by_field_name[field_name] meta = proto_meta.meta_by_field_name[field_name]
value: Any value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES: if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
@@ -739,10 +856,10 @@ class Message(ABC):
pos = 0 pos = 0
value = [] value = []
while pos < len(parsed.value): while pos < len(parsed.value):
if meta.proto_type in ["float", "fixed32", "sfixed32"]: if meta.proto_type in [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]:
decoded, pos = parsed.value[pos : pos + 4], pos + 4 decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32 wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]: elif meta.proto_type in [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8 decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64 wire_type = WIRE_FIXED_64
else: else:
@@ -771,80 +888,149 @@ class Message(ABC):
# For compatibility with other libraries. # For compatibility with other libraries.
@classmethod @classmethod
def FromString(cls: Type[T], data: bytes) -> T: def FromString(cls: Type[T], data: bytes) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable.
.. note::
This is a method for compatibility with other libraries,
you should really use :meth:`parse`.
Parameters
-----------
data: :class:`bytes`
The data to parse the protobuf from.
Returns
--------
:class:`Message`
The initialized message.
"""
return cls().parse(data) return cls().parse(data)
def to_dict( def to_dict(
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Returns a dict representation of this message instance which can be Returns a JSON serializable dict representation of this object.
used to serialize to e.g. JSON. Defaults to camel casing for
compatibility but can be set to other modes.
`include_default_values` can be set to `True` to include default Parameters
values of fields. E.g. an `int32` type field with `0` value will -----------
not be in returned dict if `include_default_values` is set to casing: :class:`Casing`
`False`. The casing to use for key values. Default is :attr:`Casing.CAMEL` for
compatibility purposes.
include_default_values: :class:`bool`
If ``True`` will include the default values of fields. Default is ``False``.
E.g. an ``int32`` field will be included with a value of ``0`` if this is
set to ``True``, otherwise this would be ignored.
Returns
--------
Dict[:class:`str`, Any]
The JSON serializable dict representation of this object.
""" """
output: Dict[str, Any] = {} output: Dict[str, Any] = {}
field_types = self._type_hints()
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
v = getattr(self, field_name) field_is_repeated = defaults[field_name] is list
value = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == "message": if meta.proto_type == TYPE_MESSAGE:
if isinstance(v, datetime): if isinstance(value, datetime):
if v != DATETIME_ZERO or include_default_values: if (
output[cased_name] = _Timestamp.timestamp_to_json(v) value != DATETIME_ZERO
elif isinstance(v, timedelta): or include_default_values
if v != timedelta(0) or include_default_values: or self._include_default_value_for_oneof(
output[cased_name] = _Duration.delta_to_json(v) field_name=field_name, meta=meta
)
):
output[cased_name] = _Timestamp.timestamp_to_json(value)
elif isinstance(value, timedelta):
if (
value != timedelta(0)
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
):
output[cased_name] = _Duration.delta_to_json(value)
elif meta.wraps: elif meta.wraps:
if v is not None or include_default_values: if value is not None or include_default_values:
output[cased_name] = v output[cased_name] = value
elif isinstance(v, list): elif field_is_repeated:
# Convert each item. # Convert each item.
v = [i.to_dict(casing, include_default_values) for i in v] value = [i.to_dict(casing, include_default_values) for i in value]
if v or include_default_values: if value or include_default_values:
output[cased_name] = v output[cased_name] = value
else: elif (
if v._serialized_on_wire or include_default_values: value._serialized_on_wire
output[cased_name] = v.to_dict(casing, include_default_values) or include_default_values
elif meta.proto_type == "map": or self._include_default_value_for_oneof(
for k in v: field_name=field_name, meta=meta
if hasattr(v[k], "to_dict"): )
v[k] = v[k].to_dict(casing, include_default_values) ):
output[cased_name] = value.to_dict(casing, include_default_values)
elif meta.proto_type == TYPE_MAP:
for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)
if v or include_default_values: if value or include_default_values:
output[cased_name] = v output[cased_name] = value
elif v != self._get_field_default(field_name) or include_default_values: elif (
value != self._get_field_default(field_name)
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
):
if meta.proto_type in INT_64_TYPES: if meta.proto_type in INT_64_TYPES:
if isinstance(v, list): if field_is_repeated:
output[cased_name] = [str(n) for n in v] output[cased_name] = [str(n) for n in value]
else: else:
output[cased_name] = str(v) output[cased_name] = str(value)
elif meta.proto_type == TYPE_BYTES: elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list): if field_is_repeated:
output[cased_name] = [b64encode(b).decode("utf8") for b in v] output[cased_name] = [
b64encode(b).decode("utf8") for b in value
]
else: else:
output[cased_name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list( if field_is_repeated:
self._betterproto.cls_by_field[field_name] enum_class: Type[Enum] = field_types[field_name].__args__[0]
) # type: ignore if isinstance(value, typing.Iterable) and not isinstance(
if isinstance(v, list): value, str
output[cased_name] = [enum_values[e].name for e in v] ):
output[cased_name] = [enum_class(el).name for el in value]
else:
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
else: else:
output[cased_name] = enum_values[v].name enum_class: Type[Enum] = field_types[field_name] # noqa
output[cased_name] = enum_class(value).name
else: else:
output[cased_name] = v output[cased_name] = value
return output return output
def from_dict(self: T, value: dict) -> T: def from_dict(self: T, value: Dict[str, Any]) -> T:
""" """
Parse the key/value pairs in `value` into this message instance. This Parse the key/value pairs into the current message instance. This returns the
returns the instance itself and is therefore assignable and chainable. instance itself and is therefore assignable and chainable.
Parameters
-----------
value: Dict[:class:`str`, Any]
The dictionary to parse from.
Returns
--------
:class:`Message`
The initialized message.
""" """
self._serialized_on_wire = True self._serialized_on_wire = True
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value: for key in value:
field_name = safe_snake_case(key) field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name) meta = self._betterproto.meta_by_field_name.get(field_name)
@@ -852,12 +1038,12 @@ class Message(ABC):
continue continue
if value[key] is not None: if value[key] is not None:
if meta.proto_type == "message": if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name) v = getattr(self, field_name)
if isinstance(v, list): if isinstance(v, list):
cls = self._betterproto.cls_by_field[field_name] cls = self._betterproto.cls_by_field[field_name]
for i in range(len(value[key])): for item in value[key]:
v.append(cls().from_dict(value[key][i])) v.append(cls().from_dict(item))
elif isinstance(v, datetime): elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
setattr(self, field_name, v) setattr(self, field_name, v)
@@ -867,10 +1053,12 @@ class Message(ABC):
elif meta.wraps: elif meta.wraps:
setattr(self, field_name, value[key]) setattr(self, field_name, value[key])
else: else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key]) v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name) v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[field_name + ".value"] cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]: for k in value[key]:
v[k] = cls().from_dict(value[key][k]) v[k] = cls().from_dict(value[key][k])
else: else:
@@ -892,50 +1080,92 @@ class Message(ABC):
elif isinstance(v, str): elif isinstance(v, str):
v = enum_cls.from_string(v) v = enum_cls.from_string(v)
if v is not None: if v is not None:
setattr(self, field_name, v) setattr(self, field_name, v)
return self return self
def to_json(self, indent: Union[None, int, str] = None) -> str: def to_json(self, indent: Union[None, int, str] = None) -> str:
"""Returns the encoded JSON representation of this message instance.""" """A helper function to parse the message instance into its JSON
representation.
This is equivalent to::
json.dumps(message.to_dict(), indent=indent)
Parameters
-----------
indent: Optional[Union[:class:`int`, :class:`str`]]
The indent to pass to :func:`json.dumps`.
Returns
--------
:class:`str`
The JSON representation of the message.
"""
return json.dumps(self.to_dict(), indent=indent) return json.dumps(self.to_dict(), indent=indent)
def from_json(self: T, value: Union[str, bytes]) -> T: def from_json(self: T, value: Union[str, bytes]) -> T:
""" """A helper function to return the message instance from its JSON
Parse the key/value pairs in `value` into this message instance. This representation. This returns the instance itself and is therefore assignable
returns the instance itself and is therefore assignable and chainable. and chainable.
This is equivalent to::
return message.from_dict(json.loads(value))
Parameters
-----------
value: Union[:class:`str`, :class:`bytes`]
The value to pass to :func:`json.loads`.
Returns
--------
:class:`Message`
The initialized message.
""" """
return self.from_dict(json.loads(value)) return self.from_dict(json.loads(value))
def serialized_on_wire(message: Message) -> bool: def serialized_on_wire(message: Message) -> bool:
""" """
True if this message was or should be serialized on the wire. This can If this message was or should be serialized on the wire. This can be used to detect
be used to detect presence (e.g. optional wrapper message) and is used presence (e.g. optional wrapper message) and is used internally during
internally during parsing/serialization. parsing/serialization.
Returns
--------
:class:`bool`
Whether this message was or should be serialized on the wire.
""" """
return message._serialized_on_wire return message._serialized_on_wire
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
"""Return the name and value of a message's one-of field group.""" """
Return the name and value of a message's one-of field group.
Returns
--------
Tuple[:class:`str`, Any]
The field name and the value for that field.
"""
field_name = message._group_current.get(group_name) field_name = message._group_current.get(group_name)
if not field_name: if not field_name:
return ("", None) return "", None
return (field_name, getattr(message, field_name)) return field_name, getattr(message, field_name)
# Circular import workaround: google.protobuf depends on base classes defined above. # Circular import workaround: google.protobuf depends on base classes defined above.
from .lib.google.protobuf import ( from .lib.google.protobuf import ( # noqa
Duration,
Timestamp,
BoolValue, BoolValue,
BytesValue, BytesValue,
DoubleValue, DoubleValue,
Duration,
FloatValue, FloatValue,
Int32Value, Int32Value,
Int64Value, Int64Value,
StringValue, StringValue,
Timestamp,
UInt32Value, UInt32Value,
UInt64Value, UInt64Value,
) )
@@ -950,8 +1180,8 @@ class _Duration(Duration):
parts = str(delta.total_seconds()).split(".") parts = str(delta.total_seconds()).split(".")
if len(parts) > 1: if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]: while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0" parts[1] = f"{parts[1]}0"
return ".".join(parts) + "s" return f"{'.'.join(parts)}s"
class _Timestamp(Timestamp): class _Timestamp(Timestamp):
@@ -967,15 +1197,15 @@ class _Timestamp(Timestamp):
if (nanos % 1e9) == 0: if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional # If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing. # point '.' should be omitted when serializing.
return result + "Z" return f"{result}Z"
if (nanos % 1e6) == 0: if (nanos % 1e6) == 0:
# Serialize 3 fractional digits. # Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6) return f"{result}.{int(nanos // 1e6) :03d}Z"
if (nanos % 1e3) == 0: if (nanos % 1e3) == 0:
# Serialize 6 fractional digits. # Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3) return f"{result}.{int(nanos // 1e3) :06d}Z"
# Serialize 9 fractional digits. # Serialize 9 fractional digits.
return result + ".%09dZ" % nanos return f"{result}.{nanos:09d}"
class _WrappedMessage(Message): class _WrappedMessage(Message):

View File

@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from grpclib._typing import IProtoMessage
from . import Message from . import Message
from grpclib._protocols import IProtoMessage
# Bound type variable to allow methods to return `self` of subclasses # Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message") T = TypeVar("T", bound="Message")

View File

@@ -1,3 +1,4 @@
import keyword
import re import re
# Word delimiters and symbols that will not be preserved when re-casing. # Word delimiters and symbols that will not be preserved when re-casing.
@@ -16,51 +17,28 @@ WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
def safe_snake_case(value: str) -> str: def safe_snake_case(value: str) -> str:
"""Snake case a value taking into account Python keywords.""" """Snake case a value taking into account Python keywords."""
value = snake_case(value) value = snake_case(value)
if value in [ value = sanitize_name(value)
"and",
"as",
"assert",
"break",
"class",
"continue",
"def",
"del",
"elif",
"else",
"except",
"finally",
"for",
"from",
"global",
"if",
"import",
"in",
"is",
"lambda",
"nonlocal",
"not",
"or",
"pass",
"raise",
"return",
"try",
"while",
"with",
"yield",
]:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
value += "_"
return value return value
def snake_case(value: str, strict: bool = True): def snake_case(value: str, strict: bool = True) -> str:
""" """
Join words with an underscore into lowercase and remove symbols. Join words with an underscore into lowercase and remove symbols.
@param value: value to convert
@param strict: force single underscores Parameters
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to force single underscores.
Returns
--------
:class:`str`
The value in snake_case.
""" """
def substitute_word(symbols, word, is_start): def substitute_word(symbols: str, word: str, is_start: bool) -> str:
if not word: if not word:
return "" return ""
if strict: if strict:
@@ -84,11 +62,21 @@ def snake_case(value: str, strict: bool = True):
return snake return snake
def pascal_case(value: str, strict: bool = True): def pascal_case(value: str, strict: bool = True) -> str:
""" """
Capitalize each word and remove symbols. Capitalize each word and remove symbols.
@param value: value to convert
@param strict: output only alphanumeric characters Parameters
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to output only alphanumeric characters.
Returns
--------
:class:`str`
The value in PascalCase.
""" """
def substitute_word(symbols, word): def substitute_word(symbols, word):
@@ -109,12 +97,42 @@ def pascal_case(value: str, strict: bool = True):
) )
def camel_case(value: str, strict: bool = True): def camel_case(value: str, strict: bool = True) -> str:
""" """
Capitalize all words except first and remove symbols. Capitalize all words except first and remove symbols.
Parameters
-----------
value: :class:`str`
The value to convert.
strict: :class:`bool`
Whether or not to output only alphanumeric characters.
Returns
--------
:class:`str`
The value in camelCase.
""" """
return lowercase_first(pascal_case(value, strict=strict)) return lowercase_first(pascal_case(value, strict=strict))
def lowercase_first(value: str): def lowercase_first(value: str) -> str:
"""
Lower cases the first character of the value.
Parameters
----------
value: :class:`str`
The value to lower case.
Returns
-------
:class:`str`
The lower cased string.
"""
return value[0:1].lower() + value[1:] return value[0:1].lower() + value[1:]
def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
return f"{value}_" if keyword.iskeyword(value) else value

View File

@@ -1,10 +1,10 @@
import os import os
import re import re
from typing import Dict, List, Set, Type from typing import Dict, List, Set, Tuple, Type
from betterproto import safe_snake_case from ..casing import safe_snake_case
from betterproto.compile.naming import pythonize_class_name from ..lib.google import protobuf as google_protobuf
from betterproto.lib.google import protobuf as google_protobuf from .naming import pythonize_class_name
WRAPPER_TYPES: Dict[str, Type] = { WRAPPER_TYPES: Dict[str, Type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue, ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
@@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
} }
def parse_source_type_name(field_type_name): def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
""" """
Split full source type name into package and type name. Split full source type name into package and type name.
E.g. 'root.package.Message' -> ('root.package', 'Message') E.g. 'root.package.Message' -> ('root.package', 'Message')
@@ -36,7 +36,7 @@ def parse_source_type_name(field_type_name):
def get_type_reference( def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True, package: str, imports: set, source_type: str, unwrap: bool = True
) -> str: ) -> str:
""" """
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
@@ -50,7 +50,7 @@ def get_type_reference(
if source_type == ".google.protobuf.Duration": if source_type == ".google.protobuf.Duration":
return "timedelta" return "timedelta"
if source_type == ".google.protobuf.Timestamp": elif source_type == ".google.protobuf.Timestamp":
return "datetime" return "datetime"
source_package, source_type = parse_source_type_name(source_type) source_package, source_type = parse_source_type_name(source_type)
@@ -79,14 +79,14 @@ def get_type_reference(
return reference_cousin(current_package, imports, py_package, py_type) return reference_cousin(current_package, imports, py_package, py_type)
def reference_absolute(imports, py_package, py_type): def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
""" """
Returns a reference to a python type located in the root, i.e. sys.path. Returns a reference to a python type located in the root, i.e. sys.path.
""" """
string_import = ".".join(py_package) string_import = ".".join(py_package)
string_alias = safe_snake_case(string_import) string_alias = safe_snake_case(string_import)
imports.add(f"import {string_import} as {string_alias}") imports.add(f"import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}" return f'"{string_alias}.{py_type}"'
def reference_sibling(py_type: str) -> str: def reference_sibling(py_type: str) -> str:
@@ -109,10 +109,10 @@ def reference_descendent(
if string_from: if string_from:
string_alias = "_".join(importing_descendent) string_alias = "_".join(importing_descendent)
imports.add(f"from .{string_from} import {string_import} as {string_alias}") imports.add(f"from .{string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}" return f'"{string_alias}.{py_type}"'
else: else:
imports.add(f"from . import {string_import}") imports.add(f"from . import {string_import}")
return f"{string_import}.{py_type}" return f'"{string_import}.{py_type}"'
def reference_ancestor( def reference_ancestor(
@@ -130,11 +130,11 @@ def reference_ancestor(
string_alias = f"_{'_' * distance_up}{string_import}__" string_alias = f"_{'_' * distance_up}{string_import}__"
string_from = f"..{'.' * distance_up}" string_from = f"..{'.' * distance_up}"
imports.add(f"from {string_from} import {string_import} as {string_alias}") imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}" return f'"{string_alias}.{py_type}"'
else: else:
string_alias = f"{'_' * distance_up}{py_type}__" string_alias = f"{'_' * distance_up}{py_type}__"
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
return string_alias return f'"{string_alias}"'
def reference_cousin( def reference_cousin(
@@ -157,4 +157,4 @@ def reference_cousin(
+ "__" + "__"
) )
imports.add(f"from {string_from} import {string_import} as {string_alias}") imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}" return f'"{string_alias}.{py_type}"'

View File

@@ -1,13 +1,13 @@
from betterproto import casing from betterproto import casing
def pythonize_class_name(name): def pythonize_class_name(name: str) -> str:
return casing.pascal_case(name) return casing.pascal_case(name)
def pythonize_field_name(name: str): def pythonize_field_name(name: str) -> str:
return casing.safe_snake_case(name) return casing.safe_snake_case(name)
def pythonize_method_name(name: str): def pythonize_method_name(name: str) -> str:
return casing.safe_snake_case(name) return casing.safe_snake_case(name)

View File

@@ -1,8 +1,7 @@
from abc import ABC
import asyncio import asyncio
import grpclib.const from abc import ABC
from typing import ( from typing import (
Any, TYPE_CHECKING,
AsyncIterable, AsyncIterable,
AsyncIterator, AsyncIterator,
Collection, Collection,
@@ -10,21 +9,23 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Tuple, Tuple,
TYPE_CHECKING,
Type, Type,
Union, Union,
) )
import grpclib.const
from .._types import ST, T from .._types import ST, T
if TYPE_CHECKING: if TYPE_CHECKING:
from grpclib._protocols import IProtoMessage from grpclib.client import Channel
from grpclib.client import Channel, Stream
from grpclib.metadata import Deadline from grpclib.metadata import Deadline
_Value = Union[str, bytes] _Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] _MessageLike = Union[T, ST]
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
class ServiceStub(ABC): class ServiceStub(ABC):
@@ -60,7 +61,7 @@ class ServiceStub(ABC):
async def _unary_unary( async def _unary_unary(
self, self,
route: str, route: str,
request: "IProtoMessage", request: _MessageLike,
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
@@ -77,13 +78,13 @@ class ServiceStub(ABC):
) as stream: ) as stream:
await stream.send_message(request, end=True) await stream.send_message(request, end=True)
response = await stream.recv_message() response = await stream.recv_message()
assert response is not None assert response is not None
return response return response
async def _unary_stream( async def _unary_stream(
self, self,
route: str, route: str,
request: "IProtoMessage", request: _MessageLike,
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
@@ -123,8 +124,8 @@ class ServiceStub(ABC):
) as stream: ) as stream:
await self._send_messages(stream, request_iterator) await self._send_messages(stream, request_iterator)
response = await stream.recv_message() response = await stream.recv_message()
assert response is not None assert response is not None
return response return response
async def _stream_stream( async def _stream_stream(
self, self,

View File

@@ -1,12 +1,5 @@
import asyncio import asyncio
from typing import ( from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union
AsyncIterable,
AsyncIterator,
Iterable,
Optional,
TypeVar,
Union,
)
T = TypeVar("T") T = TypeVar("T")
@@ -16,57 +9,53 @@ class ChannelClosed(Exception):
An exception raised on an attempt to send through a closed channel An exception raised on an attempt to send through a closed channel
""" """
pass
class ChannelDone(Exception): class ChannelDone(Exception):
""" """
An exception raised on an attempt to send recieve from a channel that is both closed An exception raised on an attempt to send receive from a channel that is both closed
and empty. and empty.
""" """
pass
class AsyncChannel(AsyncIterable[T]): class AsyncChannel(AsyncIterable[T]):
""" """
A buffered async channel for sending items between coroutines with FIFO ordering. A buffered async channel for sending items between coroutines with FIFO ordering.
This makes decoupled bidirection steaming gRPC requests easy if used like: This makes decoupled bidirectional steaming gRPC requests easy if used like:
.. code-block:: python .. code-block:: python
client = GeneratedStub(grpclib_chan) client = GeneratedStub(grpclib_chan)
request_chan = await AsyncChannel() request_channel = await AsyncChannel()
# We can start be sending all the requests we already have # We can start be sending all the requests we already have
await request_chan.send_from([ReqestObject(...), ReqestObject(...)]) await request_channel.send_from([RequestObject(...), RequestObject(...)])
async for response in client.rpc_call(request_chan): async for response in client.rpc_call(request_channel):
# The response iterator will remain active until the connection is closed # The response iterator will remain active until the connection is closed
... ...
# More items can be sent at any time # More items can be sent at any time
await request_chan.send(ReqestObject(...)) await request_channel.send(RequestObject(...))
... ...
# The channel must be closed to complete the gRPC connection # The channel must be closed to complete the gRPC connection
request_chan.close() request_channel.close()
Items can be sent through the channel by either: Items can be sent through the channel by either:
- providing an iterable to the send_from method - providing an iterable to the send_from method
- passing them to the send method one at a time - passing them to the send method one at a time
Items can be recieved from the channel by either: Items can be received from the channel by either:
- iterating over the channel with a for loop to get all items - iterating over the channel with a for loop to get all items
- calling the recieve method to get one item at a time - calling the receive method to get one item at a time
If the channel is empty then recievers will wait until either an item appears or the If the channel is empty then receivers will wait until either an item appears or the
channel is closed. channel is closed.
Once the channel is closed then subsequent attempt to send through the channel will Once the channel is closed then subsequent attempt to send through the channel will
fail with a ChannelClosed exception. fail with a ChannelClosed exception.
When th channel is closed and empty then it is done, and further attempts to recieve When th channel is closed and empty then it is done, and further attempts to receive
from it will fail with a ChannelDone exception from it will fail with a ChannelDone exception
If multiple coroutines recieve from the channel concurrently, each item sent will be If multiple coroutines receive from the channel concurrently, each item sent will be
recieved by only one of the recievers. received by only one of the receivers.
:param source: :param source:
An optional iterable will items that should be sent through the channel An optional iterable will items that should be sent through the channel
@@ -74,18 +63,16 @@ class AsyncChannel(AsyncIterable[T]):
:param buffer_limit: :param buffer_limit:
Limit the number of items that can be buffered in the channel, A value less than Limit the number of items that can be buffered in the channel, A value less than
1 implies no limit. If the channel is full then attempts to send more items will 1 implies no limit. If the channel is full then attempts to send more items will
result in the sender waiting until an item is recieved from the channel. result in the sender waiting until an item is received from the channel.
:param close: :param close:
If set to True then the channel will automatically close after exhausting source If set to True then the channel will automatically close after exhausting source
or immediately if no source is provided. or immediately if no source is provided.
""" """
def __init__( def __init__(self, *, buffer_limit: int = 0, close: bool = False):
self, *, buffer_limit: int = 0, close: bool = False,
):
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
self._closed = False self._closed = False
self._waiting_recievers: int = 0 self._waiting_receivers: int = 0
# Track whether flush has been invoked so it can only happen once # Track whether flush has been invoked so it can only happen once
self._flushed = False self._flushed = False
@@ -95,14 +82,14 @@ class AsyncChannel(AsyncIterable[T]):
async def __anext__(self) -> T: async def __anext__(self) -> T:
if self.done(): if self.done():
raise StopAsyncIteration raise StopAsyncIteration
self._waiting_recievers += 1 self._waiting_receivers += 1
try: try:
result = await self._queue.get() result = await self._queue.get()
if result is self.__flush: if result is self.__flush:
raise StopAsyncIteration raise StopAsyncIteration
return result return result
finally: finally:
self._waiting_recievers -= 1 self._waiting_receivers -= 1
self._queue.task_done() self._queue.task_done()
def closed(self) -> bool: def closed(self) -> bool:
@@ -116,12 +103,12 @@ class AsyncChannel(AsyncIterable[T]):
Check if this channel is done. Check if this channel is done.
:return: True if this channel is closed and and has been drained of items in :return: True if this channel is closed and and has been drained of items in
which case any further attempts to recieve an item from this channel will raise which case any further attempts to receive an item from this channel will raise
a ChannelDone exception. a ChannelDone exception.
""" """
# After close the channel is not yet done until there is at least one waiting # After close the channel is not yet done until there is at least one waiting
# reciever per enqueued item. # receiver per enqueued item.
return self._closed and self._queue.qsize() <= self._waiting_recievers return self._closed and self._queue.qsize() <= self._waiting_receivers
async def send_from( async def send_from(
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
@@ -158,22 +145,22 @@ class AsyncChannel(AsyncIterable[T]):
await self._queue.put(item) await self._queue.put(item)
return self return self
async def recieve(self) -> Optional[T]: async def receive(self) -> Optional[T]:
""" """
Returns the next item from this channel when it becomes available, Returns the next item from this channel when it becomes available,
or None if the channel is closed before another item is sent. or None if the channel is closed before another item is sent.
:return: An item from the channel :return: An item from the channel
""" """
if self.done(): if self.done():
raise ChannelDone("Cannot recieve from a closed channel") raise ChannelDone("Cannot receive from a closed channel")
self._waiting_recievers += 1 self._waiting_receivers += 1
try: try:
result = await self._queue.get() result = await self._queue.get()
if result is self.__flush: if result is self.__flush:
return None return None
return result return result
finally: finally:
self._waiting_recievers -= 1 self._waiting_receivers -= 1
self._queue.task_done() self._queue.task_done()
def close(self): def close(self):
@@ -190,8 +177,8 @@ class AsyncChannel(AsyncIterable[T]):
""" """
if not self._flushed: if not self._flushed:
self._flushed = True self._flushed = True
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize()) deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
for _ in range(deadlocked_recievers): for _ in range(deadlocked_receivers):
await self._queue.put(self.__flush) await self._queue.put(self.__flush)
# A special signal object for flushing the queue when the channel is closed # A special signal object for flushing the queue when the channel is closed

View File

@@ -0,0 +1 @@
from .main import main

View File

@@ -0,0 +1,4 @@
from .main import main
main()

View File

@@ -0,0 +1,37 @@
import os.path
try:
# betterproto[compiler] specific dependencies
import black
import jinja2
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
from .models import OutputTemplate
def outputfile_compiler(output_file: OutputTemplate) -> str:
templates_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "templates")
)
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
)
template = env.get_template("template.py.j2")
return black.format_str(
template.render(output_file=output_file),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
)

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python
import os
import sys
from google.protobuf.compiler import plugin_pb2 as plugin
from betterproto.plugin.parser import generate_code
def main() -> None:
"""The plugin's main entry point."""
# Read request message from stdin
data = sys.stdin.buffer.read()
# Parse request
request = plugin.CodeGeneratorRequest()
request.ParseFromString(data)
dump_file = os.getenv("BETTERPROTO_DUMP")
if dump_file:
dump_request(dump_file, request)
# Create response
response = plugin.CodeGeneratorResponse()
# Generate code
generate_code(request, response)
# Serialise response message
output = response.SerializeToString()
# Write to stdout
sys.stdout.buffer.write(output)
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
"""
For developers: Supports running plugin.py standalone so its possible to debug it.
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file.
"""
with open(str(dump_file), "wb") as fh:
sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n")
fh.write(request.SerializeToString())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,706 @@
"""Plugin model dataclasses.
These classes are meant to be an intermediate representation
of protobuf objects. They are used to organize the data collected during parsing.
The general intention is to create a doubly-linked tree-like structure
with the following types of references:
- Downwards references: from message -> fields, from output package -> messages
or from service -> service methods
- Upwards references: from field -> message, message -> package.
- Input/output message references: from a service method to it's corresponding
input/output messages, which may even be in another package.
There are convenience methods to allow climbing up and down this tree, for
example to retrieve the list of all messages that are in the same package as
the current message.
Most of these classes take as inputs:
- proto_obj: A reference to it's corresponding protobuf object as
presented by the protoc plugin.
- parent: a reference to the parent object in the tree.
With this information, the class is able to expose attributes,
such as a pythonized name, that will be calculated from proto_obj.
The instantiation should also attach a reference to the new object
into the corresponding place within it's parent object. For example,
instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
import betterproto
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
)
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
# Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
PLACEHOLDER = object()
# Organize proto types into categories
PROTO_FLOAT_TYPES = (
FieldDescriptorProto.TYPE_DOUBLE, # 1
FieldDescriptorProto.TYPE_FLOAT, # 2
)
PROTO_INT_TYPES = (
FieldDescriptorProto.TYPE_INT64, # 3
FieldDescriptorProto.TYPE_UINT64, # 4
FieldDescriptorProto.TYPE_INT32, # 5
FieldDescriptorProto.TYPE_FIXED64, # 6
FieldDescriptorProto.TYPE_FIXED32, # 7
FieldDescriptorProto.TYPE_UINT32, # 13
FieldDescriptorProto.TYPE_SFIXED32, # 15
FieldDescriptorProto.TYPE_SFIXED64, # 16
FieldDescriptorProto.TYPE_SINT32, # 17
FieldDescriptorProto.TYPE_SINT64, # 18
)
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12
PROTO_MESSAGE_TYPES = (
FieldDescriptorProto.TYPE_MESSAGE, # 11
FieldDescriptorProto.TYPE_ENUM, # 14
)
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11
PROTO_PACKED_TYPES = (
FieldDescriptorProto.TYPE_DOUBLE, # 1
FieldDescriptorProto.TYPE_FLOAT, # 2
FieldDescriptorProto.TYPE_INT64, # 3
FieldDescriptorProto.TYPE_UINT64, # 4
FieldDescriptorProto.TYPE_INT32, # 5
FieldDescriptorProto.TYPE_FIXED64, # 6
FieldDescriptorProto.TYPE_FIXED32, # 7
FieldDescriptorProto.TYPE_BOOL, # 8
FieldDescriptorProto.TYPE_UINT32, # 13
FieldDescriptorProto.TYPE_SFIXED32, # 15
FieldDescriptorProto.TYPE_SFIXED64, # 16
FieldDescriptorProto.TYPE_SINT32, # 17
FieldDescriptorProto.TYPE_SINT64, # 18
)
def get_comment(
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
) -> str:
pad = " " * indent
for sci in proto_file.source_code_info.location:
if list(sci.path) == path and sci.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
)
if path[-2] == 2 and path[-4] != 6:
# This is a field
return f"{pad}# " + f"\n{pad}# ".join(lines)
else:
# This is a message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
lines[0] = lines[0].strip('"')
return f'{pad}"""{lines[0]}"""'
else:
joined = f"\n{pad}".join(lines)
return f'{pad}"""\n{pad}{joined}\n{pad}"""'
return ""
class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
def __post_init__(self) -> None:
"""Checks that no fake default fields were left as placeholders."""
for field_name, field_val in self.__dataclass_fields__.items():
if field_val is PLACEHOLDER:
raise ValueError(f"`{field_name}` is a required field.")
@property
def output_file(self) -> "OutputTemplate":
current = self
while not isinstance(current, OutputTemplate):
current = current.parent
return current
@property
def proto_file(self) -> FieldDescriptorProto:
current = self
while not isinstance(current, OutputTemplate):
current = current.parent
return current.package_proto_obj
@property
def request(self) -> "PluginRequestCompiler":
current = self
while not isinstance(current, OutputTemplate):
current = current.parent
return current.parent_request
@property
def comment(self) -> str:
"""Crawl the proto source code and retrieve comments
for this object.
"""
return get_comment(
proto_file=self.proto_file, path=self.path, indent=self.comment_indent
)
@dataclass
class PluginRequestCompiler:
plugin_request_obj: plugin.CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
@property
def all_messages(self) -> List["MessageCompiler"]:
"""All of the messages in this request.
Returns
-------
List[MessageCompiler]
List of all of the messages in this request.
"""
return [
msg for output in self.output_packages.values() for msg in output.messages
]
@dataclass
class OutputTemplate:
"""Representation of an output .py file.
Each output file corresponds to a .proto input file,
but may need references to other .proto files to be
built.
"""
parent_request: PluginRequestCompiler
package_proto_obj: FileDescriptorProto
input_files: List[str] = field(default_factory=list)
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
@property
def package(self) -> str:
"""Name of input package.
Returns
-------
str
Name of input package.
"""
return self.package_proto_obj.package
@property
def input_filenames(self) -> List[str]:
"""Names of the input files used to build this output.
Returns
-------
List[str]
Names of the input files used to build this output.
"""
return [f.name for f in self.input_files]
@property
def python_module_imports(self) -> Set[str]:
imports = set()
if any(x for x in self.messages if any(x.deprecated_fields)):
imports.add("warnings")
return imports
@dataclass
class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
default_factory=list
)
deprecated: bool = field(default=False, init=False)
def __post_init__(self) -> None:
# Add message to output file
if isinstance(self.parent, OutputTemplate):
if isinstance(self, EnumDefinitionCompiler):
self.output_file.enums.append(self)
else:
self.output_file.messages.append(self)
self.deprecated = self.proto_obj.options.deprecated
super().__post_init__()
@property
def proto_name(self) -> str:
return self.proto_obj.name
@property
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_name}]"
return self.py_name
@property
def deprecated_fields(self) -> Iterator[str]:
for f in self.fields:
if f.deprecated:
yield f.py_name
def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool:
"""True if proto_field_obj is a map, otherwise False."""
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
# This might be a map...
message_type = proto_field_obj.type_name.split(".").pop().lower()
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
if message_type == map_entry:
for nested in parent_message.nested_type: # parent message
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
return True
return False
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
"""True if proto_field_obj is a OneOf, otherwise False."""
return proto_field_obj.HasField("oneof_index")
@dataclass
class FieldCompiler(MessageCompiler):
parent: MessageCompiler = PLACEHOLDER
proto_obj: FieldDescriptorProto = PLACEHOLDER
def __post_init__(self) -> None:
# Add field to message
self.parent.fields.append(self)
# Check for new imports
annotation = self.annotation
if "Optional[" in annotation:
self.output_file.typing_imports.add("Optional")
if "List[" in annotation:
self.output_file.typing_imports.add("List")
if "Dict[" in annotation:
self.output_file.typing_imports.add("Dict")
if "timedelta" in annotation:
self.output_file.datetime_imports.add("timedelta")
if "datetime" in annotation:
self.output_file.datetime_imports.add("datetime")
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
def get_field_string(self, indent: int = 4) -> str:
"""Construct string representation of this field as a field."""
name = f"{self.py_name}"
annotations = f": {self.annotation}"
field_args = ", ".join(
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
)
betterproto_field_type = (
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
)
return f"{name}{annotations} = {betterproto_field_type}"
@property
def betterproto_field_args(self) -> List[str]:
args = []
if self.field_wraps:
args.append(f"wraps={self.field_wraps}")
return args
@property
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
)
if match_wrapper:
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
if hasattr(betterproto, wrapped_type):
return f"betterproto.{wrapped_type}"
return None
@property
def repeated(self) -> bool:
return (
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
and not is_map(self.proto_obj, self.parent)
)
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
return self.annotation.startswith(("List[", "Dict["))
@property
def field_type(self) -> str:
"""String representation of proto field type."""
return (
self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
)
@property
def default_value_string(self) -> Union[Text, None, float, int]:
"""Python representation of the default proto value."""
if self.repeated:
return "[]"
if self.py_type == "int":
return "0"
if self.py_type == "float":
return "0.0"
elif self.py_type == "bool":
return "False"
elif self.py_type == "str":
return '""'
elif self.py_type == "bytes":
return 'b""'
else:
# Message type
return "None"
@property
def packed(self) -> bool:
"""True if the wire representation is a packed format."""
return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
@property
def py_name(self) -> str:
"""Pythonized name."""
return pythonize_field_name(self.proto_name)
@property
def proto_name(self) -> str:
"""Original protobuf name."""
return self.proto_obj.name
@property
def py_type(self) -> str:
"""String representation of Python type."""
if self.proto_obj.type in PROTO_FLOAT_TYPES:
return "float"
elif self.proto_obj.type in PROTO_INT_TYPES:
return "int"
elif self.proto_obj.type in PROTO_BOOL_TYPES:
return "bool"
elif self.proto_obj.type in PROTO_STR_TYPES:
return "str"
elif self.proto_obj.type in PROTO_BYTES_TYPES:
return "bytes"
elif self.proto_obj.type in PROTO_MESSAGE_TYPES:
# Type referencing another defined Message or a named enum
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.type_name,
)
else:
raise NotImplementedError(f"Unknown type {field.type}")
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_type}]"
return self.py_type
@dataclass
class OneOfFieldCompiler(FieldCompiler):
@property
def betterproto_field_args(self) -> List[str]:
args = super().betterproto_field_args
group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
args.append(f'group="{group}"')
return args
@dataclass
class MapEntryCompiler(FieldCompiler):
py_k_type: Type = PLACEHOLDER
py_v_type: Type = PLACEHOLDER
proto_k_type: str = PLACEHOLDER
proto_v_type: str = PLACEHOLDER
def __post_init__(self) -> None:
"""Explore nested types and set k_type and v_type if unset."""
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
for nested in self.parent.proto_obj.nested_type:
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0] # key
).py_type
self.py_v_type = FieldCompiler(
parent=self, proto_obj=nested.field[1] # value
).py_type
# Get proto types
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
@property
def betterproto_field_args(self) -> List[str]:
return [f"betterproto.{self.proto_k_type}", f"betterproto.{self.proto_v_type}"]
@property
def field_type(self) -> str:
return "map"
@property
def annotation(self) -> str:
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
@property
def repeated(self) -> bool:
return False # maps cannot be repeated
@dataclass
class EnumDefinitionCompiler(MessageCompiler):
"""Representation of a proto Enum definition."""
proto_obj: EnumDescriptorProto = PLACEHOLDER
entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
@dataclass(unsafe_hash=True)
class EnumEntry:
"""Representation of an Enum entry."""
name: str
value: int
comment: str
def __post_init__(self) -> None:
# Get entries/allowed values for this Enum
self.entries = [
self.EnumEntry(
name=sanitize_name(entry_proto_value.name),
value=entry_proto_value.number,
comment=get_comment(
proto_file=self.proto_file, path=self.path + [2, entry_number]
),
)
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
]
super().__post_init__() # call MessageCompiler __post_init__
@property
def default_value_string(self) -> str:
"""Python representation of the default value for Enums.
As per the spec, this is the first value of the Enum.
"""
return str(self.entries[0].value) # ideally, should ALWAYS be int(0)!
@dataclass
class ServiceCompiler(ProtoContentBase):
parent: OutputTemplate = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
methods: List["ServiceMethodCompiler"] = field(default_factory=list)
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
super().__post_init__() # check for unset fields
@property
def proto_name(self) -> str:
return self.proto_obj.name
@property
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
@dataclass
class ServiceMethodCompiler(ProtoContentBase):
parent: ServiceCompiler
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER
comment_indent: int = 8
def __post_init__(self) -> None:
# Add method to service
self.parent.methods.append(self)
# Check for Optional import
if self.py_input_message:
for f in self.py_input_message.fields:
if f.default_value_string == "None":
self.output_file.typing_imports.add("Optional")
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
# Check for Async imports
if self.client_streaming:
self.output_file.typing_imports.add("AsyncIterable")
self.output_file.typing_imports.add("Iterable")
self.output_file.typing_imports.add("Union")
if self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator")
super().__post_init__() # check for unset fields
@property
def mutable_default_args(self) -> Dict[str, str]:
"""Handle mutable default arguments.
Returns a list of tuples containing the name and default value
for arguments to this message who's default value is mutable.
The defaults are swapped out for None and replaced back inside
the method's body.
Reference:
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
Returns
-------
Dict[str, str]
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
if (
not self.client_streaming
and f.default_value_string != "None"
and f.mutable
):
mutable_default_args[f.py_name] = f.default_value_string
self.output_file.typing_imports.add("Optional")
return mutable_default_args
@property
def py_name(self) -> str:
"""Pythonized method name."""
return pythonize_method_name(self.proto_obj.name)
@property
def proto_name(self) -> str:
"""Original protobuf name."""
return self.proto_obj.name
@property
def route(self) -> str:
return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Optional[MessageCompiler]:
"""Find the input message object.
Returns
-------
Optional[MessageCompiler]
Method instance representing the input message.
If not input message could be found or there are no
input messages, None is returned.
"""
package, name = parse_source_type_name(self.proto_obj.input_type)
# Nested types are currently flattened without dots.
# Todo: keep a fully quantified name in types, that is
# comparable with method.input_type
for msg in self.request.all_messages:
if (
msg.py_name == name.replace(".", "")
and msg.output_file.package == package
):
return msg
return None
@property
def py_input_message_type(self) -> str:
"""String representation of the Python type corresponding to the
input message.
Returns
-------
str
String representation of the Python type corresponding to the input message.
"""
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
).strip('"')
@property
def py_output_message_type(self) -> str:
"""String representation of the Python type corresponding to the
output message.
Returns
-------
str
String representation of the Python type corresponding to the output message.
"""
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.output_type,
unwrap=False,
).strip('"')
@property
def client_streaming(self) -> bool:
return self.proto_obj.client_streaming
@property
def server_streaming(self) -> bool:
return self.proto_obj.server_streaming

View File

@@ -0,0 +1,176 @@
import itertools
import pathlib
import sys
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
ServiceDescriptorProto,
)
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
from .compiler import outputfile_compiler
from .models import (
EnumDefinitionCompiler,
FieldCompiler,
MapEntryCompiler,
MessageCompiler,
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
is_oneof,
)
if TYPE_CHECKING:
from google.protobuf.descriptor import Descriptor
def traverse(
proto_file: FieldDescriptorProto,
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
# Todo: Keep information about nested hierarchy
def _traverse(
path: List[int], items: List["Descriptor"], prefix=""
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
for i, item in enumerate(items):
# Adjust the name since we flatten the hierarchy.
# Todo: don't change the name, but include full name in returned tuple
item.name = next_prefix = prefix + item.name
yield item, path + [i]
if isinstance(item, DescriptorProto):
for enum in item.enum_type:
enum.name = next_prefix + enum.name
yield enum, path + [i, 4]
if item.nested_type:
for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix):
yield n, p
return itertools.chain(
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
)
def generate_code(
request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse
) -> None:
plugin_options = request.parameter.split(",") if request.parameter else []
request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages
for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip re-compiling Google's well-known types
continue
output_package_name = proto_file.package
if output_package_name not in request_data.output_packages:
# Create a new output if there is no output for this package
request_data.output_packages[output_package_name] = OutputTemplate(
parent_request=request_data, package_proto_obj=proto_file
)
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for item, path in traverse(proto_input_file):
read_protobuf_type(item=item, path=path, output_package=output_package)
# Read Services
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for index, service in enumerate(proto_input_file.service):
read_protobuf_service(service, index, output_package)
# Generate output files
output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items():
# Add files to the response object
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
output_paths.add(output_path)
f: response.File = response.file.add()
f.name = str(output_path)
# Render and then format the output file
f.content = outputfile_compiler(output_file=output_package)
# Make each output directory a package with __init__ file
init_files = {
directory.joinpath("__init__.py")
for path in output_paths
for directory in path.parents
} - output_paths
for init_file in init_files:
init = response.file.add()
init.name = str(init_file)
for output_package_name in sorted(output_paths.union(init_files)):
print(f"Writing {output_package_name}", file=sys.stderr)
def read_protobuf_type(
item: DescriptorProto, path: List[int], output_package: OutputTemplate
) -> None:
if isinstance(item, DescriptorProto):
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
return
# Process Message
message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path)
for index, field in enumerate(item.field):
if is_map(field, item):
MapEntryCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
)
elif is_oneof(field):
OneOfFieldCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
)
else:
FieldCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
)
elif isinstance(item, EnumDescriptorProto):
# Enum
EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path)
def read_protobuf_service(
service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
) -> None:
service_data = ServiceCompiler(
parent=output_package, proto_obj=service, path=[6, index]
)
for j, method in enumerate(service.method):
ServiceMethodCompiler(
parent=service_data, proto_obj=method, path=[6, index, 2, j]
)

View File

@@ -0,0 +1,2 @@
@SET plugin_dir=%~dp0
@python -m %plugin_dir% %*

View File

@@ -0,0 +1,159 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
from dataclasses import dataclass
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if output_file.typing_imports %}
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
import grpclib
{% endif %}
{% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}
{% endif %}
{% for entry in enum.entries %}
{% if entry.comment %}
{{ entry.comment }}
{% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %}
{% endfor %}
{% endif %}
{% for message in output_file.messages %}
@dataclass(eq=False, repr=False)
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
{% endif %}
{% for field in message.fields %}
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.get_field_string() }}
{% endfor %}
{% if not message.fields %}
pass
{% endif %}
{% if message.deprecated or message.deprecated_fields %}
def __post_init__(self) -> None:
{% if message.deprecated %}
warnings.warn("{{ message.py_name }} is deprecated", DeprecationWarning)
{% endif %}
super().__post_init__()
{% for field in message.deprecated_fields %}
if self.{{ field }}:
warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning)
{% endfor %}
{% endif %}
{% endfor %}
{% for service in output_file.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
None
{% endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{%- for py_name, zero in method.mutable_default_args.items() %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}
{% if not method.client_streaming %}
request = {{ method.py_input_message_type }}()
{% for field in method.py_input_message.fields %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }},
):
yield response
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream(
"{{ method.route }}",
request,
{{ method.py_output_message_type.strip('"') }},
):
yield response
{% endif %}{# if client streaming #}
{% else %}{# i.e. not server streaming #}
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary(
"{{ method.route }}",
request,
{{ method.py_output_message_type.strip('"') }}
)
{% endif %}{# client streaming #}
{% endif %}
{% endfor %}
{% endfor %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}

View File

@@ -50,7 +50,7 @@ You can add multiple `.proto` files to the test case, as long as one file matche
`test_<name>.py` &mdash; *Custom test to validate specific aspects of the generated class* `test_<name>.py` &mdash; *Custom test to validate specific aspects of the generated class*
```python ```python
from betterproto.tests.output_betterproto.bool.bool import Test from tests.output_betterproto.bool.bool import Test
def test_value(): def test_value():
message = Test() message = Test()

View File

@@ -2,18 +2,17 @@
import asyncio import asyncio
import os import os
from pathlib import Path from pathlib import Path
import platform
import shutil import shutil
import subprocess
import sys import sys
from typing import Set from typing import Set
from betterproto.tests.util import ( from tests.util import (
get_directories, get_directories,
inputs_path, inputs_path,
output_path_betterproto, output_path_betterproto,
output_path_reference, output_path_reference,
protoc_plugin, protoc,
protoc_reference,
) )
# Force pure-python implementation instead of C++, otherwise imports # Force pure-python implementation instead of C++, otherwise imports
@@ -89,8 +88,8 @@ async def generate_test_case_output(
(ref_out, ref_err, ref_code), (ref_out, ref_err, ref_code),
(plg_out, plg_err, plg_code), (plg_out, plg_err, plg_code),
) = await asyncio.gather( ) = await asyncio.gather(
protoc_reference(test_case_input_path, test_case_output_path_reference), protoc(test_case_input_path, test_case_output_path_reference, True),
protoc_plugin(test_case_input_path, test_case_output_path_betterproto), protoc(test_case_input_path, test_case_output_path_betterproto, False),
) )
message = f"Generated output for {test_case_name!r}" message = f"Generated output for {test_case_name!r}"
@@ -136,6 +135,10 @@ def main():
else: else:
verbose = False verbose = False
whitelist = set(sys.argv[1:]) whitelist = set(sys.argv[1:])
if platform.system() == "Windows":
asyncio.set_event_loop(asyncio.ProactorEventLoop())
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))

View File

@@ -1,12 +1,15 @@
import asyncio import asyncio
from betterproto.tests.output_betterproto.service.service import ( import sys
DoThingResponse,
from tests.output_betterproto.service.service import (
DoThingRequest, DoThingRequest,
DoThingResponse,
GetThingRequest, GetThingRequest,
GetThingResponse,
TestStub as ThingServiceClient, TestStub as ThingServiceClient,
) )
import grpclib import grpclib
import grpclib.metadata
import grpclib.server
from grpclib.testing import ChannelFor from grpclib.testing import ChannelFor
import pytest import pytest
from betterproto.grpc.util.async_channel import AsyncChannel from betterproto.grpc.util.async_channel import AsyncChannel
@@ -18,31 +21,92 @@ async def _test_client(client, name="clean room", **kwargs):
assert response.names == [name] assert response.names == [name]
def _assert_request_meta_recieved(deadline, metadata): def _assert_request_meta_received(deadline, metadata):
def server_side_test(stream): def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx( assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1 deadline._timestamp, 1
), "The provided deadline should be recieved serverside" ), "The provided deadline should be received serverside"
assert ( assert (
stream.metadata["authorization"] == metadata["authorization"] stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be recieved serverside" ), "The provided authorization metadata should be received serverside"
return server_side_test return server_side_test
@pytest.fixture
def handler_trailer_only_unauthenticated():
async def handler(stream: grpclib.server.Stream):
await stream.recv_message()
await stream.send_initial_metadata()
await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
return handler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_simple_service_call(): async def test_simple_service_call():
async with ChannelFor([ThingService()]) as channel: async with ChannelFor([ThingService()]) as channel:
await _test_client(ThingServiceClient(channel)) await _test_client(ThingServiceClient(channel))
@pytest.mark.asyncio
async def test_trailer_only_error_unary_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_thing",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(name="something")
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
async def test_trailer_only_error_stream_unary(
mocker, handler_trailer_only_unauthenticated
):
service = ThingService()
mocker.patch.object(
service,
"do_many_things",
side_effect=handler_trailer_only_unauthenticated,
autospec=True,
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things(
request_iterator=[DoThingRequest(name="something")]
)
await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
spy = mocker.spy(client, "_unary_unary")
await _test_client(client)
comments = spy.call_args_list[-1].args[1].comments
await _test_client(client)
assert spy.call_args_list[-1].args[1].comments is not comments
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_call_with_upfront_request_params(): async def test_service_call_with_upfront_request_params():
# Setting deadline # Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22) deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"} metadata = {"authorization": "12345"}
async with ChannelFor( async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel: ) as channel:
await _test_client( await _test_client(
ThingServiceClient(channel, deadline=deadline, metadata=metadata) ThingServiceClient(channel, deadline=deadline, metadata=metadata)
@@ -53,7 +117,7 @@ async def test_service_call_with_upfront_request_params():
deadline = grpclib.metadata.Deadline.from_timeout(timeout) deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"} metadata = {"authorization": "12345"}
async with ChannelFor( async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel: ) as channel:
await _test_client( await _test_client(
ThingServiceClient(channel, timeout=timeout, metadata=metadata) ThingServiceClient(channel, timeout=timeout, metadata=metadata)
@@ -70,7 +134,7 @@ async def test_service_call_lower_level_with_overrides():
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"} kwarg_metadata = {"authorization": "12345"}
async with ChannelFor( async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel: ) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary( response = await client._unary_unary(
@@ -92,7 +156,7 @@ async def test_service_call_lower_level_with_overrides():
async with ChannelFor( async with ChannelFor(
[ [
ThingService( ThingService(
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata), test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
) )
] ]
) as channel: ) as channel:
@@ -140,8 +204,8 @@ async def test_async_gen_for_stream_stream_request():
assert response.version == response_index + 1 assert response.version == response_index + 1
response_index += 1 response_index += 1
if more_things: if more_things:
# Send some more requests as we recieve reponses to be sure coordination of # Send some more requests as we receive responses to be sure coordination of
# send/recieve events doesn't matter # send/receive events doesn't matter
await request_chan.send(GetThingRequest(more_things.pop(0))) await request_chan.send(GetThingRequest(more_things.pop(0)))
elif not send_initial_requests.done(): elif not send_initial_requests.done():
# Make sure the sending task it completed # Make sure the sending task it completed
@@ -151,4 +215,4 @@ async def test_async_gen_for_stream_stream_request():
request_chan.close() request_chan.close()
assert response_index == len( assert response_index == len(
expected_things expected_things
), "Didn't recieve all exptected responses" ), "Didn't receive all expected responses"

View File

@@ -27,10 +27,7 @@ class ClientStub:
async def to_list(generator: AsyncIterator): async def to_list(generator: AsyncIterator):
result = [] return [value async for value in generator]
async for value in generator:
result.append(value)
return result
@pytest.fixture @pytest.fixture

View File

@@ -1,12 +1,12 @@
from betterproto.tests.output_betterproto.service.service import ( from tests.output_betterproto.service.service import (
DoThingResponse, DoThingResponse,
DoThingRequest, DoThingRequest,
GetThingRequest, GetThingRequest,
GetThingResponse, GetThingResponse,
TestStub as ThingServiceClient,
) )
import grpclib import grpclib
from typing import Any, Dict import grpclib.server
from typing import Dict
class ThingService: class ThingService:

View File

@@ -1,4 +1,4 @@
from betterproto.tests.output_betterproto.bool import Test from tests.output_betterproto.bool import Test
def test_value(): def test_value():

View File

@@ -1,5 +1,5 @@
import betterproto.tests.output_betterproto.casing as casing import tests.output_betterproto.casing as casing
from betterproto.tests.output_betterproto.casing import Test from tests.output_betterproto.casing import Test
def test_message_attributes(): def test_message_attributes():

View File

@@ -1,4 +1,4 @@
from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test from tests.output_betterproto.casing_message_field_uppercase import Test
def test_message_casing(): def test_message_casing():

View File

@@ -1,12 +1,11 @@
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
# Remove from list when fixed. # Remove from list when fixed.
xfail = { xfail = {
"import_circular_dependency",
"oneof_enum", # 63 "oneof_enum", # 63
"namespace_keywords", # 70 "namespace_keywords", # 70
"namespace_builtin_types", # 53 "namespace_builtin_types", # 53
"googletypes_struct", # 9 "googletypes_struct", # 9
"googletypes_value", # 9, "googletypes_value", # 9
"import_capitalized_package", "import_capitalized_package",
"example", # This is the example in the readme. Not a test. "example", # This is the example in the readme. Not a test.
} }

View File

@@ -0,0 +1,4 @@
{
"v": 10,
"value": 10
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
// Some documentation about the Test message.
message Test {
// Some documentation about the value.
option deprecated = true;
int32 v = 1 [deprecated=true];
int32 value = 2;
}

View File

@@ -0,0 +1,4 @@
{
"v": 10,
"value": 10
}

View File

@@ -0,0 +1,8 @@
syntax = "proto3";
// Some documentation about the Test message.
message Test {
// Some documentation about the value.
int32 v = 1 [deprecated=true];
int32 value = 2;
}

View File

@@ -0,0 +1,9 @@
{
"choice": "FOUR",
"choices": [
"ZERO",
"ONE",
"THREE",
"FOUR"
]
}

View File

@@ -0,0 +1,15 @@
syntax = "proto3";
// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
message Test {
Choice choice = 1;
repeated Choice choices = 2;
}
enum Choice {
ZERO = 0;
ONE = 1;
// TWO = 2;
FOUR = 4;
THREE = 3;
}

View File

@@ -0,0 +1,84 @@
from tests.output_betterproto.enum import (
Test,
Choice,
)
def test_enum_set_and_get():
assert Test(choice=Choice.ZERO).choice == Choice.ZERO
assert Test(choice=Choice.ONE).choice == Choice.ONE
assert Test(choice=Choice.THREE).choice == Choice.THREE
assert Test(choice=Choice.FOUR).choice == Choice.FOUR
def test_enum_set_with_int():
assert Test(choice=0).choice == Choice.ZERO
assert Test(choice=1).choice == Choice.ONE
assert Test(choice=3).choice == Choice.THREE
assert Test(choice=4).choice == Choice.FOUR
def test_enum_is_comparable_with_int():
assert Test(choice=Choice.ZERO).choice == 0
assert Test(choice=Choice.ONE).choice == 1
assert Test(choice=Choice.THREE).choice == 3
assert Test(choice=Choice.FOUR).choice == 4
def test_enum_to_dict():
assert (
"choice" not in Test(choice=Choice.ZERO).to_dict()
), "Default enum value is not serialized"
assert (
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
== "ZERO"
)
assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"
def test_repeated_enum_is_comparable_with_int():
assert Test(choices=[Choice.ZERO]).choices == [0]
assert Test(choices=[Choice.ONE]).choices == [1]
assert Test(choices=[Choice.THREE]).choices == [3]
assert Test(choices=[Choice.FOUR]).choices == [4]
def test_repeated_enum_set_and_get():
assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]
def test_repeated_enum_to_dict():
assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]
all_enums_dict = Test(
choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
).to_dict()
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
def test_repeated_enum_with_single_value_to_dict():
assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
assert Test(choices=1).to_dict()["choices"] == ["ONE"]
def test_repeated_enum_with_non_list_iterables_to_dict():
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [
"ONE",
"THREE",
]
def enum_generator():
yield Choice.ONE
yield Choice.THREE
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]

View File

@@ -0,0 +1,13 @@
syntax = "proto3";
message Foo{
int64 bar = 1;
}
message Test{
oneof group{
string string = 1;
int64 integer = 2;
Foo foo = 3;
}
}

View File

@@ -0,0 +1,55 @@
import pytest
from google.protobuf import json_format
import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import (
Test,
Foo,
)
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
Test as ReferenceTest,
Foo as ReferenceFoo,
)
def test_oneof_serializes_similar_to_google_oneof():
tests = [
(Test(string="abc"), ReferenceTest(string="abc")),
(Test(integer=2), ReferenceTest(integer=2)),
(Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))),
# Default values should also behave the same within oneofs
(Test(string=""), ReferenceTest(string="")),
(Test(integer=0), ReferenceTest(integer=0)),
(Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))),
]
for message, message_reference in tests:
# NOTE: As of July 2020, MessageToJson inserts newlines in the output string so,
# just compare dicts
assert message.to_dict() == json_format.MessageToDict(message_reference)
def test_bytes_are_the_same_for_oneof():
message = Test(string="")
message_reference = ReferenceTest(string="")
message_bytes = bytes(message)
message_reference_bytes = message_reference.SerializeToString()
assert message_bytes == message_reference_bytes
message2 = Test().parse(message_reference_bytes)
message_reference2 = ReferenceTest()
message_reference2.ParseFromString(message_reference_bytes)
assert message == message2
assert message_reference == message_reference2
# None of these fields were explicitly set BUT they should not actually be null
# themselves
assert isinstance(message.foo, Foo)
assert isinstance(message2.foo, Foo)
assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo)

View File

@@ -3,8 +3,8 @@ from typing import Any, Callable, Optional
import betterproto.lib.google.protobuf as protobuf import betterproto.lib.google.protobuf as protobuf
import pytest import pytest
from betterproto.tests.mocks import MockChannel from tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response import TestStub from tests.output_betterproto.googletypes_response import TestStub
test_cases = [ test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5), (TestStub.get_double, protobuf.DoubleValue, 2.5),
@@ -21,7 +21,7 @@ test_cases = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_recieves_wrapped_type( async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
): ):
wrapped_value = wrapper_class() wrapped_value = wrapper_class()

View File

@@ -1,7 +1,7 @@
import pytest import pytest
from betterproto.tests.mocks import MockChannel from tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response_embedded import ( from tests.output_betterproto.googletypes_response_embedded import (
Output, Output,
TestStub, TestStub,
) )

Some files were not shown because too many files have changed in this diff Show More