35 Commits

Author SHA1 Message Date
James Hilton-Balfe
6dd7baa26c Release v2.0.0b4 (#307)
Co-authored-by: Kalan <22137047+kalzoo@users.noreply.github.com>
2022-01-03 18:18:44 +00:00
Kalan
573c7292a6 Add Python 3.10 to GitHub Actions test matrix (#280)
Co-authored-by: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com>
2021-12-29 23:10:34 +00:00
Kalan
d77f44ebb7 Support proto3 field presence (#281)
* Update protobuf pregenerated files

* Update grpcio-tools to latest version

* Implement proto3 field presence

* Fix to_dict with None optional fields.

* Add test with optional enum

* Properly support optional enums

* Add tests for 64-bit ints and floats

* Support field presence for int64 types

* Fix oneof serialization with proto3 field presence (#292)

= Description

Serializing a oneof group that contains a message whose fields use
explicit presence was buggy.

For example:

```
message A {
    oneof kind {
        B b = 1;
        C c = 2;
    }
}

message B {}
message C {
    optional bool z = 1;
}
```

Serializing `A(b=B())` would lead to this payload:

```
0A # tag1, length delimited
00 # length: 0
12 # tag2, length delimited
00 # length: 0
```

Which when deserialized, leads to the message `A(c=C())`.
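The four bytes above can be decoded by hand. A minimal sketch (plain Python, not betterproto code) that splits a payload into `(field_number, value)` pairs shows that both oneof fields are present on the wire, so a last-field-wins decoder picks `c`:

```python
def parse_fields(buf: bytes):
    """Parse top-level length-delimited fields from a tiny protobuf payload.

    Simplified: assumes single-byte tags (field numbers < 16) and
    single-byte lengths (< 128), which holds for the payload above.
    """
    fields = []
    i = 0
    while i < len(buf):
        tag = buf[i]
        field_number, wire_type = tag >> 3, tag & 0x7
        assert wire_type == 2  # length-delimited
        length = buf[i + 1]
        value = buf[i + 2 : i + 2 + length]
        fields.append((field_number, value))
        i += 2 + length
    return fields

payload = bytes.fromhex("0A001200")
print(parse_fields(payload))  # [(1, b''), (2, b'')]
```

Field 2 (`c`) appears after field 1 (`b`), so deserialization ends with `c` as the set member of the oneof.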

= Explanation

The issue lies in the post_init method. All fields are introspected, and
if different from PLACEHOLDER, the message is marked as having been
"serialized_on_wire".
Then, when serializing `A(b=B())`, we go through each field of the
oneof:

- field 'b': this is the selected field from the group, so it is
  serialized
- field 'c': marked as 'serialized_on_wire', so it is added as well.

= Fix

The issue is that support for explicit presence changed the default
value from PLACEHOLDER to None. This breaks the post_init method in that
case, which is relatively easy to fix: if a field is optional, and set
to None, this is considered as the default value (which it is).

This fix, however, has a side effect: the group_current entry for this field
(the oneof trick used for explicit presence) is no longer set. This changes
the behavior when serializing the message to JSON: because the value is the
default one (None) and the group is not set (which would otherwise force
serialization of the field), None fields are no longer serialized in JSON.
This breaks one test, which is fixed in the next commit.
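A hypothetical sketch of the presence check described above; `PLACEHOLDER` and the `optional` flag are simplified stand-ins for betterproto's field metadata:

```python
# Sentinel meaning "field was never assigned" (stand-in for betterproto's).
PLACEHOLDER = object()

def is_set(value, optional: bool) -> bool:
    """A field counts as explicitly set only if it differs from its default."""
    if value is PLACEHOLDER:
        return False  # never assigned
    if optional and value is None:
        return False  # proto3 optional fields default to None, not PLACEHOLDER
    return True

assert not is_set(PLACEHOLDER, optional=False)
assert not is_set(None, optional=True)   # the fix: None is the default here
assert is_set(None, optional=False)      # a plain field assigned None is "set"
assert is_set(0, optional=True)          # an explicit 0 counts as set
```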

* fix: do not serialize None fields in JSON format

This is linked to the fix from the previous commit: after it, scalar
None fields were not included in the JSON format, but some were still
included.

This is all cleaned up: None fields are not added in JSON by default,
as they indicate the default value of fields with explicit presence.
However, if `include_default_values` is set, they are included.
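A hypothetical sketch of that JSON rule; the real `to_dict` also handles casing, nested messages, and more:

```python
def to_json_dict(fields: dict, include_default_values: bool = False) -> dict:
    """Drop None (default) values unless defaults are explicitly requested."""
    return {
        name: value
        for name, value in fields.items()
        if value is not None or include_default_values
    }

msg = {"z": None, "name": "x"}
print(to_json_dict(msg))                               # {'name': 'x'}
print(to_json_dict(msg, include_default_values=True))  # {'z': None, 'name': 'x'}
```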

* Fix: use builtin annotation prefix

* Remove comment

Co-authored-by: roblabla <unfiltered@roblab.la>
Co-authored-by: Vincent Thiberville <vthib@pm.me>
2021-12-29 13:38:32 -08:00
dependabot[bot]
671c0ff4ac Bump urllib3 from 1.26.4 to 1.26.5 (#288)
Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.4 to 1.26.5.
- [Release notes](https://github.com/urllib3/urllib3/releases)
- [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
- [Commits](https://github.com/urllib3/urllib3/compare/1.26.4...1.26.5)

---
updated-dependencies:
- dependency-name: urllib3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-12-11 18:31:26 -08:00
dependabot[bot]
9cecc8c3ff Bump babel from 2.9.0 to 2.9.1 (#289)
Bumps [babel](https://github.com/python-babel/babel) from 2.9.0 to 2.9.1.
- [Release notes](https://github.com/python-babel/babel/releases)
- [Changelog](https://github.com/python-babel/babel/blob/master/CHANGES)
- [Commits](https://github.com/python-babel/babel/compare/v2.9.0...v2.9.1)

---
updated-dependencies:
- dependency-name: babel
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2021-12-11 18:30:43 -08:00
Kim Gustyr
bc3cfc5562 Fix default values for enum service args #298 (#299) 2021-12-03 21:26:48 +00:00
guysz
b0a36d12e4 Fix compilation of fields with name identical to their type (#294)
* Revert "Fix compilation of fields named 'bytes' or 'str' (#226)"

This reverts commit deb623ed14.

* Fix compilation of fields with name identical to their type

* Added test for field-name identical to python type

Co-authored-by: Guy Szweigman <guysz@nvidia.com>
2021-12-01 16:31:02 +00:00
Kalan
a4d2d39546 Fix Python 3.9 Tests (#284)
Co-authored-by: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com>
2021-11-19 21:32:36 +00:00
lazytype
c424b6f8db Include AsyncIterator import for both clients and servers (#264)
Co-authored-by: Robin Lambertz <github@roblab.la>
2021-11-05 14:22:15 +00:00
James Hilton-Balfe
421fdba309 Allow parsing of messages from ByteStrings #266 2021-10-26 00:34:33 +01:00
Robin Lambertz
fb2793e0b6 Allow parsing messages from byteslike
Byteslike objects (like memoryview) do not have a decode function defined.
Instead, a string may be created from them by passing them to the str
constructor along with an encoding.
2021-08-25 12:53:02 +02:00
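The point is easy to verify in plain Python:

```python
# memoryview (and other bytes-like objects) have no .decode() method,
# but str() accepts any bytes-like object together with an encoding.
view = memoryview(b"hello")
assert not hasattr(view, "decode")

text = str(view, "utf-8")
print(text)  # hello
```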
PIGNOSE
ad8b91766a Add benchmarking cases for nested, repeat and deserialize (#241) 2021-06-21 23:38:22 +02:00
Bekhzod Tillakhanov
a33126544b Fix readme docs 'Async gRPC Support' (#249) 2021-06-21 23:29:59 +02:00
nat
02e41afd09 Release v2.0.0b3 (#182)
Updated change log to include all new features and fixes from master.
2021-04-07 12:57:44 +02:00
nat
7368299a70 Fix serialization of repeated fields with empty messages (#180)
Extend test config and utils to support exclusion of certain json samples from
testing for symmetry.
2021-04-06 10:50:45 +10:00
nat
deb623ed14 Fix compilation of fields named 'bytes' or 'str' (#226)
* if you have a field named "bytes" using the bytes type, it doesn't work.
* Enable existing use-case & generalize solution to cover it

Co-authored-by: Spencer <spencer@sf-n.com>
2021-04-06 10:45:57 +10:00
nat
95339bf74d Misc cleanup, see commit body (#227)
- Enable oneof_enum test case that passes now (removed the xfail)
- Switch from toml to tomlkit as a dev dep for better toml support
- upgrade poethepoet to latest stable release
- use full table format for poe tasks to avoid long lines in pyproject.toml
- remove redundant _WrappedMessage class
- fix various Mypy warnings
- reformat some comments for consistent line length
2021-04-06 10:43:09 +10:00
nat
5b639c82b2 Micro-optimization: use tuples instead of lists for conditions (#228)
This should give a small speed boost to some critical code paths.
2021-04-06 10:40:45 +10:00
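A minimal illustration of the change (the names below are stand-ins, not betterproto's actual constants): membership tests behave identically for lists and tuples, but when the elements are names rather than literals, the list is rebuilt on every call while a tuple is cheaper to construct.

```python
TYPE_ENUM, TYPE_BOOL, TYPE_INT32 = "enum", "bool", "int32"

def is_varint_list(proto_type: str) -> bool:
    # Builds a fresh list object on each call.
    return proto_type in [TYPE_ENUM, TYPE_BOOL, TYPE_INT32]

def is_varint_tuple(proto_type: str) -> bool:
    # Tuples are lighter-weight to construct; semantics are unchanged.
    return proto_type in (TYPE_ENUM, TYPE_BOOL, TYPE_INT32)

assert is_varint_list("bool") and is_varint_tuple("bool")
assert not is_varint_list("string") and not is_varint_tuple("string")
```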
Matthew Badger
7c5ee47e68 Added support for infinite and nan floats/doubles (#215)
- Added support for the custom double values from
   the protobuf json spec: "Infinity", "-Infinity", and "NaN"
- Added `infinite_floats` test data
- Updated Message.__eq__ to consider nan values
   equal
- Updated `test_message_json` and
   `test_binary_compatibility` to replace NaN float
   values in dictionaries before comparison
   (because two NaN values are not equal)
2021-04-02 15:15:28 +02:00
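The NaN caveat noted above can be seen directly; this is a sketch of the comparison rule the updated `__eq__` needs, not the actual implementation:

```python
import math

# Two NaNs never compare equal, so a naive field-by-field == would report
# messages with NaN fields as different even when they are byte-identical.
assert float("nan") != float("nan")

def floats_equal(a: float, b: float) -> bool:
    """Float equality that treats NaN == NaN."""
    if math.isnan(a) and math.isnan(b):
        return True
    return a == b

assert floats_equal(float("nan"), float("nan"))
assert floats_equal(1.5, 1.5)
assert not floats_equal(1.5, 2.5)
```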
Nat Noordanus
bb646fe26f Fix template bug resulting in empty __post_init__ methods 2021-04-02 10:13:08 +11:00
Nat Noordanus
fc90653ab1 Sort the list of sources in generated file headers 2021-04-02 10:13:00 +11:00
Nat Noordanus
2a73dbac98 Make plugin use betterproto generated classes internally
This means the betterproto plugin no longer needs to depend directly on
protobuf.

This requires a small runtime hack to monkey patch some google types to
get around the fact that the compiler uses proto2, but betterproto
expects proto3.

Also:
- regenerate google.protobuf package
- fix a regex bug in the logic for determining whether to use a google
  wrapper type.
- fix a bug causing comments to get mixed up when multiple proto files
  generate code into a single python module
2021-04-02 10:13:00 +11:00
nat
891c9e5d6c Update readme to avoid confusion about unreleased features. (#223) 2021-04-01 20:40:02 +02:00
Nat Noordanus
a890514b5c Update deps & add generate_lib task
- Remove plugin dependency on protobuf since it's no longer required.
- Update poethepoet to for better pyproject toml syntax support
- Add handy generate_lib poe task for maintaining generated libs
2021-04-01 09:49:22 +11:00
Nat Noordanus
fe1e712fdb Make plugin use betterproto generated classes internally
This means the betterproto plugin no longer needs to depend directly on
protobuf.

This requires a small runtime hack to monkey patch some google types to
get around the fact that the compiler uses proto2, but betterproto
expects proto3.

Also:
- regenerate google.protobuf package
- fix a regex bug in the logic for determining whether to use a google
  wrapper type.
- fix a bug causing comments to get mixed up when multiple proto files
  generate code into a single python module
2021-04-01 09:49:22 +11:00
Vasili Syrakis
7a358a63cf Add __version__ attribute to package 2021-03-31 11:44:32 +11:00
nat
342e6559dc Properly serialize zero-value messages in a oneof group (#176)
Also improve test utils to make it easier to have multiple json examples.

Co-authored-by: Christopher Chambers <chris@peanutcode.com>
2021-03-15 13:52:35 +01:00
Vladimir Solomatin
2f62189346 Fix typing and datetime imports not being present for service method type annotations (#183) 2021-03-12 22:15:15 +01:00
MinJune Kim
8a215367ad Allow empty services (#222)
Fixes issue #220
2021-03-12 21:49:58 +01:00
robinaly
6c1c41e9cc Use dateutil parser (#213)
Switch to using `isoparse` from `dateutil.parser` instead of `datetime.fromisoformat` for more robust parsing of dates in from_dict.
2021-02-24 22:18:05 +01:00
Matthew Badger
9e6881999e Add support for repeated timestamps and durations to to_dict from_dict (#211) 2021-02-16 19:54:50 +01:00
nat
59f5f88c0d Rebuild poetry.lock to fix CI (#202) 2021-01-25 20:28:30 +01:00
Tim Schmidt
8eea5fe256 added documentation for server-facing stubs (#186) 2021-01-24 22:20:32 +01:00
Tim Schmidt
1d54ef8f99 Generate grpclib service stubs (#170) 2020-12-04 22:22:11 +01:00
nat
73cea12e1f Fix incorrect routes in generated client when service is not in a package (#177) 2020-11-28 17:50:25 +01:00
61 changed files with 3475 additions and 1042 deletions


@@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
os: [Ubuntu, MacOS, Windows]
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: ['3.6.7', '3.7', '3.8', '3.9', '3.10']
exclude:
- os: Windows
python-version: 3.6
@@ -66,4 +66,4 @@ jobs:
- name: Execute test suite
shell: bash
run: poetry run pytest tests/
run: poetry run python -m pytest tests/


@@ -7,6 +7,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [2.0.0b4] - 2022-01-03
- **Breaking**: the minimum Python version has been bumped to `3.6.2`
- Always add `AsyncIterator` to imports if there are services [#264](https://github.com/danielgtaylor/python-betterproto/pull/264)
- Allow parsing of messages from `ByteStrings` [#266](https://github.com/danielgtaylor/python-betterproto/pull/266)
- Add support for proto3 optional [#281](https://github.com/danielgtaylor/python-betterproto/pull/281)
- Fix compilation of fields with names identical to builtin types [#294](https://github.com/danielgtaylor/python-betterproto/pull/294)
- Fix default values for enum service args [#299](https://github.com/danielgtaylor/python-betterproto/pull/299)
## [2.0.0b3] - 2021-04-07
- Generate grpclib service stubs [#170](https://github.com/danielgtaylor/python-betterproto/pull/170)
- Add \_\_version\_\_ attribute to package [#134](https://github.com/danielgtaylor/python-betterproto/pull/134)
- Use betterproto generated messages in the plugin [#161](https://github.com/danielgtaylor/python-betterproto/pull/161)
- Sort the list of sources in generated file headers [#164](https://github.com/danielgtaylor/python-betterproto/pull/164)
- Micro-optimization: use tuples instead of lists for conditions [#228](https://github.com/danielgtaylor/python-betterproto/pull/228)
- Improve datestring parsing [#213](https://github.com/danielgtaylor/python-betterproto/pull/213)
- Fix serialization of repeated fields with empty messages [#180](https://github.com/danielgtaylor/python-betterproto/pull/180)
- Fix compilation of fields named 'bytes' or 'str' [#226](https://github.com/danielgtaylor/python-betterproto/pull/226)
- Fix json serialization of infinite and nan floats/doubles [#215](https://github.com/danielgtaylor/python-betterproto/pull/215)
- Fix template bug resulting in empty \_\_post_init\_\_ methods [#162](https://github.com/danielgtaylor/python-betterproto/pull/162)
- Fix serialization of zero-value messages in a oneof group [#176](https://github.com/danielgtaylor/python-betterproto/pull/176)
- Fix missing typing and datetime imports [#183](https://github.com/danielgtaylor/python-betterproto/pull/183)
- Fix code generation for empty services [#222](https://github.com/danielgtaylor/python-betterproto/pull/222)
- Fix Message.to_dict and from_dict handling of repeated timestamps and durations [#211](https://github.com/danielgtaylor/python-betterproto/pull/211)
- Fix incorrect routes in generated client when service is not in a package [#177](https://github.com/danielgtaylor/python-betterproto/pull/177)
## [2.0.0b2] - 2020-11-24
- Add support for deprecated message and fields [#126](https://github.com/danielgtaylor/python-betterproto/pull/126)
@@ -35,7 +65,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [2.0.0b1] - 2020-07-04
[Upgrade Guide](./docs/upgrading.md)
[Upgrade Guide](./docs/upgrading.md)
> Several bugfixes and improvements required or will require small breaking changes, necessitating a new version.
> `2.0.0` will be released once the interface is stable.


@@ -1,6 +1,7 @@
# Better Protobuf / gRPC Support for Python
![](https://github.com/danielgtaylor/python-betterproto/workflows/CI/badge.svg)
> :octocat: If you're reading this on github, please be aware that it might mention unreleased features! See the latest released README on [pypi](https://pypi.org/project/betterproto/).
This project aims to provide an improved experience when using Protobuf / gRPC in a modern Python environment by making use of modern language features and generating readable, understandable, idiomatic Python code. It will not support legacy features or environments (e.g. Protobuf 2). The following are supported:
@@ -159,6 +160,12 @@ service Echo {
}
```
Generate echo proto file:
```
python -m grpc_tools.protoc -I . --python_betterproto_out=. echo.proto
```
A client can be implemented as follows:
```python
import asyncio
@@ -192,6 +199,37 @@ EchoStreamResponse(value='hello')
EchoStreamResponse(value='hello')
```
This project also produces server-facing stubs that can be used to implement a Python
gRPC server.
To use them, simply subclass the base class in the generated files and override the
service methods:
```python
import asyncio

from echo import EchoBase, EchoResponse, EchoStreamResponse
from grpclib.server import Server
from typing import AsyncIterator


class EchoService(EchoBase):
    async def echo(self, value: str, extra_times: int) -> "EchoResponse":
        return EchoResponse([value for _ in range(extra_times)])

    async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]:
        for _ in range(extra_times):
            yield EchoStreamResponse(value)


async def main():
    server = Server([EchoService()])
    await server.start("127.0.0.1", 50051)
    await server.wait_closed()


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(main())
```
### JSON
Both serializing and parsing are supported to/from JSON and Python dictionaries using the following methods:
@@ -413,9 +451,9 @@ Assuming your `google.protobuf` source files (included with all releases of `pro
```sh
protoc \
--plugin=protoc-gen-custom=betterproto/plugin.py \
--plugin=protoc-gen-custom=src/betterproto/plugin/main.py \
--custom_opt=INCLUDE_GOOGLE \
--custom_out=betterproto/lib \
--custom_out=src/betterproto/lib \
-I /usr/local/include/ \
/usr/local/include/google/protobuf/*.proto
```


@@ -1,6 +1,8 @@
import betterproto
from dataclasses import dataclass
from typing import List
@dataclass
class TestMessage(betterproto.Message):
@@ -9,6 +11,29 @@ class TestMessage(betterproto.Message):
baz: float = betterproto.float_field(2)
@dataclass
class TestNestedChildMessage(betterproto.Message):
str_key: str = betterproto.string_field(0)
bytes_key: bytes = betterproto.bytes_field(1)
bool_key: bool = betterproto.bool_field(2)
float_key: float = betterproto.float_field(3)
int_key: int = betterproto.uint64_field(4)
@dataclass
class TestNestedMessage(betterproto.Message):
foo: TestNestedChildMessage = betterproto.message_field(0)
bar: TestNestedChildMessage = betterproto.message_field(1)
baz: TestNestedChildMessage = betterproto.message_field(2)
@dataclass
class TestRepeatedMessage(betterproto.Message):
foo_repeat: List[str] = betterproto.string_field(0)
bar_repeat: List[int] = betterproto.int64_field(1)
baz_repeat: List[bool] = betterproto.bool_field(2)
class BenchMessage:
"""Test creation and usage a proto message."""
@@ -16,6 +41,30 @@ class BenchMessage:
self.cls = TestMessage
self.instance = TestMessage()
self.instance_filled = TestMessage(0, "test", 0.0)
self.instance_filled_bytes = bytes(self.instance_filled)
self.instance_filled_nested = TestNestedMessage(
TestNestedChildMessage("foo", bytearray(b"test1"), True, 0.1234, 500),
TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, -302),
TestNestedChildMessage("baz", bytearray(b"test3"), False, 1e5, 300),
)
self.instance_filled_nested_bytes = bytes(self.instance_filled_nested)
self.instance_filled_repeated = TestRepeatedMessage(
[
"test1",
"test2",
"test3",
"test4",
"test5",
"test6",
"test7",
"test8",
"test9",
"test10",
],
[2, -100, 0, 500000, 600, -425678, 1000000000, -300, 1, -694214214466],
[True, False, False, False, True, True, False, True, False, False],
)
self.instance_filled_repeated_bytes = bytes(self.instance_filled_repeated)
def time_overhead(self):
"""Overhead in class definition."""
@@ -50,6 +99,26 @@ class BenchMessage:
"""Time serializing a message to wire."""
bytes(self.instance_filled)
def time_deserialize(self):
"""Time deserialize a message."""
TestMessage().parse(self.instance_filled_bytes)
def time_serialize_nested(self):
"""Time serializing a nested message to wire."""
bytes(self.instance_filled_nested)
def time_deserialize_nested(self):
"""Time deserialize a nested message."""
TestNestedMessage().parse(self.instance_filled_nested_bytes)
def time_serialize_repeated(self):
"""Time serializing a repeated message to wire."""
bytes(self.instance_filled_repeated)
def time_deserialize_repeated(self):
"""Time deserialize a repeated message."""
TestRepeatedMessage().parse(self.instance_filled_repeated_bytes)
class MemSuite:
def setup(self):


@@ -12,7 +12,7 @@ Features:
- Generated messages are both binary & JSON serializable
- Messages use relevant python types, e.g. ``Enum``, ``datetime`` and ``timedelta``
objects
- ``async``/``await`` support for gRPC Clients
- ``async``/``await`` support for gRPC Clients and Servers
- Generates modern, readable, idiomatic python code
Contents:


@@ -100,7 +100,7 @@ Async gRPC Support
++++++++++++++++++
The generated code includes `grpclib <https://grpclib.readthedocs.io/en/latest>`_ based
stub (client) classes for rpc services declared in the input proto files.
stub (client and server) classes for rpc services declared in the input proto files.
It is enabled by default.
@@ -160,6 +160,36 @@ The generated client can be used like so:
EchoStreamResponse(value='hello')
The server-facing stubs can be used to implement a Python
gRPC server.
To use them, simply subclass the base class in the generated files and override the
service methods:
.. code-block:: python
   from echo import EchoBase, EchoResponse, EchoStreamResponse
   from grpclib.server import Server
   from typing import AsyncIterator


   class EchoService(EchoBase):
       async def echo(self, value: str, extra_times: int) -> "EchoResponse":
           return EchoResponse([value for _ in range(extra_times)])

       async def echo_stream(
           self, value: str, extra_times: int
       ) -> AsyncIterator["EchoStreamResponse"]:
           for _ in range(extra_times):
               yield EchoStreamResponse(value)


   async def start_server():
       HOST = "127.0.0.1"
       PORT = 1337
       server = Server([EchoService()])
       await server.start(HOST, PORT)
       await server.serve_forever()
JSON
++++
Message objects include :meth:`betterproto.Message.to_json` and

poetry.lock (generated; 1307 lines changed)

File diff suppressed because it is too large


@@ -1,6 +1,6 @@
[tool.poetry]
name = "betterproto"
version = "2.0.0b2"
version = "2.0.0b4"
description = "A better Protobuf / gRPC generator & library"
authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"]
readme = "README.md"
@@ -12,50 +12,94 @@ packages = [
]
[tool.poetry.dependencies]
python = "^3.6"
backports-datetime-fromisoformat = { version = "^1.0.0", python = "<3.7" }
python = ">=3.6.2,<4.0"
black = { version = ">=19.3b0", optional = true }
dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
grpclib = "^0.4.1"
jinja2 = { version = "^2.11.2", optional = true }
protobuf = { version = "^3.12.2", optional = true }
python-dateutil = "^2.8"
[tool.poetry.dev-dependencies]
black = "^20.8b1"
asv = "^0.4.2"
black = "^21.11b0"
bpython = "^0.19"
grpcio-tools = "^1.30.0"
grpcio-tools = "^1.40.0"
jinja2 = "^2.11.2"
mypy = "^0.770"
poethepoet = "^0.5.0"
mypy = "^0.930"
poethepoet = ">=0.9.0"
protobuf = "^3.12.2"
pytest = "^5.4.2"
pytest = "^6.2.5"
pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1"
tox = "^3.15.1"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
asv = "^0.4.2"
tomlkit = "^0.7.0"
tox = "^3.15.1"
[tool.poetry.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main"
[tool.poetry.extras]
compiler = ["black", "jinja2", "protobuf"]
compiler = ["black", "jinja2"]
[tool.poe.tasks]
# Dev workflow tasks
generate = { script = "tests.generate:main", help = "Generate test cases (do this once before running test)" }
test = { cmd = "pytest --cov src", help = "Run tests" }
types = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" }
format = { cmd = "black . --exclude tests/output_", help = "Apply black formatting to source code" }
clean = { cmd = "rm -rf .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" }
docs = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"}
bench = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"}
[tool.poe.tasks.generate]
script = "tests.generate:main"
help = "Generate test cases (do this once before running test)"
[tool.poe.tasks.test]
cmd = "pytest"
help = "Run tests"
[tool.poe.tasks.types]
cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy"
[tool.poe.tasks.format]
cmd = "black . --exclude tests/output_"
help = "Apply black formatting to source code"
[tool.poe.tasks.docs]
cmd = "sphinx-build docs docs/build"
help = "Build the sphinx docs"
[tool.poe.tasks.bench]
shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD"
help = "Benchmark current commit vs. master branch"
[tool.poe.tasks.clean]
cmd = """
rm -rf .asv .coverage .mypy_cache .pytest_cache
dist betterproto.egg-info **/__pycache__
testsoutput_*
"""
help = "Clean out generated files from the workspace"
[tool.poe.tasks.generate_lib]
cmd = """
protoc
--plugin=protoc-gen-custom=src/betterproto/plugin/main.py
--custom_opt=INCLUDE_GOOGLE
--custom_out=src/betterproto/lib
-I /usr/local/include/
/usr/local/include/google/protobuf/**/*.proto
"""
help = "Regenerate the types in betterproto.lib.google"
# CI tasks
full-test = { shell = "poe generate && tox", help = "Run tests with multiple pythons" }
check-style = { cmd = "black . --check --diff --exclude tests/output_", help = "Check if code style is correct"}
[tool.poe.tasks.full-test]
shell = "poe generate && tox"
help = "Run tests with multiple pythons"
[tool.poe.tasks.check-style]
cmd = "black . --check --diff --exclude tests/output_"
help = "Check if code style is correct"
[tool.black]
target-version = ['py36']


@@ -2,17 +2,20 @@ import dataclasses
import enum
import inspect
import json
import math
import struct
import sys
import typing
from abc import ABC
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
from dateutil.parser import isoparse
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
@@ -23,15 +26,10 @@ from typing import (
)
from ._types import T
from ._version import __version__
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
if sys.version_info[:2] < (3, 7):
# Apply backport of datetime.fromisoformat from 3.7
from backports.datetime_fromisoformat import MonkeyPatch
MonkeyPatch.patch_fromisoformat()
# Proto 3 data types
TYPE_ENUM = "enum"
@@ -117,6 +115,12 @@ def datetime_default_gen() -> datetime:
DATETIME_ZERO = datetime_default_gen()
# Special protobuf json doubles
INFINITY = "Infinity"
NEG_INFINITY = "-Infinity"
NAN = "NaN"
class Casing(enum.Enum):
"""Casing constants for serialization."""
@@ -141,6 +145,8 @@ class FieldMetadata:
group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
# Is the field optional
optional: Optional[bool] = False
@staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata":
@@ -155,12 +161,15 @@ def dataclass_field(
map_types: Optional[Tuple[str, str]] = None,
group: Optional[str] = None,
wraps: Optional[str] = None,
optional: bool = False,
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=PLACEHOLDER,
default=None if optional else PLACEHOLDER,
metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
"betterproto": FieldMetadata(
number, proto_type, map_types, group, wraps, optional
)
},
)
@@ -170,74 +179,107 @@ def dataclass_field(
# out at runtime. The generated dataclass variables are still typed correctly.
def enum_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_ENUM, group=group)
def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
return dataclass_field(number, TYPE_ENUM, group=group, optional=optional)
def bool_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BOOL, group=group)
def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
return dataclass_field(number, TYPE_BOOL, group=group, optional=optional)
def int32_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_INT32, group=group)
def int32_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_INT32, group=group, optional=optional)
def int64_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_INT64, group=group)
def int64_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_INT64, group=group, optional=optional)
def uint32_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_UINT32, group=group)
def uint32_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_UINT32, group=group, optional=optional)
def uint64_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_UINT64, group=group)
def uint64_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_UINT64, group=group, optional=optional)
def sint32_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_SINT32, group=group)
def sint32_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_SINT32, group=group, optional=optional)
def sint64_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_SINT64, group=group)
def sint64_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_SINT64, group=group, optional=optional)
def float_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_FLOAT, group=group)
def float_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_FLOAT, group=group, optional=optional)
def double_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_DOUBLE, group=group)
def double_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_DOUBLE, group=group, optional=optional)
def fixed32_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_FIXED32, group=group)
def fixed32_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_FIXED32, group=group, optional=optional)
def fixed64_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_FIXED64, group=group)
def fixed64_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_FIXED64, group=group, optional=optional)
def sfixed32_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_SFIXED32, group=group)
def sfixed32_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_SFIXED32, group=group, optional=optional)
def sfixed64_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_SFIXED64, group=group)
def sfixed64_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_SFIXED64, group=group, optional=optional)
def string_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_STRING, group=group)
def string_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_STRING, group=group, optional=optional)
def bytes_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group)
def bytes_field(
number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group, optional=optional)
def message_field(
number: int, group: Optional[str] = None, wraps: Optional[str] = None
number: int,
group: Optional[str] = None,
wraps: Optional[str] = None,
optional: bool = False,
) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
return dataclass_field(
number, TYPE_MESSAGE, group=group, wraps=wraps, optional=optional
)
def map_field(
@@ -269,7 +311,7 @@ class Enum(enum.IntEnum):
The member was not found in the Enum.
"""
try:
return cls._member_map_[name]
return cls._member_map_[name] # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
@@ -304,16 +346,16 @@ def encode_varint(value: int) -> bytes:
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization."""
if proto_type in [
if proto_type in (
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
]:
):
return encode_varint(value)
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
elif proto_type in (TYPE_SINT32, TYPE_SINT64):
# Handle zig-zag encoding.
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
elif proto_type in FIXED_TYPES:
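The zig-zag branch above maps signed integers to unsigned ones so that small negative numbers stay small on the wire. A standalone sketch of both directions (helper names are illustrative, not betterproto's API):

```python
def zigzag_encode(n: int) -> int:
    # 0, -1, 1, -2, 2, ... map to 0, 1, 2, 3, 4, ...
    return n << 1 if n >= 0 else (n << 1) ^ (~0)

def zigzag_decode(n: int) -> int:
    # Invert the transform: the low bit carries the sign.
    return (n >> 1) ^ (-(n & 1))
```

Encoding -1 as a plain varint would take ten bytes of two's-complement payload; zig-zag makes it a single byte.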
@@ -373,6 +415,51 @@ def _serialize_single(
return bytes(output)
def _parse_float(value: Any) -> float:
"""Parse the given value to a float
Parameters
----------
value : Any
Value to parse
Returns
-------
float
Parsed value
"""
if value == INFINITY:
return float("inf")
if value == NEG_INFINITY:
return -float("inf")
if value == NAN:
return float("nan")
return float(value)
def _dump_float(value: float) -> Union[float, str]:
"""Dump the given float to JSON
Parameters
----------
value : float
Value to dump
Returns
-------
Union[float, str]
Dumped value, either a float or one of the strings
"Infinity", "-Infinity" or "NaN"
"""
if value == float("inf"):
return INFINITY
if value == -float("inf"):
return NEG_INFINITY
if math.isnan(value):
return NAN
return value
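The two helpers above implement the proto3 JSON mapping for non-finite floats. Note that detecting NaN requires `math.isnan`, since `nan == nan` is false under IEEE 754. A minimal sketch of the dump direction:

```python
import math

def dump_float(value: float):
    # proto3 JSON spells non-finite floats as strings.
    if value == float("inf"):
        return "Infinity"
    if value == -float("inf"):
        return "-Infinity"
    if math.isnan(value):
        return "NaN"
    return value
```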
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
"""
Decode a single varint value from a byte buffer. Returns the value and the
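A self-contained sketch of that decoding loop, assuming a well-formed buffer (no truncation handling):

```python
from typing import Tuple

def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
    """Decode one varint starting at pos; return (value, new position)."""
    result = shift = 0
    while True:
        byte = buffer[pos]
        result |= (byte & 0x7F) << shift  # low 7 bits are payload
        pos += 1
        if not byte & 0x80:  # high bit clear: this was the last byte
            return result, pos
        shift += 7
```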
@@ -474,13 +561,13 @@ class ProtoClassMetadata:
@staticmethod
def _get_default_gen(
cls: Type["Message"], fields: List[dataclasses.Field]
cls: Type["Message"], fields: Iterable[dataclasses.Field]
) -> Dict[str, Callable[[], Any]]:
return {field.name: cls._get_field_default_gen(field) for field in fields}
@staticmethod
def _get_cls_by_field(
cls: Type["Message"], fields: List[dataclasses.Field]
cls: Type["Message"], fields: Iterable[dataclasses.Field]
) -> Dict[str, Type]:
field_cls = {}
@@ -537,7 +624,8 @@ class Message(ABC):
if meta.group:
group_current.setdefault(meta.group)
if self.__raw_get(field_name) != PLACEHOLDER:
value = self.__raw_get(field_name)
if value != PLACEHOLDER and not (meta.optional and value is None):
# Found a non-sentinel value
all_sentinel = False
@@ -568,7 +656,18 @@ class Message(ABC):
other_val = other._get_field_default(field_name)
if self_val != other_val:
return False
# We consider two nan values to be the same for the
# purposes of comparing messages (otherwise a message
# is not equal to itself)
if (
isinstance(self_val, float)
and isinstance(other_val, float)
and math.isnan(self_val)
and math.isnan(other_val)
):
continue
else:
return False
return True
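The NaN special case above exists because IEEE 754 defines `nan != nan`, so without it a message holding a NaN field would compare unequal to itself. The rule in isolation (a sketch, not betterproto's method):

```python
import math

def scalar_values_equal(a, b) -> bool:
    # Two NaNs are treated as equal for message-comparison purposes.
    if (
        isinstance(a, float)
        and isinstance(b, float)
        and math.isnan(a)
        and math.isnan(b)
    ):
        return True
    return a == b
```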
@@ -628,7 +727,7 @@ class Message(ABC):
meta = getattr(self.__class__, "_betterproto_meta", None)
if not meta:
meta = ProtoClassMetadata(self.__class__)
self.__class__._betterproto_meta = meta
self.__class__._betterproto_meta = meta # type: ignore
return meta
def __bytes__(self) -> bytes:
@@ -641,12 +740,16 @@ class Message(ABC):
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types.
# wrapper types and proto3 field presence/optional fields.
continue
# Being selected in a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
#
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = (
meta.group and self._group_current[meta.group] == field_name
)
@@ -679,9 +782,18 @@ class Message(ABC):
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
output += (
_serialize_single(
meta.number,
meta.proto_type,
item,
wraps=meta.wraps or "",
)
# if it's an empty message it still needs to be represented
# as an item in the repeated list
or b"\n\x00"
)
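The `b"\n\x00"` fallback above is the wire encoding of an empty length-delimited field: one tag byte followed by a zero length (the literal corresponds to field number 1). Decomposed:

```python
# Tag byte: field number in the high bits, wire type in the low three.
tag = (1 << 3) | 2  # field number 1, wire type 2 (length-delimited)
assert bytes([tag, 0]) == b"\n\x00"  # 0x0A is "\n", then length 0
```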
elif isinstance(value, dict):
for k, v in value.items():
assert meta.map_types
@@ -704,7 +816,7 @@ class Message(ABC):
meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
serialize_empty=serialize_empty or bool(selected_in_group),
wraps=meta.wraps or "",
)
@@ -734,7 +846,7 @@ class Message(ABC):
@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = sys.modules[cls.__module__]
return get_type_hints(cls, vars(module))
return get_type_hints(cls, module.__dict__, {})
@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@@ -760,8 +872,9 @@ class Message(ABC):
# This is some kind of list (repeated) field.
return list
elif t.__origin__ is Union and t.__args__[1] is type(None):
# This is an optional (wrapped) field. For setting the default we
# really don't care what kind of field it is.
# This is an optional field (either wrapped, or using proto3
# field presence). For setting the default we really don't care
# what kind of field it is.
return type(None)
else:
return t
@@ -781,23 +894,23 @@ class Message(ABC):
) -> Any:
"""Adjusts values after parsing."""
if wire_type == WIRE_VARINT:
if meta.proto_type in [TYPE_INT32, TYPE_INT64]:
if meta.proto_type in (TYPE_INT32, TYPE_INT64):
bits = int(meta.proto_type[3:])
value = value & ((1 << bits) - 1)
signbit = 1 << (bits - 1)
value = int((value ^ signbit) - signbit)
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64):
# Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1))
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8")
value = str(value, "utf-8")
elif meta.proto_type == TYPE_MESSAGE:
cls = self._betterproto.cls_by_field[field_name]
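The INT32/INT64 branch above reinterprets the decoded unsigned varint as a signed two's-complement value. The mask-and-xor trick in isolation:

```python
def sign_extend(value: int, bits: int) -> int:
    # Keep only the low `bits` bits, then flip around the sign bit.
    value &= (1 << bits) - 1
    signbit = 1 << (bits - 1)
    return (value ^ signbit) - signbit
```

For example, a wire value of `2**32 - 1` in an int32 field is really -1.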
@@ -856,10 +969,10 @@ class Message(ABC):
pos = 0
value = []
while pos < len(parsed.value):
if meta.proto_type in [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]:
if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32):
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]:
elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64):
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
@@ -961,9 +1074,20 @@ class Message(ABC):
output[cased_name] = value
elif field_is_repeated:
# Convert each item.
value = [i.to_dict(casing, include_default_values) for i in value]
cls = self._betterproto.cls_by_field[field_name]
if cls == datetime:
value = [_Timestamp.timestamp_to_json(i) for i in value]
elif cls == timedelta:
value = [_Duration.delta_to_json(i) for i in value]
else:
value = [
i.to_dict(casing, include_default_values) for i in value
]
if value or include_default_values:
output[cased_name] = value
elif value is None:
if include_default_values:
output[cased_name] = value
elif (
value._serialized_on_wire
or include_default_values
@@ -989,6 +1113,9 @@ class Message(ABC):
if meta.proto_type in INT_64_TYPES:
if field_is_repeated:
output[cased_name] = [str(n) for n in value]
elif value is None:
if include_default_values:
output[cased_name] = value
else:
output[cased_name] = str(value)
elif meta.proto_type == TYPE_BYTES:
@@ -996,11 +1123,13 @@ class Message(ABC):
output[cased_name] = [
b64encode(b).decode("utf8") for b in value
]
elif value is None and include_default_values:
output[cased_name] = value
else:
output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
if field_is_repeated:
enum_class: Type[Enum] = field_types[field_name].__args__[0]
enum_class = field_types[field_name].__args__[0]
if isinstance(value, typing.Iterable) and not isinstance(
value, str
):
@@ -1008,9 +1137,20 @@ class Message(ABC):
else:
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
else:
enum_class: Type[Enum] = field_types[field_name] # noqa
elif value is None:
if include_default_values:
output[cased_name] = value
elif meta.optional:
enum_class = field_types[field_name].__args__[0]
output[cased_name] = enum_class(value).name
else:
enum_class = field_types[field_name] # noqa
output[cased_name] = enum_class(value).name
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if field_is_repeated:
output[cased_name] = [_dump_float(n) for n in value]
else:
output[cased_name] = _dump_float(value)
else:
output[cased_name] = value
return output
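Two of the `to_dict` conversions above in miniature: proto3 JSON renders 64-bit integers as strings (so JavaScript consumers don't lose precision beyond 53 bits) and bytes fields as base64 text.

```python
from base64 import b64encode

# 64-bit ints become strings in the JSON output.
big = 2**60
assert str(big) == "1152921504606846976"

# bytes fields become standard base64.
assert b64encode(b"\x00\xff").decode("utf8") == "AP8="
```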
@@ -1042,16 +1182,26 @@ class Message(ABC):
v = getattr(self, field_name)
if isinstance(v, list):
cls = self._betterproto.cls_by_field[field_name]
for item in value[key]:
v.append(cls().from_dict(item))
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
v = isoparse(value[key])
setattr(self, field_name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
cls = self._betterproto.cls_by_field[field_name]
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
@@ -1079,6 +1229,11 @@ class Message(ABC):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
if v is not None:
setattr(self, field_name, v)
@@ -1161,6 +1316,7 @@ from .lib.google.protobuf import ( # noqa
BytesValue,
DoubleValue,
Duration,
EnumValue,
FloatValue,
Int32Value,
Int64Value,
@@ -1179,7 +1335,7 @@ class _Duration(Duration):
def delta_to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
while len(parts[1]) not in (3, 6, 9):
parts[1] = f"{parts[1]}0"
return f"{'.'.join(parts)}s"
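The padding loop above exists because the proto3 JSON mapping for `Duration` allows exactly 0, 3, 6, or 9 fractional digits. A standalone restatement of the rule:

```python
from datetime import timedelta

def delta_to_json(delta: timedelta) -> str:
    # Total seconds with an "s" suffix; pad fractions to 3, 6, or 9 digits.
    parts = str(delta.total_seconds()).split(".")
    if len(parts) > 1:
        while len(parts[1]) not in (3, 6, 9):
            parts[1] = f"{parts[1]}0"
    return f"{'.'.join(parts)}s"
```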
@@ -1208,33 +1364,19 @@ class _Timestamp(Timestamp):
return f"{result}.{nanos:09d}"
class _WrappedMessage(Message):
"""
Google protobuf wrapper types base class. JSON representation is just the
value itself.
"""
value: Any
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
return self.value
def from_dict(self: T, value: Any) -> T:
if value is not None:
self.value = value
return self
def _get_wrapper(proto_type: str) -> Type:
"""Get the wrapper message class for a wrapped type."""
# TODO: include ListValue and NullValue?
return {
TYPE_BOOL: BoolValue,
TYPE_INT32: Int32Value,
TYPE_UINT32: UInt32Value,
TYPE_INT64: Int64Value,
TYPE_UINT64: UInt64Value,
TYPE_FLOAT: FloatValue,
TYPE_DOUBLE: DoubleValue,
TYPE_STRING: StringValue,
TYPE_BYTES: BytesValue,
TYPE_DOUBLE: DoubleValue,
TYPE_FLOAT: FloatValue,
TYPE_ENUM: EnumValue,
TYPE_INT32: Int32Value,
TYPE_INT64: Int64Value,
TYPE_STRING: StringValue,
TYPE_UINT32: UInt32Value,
TYPE_UINT64: UInt64Value,
}[proto_type]


@@ -0,0 +1,3 @@
from pkg_resources import get_distribution
__version__ = get_distribution("betterproto").version


@@ -100,8 +100,9 @@ def reference_descendent(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str:
"""
Returns a reference to a python type in a package that is a descendent of the current package,
and adds the required import that is aliased to avoid name conflicts.
Returns a reference to a python type in a package that is a descendent of the
current package, and adds the required import that is aliased to avoid name
conflicts.
"""
importing_descendent = py_package[len(current_package) :]
string_from = ".".join(importing_descendent[:-1])
@@ -119,8 +120,9 @@ def reference_ancestor(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str:
"""
Returns a reference to a python type in a package which is an ancestor to the current package,
and adds the required import that is aliased (if possible) to avoid name conflicts.
Returns a reference to a python type in a package which is an ancestor to the
current package, and adds the required import that is aliased (if possible) to avoid
name conflicts.
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
"""
@@ -141,10 +143,10 @@ def reference_cousin(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str:
"""
Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
and adds the required import that is aliased to avoid name conflicts.
Returns a reference to a python type in a package that is not descendent, ancestor
or sibling, and adds the required import that is aliased to avoid name conflicts.
"""
shared_ancestry = os.path.commonprefix([current_package, py_package])
shared_ancestry = os.path.commonprefix([current_package, py_package]) # type: ignore
distance_up = len(current_package) - len(shared_ancestry)
string_from = f".{'.' * distance_up}" + ".".join(
py_package[len(shared_ancestry) : -1]


@@ -0,0 +1,30 @@
from abc import ABC
from collections.abc import AsyncIterable
from typing import Callable, Any, Dict
import grpclib
import grpclib.server
class ServiceBase(ABC):
"""
Base class for async gRPC servers.
"""
async def _call_rpc_handler_server_stream(
self,
handler: Callable,
stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any],
) -> None:
response_iter = handler(**request_kwargs)
# check if response is actually an AsyncIterator
# this might be false if the method just returns without
# yielding at least once
# in that case, we just interpret it as an empty iterator
if isinstance(response_iter, AsyncIterable):
async for response_message in response_iter:
await stream.send_message(response_message)
else:
response_iter.close()
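The `AsyncIterable` check above distinguishes an async generator (which yields responses) from the plain coroutine returned by a handler that exits without yielding. A dependency-free illustration of the same dispatch using only asyncio:

```python
import asyncio
from collections.abc import AsyncIterable

async def streaming_handler():
    yield "hello"  # async generator: iterable responses

async def empty_handler():
    return  # plain coroutine: nothing to stream

async def collect(handler):
    response_iter = handler()
    sent = []
    if isinstance(response_iter, AsyncIterable):
        async for message in response_iter:
            sent.append(message)
    else:
        # Close the never-awaited coroutine to avoid a RuntimeWarning.
        response_iter.close()
    return sent

assert asyncio.run(collect(streaming_handler)) == ["hello"]
assert asyncio.run(collect(empty_handler)) == []
```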


@@ -70,7 +70,7 @@ class AsyncChannel(AsyncIterable[T]):
"""
def __init__(self, *, buffer_limit: int = 0, close: bool = False):
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit)
self._closed = False
self._waiting_receivers: int = 0
# Track whether flush has been invoked so it can only happen once


@@ -1,10 +1,12 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/any.proto, google/protobuf/source_context.proto, google/protobuf/type.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/wrappers.proto
# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto
# plugin: python-betterproto
import warnings
from dataclasses import dataclass
from typing import Dict, List
import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
class Syntax(betterproto.Enum):
@@ -107,7 +109,7 @@ class NullValue(betterproto.Enum):
NULL_VALUE = 0
@dataclass
@dataclass(eq=False, repr=False)
class Any(betterproto.Message):
"""
`Any` contains an arbitrary serialized protocol buffer message along with a
@@ -121,24 +123,25 @@ class Any(betterproto.Message):
Example 3: Pack and unpack a message in Python. foo = Foo(...) any
= Any() any.Pack(foo) ... if any.Is(Foo.DESCRIPTOR):
any.Unpack(foo) ... Example 4: Pack and unpack a message in Go
foo := &pb.Foo{...} any, err := ptypes.MarshalAny(foo) ...
foo := &pb.Foo{} if err := ptypes.UnmarshalAny(any, foo); err != nil {
... } The pack methods provided by protobuf library will by default
use 'type.googleapis.com/full.type.name' as the type URL and the unpack
methods only use the fully qualified type name after the last '/' in the
type URL, for example "foo.bar.com/x/y.z" will yield type name "y.z". JSON
==== The JSON representation of an `Any` value uses the regular
representation of the deserialized, embedded message, with an additional
field `@type` which contains the type URL. Example: package
google.profile; message Person { string first_name = 1;
string last_name = 2; } { "@type":
"type.googleapis.com/google.profile.Person", "firstName": <string>,
"lastName": <string> } If the embedded message type is well-known and
has a custom JSON representation, that representation will be embedded
adding a field `value` which holds the custom JSON in addition to the
`@type` field. Example (for message [google.protobuf.Duration][]): {
"@type": "type.googleapis.com/google.protobuf.Duration", "value":
"1.212s" }
foo := &pb.Foo{...} any, err := anypb.New(foo) if err != nil {
... } ... foo := &pb.Foo{} if err :=
any.UnmarshalTo(foo); err != nil { ... } The pack methods
provided by protobuf library will by default use
'type.googleapis.com/full.type.name' as the type URL and the unpack methods
only use the fully qualified type name after the last '/' in the type URL,
for example "foo.bar.com/x/y.z" will yield type name "y.z". JSON ==== The
JSON representation of an `Any` value uses the regular representation of
the deserialized, embedded message, with an additional field `@type` which
contains the type URL. Example: package google.profile; message
Person { string first_name = 1; string last_name = 2; }
{ "@type": "type.googleapis.com/google.profile.Person",
"firstName": <string>, "lastName": <string> } If the embedded
message type is well-known and has a custom JSON representation, that
representation will be embedded adding a field `value` which holds the
custom JSON in addition to the `@type` field. Example (for message
[google.protobuf.Duration][]): { "@type":
"type.googleapis.com/google.protobuf.Duration", "value": "1.212s"
}
"""
# A URL/resource name that uniquely identifies the type of the serialized
@@ -165,7 +168,7 @@ class Any(betterproto.Message):
value: bytes = betterproto.bytes_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class SourceContext(betterproto.Message):
"""
`SourceContext` represents information about the source of a protobuf
@@ -177,7 +180,7 @@ class SourceContext(betterproto.Message):
file_name: str = betterproto.string_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class Type(betterproto.Message):
"""A protocol buffer message type."""
@@ -195,7 +198,7 @@ class Type(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(6)
@dataclass
@dataclass(eq=False, repr=False)
class Field(betterproto.Message):
"""A single field of a message type."""
@@ -223,7 +226,7 @@ class Field(betterproto.Message):
default_value: str = betterproto.string_field(11)
@dataclass
@dataclass(eq=False, repr=False)
class Enum(betterproto.Message):
"""Enum type definition."""
@@ -241,7 +244,7 @@ class Enum(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(5)
@dataclass
@dataclass(eq=False, repr=False)
class EnumValue(betterproto.Message):
"""Enum value definition."""
@@ -253,7 +256,7 @@ class EnumValue(betterproto.Message):
options: List["Option"] = betterproto.message_field(3)
@dataclass
@dataclass(eq=False, repr=False)
class Option(betterproto.Message):
"""
A protocol buffer option, which can be attached to a message, field,
@@ -272,7 +275,7 @@ class Option(betterproto.Message):
value: "Any" = betterproto.message_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class Api(betterproto.Message):
"""
Api is a light-weight descriptor for an API Interface. Interfaces are also
@@ -315,7 +318,7 @@ class Api(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(7)
@dataclass
@dataclass(eq=False, repr=False)
class Method(betterproto.Message):
"""Method represents a method of an API interface."""
@@ -335,7 +338,7 @@ class Method(betterproto.Message):
syntax: "Syntax" = betterproto.enum_field(7)
@dataclass
@dataclass(eq=False, repr=False)
class Mixin(betterproto.Message):
"""
Declares an API Interface to be included in this interface. The including
@@ -360,7 +363,7 @@ class Mixin(betterproto.Message):
implies that all methods in `AccessControl` are also declared with same
name and request/response types in `Storage`. A documentation generator or
annotation processor will see the effective `Storage.GetAcl` method after
inherting documentation and annotations as follows: service Storage {
inheriting documentation and annotations as follows: service Storage {
// Get the underlying ACL object. rpc GetAcl(GetAclRequest) returns
(Acl) { option (google.api.http).get = "/v2/{resource=**}:getAcl";
} ... } Note how the version in the path pattern changed from
@@ -380,7 +383,7 @@ class Mixin(betterproto.Message):
root: str = betterproto.string_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class FileDescriptorSet(betterproto.Message):
"""
The protocol compiler can output a FileDescriptorSet containing the .proto
@@ -390,7 +393,7 @@ class FileDescriptorSet(betterproto.Message):
file: List["FileDescriptorProto"] = betterproto.message_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class FileDescriptorProto(betterproto.Message):
"""Describes a complete .proto file."""
@@ -419,7 +422,7 @@ class FileDescriptorProto(betterproto.Message):
syntax: str = betterproto.string_field(12)
@dataclass
@dataclass(eq=False, repr=False)
class DescriptorProto(betterproto.Message):
"""Describes a message type."""
@@ -439,14 +442,14 @@ class DescriptorProto(betterproto.Message):
reserved_name: List[str] = betterproto.string_field(10)
@dataclass
@dataclass(eq=False, repr=False)
class DescriptorProtoExtensionRange(betterproto.Message):
start: int = betterproto.int32_field(1)
end: int = betterproto.int32_field(2)
options: "ExtensionRangeOptions" = betterproto.message_field(3)
@dataclass
@dataclass(eq=False, repr=False)
class DescriptorProtoReservedRange(betterproto.Message):
"""
Range of reserved tag numbers. Reserved tag numbers may not be used by
@@ -458,13 +461,13 @@ class DescriptorProtoReservedRange(betterproto.Message):
end: int = betterproto.int32_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class ExtensionRangeOptions(betterproto.Message):
# The parser stores options it doesn't recognize here. See above.
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class FieldDescriptorProto(betterproto.Message):
"""Describes a field within a message."""
@@ -496,9 +499,26 @@ class FieldDescriptorProto(betterproto.Message):
# camelCase.
json_name: str = betterproto.string_field(10)
options: "FieldOptions" = betterproto.message_field(8)
# If true, this is a proto3 "optional". When a proto3 field is optional, it
# tracks presence regardless of field type. When proto3_optional is true,
# this field must belong to a oneof to signal to old proto3 clients that
# presence is tracked for this field. This oneof is known as a "synthetic"
# oneof, and this field must be its sole member (each proto3 optional field
# gets its own synthetic oneof). Synthetic oneofs exist in the descriptor
# only, and do not generate any API. Synthetic oneofs must be ordered after
# all "real" oneofs. For message fields, proto3_optional doesn't create any
# semantic change, since non-repeated message fields always track presence.
# However it still indicates the semantic detail of whether the user wrote
# "optional" or not. This can be useful for round-tripping the .proto file.
# For consistency we give message fields a synthetic oneof also, even though
# it is not required to track presence. This is especially important because
# the parser can't tell if a field is a message or an enum, so it must always
# create a synthetic oneof. Proto2 optional fields do not set this flag,
# because they already indicate optional with `LABEL_OPTIONAL`.
proto3_optional: bool = betterproto.bool_field(17)
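The presence semantics this synthetic oneof enables can be pictured with a plain dataclass: an ordinary proto3 scalar cannot distinguish "unset" from its zero value, while an `optional` field can (a library-free analogy, not betterproto's actual representation):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Example:
    plain: int = 0  # proto3 scalar: unset is indistinguishable from zero
    tracked: Optional[int] = None  # proto3 optional: presence is tracked

assert Example().tracked is None        # never set
assert Example(tracked=0).tracked == 0  # explicitly set to zero
```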
@dataclass
@dataclass(eq=False, repr=False)
class OneofDescriptorProto(betterproto.Message):
"""Describes a oneof."""
@@ -506,14 +526,12 @@ class OneofDescriptorProto(betterproto.Message):
options: "OneofOptions" = betterproto.message_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class EnumDescriptorProto(betterproto.Message):
"""Describes an enum type."""
name: str = betterproto.string_field(1)
value: List["EnumValueDescriptorProto"] = betterproto.message_field(
2, wraps=betterproto.TYPE_ENUM
)
value: List["EnumValueDescriptorProto"] = betterproto.message_field(2)
options: "EnumOptions" = betterproto.message_field(3)
# Range of reserved numeric values. Reserved numeric values may not be used
# by enum values in the same enum declaration. Reserved ranges may not
@@ -526,7 +544,7 @@ class EnumDescriptorProto(betterproto.Message):
reserved_name: List[str] = betterproto.string_field(5)
@dataclass
@dataclass(eq=False, repr=False)
class EnumDescriptorProtoEnumReservedRange(betterproto.Message):
"""
Range of reserved numeric values. Reserved values may not be used by
@@ -539,18 +557,16 @@ class EnumDescriptorProtoEnumReservedRange(betterproto.Message):
end: int = betterproto.int32_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class EnumValueDescriptorProto(betterproto.Message):
"""Describes a value within an enum."""
name: str = betterproto.string_field(1)
number: int = betterproto.int32_field(2)
options: "EnumValueOptions" = betterproto.message_field(
3, wraps=betterproto.TYPE_ENUM
)
options: "EnumValueOptions" = betterproto.message_field(3)
@dataclass
@dataclass(eq=False, repr=False)
class ServiceDescriptorProto(betterproto.Message):
"""Describes a service."""
@@ -559,7 +575,7 @@ class ServiceDescriptorProto(betterproto.Message):
options: "ServiceOptions" = betterproto.message_field(3)
@dataclass
@dataclass(eq=False, repr=False)
class MethodDescriptorProto(betterproto.Message):
"""Describes a method of a service."""
@@ -575,24 +591,25 @@ class MethodDescriptorProto(betterproto.Message):
server_streaming: bool = betterproto.bool_field(6)
@dataclass
@dataclass(eq=False, repr=False)
class FileOptions(betterproto.Message):
# Sets the Java package where classes generated from this .proto will be
# placed. By default, the proto package is used, but this is often
# inappropriate because proto packages do not normally start with backwards
# domain names.
java_package: str = betterproto.string_field(1)
# If set, all the classes from the .proto file are wrapped in a single outer
# class with the given name. This applies to both Proto1 (equivalent to the
# old "--one_java_file" option) and Proto2 (where a .proto always translates
# to a single class, but you may want to explicitly choose the class name).
# Controls the name of the wrapper Java class generated for the .proto file.
# That class will always contain the .proto file's getDescriptor() method as
# well as any top-level extensions defined in the .proto file. If
# java_multiple_files is disabled, then all the other classes from the .proto
# file will be nested inside the single wrapper outer class.
java_outer_classname: str = betterproto.string_field(8)
# If set true, then the Java code generator will generate a separate .java
# If enabled, then the Java code generator will generate a separate .java
# file for each top-level message, enum, and service defined in the .proto
# file. Thus, these types will *not* be nested inside the outer class named
# by java_outer_classname. However, the outer class will still be generated
# to contain the file's getDescriptor() method as well as any top-level
# extensions defined in the file.
# file. Thus, these types will *not* be nested inside the wrapper class
# named by java_outer_classname. However, the wrapper class will still be
# generated to contain the file's getDescriptor() method as well as any top-
# level extensions defined in the file.
java_multiple_files: bool = betterproto.bool_field(10)
# This option does nothing.
java_generate_equals_and_hash: bool = betterproto.bool_field(20)
@@ -657,8 +674,16 @@ class FileOptions(betterproto.Message):
# for the "Options" section above.
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
def __post_init__(self) -> None:
super().__post_init__()
if self.java_generate_equals_and_hash:
warnings.warn(
"FileOptions.java_generate_equals_and_hash is deprecated",
DeprecationWarning,
)
@dataclass
@dataclass(eq=False, repr=False)
class MessageOptions(betterproto.Message):
# Set true to use the old proto1 MessageSet wire format for extensions. This
# is provided for backwards-compatibility with the MessageSet wire format.
@@ -695,7 +720,7 @@ class MessageOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class FieldOptions(betterproto.Message):
# The ctype option instructs the C++ code generator to use a different
# representation of the field than it normally would. See the specific
@@ -752,13 +777,13 @@ class FieldOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class OneofOptions(betterproto.Message):
# The parser stores options it doesn't recognize here. See above.
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class EnumOptions(betterproto.Message):
# Set this option to true to allow mapping different tag names to the same
# value.
@@ -771,7 +796,7 @@ class EnumOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class EnumValueOptions(betterproto.Message):
# Is this enum value deprecated? Depending on the target platform, this can
# emit Deprecated annotations for the enum value, or it will be completely
@@ -782,7 +807,7 @@ class EnumValueOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class ServiceOptions(betterproto.Message):
# Is this service deprecated? Depending on the target platform, this can emit
# Deprecated annotations for the service, or it will be completely ignored;
@@ -792,7 +817,7 @@ class ServiceOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class MethodOptions(betterproto.Message):
# Is this method deprecated? Depending on the target platform, this can emit
# Deprecated annotations for the method, or it will be completely ignored; in
@@ -803,7 +828,7 @@ class MethodOptions(betterproto.Message):
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
@dataclass
@dataclass(eq=False, repr=False)
class UninterpretedOption(betterproto.Message):
"""
A message representing an option the parser does not recognize. This only
@@ -825,7 +850,7 @@ class UninterpretedOption(betterproto.Message):
aggregate_value: str = betterproto.string_field(8)
@dataclass
@dataclass(eq=False, repr=False)
class UninterpretedOptionNamePart(betterproto.Message):
"""
The name of the uninterpreted option. Each string represents a segment in
@@ -839,7 +864,7 @@ class UninterpretedOptionNamePart(betterproto.Message):
is_extension: bool = betterproto.bool_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class SourceCodeInfo(betterproto.Message):
"""
Encapsulates information about the original source file from which a
@@ -878,7 +903,7 @@ class SourceCodeInfo(betterproto.Message):
location: List["SourceCodeInfoLocation"] = betterproto.message_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class SourceCodeInfoLocation(betterproto.Message):
# Identifies which part of the FileDescriptorProto was defined at this
# location. Each element is a field number or an index. They form a path
@@ -925,7 +950,7 @@ class SourceCodeInfoLocation(betterproto.Message):
leading_detached_comments: List[str] = betterproto.string_field(6)
@dataclass
@dataclass(eq=False, repr=False)
class GeneratedCodeInfo(betterproto.Message):
"""
Describes the relationship between generated code and its original source
@@ -938,7 +963,7 @@ class GeneratedCodeInfo(betterproto.Message):
annotation: List["GeneratedCodeInfoAnnotation"] = betterproto.message_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class GeneratedCodeInfoAnnotation(betterproto.Message):
# Identifies the element in the original source .proto file. This field is
# formatted the same as SourceCodeInfo.Location.path.
@@ -954,7 +979,7 @@ class GeneratedCodeInfoAnnotation(betterproto.Message):
end: int = betterproto.int32_field(4)
@dataclass
@dataclass(eq=False, repr=False)
class Duration(betterproto.Message):
"""
A Duration represents a signed, fixed-length span of time represented as a
@@ -999,7 +1024,7 @@ class Duration(betterproto.Message):
nanos: int = betterproto.int32_field(2)
@dataclass
@dataclass(eq=False, repr=False)
class Empty(betterproto.Message):
"""
A generic empty message that you can re-use to avoid defining duplicated
@@ -1012,7 +1037,7 @@ class Empty(betterproto.Message):
pass
@dataclass
@dataclass(eq=False, repr=False)
class FieldMask(betterproto.Message):
"""
`FieldMask` represents a set of symbolic field paths, for example:
@@ -1096,7 +1121,7 @@ class FieldMask(betterproto.Message):
paths: List[str] = betterproto.string_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class Struct(betterproto.Message):
"""
`Struct` represents a structured data value, consisting of fields which map
@@ -1113,7 +1138,7 @@ class Struct(betterproto.Message):
)
@dataclass
@dataclass(eq=False, repr=False)
class Value(betterproto.Message):
"""
`Value` represents a dynamically typed value which can be either null, a
@@ -1137,7 +1162,7 @@ class Value(betterproto.Message):
list_value: "ListValue" = betterproto.message_field(6, group="kind")
@dataclass
@dataclass(eq=False, repr=False)
class ListValue(betterproto.Message):
"""
`ListValue` is a wrapper around a repeated field of values. The JSON
@@ -1148,7 +1173,7 @@ class ListValue(betterproto.Message):
values: List["Value"] = betterproto.message_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class Timestamp(betterproto.Message):
"""
A Timestamp represents a point in time independent of any time zone or
@@ -1178,20 +1203,22 @@ class Timestamp(betterproto.Message):
long millis = System.currentTimeMillis(); Timestamp timestamp =
Timestamp.newBuilder().setSeconds(millis / 1000) .setNanos((int)
((millis % 1000) * 1000000)).build(); Example 5: Compute Timestamp from
current time in Python. timestamp = Timestamp()
timestamp.GetCurrentTime() # JSON Mapping In JSON format, the Timestamp
type is encoded as a string in the [RFC
3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the format is
"{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" where {year} is
always expressed using four digits while {month}, {day}, {hour}, {min}, and
{sec} are zero-padded to two digits each. The fractional seconds, which can
go up to 9 digits (i.e. up to 1 nanosecond resolution), are optional. The
"Z" suffix indicates the timezone ("UTC"); the timezone is required. A
proto3 JSON serializer should always use UTC (as indicated by "Z") when
printing the Timestamp type and a proto3 JSON parser should be able to
accept both UTC and other timezones (as indicated by an offset). For
example, "2017-01-15T01:30:15.01Z" encodes 15.01 seconds past 01:30 UTC on
January 15, 2017. In JavaScript, one can convert a Date object to this
Java `Instant.now()`. Instant now = Instant.now(); Timestamp
timestamp = Timestamp.newBuilder().setSeconds(now.getEpochSecond())
.setNanos(now.getNano()).build(); Example 6: Compute Timestamp from current
time in Python. timestamp = Timestamp() timestamp.GetCurrentTime()
# JSON Mapping In JSON format, the Timestamp type is encoded as a string in
the [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the
format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" where
{year} is always expressed using four digits while {month}, {day}, {hour},
{min}, and {sec} are zero-padded to two digits each. The fractional
seconds, which can go up to 9 digits (i.e. up to 1 nanosecond resolution),
are optional. The "Z" suffix indicates the timezone ("UTC"); the timezone
is required. A proto3 JSON serializer should always use UTC (as indicated
by "Z") when printing the Timestamp type and a proto3 JSON parser should be
able to accept both UTC and other timezones (as indicated by an offset).
For example, "2017-01-15T01:30:15.01Z" encodes 15.01 seconds past 01:30 UTC
on January 15, 2017. In JavaScript, one can convert a Date object to this
format using the standard [toISOString()](https://developer.mozilla.org/en-
US/docs/Web/JavaScript/Reference/Global_Objects/Date/toISOString) method.
In Python, a standard `datetime.datetime` object can be converted to this
@@ -1213,7 +1240,7 @@ class Timestamp(betterproto.Message):
nanos: int = betterproto.int32_field(2)
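The JSON mapping described in the Timestamp docstring above can be reproduced with the standard library alone; a minimal sketch (the datetime value is an arbitrary example, not from the source):

```python
from datetime import datetime, timezone

# Render a UTC datetime in the RFC 3339 shape the docstring describes:
# "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z".
dt = datetime(2017, 1, 15, 1, 30, 15, 10000, tzinfo=timezone.utc)
rfc3339 = dt.isoformat().replace("+00:00", "Z")
# rfc3339 is now "2017-01-15T01:30:15.010000Z"
```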
@dataclass
@dataclass(eq=False, repr=False)
class DoubleValue(betterproto.Message):
"""
Wrapper message for `double`. The JSON representation for `DoubleValue` is
@@ -1224,7 +1251,7 @@ class DoubleValue(betterproto.Message):
value: float = betterproto.double_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class FloatValue(betterproto.Message):
"""
Wrapper message for `float`. The JSON representation for `FloatValue` is
@@ -1235,7 +1262,7 @@ class FloatValue(betterproto.Message):
value: float = betterproto.float_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class Int64Value(betterproto.Message):
"""
Wrapper message for `int64`. The JSON representation for `Int64Value` is
@@ -1246,7 +1273,7 @@ class Int64Value(betterproto.Message):
value: int = betterproto.int64_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class UInt64Value(betterproto.Message):
"""
Wrapper message for `uint64`. The JSON representation for `UInt64Value` is
@@ -1257,7 +1284,7 @@ class UInt64Value(betterproto.Message):
value: int = betterproto.uint64_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class Int32Value(betterproto.Message):
"""
Wrapper message for `int32`. The JSON representation for `Int32Value` is
@@ -1268,7 +1295,7 @@ class Int32Value(betterproto.Message):
value: int = betterproto.int32_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class UInt32Value(betterproto.Message):
"""
Wrapper message for `uint32`. The JSON representation for `UInt32Value` is
@@ -1279,7 +1306,7 @@ class UInt32Value(betterproto.Message):
value: int = betterproto.uint32_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class BoolValue(betterproto.Message):
"""
Wrapper message for `bool`. The JSON representation for `BoolValue` is JSON
@@ -1290,7 +1317,7 @@ class BoolValue(betterproto.Message):
value: bool = betterproto.bool_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class StringValue(betterproto.Message):
"""
Wrapper message for `string`. The JSON representation for `StringValue` is
@@ -1301,7 +1328,7 @@ class StringValue(betterproto.Message):
value: str = betterproto.string_field(1)
@dataclass
@dataclass(eq=False, repr=False)
class BytesValue(betterproto.Message):
"""
Wrapper message for `bytes`. The JSON representation for `BytesValue` is

View File

@@ -0,0 +1,128 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
from dataclasses import dataclass
from typing import List
import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
class CodeGeneratorResponseFeature(betterproto.Enum):
FEATURE_NONE = 0
FEATURE_PROTO3_OPTIONAL = 1
@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
"""The version number of protocol compiler."""
major: int = betterproto.int32_field(1)
minor: int = betterproto.int32_field(2)
patch: int = betterproto.int32_field(3)
# A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
# be empty for mainline stable releases.
suffix: str = betterproto.string_field(4)
@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
"""An encoded CodeGeneratorRequest is written to the plugin's stdin."""
# The .proto files that were explicitly listed on the command-line. The code
# generator should generate code only for these files. Each file's
# descriptor will be included in proto_file, below.
file_to_generate: List[str] = betterproto.string_field(1)
# The generator parameter passed on the command-line.
parameter: str = betterproto.string_field(2)
# FileDescriptorProtos for all files in files_to_generate and everything they
# import. The files will appear in topological order, so each file appears
# before any file that imports it. protoc guarantees that all proto_files
# will be written after the fields above, even though this is not technically
# guaranteed by the protobuf wire format. This theoretically could allow a
# plugin to stream in the FileDescriptorProtos and handle them one by one
# rather than read the entire set into memory at once. However, as of this
# writing, this is not similarly optimized on protoc's end -- it will store
# all fields in memory at once before sending them to the plugin. Type names
# of fields and extensions in the FileDescriptorProto are always fully
# qualified.
proto_file: List[
"betterproto_lib_google_protobuf.FileDescriptorProto"
] = betterproto.message_field(15)
# The version number of protocol compiler.
compiler_version: "Version" = betterproto.message_field(3)
@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
"""The plugin writes an encoded CodeGeneratorResponse to stdout."""
# Error message. If non-empty, code generation failed. The plugin process
# should exit with status code zero even if it reports an error in this way.
# This should be used to indicate errors in .proto files which prevent the
# code generator from generating correct code. Errors which indicate a
# problem in protoc itself -- such as the input CodeGeneratorRequest being
# unparseable -- should be reported by writing a message to stderr and
# exiting with a non-zero status code.
error: str = betterproto.string_field(1)
# A bitmask of the features that the code generator supports. This is a
# bitwise "or" of values from the Feature enum.
supported_features: int = betterproto.uint64_field(2)
file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)
@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
"""Represents a single generated file."""
# The file name, relative to the output directory. The name must not contain
# "." or ".." components and must be relative, not be absolute (so, the file
# cannot lie outside the output directory). "/" must be used as the path
# separator, not "\". If the name is omitted, the content will be appended to
# the previous file. This allows the generator to break large files into
# small chunks, and allows the generated text to be streamed back to protoc
# so that large files need not reside completely in memory at one time. Note
# that as of this writing protoc does not optimize for this -- it will read
# the entire CodeGeneratorResponse before writing files to disk.
name: str = betterproto.string_field(1)
# If non-empty, indicates that the named file should already exist, and the
# content here is to be inserted into that file at a defined insertion point.
# This feature allows a code generator to extend the output produced by
# another code generator. The original generator may provide insertion
# points by placing special annotations in the file that look like:
# @@protoc_insertion_point(NAME) The annotation can have arbitrary text
# before and after it on the line, which allows it to be placed in a comment.
# NAME should be replaced with an identifier naming the point -- this is what
# other generators will use as the insertion_point. Code inserted at this
# point will be placed immediately above the line containing the insertion
# point (thus multiple insertions to the same point will come out in the
# order they were added). The double-@ is intended to make it unlikely that
# the generated code could contain things that look like insertion points by
# accident. For example, the C++ code generator places the following line in
# the .pb.h files that it generates: //
# @@protoc_insertion_point(namespace_scope) This line appears within the
# scope of the file's package namespace, but outside of any particular class.
# Another plugin can then specify the insertion_point "namespace_scope" to
# generate additional classes or other declarations that should be placed in
# this scope. Note that if the line containing the insertion point begins
# with whitespace, the same whitespace will be added to every line of the
# inserted text. This is useful for languages like Python, where indentation
# matters. In these languages, the insertion point comment should be
# indented the same amount as any inserted code will need to be in order to
# work correctly in that context. The code generator that generates the
# initial file and the one which inserts into it must both run as part of a
# single invocation of protoc. Code generators are executed in the order in
# which they appear on the command line. If |insertion_point| is present,
# |name| must also be present.
insertion_point: str = betterproto.string_field(2)
# The file contents.
content: str = betterproto.string_field(15)
# Information describing the file content being inserted. If an insertion
# point is used, this information will be appropriately offset and inserted
# into the code generation metadata for the generated files.
generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = (
betterproto.message_field(16)
)
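The insertion-point semantics described in the `insertion_point` comment above (insert immediately above the marker line, copying its leading whitespace) can be sketched in a few lines; this is an illustration of the behavior, not protoc's actual implementation:

```python
def insert_at_point(existing: str, point: str, new_code: str) -> str:
    """Insert new_code immediately above the line containing the
    @@protoc_insertion_point(NAME) marker, copying that line's leading
    whitespace onto every inserted line (so indentation-sensitive
    languages like Python keep working)."""
    marker = f"@@protoc_insertion_point({point})"
    out = []
    for line in existing.splitlines():
        if marker in line:
            indent = line[: len(line) - len(line.lstrip())]
            out.extend(indent + inserted for inserted in new_code.splitlines())
        out.append(line)
    return "\n".join(out)

original = (
    "class Foo:\n"
    "    # @@protoc_insertion_point(class_scope)\n"
    "    pass"
)
patched = insert_at_point(original, "class_scope", "x = 1")
# The inserted line picks up the marker line's 4-space indent.
assert patched.splitlines()[1] == "    x = 1"
```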
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf

View File

@@ -33,5 +33,5 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
return black.format_str(
template.render(output_file=output_file),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
mode=black.Mode(),
)

src/betterproto/plugin/main.py Normal file → Executable file
View File

@@ -3,9 +3,13 @@
import os
import sys
from google.protobuf.compiler import plugin_pb2 as plugin
from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorRequest,
CodeGeneratorResponse,
)
from betterproto.plugin.parser import generate_code
from betterproto.plugin.models import monkey_patch_oneof_index
def main() -> None:
@@ -13,19 +17,19 @@ def main() -> None:
# Read request message from stdin
data = sys.stdin.buffer.read()
# Apply Work around for proto2/3 difference in protoc messages
monkey_patch_oneof_index()
# Parse request
request = plugin.CodeGeneratorRequest()
request.ParseFromString(data)
request = CodeGeneratorRequest()
request.parse(data)
dump_file = os.getenv("BETTERPROTO_DUMP")
if dump_file:
dump_request(dump_file, request)
# Create response
response = plugin.CodeGeneratorResponse()
# Generate code
generate_code(request, response)
response = generate_code(request)
# Serialise response message
output = response.SerializeToString()
@@ -34,7 +38,7 @@ def main() -> None:
sys.stdout.buffer.write(output)
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
def dump_request(dump_file: str, request: CodeGeneratorRequest) -> None:
"""
For developers: supports running plugin.py standalone so it's possible to debug it.
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.

View File

@@ -29,12 +29,37 @@ instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import builtins
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
Field,
FieldDescriptorProto,
FieldDescriptorProtoType,
FieldDescriptorProtoLabel,
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
import betterproto
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
@@ -44,26 +69,6 @@ from ..compile.naming import (
pythonize_method_name,
)
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
)
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
# Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@@ -71,54 +76,75 @@ PLACEHOLDER = object()
# Organize proto types into categories
PROTO_FLOAT_TYPES = (
FieldDescriptorProto.TYPE_DOUBLE, # 1
FieldDescriptorProto.TYPE_FLOAT, # 2
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
FieldDescriptorProtoType.TYPE_FLOAT, # 2
)
PROTO_INT_TYPES = (
FieldDescriptorProto.TYPE_INT64, # 3
FieldDescriptorProto.TYPE_UINT64, # 4
FieldDescriptorProto.TYPE_INT32, # 5
FieldDescriptorProto.TYPE_FIXED64, # 6
FieldDescriptorProto.TYPE_FIXED32, # 7
FieldDescriptorProto.TYPE_UINT32, # 13
FieldDescriptorProto.TYPE_SFIXED32, # 15
FieldDescriptorProto.TYPE_SFIXED64, # 16
FieldDescriptorProto.TYPE_SINT32, # 17
FieldDescriptorProto.TYPE_SINT64, # 18
FieldDescriptorProtoType.TYPE_INT64, # 3
FieldDescriptorProtoType.TYPE_UINT64, # 4
FieldDescriptorProtoType.TYPE_INT32, # 5
FieldDescriptorProtoType.TYPE_FIXED64, # 6
FieldDescriptorProtoType.TYPE_FIXED32, # 7
FieldDescriptorProtoType.TYPE_UINT32, # 13
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
FieldDescriptorProtoType.TYPE_SINT32, # 17
FieldDescriptorProtoType.TYPE_SINT64, # 18
)
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12
PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,) # 8
PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,) # 9
PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,) # 12
PROTO_MESSAGE_TYPES = (
FieldDescriptorProto.TYPE_MESSAGE, # 11
FieldDescriptorProto.TYPE_ENUM, # 14
FieldDescriptorProtoType.TYPE_MESSAGE, # 11
FieldDescriptorProtoType.TYPE_ENUM, # 14
)
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11
PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,) # 11
PROTO_PACKED_TYPES = (
FieldDescriptorProto.TYPE_DOUBLE, # 1
FieldDescriptorProto.TYPE_FLOAT, # 2
FieldDescriptorProto.TYPE_INT64, # 3
FieldDescriptorProto.TYPE_UINT64, # 4
FieldDescriptorProto.TYPE_INT32, # 5
FieldDescriptorProto.TYPE_FIXED64, # 6
FieldDescriptorProto.TYPE_FIXED32, # 7
FieldDescriptorProto.TYPE_BOOL, # 8
FieldDescriptorProto.TYPE_UINT32, # 13
FieldDescriptorProto.TYPE_SFIXED32, # 15
FieldDescriptorProto.TYPE_SFIXED64, # 16
FieldDescriptorProto.TYPE_SINT32, # 17
FieldDescriptorProto.TYPE_SINT64, # 18
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
FieldDescriptorProtoType.TYPE_FLOAT, # 2
FieldDescriptorProtoType.TYPE_INT64, # 3
FieldDescriptorProtoType.TYPE_UINT64, # 4
FieldDescriptorProtoType.TYPE_INT32, # 5
FieldDescriptorProtoType.TYPE_FIXED64, # 6
FieldDescriptorProtoType.TYPE_FIXED32, # 7
FieldDescriptorProtoType.TYPE_BOOL, # 8
FieldDescriptorProtoType.TYPE_UINT32, # 13
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
FieldDescriptorProtoType.TYPE_SINT32, # 17
FieldDescriptorProtoType.TYPE_SINT64, # 18
)
def monkey_patch_oneof_index():
"""
The compiler message types are written for proto2, but we read them as proto3.
For this to work in the case of the oneof_index fields, which depend on being able
to tell whether they were set, we have to treat them as oneof fields. This method
monkey patches the generated classes after the fact to force this behaviour.
"""
object.__setattr__(
FieldDescriptorProto.__dataclass_fields__["oneof_index"].metadata[
"betterproto"
],
"group",
"oneof_index",
)
object.__setattr__(
Field.__dataclass_fields__["oneof_index"].metadata["betterproto"],
"group",
"oneof_index",
)
def get_comment(
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
) -> str:
pad = " " * indent
for sci in proto_file.source_code_info.location:
if list(sci.path) == path and sci.leading_comments:
for sci_loc in proto_file.source_code_info.location:
if list(sci_loc.path) == path and sci_loc.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
)
if path[-2] == 2 and path[-4] != 6:
@@ -139,10 +165,13 @@ def get_comment(
class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
source_file: FileDescriptorProto
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
__dataclass_fields__: Dict[str, object]
def __post_init__(self) -> None:
"""Checks that no fake default fields were left as placeholders."""
for field_name, field_val in self.__dataclass_fields__.items():
@@ -156,13 +185,6 @@ class ProtoContentBase:
current = current.parent
return current
@property
def proto_file(self) -> FieldDescriptorProto:
current = self
while not isinstance(current, OutputTemplate):
current = current.parent
return current.package_proto_obj
@property
def request(self) -> "PluginRequestCompiler":
current = self
@@ -176,14 +198,14 @@ class ProtoContentBase:
for this object.
"""
return get_comment(
proto_file=self.proto_file, path=self.path, indent=self.comment_indent
proto_file=self.source_file, path=self.path, indent=self.comment_indent
)
@dataclass
class PluginRequestCompiler:
plugin_request_obj: plugin.CodeGeneratorRequest
plugin_request_obj: CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
@property
@@ -215,6 +237,7 @@ class OutputTemplate:
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
@@ -231,21 +254,23 @@ class OutputTemplate:
return self.package_proto_obj.package
@property
def input_filenames(self) -> List[str]:
def input_filenames(self) -> Iterable[str]:
"""Names of the input files used to build this output.
Returns
-------
List[str]
Iterable[str]
Names of the input files used to build this output.
"""
return [f.name for f in self.input_files]
return sorted(f.name for f in self.input_files)
@property
def python_module_imports(self) -> Set[str]:
imports = set()
if any(x for x in self.messages if any(x.deprecated_fields)):
imports.add("warnings")
if self.builtins_import:
imports.add("builtins")
return imports
@@ -253,6 +278,7 @@ class OutputTemplate:
class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""
source_file: FileDescriptorProto
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
@@ -260,6 +286,7 @@ class MessageCompiler(ProtoContentBase):
default_factory=list
)
deprecated: bool = field(default=False, init=False)
builtins_types: Set[str] = field(default_factory=set)
def __post_init__(self) -> None:
# Add message to output file
@@ -291,12 +318,16 @@ class MessageCompiler(ProtoContentBase):
if f.deprecated:
yield f.py_name
@property
def has_deprecated_fields(self) -> bool:
return any(self.deprecated_fields)
def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool:
"""True if proto_field_obj is a map, otherwise False."""
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
# This might be a map...
message_type = proto_field_obj.type_name.split(".").pop().lower()
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
@@ -311,8 +342,20 @@ def is_map(
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
"""True if proto_field_obj is a OneOf, otherwise False."""
return proto_field_obj.HasField("oneof_index")
"""
True if proto_field_obj is a OneOf, otherwise False.
.. warning::
Because the message from protoc is defined in proto2 while betterproto works with
proto3, interpreting the FieldDescriptorProto.oneof_index field requires
distinguishing between default and unset values (which proto3 doesn't support),
so we have to hack the generated FieldDescriptorProto class for this to work.
The hack consists of setting group="oneof_index" in the field metadata,
essentially making oneof_index the sole member of a one_of group, which allows
us to tell whether it was set, via the which_one_of interface.
"""
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
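The presence problem the docstring above describes can be illustrated without betterproto at all; a self-contained sketch (FakeField and PLACEHOLDER are hypothetical stand-ins, assuming the single-member-group trick works as documented): in plain proto3 semantics an int field defaults to 0, so `oneof_index = 0` (the first oneof) and "not set" collapse, while a sentinel restores presence.

```python
# Sentinel distinguishing "never assigned" from "assigned 0".
PLACEHOLDER = object()

class FakeField:
    """Hypothetical stand-in for a FieldDescriptorProto whose
    oneof_index lives in a single-member oneof group."""

    def __init__(self, oneof_index=PLACEHOLDER):
        self._oneof_index = oneof_index

    def which_one_of(self, group):
        # Mirrors the which_one_of check used by is_oneof above:
        # report the field name only when it was actually set.
        if group == "oneof_index" and self._oneof_index is not PLACEHOLDER:
            return ("oneof_index", self._oneof_index)
        return ("", None)

def is_oneof(f: FakeField) -> bool:
    return f.which_one_of("oneof_index")[0] == "oneof_index"

assert is_oneof(FakeField(oneof_index=0))  # set to 0: still detected
assert not is_oneof(FakeField())           # genuinely unset
```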
@dataclass
@@ -324,17 +367,7 @@ class FieldCompiler(MessageCompiler):
# Add field to message
self.parent.fields.append(self)
# Check for new imports
annotation = self.annotation
if "Optional[" in annotation:
self.output_file.typing_imports.add("Optional")
if "List[" in annotation:
self.output_file.typing_imports.add("List")
if "Dict[" in annotation:
self.output_file.typing_imports.add("Dict")
if "timedelta" in annotation:
self.output_file.datetime_imports.add("timedelta")
if "datetime" in annotation:
self.output_file.datetime_imports.add("datetime")
self.add_imports_to(self.output_file)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
def get_field_string(self, indent: int = 4) -> str:
@@ -347,6 +380,8 @@ class FieldCompiler(MessageCompiler):
betterproto_field_type = (
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
)
if self.py_name in dir(builtins):
self.parent.builtins_types.add(self.py_name)
return f"{name}{annotations} = {betterproto_field_type}"
@property
@@ -354,13 +389,49 @@ class FieldCompiler(MessageCompiler):
args = []
if self.field_wraps:
args.append(f"wraps={self.field_wraps}")
if self.optional:
args.append(f"optional=True")
return args
@property
def datetime_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
# FIXME: false positives - e.g. `MyDatetimedelta`
if "timedelta" in annotation:
imports.add("timedelta")
if "datetime" in annotation:
imports.add("datetime")
return imports
@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports
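The substring checks in the `datetime_imports`/`typing_imports` properties above can misfire, as the FIXME notes; a condensed sketch of the detection as written (the `typing_imports` helper below is a standalone restatement, not the class method itself):

```python
def typing_imports(annotation: str) -> set:
    """Detect needed typing imports by substring, as the compiler does.
    Note the FIXME-style false positive: any annotation merely
    *containing* these markers also triggers an import."""
    imports = set()
    for marker, name in (
        ("Optional[", "Optional"),
        ("List[", "List"),
        ("Dict[", "Dict"),
    ):
        if marker in annotation:
            imports.add(name)
    return imports

assert typing_imports("Optional[List[int]]") == {"Optional", "List"}
```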
@property
def use_builtins(self) -> bool:
return self.py_type in self.parent.builtins_types or (
self.py_type == self.py_name and self.py_name in dir(builtins)
)
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins
@property
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
)
if match_wrapper:
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
@@ -371,10 +442,14 @@ class FieldCompiler(MessageCompiler):
@property
def repeated(self) -> bool:
return (
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
and not is_map(self.proto_obj, self.parent)
)
@property
def optional(self) -> bool:
return self.proto_obj.proto3_optional
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
@@ -384,14 +459,18 @@ class FieldCompiler(MessageCompiler):
def field_type(self) -> str:
"""String representation of proto field type."""
return (
self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
FieldDescriptorProtoType(self.proto_obj.type)
.name.lower()
.replace("type_", "")
)
@property
def default_value_string(self) -> Union[Text, None, float, int]:
def default_value_string(self) -> str:
"""Python representation of the default proto value."""
if self.repeated:
return "[]"
if self.optional:
return "None"
if self.py_type == "int":
return "0"
if self.py_type == "float":
@@ -402,6 +481,14 @@ class FieldCompiler(MessageCompiler):
return '""'
elif self.py_type == "bytes":
return 'b""'
elif self.field_type == "enum":
enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
enum = next(
e
for e in self.output_file.enums
if e.proto_obj.name == enum_proto_obj_name
)
return enum.default_value_string
else:
# Message type
return "None"
@@ -446,9 +533,14 @@ class FieldCompiler(MessageCompiler):
@property
def annotation(self) -> str:
py_type = self.py_type
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return f"List[{self.py_type}]"
return self.py_type
return f"List[{py_type}]"
if self.optional:
return f"Optional[{py_type}]"
return py_type
@dataclass
@@ -478,14 +570,19 @@ class MapEntryCompiler(FieldCompiler):
):
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0] # key
source_file=self.source_file,
parent=self,
proto_obj=nested.field[0], # key
).py_type
self.py_v_type = FieldCompiler(
parent=self, proto_obj=nested.field[1] # value
source_file=self.source_file,
parent=self,
proto_obj=nested.field[1], # value
).py_type
# Get proto types
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
@property
@@ -527,7 +624,7 @@ class EnumDefinitionCompiler(MessageCompiler):
name=sanitize_name(entry_proto_value.name),
value=entry_proto_value.number,
comment=get_comment(
proto_file=self.proto_file, path=self.path + [2, entry_number]
proto_file=self.source_file, path=self.path + [2, entry_number]
),
)
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
@@ -553,6 +650,7 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields
@property
@@ -576,11 +674,10 @@ class ServiceMethodCompiler(ProtoContentBase):
# Add method to service
self.parent.methods.append(self)
# Check for Optional import
# Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
if f.default_value_string == "None":
self.output_file.typing_imports.add("Optional")
f.add_imports_to(self.output_file)
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
@@ -590,7 +687,9 @@ class ServiceMethodCompiler(ProtoContentBase):
self.output_file.typing_imports.add("AsyncIterable")
self.output_file.typing_imports.add("Iterable")
self.output_file.typing_imports.add("Union")
if self.server_streaming:
# Required by both client and server
if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator")
super().__post_init__() # check for unset fields
@@ -638,7 +737,10 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def route(self) -> str:
return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"
package_part = (
f"{self.output_file.package}." if self.output_file.package else ""
)
return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Optional[MessageCompiler]:

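The route fix above handles proto files with no `package` declaration: the old f-string unconditionally emitted `{package}.`, so a packageless file produced a malformed route like `/.Test/Ping`. A minimal sketch of the corrected logic (function name hypothetical):

```python
def route(package: str, service_name: str, method_name: str) -> str:
    # Only prepend "package." when the proto file actually declares a package.
    package_part = f"{package}." if package else ""
    return f"/{package_part}{service_name}/{method_name}"
```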

@@ -1,28 +1,20 @@
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorRequest,
CodeGeneratorResponse,
CodeGeneratorResponseFeature,
CodeGeneratorResponseFile,
)
import itertools
import pathlib
import sys
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
ServiceDescriptorProto,
)
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
from typing import Iterator, List, Set, Tuple, TYPE_CHECKING, Union
from .compiler import outputfile_compiler
from .models import (
EnumDefinitionCompiler,
@@ -47,7 +39,7 @@ def traverse(
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
# Todo: Keep information about nested hierarchy
def _traverse(
path: List[int], items: List["Descriptor"], prefix=""
path: List[int], items: List["EnumDescriptorProto"], prefix=""
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
for i, item in enumerate(items):
# Adjust the name since we flatten the hierarchy.
@@ -69,10 +61,11 @@ def traverse(
)
def generate_code(
request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse
) -> None:
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
response = CodeGeneratorResponse()
plugin_options = request.parameter.split(",") if request.parameter else []
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages
@@ -100,7 +93,12 @@ def generate_code(
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for item, path in traverse(proto_input_file):
read_protobuf_type(item=item, path=path, output_package=output_package)
read_protobuf_type(
source_file=proto_input_file,
item=item,
path=path,
output_package=output_package,
)
# Read Services
for output_package_name, output_package in request_data.output_packages.items():
@@ -116,11 +114,13 @@ def generate_code(
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
output_paths.add(output_path)
f: response.File = response.file.add()
f.name = str(output_path)
# Render and then format the output file
f.content = outputfile_compiler(output_file=output_package)
response.file.append(
CodeGeneratorResponseFile(
name=str(output_path),
# Render and then format the output file
content=outputfile_compiler(output_file=output_package),
)
)
# Make each output directory a package with __init__ file
init_files = {
@@ -130,38 +130,55 @@ def generate_code(
} - output_paths
for init_file in init_files:
init = response.file.add()
init.name = str(init_file)
response.file.append(CodeGeneratorResponseFile(name=str(init_file)))
for output_package_name in sorted(output_paths.union(init_files)):
print(f"Writing {output_package_name}", file=sys.stderr)
return response
def read_protobuf_type(
item: DescriptorProto, path: List[int], output_package: OutputTemplate
item: DescriptorProto,
path: List[int],
source_file: "FileDescriptorProto",
output_package: OutputTemplate,
) -> None:
if isinstance(item, DescriptorProto):
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
return
# Process Message
message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path)
message_data = MessageCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
)
for index, field in enumerate(item.field):
if is_map(field, item):
MapEntryCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
)
elif is_oneof(field):
OneOfFieldCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
)
else:
FieldCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
)
elif isinstance(item, EnumDescriptorProto):
# Enum
EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path)
EnumDefinitionCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
)
def read_protobuf_service(


@@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}
import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
{% if output_file.services %}
import grpclib
{% endif %}
@@ -53,7 +54,7 @@ class {{ message.py_name }}(betterproto.Message):
pass
{% endif %}
{% if message.deprecated or message.deprecated_fields %}
{% if message.deprecated or message.has_deprecated_fields %}
def __post_init__(self) -> None:
{% if message.deprecated %}
warnings.warn("{{ message.py_name }} is deprecated", DeprecationWarning)
@@ -72,6 +73,8 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}
{% elif not service.methods %}
pass
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self
@@ -82,7 +85,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
@@ -154,6 +157,89 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %}
{% endfor %}
{% for service in output_file.services %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
{{ service.comment }}
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%},
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
{% endfor %}
{% for method in service.methods %}
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()
request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}
{% else %}
request_kwargs = {"request_iterator": stream.__aiter__()}
{% endif %}
{% if not method.server_streaming %}
response = await self.{{ method.py_name }}(**request_kwargs)
await stream.send_message(response)
{% else %}
await self._call_rpc_handler_server_stream(
self.{{ method.py_name }},
stream,
request_kwargs,
)
{% endif %}
{% endfor %}
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(
self.__rpc_{{ method.py_name }},
{% if not method.client_streaming and not method.server_streaming %}
grpclib.const.Cardinality.UNARY_UNARY,
{% elif not method.client_streaming and method.server_streaming %}
grpclib.const.Cardinality.UNARY_STREAM,
{% elif method.client_streaming and not method.server_streaming %}
grpclib.const.Cardinality.STREAM_UNARY,
{% else %}
grpclib.const.Cardinality.STREAM_STREAM,
{% endif %}
{{ method.py_input_message_type }},
{{ method.py_output_message_type }},
),
{% endfor %}
}
{% endfor %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}

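The four-way branch in the generated `__mapping__` above selects a grpclib cardinality from the two streaming flags. As a truth table (string names stand in for the `grpclib.const.Cardinality` members):

```python
def cardinality(client_streaming: bool, server_streaming: bool) -> str:
    # Mirrors the template's if/elif chain: (client, server) -> Cardinality.
    return {
        (False, False): "UNARY_UNARY",
        (False, True): "UNARY_STREAM",
        (True, False): "STREAM_UNARY",
        (True, True): "STREAM_STREAM",
    }[(client_streaming, server_streaming)]
```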

@@ -60,13 +60,15 @@ async def generate(whitelist: Set[str], verbose: bool):
if result != 0:
failed_test_cases.append(test_case_name)
if failed_test_cases:
if len(failed_test_cases) > 0:
sys.stderr.write(
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
)
for failed_test_case in failed_test_cases:
sys.stderr.write(f"- {failed_test_case}\n")
sys.exit(1)
async def generate_test_case_output(
test_case_input_path: Path, test_case_name: str, verbose: bool
@@ -92,21 +94,41 @@ async def generate_test_case_output(
protoc(test_case_input_path, test_case_output_path_betterproto, False),
)
message = f"Generated output for {test_case_name!r}"
if verbose:
print(f"\033[31;1;4m{message}\033[0m")
if ref_out:
sys.stdout.buffer.write(ref_out)
if ref_err:
sys.stderr.buffer.write(ref_err)
if plg_out:
sys.stdout.buffer.write(plg_out)
if plg_err:
sys.stderr.buffer.write(plg_err)
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
if ref_code == 0:
print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m")
else:
print(message)
print(
f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
)
if verbose:
if ref_out:
print("Reference stdout:")
sys.stdout.buffer.write(ref_out)
sys.stdout.buffer.flush()
if ref_err:
print("Reference stderr:")
sys.stderr.buffer.write(ref_err)
sys.stderr.buffer.flush()
if plg_code == 0:
print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m")
else:
print(
f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
)
if verbose:
if plg_out:
print("Plugin stdout:")
sys.stdout.buffer.write(plg_out)
sys.stdout.buffer.flush()
if plg_err:
print("Plugin stderr:")
sys.stderr.buffer.write(plg_err)
sys.stderr.buffer.flush()
return max(ref_code, plg_code)


@@ -1,9 +1,7 @@
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
# Remove from list when fixed.
xfail = {
"oneof_enum", # 63
"namespace_keywords", # 70
"namespace_builtin_types", # 53
"googletypes_struct", # 9
"googletypes_value", # 9
"import_capitalized_package",
@@ -14,7 +12,17 @@ services = {
"googletypes_response",
"googletypes_response_embedded",
"service",
"service_separate_packages",
"import_service_input_message",
"googletypes_service_returns_empty",
"googletypes_service_returns_googletype",
"example_service",
"empty_service",
}
# Indicate json sample messages to skip when testing that json (de)serialization
# is symmetrical, because some cases legitimately are not symmetrical.
# Each key references the name of the test scenario and the values in the tuple
# are the names of the json files.
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}
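A simplified sketch of why `empty_repeated` is listed: proto3 JSON serialization omits fields holding their default value, so an explicitly empty repeated field parsed from input does not reappear on output (function name hypothetical, flat fields only):

```python
def to_json_dict(fields: dict) -> dict:
    # Proto3 JSON output drops default-valued fields; an empty list is the
    # default for a repeated field, so {"values": []} round-trips to {}.
    return {name: value for name, value in fields.items() if value}
```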


@@ -0,0 +1,3 @@
{
"msg": [{"values":[]}]
}


@@ -0,0 +1,9 @@
syntax = "proto3";
message MessageA {
repeated float values = 1;
}
message Test {
repeated MessageA msg = 1;
}


@@ -0,0 +1,7 @@
/* Empty service without comments */
syntax = "proto3";
package empty_service;
service Test {
}


@@ -1,8 +1,909 @@
syntax = "proto3";
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package hello;
// Author: kenton@google.com (Kenton Varda)
// Based on original Protocol Buffers design by
// Sanjay Ghemawat, Jeff Dean, and others.
//
// The messages in this file describe the definitions found in .proto files.
// A valid .proto file can be translated directly to a FileDescriptorProto
// without any other information (e.g. without reading its imports).
// Greeting represents a message you can tell a user.
message Greeting {
string message = 1;
syntax = "proto2";
// package google.protobuf;
option go_package = "google.golang.org/protobuf/types/descriptorpb";
option java_package = "com.google.protobuf";
option java_outer_classname = "DescriptorProtos";
option csharp_namespace = "Google.Protobuf.Reflection";
option objc_class_prefix = "GPB";
option cc_enable_arenas = true;
// descriptor.proto must be optimized for speed because reflection-based
// algorithms don't work during bootstrapping.
option optimize_for = SPEED;
// The protocol compiler can output a FileDescriptorSet containing the .proto
// files it parses.
message FileDescriptorSet {
repeated FileDescriptorProto file = 1;
}
// Describes a complete .proto file.
message FileDescriptorProto {
optional string name = 1; // file name, relative to root of source tree
optional string package = 2; // e.g. "foo", "foo.bar", etc.
// Names of files imported by this file.
repeated string dependency = 3;
// Indexes of the public imported files in the dependency list above.
repeated int32 public_dependency = 10;
// Indexes of the weak imported files in the dependency list.
// For Google-internal migration only. Do not use.
repeated int32 weak_dependency = 11;
// All top-level definitions in this file.
repeated DescriptorProto message_type = 4;
repeated EnumDescriptorProto enum_type = 5;
repeated ServiceDescriptorProto service = 6;
repeated FieldDescriptorProto extension = 7;
optional FileOptions options = 8;
// This field contains optional information about the original source code.
// You may safely remove this entire field without harming runtime
// functionality of the descriptors -- the information is needed only by
// development tools.
optional SourceCodeInfo source_code_info = 9;
// The syntax of the proto file.
// The supported values are "proto2" and "proto3".
optional string syntax = 12;
}
// Describes a message type.
message DescriptorProto {
optional string name = 1;
repeated FieldDescriptorProto field = 2;
repeated FieldDescriptorProto extension = 6;
repeated DescriptorProto nested_type = 3;
repeated EnumDescriptorProto enum_type = 4;
message ExtensionRange {
optional int32 start = 1; // Inclusive.
optional int32 end = 2; // Exclusive.
optional ExtensionRangeOptions options = 3;
}
repeated ExtensionRange extension_range = 5;
repeated OneofDescriptorProto oneof_decl = 8;
optional MessageOptions options = 7;
// Range of reserved tag numbers. Reserved tag numbers may not be used by
// fields or extension ranges in the same message. Reserved ranges may
// not overlap.
message ReservedRange {
optional int32 start = 1; // Inclusive.
optional int32 end = 2; // Exclusive.
}
repeated ReservedRange reserved_range = 9;
// Reserved field names, which may not be used by fields in the same message.
// A given name may only be reserved once.
repeated string reserved_name = 10;
}
message ExtensionRangeOptions {
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
// Describes a field within a message.
message FieldDescriptorProto {
enum Type {
// 0 is reserved for errors.
// Order is weird for historical reasons.
TYPE_DOUBLE = 1;
TYPE_FLOAT = 2;
// Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT64 if
// negative values are likely.
TYPE_INT64 = 3;
TYPE_UINT64 = 4;
// Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT32 if
// negative values are likely.
TYPE_INT32 = 5;
TYPE_FIXED64 = 6;
TYPE_FIXED32 = 7;
TYPE_BOOL = 8;
TYPE_STRING = 9;
// Tag-delimited aggregate.
// Group type is deprecated and not supported in proto3. However, Proto3
// implementations should still be able to parse the group wire format and
// treat group fields as unknown fields.
TYPE_GROUP = 10;
TYPE_MESSAGE = 11; // Length-delimited aggregate.
// New in version 2.
TYPE_BYTES = 12;
TYPE_UINT32 = 13;
TYPE_ENUM = 14;
TYPE_SFIXED32 = 15;
TYPE_SFIXED64 = 16;
TYPE_SINT32 = 17; // Uses ZigZag encoding.
TYPE_SINT64 = 18; // Uses ZigZag encoding.
}
enum Label {
// 0 is reserved for errors
LABEL_OPTIONAL = 1;
LABEL_REQUIRED = 2;
LABEL_REPEATED = 3;
}
optional string name = 1;
optional int32 number = 3;
optional Label label = 4;
// If type_name is set, this need not be set. If both this and type_name
// are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP.
optional Type type = 5;
// For message and enum types, this is the name of the type. If the name
// starts with a '.', it is fully-qualified. Otherwise, C++-like scoping
// rules are used to find the type (i.e. first the nested types within this
// message are searched, then within the parent, on up to the root
// namespace).
optional string type_name = 6;
// For extensions, this is the name of the type being extended. It is
// resolved in the same manner as type_name.
optional string extendee = 2;
// For numeric types, contains the original text representation of the value.
// For booleans, "true" or "false".
// For strings, contains the default text contents (not escaped in any way).
// For bytes, contains the C escaped value. All bytes >= 128 are escaped.
// TODO(kenton): Base-64 encode?
optional string default_value = 7;
// If set, gives the index of a oneof in the containing type's oneof_decl
// list. This field is a member of that oneof.
optional int32 oneof_index = 9;
// JSON name of this field. The value is set by protocol compiler. If the
// user has set a "json_name" option on this field, that option's value
// will be used. Otherwise, it's deduced from the field's name by converting
// it to camelCase.
optional string json_name = 10;
optional FieldOptions options = 8;
// If true, this is a proto3 "optional". When a proto3 field is optional, it
// tracks presence regardless of field type.
//
// When proto3_optional is true, this field must be belong to a oneof to
// signal to old proto3 clients that presence is tracked for this field. This
// oneof is known as a "synthetic" oneof, and this field must be its sole
// member (each proto3 optional field gets its own synthetic oneof). Synthetic
// oneofs exist in the descriptor only, and do not generate any API. Synthetic
// oneofs must be ordered after all "real" oneofs.
//
// For message fields, proto3_optional doesn't create any semantic change,
// since non-repeated message fields always track presence. However it still
// indicates the semantic detail of whether the user wrote "optional" or not.
// This can be useful for round-tripping the .proto file. For consistency we
// give message fields a synthetic oneof also, even though it is not required
// to track presence. This is especially important because the parser can't
// tell if a field is a message or an enum, so it must always create a
// synthetic oneof.
//
// Proto2 optional fields do not set this flag, because they already indicate
// optional with `LABEL_OPTIONAL`.
optional bool proto3_optional = 17;
}
// Describes a oneof.
message OneofDescriptorProto {
optional string name = 1;
optional OneofOptions options = 2;
}
// Describes an enum type.
message EnumDescriptorProto {
optional string name = 1;
repeated EnumValueDescriptorProto value = 2;
optional EnumOptions options = 3;
// Range of reserved numeric values. Reserved values may not be used by
// entries in the same enum. Reserved ranges may not overlap.
//
// Note that this is distinct from DescriptorProto.ReservedRange in that it
// is inclusive such that it can appropriately represent the entire int32
// domain.
message EnumReservedRange {
optional int32 start = 1; // Inclusive.
optional int32 end = 2; // Inclusive.
}
// Range of reserved numeric values. Reserved numeric values may not be used
// by enum values in the same enum declaration. Reserved ranges may not
// overlap.
repeated EnumReservedRange reserved_range = 4;
// Reserved enum value names, which may not be reused. A given name may only
// be reserved once.
repeated string reserved_name = 5;
}
// Describes a value within an enum.
message EnumValueDescriptorProto {
optional string name = 1;
optional int32 number = 2;
optional EnumValueOptions options = 3;
}
// Describes a service.
message ServiceDescriptorProto {
optional string name = 1;
repeated MethodDescriptorProto method = 2;
optional ServiceOptions options = 3;
}
// Describes a method of a service.
message MethodDescriptorProto {
optional string name = 1;
// Input and output type names. These are resolved in the same way as
// FieldDescriptorProto.type_name, but must refer to a message type.
optional string input_type = 2;
optional string output_type = 3;
optional MethodOptions options = 4;
// Identifies if client streams multiple client messages
optional bool client_streaming = 5 [default = false];
// Identifies if server streams multiple server messages
optional bool server_streaming = 6 [default = false];
}
// ===================================================================
// Options
// Each of the definitions above may have "options" attached. These are
// just annotations which may cause code to be generated slightly differently
// or may contain hints for code that manipulates protocol messages.
//
// Clients may define custom options as extensions of the *Options messages.
// These extensions may not yet be known at parsing time, so the parser cannot
// store the values in them. Instead it stores them in a field in the *Options
// message called uninterpreted_option. This field must have the same name
// across all *Options messages. We then use this field to populate the
// extensions when we build a descriptor, at which point all protos have been
// parsed and so all extensions are known.
//
// Extension numbers for custom options may be chosen as follows:
// * For options which will only be used within a single application or
// organization, or for experimental options, use field numbers 50000
// through 99999. It is up to you to ensure that you do not use the
// same number for multiple options.
// * For options which will be published and used publicly by multiple
// independent entities, e-mail protobuf-global-extension-registry@google.com
// to reserve extension numbers. Simply provide your project name (e.g.
// Objective-C plugin) and your project website (if available) -- there's no
// need to explain how you intend to use them. Usually you only need one
// extension number. You can declare multiple options with only one extension
// number by putting them in a sub-message. See the Custom Options section of
// the docs for examples:
// https://developers.google.com/protocol-buffers/docs/proto#options
// If this turns out to be popular, a web service will be set up
// to automatically assign option numbers.
message FileOptions {
// Sets the Java package where classes generated from this .proto will be
// placed. By default, the proto package is used, but this is often
// inappropriate because proto packages do not normally start with backwards
// domain names.
optional string java_package = 1;
// If set, all the classes from the .proto file are wrapped in a single
// outer class with the given name. This applies to both Proto1
// (equivalent to the old "--one_java_file" option) and Proto2 (where
// a .proto always translates to a single class, but you may want to
// explicitly choose the class name).
optional string java_outer_classname = 8;
// If set true, then the Java code generator will generate a separate .java
// file for each top-level message, enum, and service defined in the .proto
// file. Thus, these types will *not* be nested inside the outer class
// named by java_outer_classname. However, the outer class will still be
// generated to contain the file's getDescriptor() method as well as any
// top-level extensions defined in the file.
optional bool java_multiple_files = 10 [default = false];
// This option does nothing.
optional bool java_generate_equals_and_hash = 20 [deprecated=true];
// If set true, then the Java2 code generator will generate code that
// throws an exception whenever an attempt is made to assign a non-UTF-8
// byte sequence to a string field.
// Message reflection will do the same.
// However, an extension field still accepts non-UTF-8 byte sequences.
// This option has no effect on when used with the lite runtime.
optional bool java_string_check_utf8 = 27 [default = false];
// Generated classes can be optimized for speed or code size.
enum OptimizeMode {
SPEED = 1; // Generate complete code for parsing, serialization,
// etc.
CODE_SIZE = 2; // Use ReflectionOps to implement these methods.
LITE_RUNTIME = 3; // Generate code using MessageLite and the lite runtime.
}
optional OptimizeMode optimize_for = 9 [default = SPEED];
// Sets the Go package where structs generated from this .proto will be
// placed. If omitted, the Go package will be derived from the following:
// - The basename of the package import path, if provided.
// - Otherwise, the package statement in the .proto file, if present.
// - Otherwise, the basename of the .proto file, without extension.
optional string go_package = 11;
// Should generic services be generated in each language? "Generic" services
// are not specific to any particular RPC system. They are generated by the
// main code generators in each language (without additional plugins).
// Generic services were the only kind of service generation supported by
// early versions of google.protobuf.
//
// Generic services are now considered deprecated in favor of using plugins
// that generate code specific to your particular RPC system. Therefore,
// these default to false. Old code which depends on generic services should
// explicitly set them to true.
optional bool cc_generic_services = 16 [default = false];
optional bool java_generic_services = 17 [default = false];
optional bool py_generic_services = 18 [default = false];
optional bool php_generic_services = 42 [default = false];
// Is this file deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for everything in the file, or it will be completely ignored; in the very
// least, this is a formalization for deprecating files.
optional bool deprecated = 23 [default = false];
// Enables the use of arenas for the proto messages in this file. This applies
// only to generated classes for C++.
optional bool cc_enable_arenas = 31 [default = true];
// Sets the objective c class prefix which is prepended to all objective c
// generated classes from this .proto. There is no default.
optional string objc_class_prefix = 36;
// Namespace for generated classes; defaults to the package.
optional string csharp_namespace = 37;
// By default Swift generators will take the proto package and CamelCase it
// replacing '.' with underscore and use that to prefix the types/symbols
// defined. When this options is provided, they will use this value instead
// to prefix the types/symbols defined.
optional string swift_prefix = 39;
// Sets the php class prefix which is prepended to all php generated classes
// from this .proto. Default is empty.
optional string php_class_prefix = 40;
// Use this option to change the namespace of php generated classes. Default
// is empty. When this option is empty, the package name will be used for
// determining the namespace.
optional string php_namespace = 41;
// Use this option to change the namespace of php generated metadata classes.
// Default is empty. When this option is empty, the proto file name will be
// used for determining the namespace.
optional string php_metadata_namespace = 44;
// Use this option to change the package of ruby generated classes. Default
// is empty. When this option is not set, the package name will be used for
// determining the ruby package.
optional string ruby_package = 45;
// The parser stores options it doesn't recognize here.
// See the documentation for the "Options" section above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message.
// See the documentation for the "Options" section above.
extensions 1000 to max;
reserved 38;
}
message MessageOptions {
// Set true to use the old proto1 MessageSet wire format for extensions.
// This is provided for backwards-compatibility with the MessageSet wire
// format. You should not use this for any other reason: It's less
// efficient, has fewer features, and is more complicated.
//
// The message must be defined exactly as follows:
// message Foo {
// option message_set_wire_format = true;
// extensions 4 to max;
// }
// Note that the message cannot have any defined fields; MessageSets only
// have extensions.
//
// All extensions of your type must be singular messages; e.g. they cannot
// be int32s, enums, or repeated messages.
//
// Because this is an option, the above two restrictions are not enforced by
// the protocol compiler.
optional bool message_set_wire_format = 1 [default = false];
// Disables the generation of the standard "descriptor()" accessor, which can
// conflict with a field of the same name. This is meant to make migration
// from proto1 easier; new code should avoid fields named "descriptor".
optional bool no_standard_descriptor_accessor = 2 [default = false];
// Is this message deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for the message, or it will be completely ignored; at the very least,
// this is a formalization for deprecating messages.
optional bool deprecated = 3 [default = false];
// Whether the message is an automatically generated map entry type for the
// maps field.
//
// For maps fields:
// map<KeyType, ValueType> map_field = 1;
// The parsed descriptor looks like:
// message MapFieldEntry {
// option map_entry = true;
// optional KeyType key = 1;
// optional ValueType value = 2;
// }
// repeated MapFieldEntry map_field = 1;
//
// Implementations may choose not to generate the map_entry=true message, but
// use a native map in the target language to hold the keys and values.
// The reflection APIs in such implementations still need to work as
// if the field is a repeated message field.
//
// NOTE: Do not set the option in .proto files. Always use the maps syntax
// instead. The option should only be implicitly set by the proto compiler
// parser.
optional bool map_entry = 7;
reserved 8; // javalite_serializable
reserved 9; // javanano_as_lite
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
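The map_entry equivalence above means a native map and the repeated-entry form carry the same information. The round trip between the two shapes can be sketched in Python (an illustration only, not betterproto's or protoc's internals):

```python
# Sketch: the synthetic MapFieldEntry form of a map field is just a list of
# key/value records; converting between it and a native dict is lossless
# (protobuf maps guarantee unique keys via last-one-wins on the wire).

def to_entries(mapping: dict) -> list:
    """Expand a native map into repeated MapFieldEntry-style records."""
    return [{"key": k, "value": v} for k, v in mapping.items()]

def from_entries(entries: list) -> dict:
    """Collapse repeated entries back into a native map (last key wins)."""
    return {e["key"]: e["value"] for e in entries}

entries = to_entries({"a": 1, "b": 2})
assert from_entries(entries) == {"a": 1, "b": 2}
```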
message FieldOptions {
// The ctype option instructs the C++ code generator to use a different
// representation of the field than it normally would. See the specific
// options below. This option is not yet implemented in the open source
// release -- sorry, we'll try to include it in a future version!
optional CType ctype = 1 [default = STRING];
enum CType {
// Default mode.
STRING = 0;
CORD = 1;
STRING_PIECE = 2;
}
// The packed option can be enabled for repeated primitive fields to enable
// a more efficient representation on the wire. Rather than repeatedly
// writing the tag and type for each element, the entire array is encoded as
// a single length-delimited blob. In proto3, only explicitly setting it to
// false will avoid using packed encoding.
optional bool packed = 2;
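A rough sense of the savings, assuming a hypothetical repeated int32 field with number 4 and values small enough to fit in single-byte varints (a sketch of the wire layout, not a full encoder):

```python
# Sketch: compare unpacked vs packed encodings of a repeated int32 field
# (field number 4). Each value below fits in one varint byte.

def tag(field_number: int, wire_type: int) -> int:
    return (field_number << 3) | wire_type

def encode_unpacked(values):
    # The tag (wire type 0 = varint) is repeated before every element.
    out = bytearray()
    for v in values:
        out.append(tag(4, 0))
        out.append(v)  # single-byte varint for values < 128
    return bytes(out)

def encode_packed(values):
    # One tag (wire type 2 = length-delimited), one length, then raw varints.
    payload = bytes(values)  # single-byte varints for values < 128
    return bytes([tag(4, 2), len(payload)]) + payload

vals = [1, 2, 3]
assert len(encode_unpacked(vals)) == 6  # (tag + value) per element
assert len(encode_packed(vals)) == 5    # tag + length + 3 values
```

The gap widens with the number of elements: packed encoding pays the tag and length once instead of once per element.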
// The jstype option determines the JavaScript type used for values of the
// field. The option is permitted only for 64 bit integral and fixed types
// (int64, uint64, sint64, fixed64, sfixed64). A field with jstype JS_STRING
// is represented as JavaScript string, which avoids loss of precision that
// can happen when a large value is converted to a floating point JavaScript.
// Specifying JS_NUMBER for the jstype causes the generated JavaScript code to
// use the JavaScript "number" type. The behavior of the default option
// JS_NORMAL is implementation dependent.
//
// This option is an enum to permit additional types to be added, e.g.
// goog.math.Integer.
optional JSType jstype = 6 [default = JS_NORMAL];
enum JSType {
// Use the default type.
JS_NORMAL = 0;
// Use JavaScript strings.
JS_STRING = 1;
// Use JavaScript numbers.
JS_NUMBER = 2;
}
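The precision loss that JS_STRING avoids comes from JavaScript numbers being IEEE-754 doubles. Python floats are the same format, so the effect can be checked directly:

```python
# JavaScript numbers are IEEE-754 doubles: integers above 2**53 cannot all
# be represented exactly, so distinct int64 values can collapse together.
MAX_SAFE_INTEGER = 2**53 - 1  # largest integer with exact neighbors

assert float(MAX_SAFE_INTEGER) == MAX_SAFE_INTEGER  # still exact
assert float(2**53 + 1) == float(2**53)             # precision lost
assert str(2**53 + 1) != str(2**53)                 # strings stay exact
```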
// Should this field be parsed lazily? Lazy applies only to message-type
// fields. It means that when the outer message is initially parsed, the
// inner message's contents will not be parsed but instead stored in encoded
// form. The inner message will actually be parsed when it is first accessed.
//
// This is only a hint. Implementations are free to choose whether to use
// eager or lazy parsing regardless of the value of this option. However,
// setting this option true suggests that the protocol author believes that
// using lazy parsing on this field is worth the additional bookkeeping
// overhead typically needed to implement it.
//
// This option does not affect the public interface of any generated code;
// all method signatures remain the same. Furthermore, thread-safety of the
// interface is not affected by this option; const methods remain safe to
// call from multiple threads concurrently, while non-const methods continue
// to require exclusive access.
//
//
// Note that implementations may choose not to check required fields within
// a lazy sub-message. That is, calling IsInitialized() on the outer message
// may return true even if the inner message has missing required fields.
// This is necessary because otherwise the inner message would have to be
// parsed in order to perform the check, defeating the purpose of lazy
// parsing. An implementation which chooses not to check required fields
// must be consistent about it. That is, for any particular sub-message, the
// implementation must either *always* check its required fields, or *never*
// check its required fields, regardless of whether or not the message has
// been parsed.
optional bool lazy = 5 [default = false];
// Is this field deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for accessors, or it will be completely ignored; at the very least, this
// is a formalization for deprecating fields.
optional bool deprecated = 3 [default = false];
// For Google-internal migration only. Do not use.
optional bool weak = 10 [default = false];
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
reserved 4; // removed jtype
}
message OneofOptions {
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
message EnumOptions {
// Set this option to true to allow mapping different tag names to the same
// value.
optional bool allow_alias = 2;
// Is this enum deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for the enum, or it will be completely ignored; at the very least, this
// is a formalization for deprecating enums.
optional bool deprecated = 3 [default = false];
reserved 5; // javanano_as_lite
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
message EnumValueOptions {
// Is this enum value deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for the enum value, or it will be completely ignored; at the very least,
// this is a formalization for deprecating enum values.
optional bool deprecated = 1 [default = false];
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
message ServiceOptions {
// Note: Field numbers 1 through 32 are reserved for Google's internal RPC
// framework. We apologize for hoarding these numbers to ourselves, but
// we were already using them long before we decided to release Protocol
// Buffers.
// Is this service deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for the service, or it will be completely ignored; at the very least,
// this is a formalization for deprecating services.
optional bool deprecated = 33 [default = false];
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
message MethodOptions {
// Note: Field numbers 1 through 32 are reserved for Google's internal RPC
// framework. We apologize for hoarding these numbers to ourselves, but
// we were already using them long before we decided to release Protocol
// Buffers.
// Is this method deprecated?
// Depending on the target platform, this can emit Deprecated annotations
// for the method, or it will be completely ignored; at the very least,
// this is a formalization for deprecating methods.
optional bool deprecated = 33 [default = false];
// Is this method side-effect-free (or safe in HTTP parlance), or idempotent,
// or neither? HTTP-based RPC implementations may choose the GET verb for safe
// methods, and the PUT verb for idempotent methods, instead of the default POST.
enum IdempotencyLevel {
IDEMPOTENCY_UNKNOWN = 0;
NO_SIDE_EFFECTS = 1; // implies idempotent
IDEMPOTENT = 2; // idempotent, but may have side effects
}
optional IdempotencyLevel idempotency_level = 34
[default = IDEMPOTENCY_UNKNOWN];
// The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999;
// Clients can define custom options in extensions of this message. See above.
extensions 1000 to max;
}
// A message representing an option the parser does not recognize. This only
// appears in options protos created by the compiler::Parser class.
// DescriptorPool resolves these when building Descriptor objects. Therefore,
// options protos in descriptor objects (e.g. returned by Descriptor::options(),
// or produced by Descriptor::CopyTo()) will never have UninterpretedOptions
// in them.
message UninterpretedOption {
// The name of the uninterpreted option. Each string represents a segment in
// a dot-separated name. is_extension is true iff a segment represents an
// extension (denoted with parentheses in options specs in .proto files).
// E.g., { ["foo", false], ["bar.baz", true], ["qux", false] } represents
// "foo.(bar.baz).qux".
message NamePart {
required string name_part = 1;
required bool is_extension = 2;
}
repeated NamePart name = 2;
// The value of the uninterpreted option, in whatever type the tokenizer
// identified it as during parsing. Exactly one of these should be set.
optional string identifier_value = 3;
optional uint64 positive_int_value = 4;
optional int64 negative_int_value = 5;
optional double double_value = 6;
optional bytes string_value = 7;
optional string aggregate_value = 8;
}
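The NamePart encoding can be turned back into the dotted option name with a few lines (a sketch; the real resolution happens in DescriptorPool):

```python
# Rebuild "foo.(bar.baz).qux" from its NamePart segments: extension
# segments are wrapped in parentheses, plain segments are used as-is.
def option_name(parts):
    """parts: list of (name_part, is_extension) tuples."""
    return ".".join(
        f"({name})" if is_extension else name for name, is_extension in parts
    )

assert (
    option_name([("foo", False), ("bar.baz", True), ("qux", False)])
    == "foo.(bar.baz).qux"
)
```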
// ===================================================================
// Optional source code info
// Encapsulates information about the original source file from which a
// FileDescriptorProto was generated.
message SourceCodeInfo {
// A Location identifies a piece of source code in a .proto file which
// corresponds to a particular definition. This information is intended
// to be useful to IDEs, code indexers, documentation generators, and similar
// tools.
//
// For example, say we have a file like:
// message Foo {
// optional string foo = 1;
// }
// Let's look at just the field definition:
// optional string foo = 1;
// ^ ^^ ^^ ^ ^^^
// a bc de f ghi
// We have the following locations:
// span path represents
// [a,i) [ 4, 0, 2, 0 ] The whole field definition.
// [a,b) [ 4, 0, 2, 0, 4 ] The label (optional).
// [c,d) [ 4, 0, 2, 0, 5 ] The type (string).
// [e,f) [ 4, 0, 2, 0, 1 ] The name (foo).
// [g,h) [ 4, 0, 2, 0, 3 ] The number (1).
//
// Notes:
// - A location may refer to a repeated field itself (i.e. not to any
// particular index within it). This is used whenever a set of elements is
// logically enclosed in a single code segment. For example, an entire
// extend block (possibly containing multiple extension definitions) will
// have an outer location whose path refers to the "extensions" repeated
// field without an index.
// - Multiple locations may have the same path. This happens when a single
// logical declaration is spread out across multiple places. The most
// obvious example is the "extend" block again -- there may be multiple
// extend blocks in the same scope, each of which will have the same path.
// - A location's span is not always a subset of its parent's span. For
// example, the "extendee" of an extension declaration appears at the
// beginning of the "extend" block and is shared by all extensions within
// the block.
// - Just because a location's span is a subset of some other location's span
// does not mean that it is a descendant. For example, a "group" defines
// both a type and a field in a single declaration. Thus, the locations
// corresponding to the type and field and their components will overlap.
// - Code which tries to interpret locations should probably be designed to
// ignore those that it doesn't understand, as more types of locations could
// be recorded in the future.
repeated Location location = 1;
message Location {
// Identifies which part of the FileDescriptorProto was defined at this
// location.
//
// Each element is a field number or an index. They form a path from
// the root FileDescriptorProto to the place where the definition appears. For
// example, this path:
// [ 4, 3, 2, 7, 1 ]
// refers to:
// file.message_type(3) // 4, 3
// .field(7) // 2, 7
// .name() // 1
// This is because FileDescriptorProto.message_type has field number 4:
// repeated DescriptorProto message_type = 4;
// and DescriptorProto.field has field number 2:
// repeated FieldDescriptorProto field = 2;
// and FieldDescriptorProto.name has field number 1:
// optional string name = 1;
//
// Thus, the above path gives the location of a field name. If we removed
// the last element:
// [ 4, 3, 2, 7 ]
// this path refers to the whole field declaration (from the beginning
// of the label to the terminating semicolon).
repeated int32 path = 1 [packed = true];
// Always has exactly three or four elements: start line, start column,
// end line (optional, otherwise assumed same as start line), end column.
// These are packed into a single field for efficiency. Note that line
// and column numbers are zero-based -- typically you will want to add
// 1 to each before displaying to a user.
repeated int32 span = 2 [packed = true];
// If this SourceCodeInfo represents a complete declaration, these are any
// comments appearing before and after the declaration which appear to be
// attached to the declaration.
//
// A series of line comments appearing on consecutive lines, with no other
// tokens appearing on those lines, will be treated as a single comment.
//
// leading_detached_comments will keep paragraphs of comments that appear
// before (but not connected to) the current element. Each paragraph,
// separated by empty lines, will be one comment element in the repeated
// field.
//
// Only the comment content is provided; comment markers (e.g. //) are
// stripped out. For block comments, leading whitespace and an asterisk
// will be stripped from the beginning of each line other than the first.
// Newlines are included in the output.
//
// Examples:
//
// optional int32 foo = 1; // Comment attached to foo.
// // Comment attached to bar.
// optional int32 bar = 2;
//
// optional string baz = 3;
// // Comment attached to baz.
// // Another line attached to baz.
//
// // Comment attached to qux.
// //
// // Another line attached to qux.
// optional double qux = 4;
//
// // Detached comment for corge. This is not leading or trailing comments
// // to qux or corge because there are blank lines separating it from
// // both.
//
// // Detached comment for corge paragraph 2.
//
// optional string corge = 5;
// /* Block comment attached
// * to corge. Leading asterisks
// * will be removed. */
// /* Block comment attached to
// * grault. */
// optional int32 grault = 6;
//
// // ignored detached comments.
optional string leading_comments = 3;
optional string trailing_comments = 4;
repeated string leading_detached_comments = 6;
}
}
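The path convention above can be sketched by consuming (field_number, index) pairs, with an optional trailing field number selecting a scalar field of the final message. The helper below is hypothetical and only covers the field numbers used in the example:

```python
# Walk a SourceCodeInfo path such as [4, 3, 2, 7, 1]: pairs of
# (field_number, index) descend through repeated fields, and a trailing
# lone field number selects a scalar field of the final message.
FIELD_NAMES = {4: "message_type", 2: "field", 1: "name"}  # subset, for this example

def describe_path(path):
    steps = []
    i = 0
    while i + 1 < len(path):
        steps.append(f"{FIELD_NAMES[path[i]]}({path[i + 1]})")
        i += 2
    if i < len(path):  # trailing field number with no index
        steps.append(f"{FIELD_NAMES[path[i]]}()")
    return "file." + ".".join(steps)

assert describe_path([4, 3, 2, 7, 1]) == "file.message_type(3).field(7).name()"
assert describe_path([4, 3, 2, 7]) == "file.message_type(3).field(7)"
```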
// Describes the relationship between generated code and its original source
// file. A GeneratedCodeInfo message is associated with only one generated
// source file, but may contain references to different source .proto files.
message GeneratedCodeInfo {
// An Annotation connects some span of text in generated code to an element
// of its generating .proto file.
repeated Annotation annotation = 1;
message Annotation {
// Identifies the element in the original source .proto file. This field
// is formatted the same as SourceCodeInfo.Location.path.
repeated int32 path = 1 [packed = true];
// Identifies the filesystem path to the original source .proto.
optional string source_file = 2;
// Identifies the starting offset in bytes in the generated code
// that relates to the identified object.
optional int32 begin = 3;
// Identifies the ending offset in bytes in the generated code that
// relates to the identified offset. The end offset should be one past
// the last relevant byte (so the length of the text = end - begin).
optional int32 end = 4;
}
}
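Because end is one past the last relevant byte, an annotation's text is an ordinary half-open slice of the generated file (a sketch; the byte offsets here are made up):

```python
# end is one past the last relevant byte, so the annotated span is a
# plain half-open slice and its length is simply end - begin.
generated = b"class FooMessage:\n    pass\n"
begin, end = 6, 16  # hypothetical annotation covering "FooMessage"

span = generated[begin:end]
assert span == b"FooMessage"
assert len(span) == end - begin
```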


@@ -0,0 +1,20 @@
syntax = "proto3";
package example_service;
service Test {
rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
}
message ExampleRequest {
string example_string = 1;
int64 example_integer = 2;
}
message ExampleResponse {
string example_string = 1;
int64 example_integer = 2;
}


@@ -0,0 +1,95 @@
from typing import AsyncIterator, AsyncIterable

import pytest
from grpclib.testing import ChannelFor

from tests.output_betterproto.example_service.example_service import (
    TestBase,
    TestStub,
    ExampleRequest,
    ExampleResponse,
)


class ExampleService(TestBase):
    async def example_unary_unary(
        self, example_string: str, example_integer: int
    ) -> "ExampleResponse":
        return ExampleResponse(
            example_string=example_string,
            example_integer=example_integer,
        )

    async def example_unary_stream(
        self, example_string: str, example_integer: int
    ) -> AsyncIterator["ExampleResponse"]:
        response = ExampleResponse(
            example_string=example_string,
            example_integer=example_integer,
        )
        yield response
        yield response
        yield response

    async def example_stream_unary(
        self, request_iterator: AsyncIterator["ExampleRequest"]
    ) -> "ExampleResponse":
        async for example_request in request_iterator:
            return ExampleResponse(
                example_string=example_request.example_string,
                example_integer=example_request.example_integer,
            )

    async def example_stream_stream(
        self, request_iterator: AsyncIterator["ExampleRequest"]
    ) -> AsyncIterator["ExampleResponse"]:
        async for example_request in request_iterator:
            yield ExampleResponse(
                example_string=example_request.example_string,
                example_integer=example_request.example_integer,
            )


@pytest.mark.asyncio
async def test_calls_with_different_cardinalities():
    test_string = "test string"
    test_int = 42
    async with ChannelFor([ExampleService()]) as channel:
        stub = TestStub(channel)

        # unary unary
        response = await stub.example_unary_unary(
            example_string="test string",
            example_integer=42,
        )
        assert response.example_string == test_string
        assert response.example_integer == test_int

        # unary stream
        async for response in stub.example_unary_stream(
            example_string="test string",
            example_integer=42,
        ):
            assert response.example_string == test_string
            assert response.example_integer == test_int

        # stream unary
        request = ExampleRequest(
            example_string=test_string,
            example_integer=42,
        )

        async def request_iterator():
            yield request
            yield request
            yield request

        response = await stub.example_stream_unary(request_iterator())
        assert response.example_string == test_string
        assert response.example_integer == test_int

        # stream stream
        async for response in stub.example_stream_stream(request_iterator()):
            assert response.example_string == test_string
            assert response.example_integer == test_int


@@ -0,0 +1,7 @@
{
"int": 26,
"float": 26.0,
"str": "value-for-str",
"bytes": "001a",
"bool": true
}


@@ -0,0 +1,11 @@
syntax = "proto3";
// Tests that messages may contain fields with names that are identical to their python types (PR #294)
message Test {
int32 int = 1;
float float = 2;
string str = 3;
bytes bytes = 4;
bool bool = 5;
}


@@ -0,0 +1,9 @@
{
"positive": "Infinity",
"negative": "-Infinity",
"nan": "NaN",
"three": 3.0,
"threePointOneFour": 3.14,
"negThree": -3.0,
"negThreePointOneFour": -3.14
}


@@ -0,0 +1,12 @@
syntax = "proto3";
// Some documentation about the Test message.
message Test {
double positive = 1;
double negative = 2;
double nan = 3;
double three = 4;
double three_point_one_four = 5;
double neg_three = 6;
double neg_three_point_one_four = 7;
}


@@ -1,3 +1,3 @@
{
"name": "foobar"
"pitier": "Mr. T"
}


@@ -1,3 +1,3 @@
{
"count": 100
"pitied": 100
}


@@ -2,7 +2,15 @@ syntax = "proto3";
message Test {
oneof foo {
int32 count = 1;
string name = 2;
int32 pitied = 1;
string pitier = 2;
}
int32 just_a_regular_field = 3;
oneof bar {
int32 drinks = 11;
string bar_name = 12;
}
}


@@ -0,0 +1,3 @@
{
"pitier": "Mr. T"
}


@@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data
def test_which_count():
message = Test()
message.from_json(get_test_case_json_data("oneof"))
assert betterproto.which_one_of(message, "foo") == ("count", 100)
message.from_json(get_test_case_json_data("oneof")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitied", 100)
def test_which_name():
message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof-name.json"))
assert betterproto.which_one_of(message, "foo") == ("name", "foobar")
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")


@@ -0,0 +1,3 @@
{
"nothing": {}
}


@@ -0,0 +1,15 @@
syntax = "proto3";
message Nothing {}
message MaybeNothing {
string sometimes = 42;
}
message Test {
oneof empty {
Nothing nothing = 1;
MaybeNothing maybe1 = 2;
MaybeNothing maybe2 = 3;
}
}


@@ -0,0 +1,3 @@
{
"maybe1": {}
}


@@ -0,0 +1,5 @@
{
"maybe2": {
"sometimes": "now"
}
}


@@ -14,7 +14,9 @@ def test_which_one_of_returns_enum_with_default_value():
returns first field when it is enum and set with default value
"""
message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json"))
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
)
assert message.move == Move(
x=0, y=0
@@ -28,7 +30,9 @@ def test_which_one_of_returns_enum_with_non_default_value():
returns first field when it is enum and set with non default value
"""
message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json"))
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
)
assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
@@ -38,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum"))
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))


@@ -0,0 +1,12 @@
{
"test1": 128,
"test2": true,
"test3": "A value",
"test4": "aGVsbG8=",
"test5": {
"test": "Hello"
},
"test6": "B",
"test7": "8589934592",
"test8": 2.5
}


@@ -0,0 +1,21 @@
syntax = "proto3";
message InnerTest {
string test = 1;
}
message Test {
optional uint32 test1 = 1;
optional bool test2 = 2;
optional string test3 = 3;
optional bytes test4 = 4;
optional InnerTest test5 = 5;
optional TestEnum test6 = 6;
optional uint64 test7 = 7;
optional float test8 = 8;
}
enum TestEnum {
A = 0;
B = 1;
}


@@ -0,0 +1,9 @@
{
"test1": 0,
"test2": false,
"test3": "",
"test4": "",
"test6": "A",
"test7": "0",
"test8": 0
}


@@ -0,0 +1,38 @@
import json

from tests.output_betterproto.proto3_field_presence import Test, InnerTest, TestEnum


def test_null_fields_json():
    """Ensure that using "null" in JSON is equivalent to not specifying a
    field, for fields with explicit presence"""

    def test_json(ref_json: str, obj_json: str) -> None:
        """`ref_json` and `obj_json` are JSON strings describing a `Test` object.
        Test that deserializing both leads to the same object, and that
        `ref_json` is the normalized format."""
        ref_obj = Test().from_json(ref_json)
        obj = Test().from_json(obj_json)
        assert obj == ref_obj
        assert json.loads(obj.to_json(0)) == json.loads(ref_json)

    test_json("{}", '{ "test1": null, "test2": null, "test3": null }')
    test_json("{}", '{ "test4": null, "test5": null, "test6": null }')
    test_json("{}", '{ "test7": null, "test8": null }')
    test_json('{ "test5": {} }', '{ "test3": null, "test5": {} }')

    # Make sure that if include_default_values is set, None values are
    # exported.
    obj = Test()
    assert obj.to_dict() == {}
    assert obj.to_dict(include_default_values=True) == {
        "test1": None,
        "test2": None,
        "test3": None,
        "test4": None,
        "test5": None,
        "test6": None,
        "test7": None,
        "test8": None,
    }


@@ -0,0 +1,3 @@
{
"nested": {}
}


@@ -0,0 +1,20 @@
syntax = "proto3";
message Test {
oneof kind {
Nested nested = 1;
WithOptional with_optional = 2;
}
}
message InnerNested {
optional bool a = 1;
}
message Nested {
InnerNested inner = 1;
}
message WithOptional {
optional bool b = 2;
}


@@ -0,0 +1,29 @@
from tests.output_betterproto.proto3_field_presence_oneof import (
    Test,
    InnerNested,
    Nested,
    WithOptional,
)


def test_serialization():
    """Ensure that serialization of fields that are unset but have explicit
    field presence does not bloat the serialized payload with length-delimited
    fields of length 0"""

    def test_empty_nested(message: Test) -> None:
        # '0a' => tag 1, length delimited
        # '00' => length: 0
        assert bytes(message) == bytearray.fromhex("0a 00")

    test_empty_nested(Test(nested=Nested()))
    test_empty_nested(Test(nested=Nested(inner=None)))
    test_empty_nested(Test(nested=Nested(inner=InnerNested(a=None))))

    def test_empty_with_optional(message: Test) -> None:
        # '12' => tag 2, length delimited
        # '00' => length: 0
        assert bytes(message) == bytearray.fromhex("12 00")

    test_empty_with_optional(Test(with_optional=WithOptional()))
    test_empty_with_optional(Test(with_optional=WithOptional(b=None)))
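The '0a' and '12' bytes asserted above follow directly from the wire format: a field's tag byte is (field_number << 3) | wire_type, and wire type 2 marks length-delimited payloads such as sub-messages. A quick standalone check:

```python
# Tag byte layout: upper bits carry the field number, low 3 bits the wire
# type. Wire type 2 marks length-delimited payloads (sub-messages, strings).
LENGTH_DELIMITED = 2

def tag_byte(field_number: int) -> int:
    return (field_number << 3) | LENGTH_DELIMITED

assert tag_byte(1) == 0x0A  # the '0a' seen for field `nested`
assert tag_byte(2) == 0x12  # the '12' seen for field `with_optional`
```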


@@ -0,0 +1,4 @@
{
"times": ["1972-01-01T10:00:20.021Z", "1972-01-01T10:00:20.021Z"],
"durations": ["1.200s", "1.200s"]
}


@@ -0,0 +1,10 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
message Test {
repeated google.protobuf.Timestamp times = 1;
repeated google.protobuf.Duration durations = 2;
}


@@ -0,0 +1,9 @@
from datetime import datetime, timedelta

from tests.output_betterproto.repeated_duration_timestamp import Test


def test_roundtrip():
    message = Test()
    message.times = [datetime.now(), datetime.now()]
    message.durations = [timedelta(), timedelta()]


@@ -2,9 +2,16 @@ syntax = "proto3";
package service;
enum ThingType {
UNKNOWN = 0;
LIVING = 1;
DEAD = 2;
}
message DoThingRequest {
string name = 1;
repeated string comments = 2;
ThingType type = 3;
}
message DoThingResponse {


@@ -0,0 +1,31 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
package things.messages;
message DoThingRequest {
string name = 1;
// use `repeated` so we can check if `List` is correctly imported
repeated string comments = 2;
// use google types `timestamp` and `duration` so we can check
// if everything from `datetime` is correctly imported
google.protobuf.Timestamp when = 3;
google.protobuf.Duration duration = 4;
}
message DoThingResponse {
repeated string names = 1;
}
message GetThingRequest {
string name = 1;
}
message GetThingResponse {
string name = 1;
int32 version = 2;
}


@@ -0,0 +1,12 @@
syntax = "proto3";
import "messages.proto";
package things.service;
service Test {
rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
}


@@ -1,6 +1,8 @@
import betterproto
from dataclasses import dataclass
from typing import Optional, List, Dict
from datetime import datetime
from inspect import signature
def test_has_field():
@@ -285,17 +287,23 @@ def test_to_dict_default_values():
def test_oneof_default_value_set_causes_writes_wire():
@dataclass
class Empty(betterproto.Message):
pass
@dataclass
class Foo(betterproto.Message):
bar: int = betterproto.int32_field(1, group="group1")
baz: str = betterproto.string_field(2, group="group1")
qux: Empty = betterproto.message_field(3, group="group1")
def _round_trip_serialization(foo: Foo) -> Foo:
return Foo().parse(bytes(foo))
foo1 = Foo(bar=0)
foo2 = Foo(baz="")
foo3 = Foo()
foo3 = Foo(qux=Empty())
foo4 = Foo()
assert bytes(foo1) == b"\x08\x00"
assert (
@@ -311,10 +319,17 @@ def test_oneof_default_value_set_causes_writes_wire():
== ("baz", "")
)
assert bytes(foo3) == b""
assert bytes(foo3) == b"\x1a\x00"
assert (
betterproto.which_one_of(foo3, "group1")
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
== ("qux", Empty())
)
assert bytes(foo4) == b""
assert (
betterproto.which_one_of(foo4, "group1")
== betterproto.which_one_of(_round_trip_serialization(foo4), "group1")
== ("", None)
)
@@ -395,3 +410,77 @@ def test_bool():
assert t
t.bar = 0
assert not t
# valid ISO datetimes according to https://www.myintervals.com/blog/2009/05/20/iso-8601-date-validation-that-doesnt-suck/
iso_candidates = """2009-12-12T12:34
2009
2009-05-19
2009-05-19
20090519
2009123
2009-05
2009-123
2009-222
2009-001
2009-W01-1
2009-W51-1
2009-W33
2009W511
2009-05-19
2009-05-19 00:00
2009-05-19 14
2009-05-19 14:31
2009-05-19 14:39:22
2009-05-19T14:39Z
2009-W21-2
2009-W21-2T01:22
2009-139
2009-05-19 14:39:22-06:00
2009-05-19 14:39:22+0600
2009-05-19 14:39:22-01
20090621T0545Z
2007-04-06T00:00
2007-04-05T24:00
2010-02-18T16:23:48.5
2010-02-18T16:23:48,444
2010-02-18T16:23:48,3-06:00
2010-02-18T16:23:00.4
2010-02-18T16:23:00,25
2010-02-18T16:23:00.33+0600
2010-02-18T16:00:00.23334444
2010-02-18T16:00:00,2283
2009-05-19 143922
2009-05-19 1439""".split(
"\n"
)
def test_iso_datetime():
    @dataclass
    class Envelope(betterproto.Message):
        ts: datetime = betterproto.message_field(1)

    msg = Envelope()
    for candidate in iso_candidates:
        msg.from_dict({"ts": candidate})
        assert isinstance(msg.ts, datetime)


def test_iso_datetime_list():
    @dataclass
    class Envelope(betterproto.Message):
        timestamps: List[datetime] = betterproto.message_field(1)

    msg = Envelope()
    msg.from_dict({"timestamps": iso_candidates})
    assert all(isinstance(item, datetime) for item in msg.timestamps)


def test_enum_service_argument__expected_default_value():
    from tests.output_betterproto.service.service import ThingType, TestStub

    sig = signature(TestStub.do_thing)
    assert sig.parameters["type"].default == ThingType.UNKNOWN


@@ -1,10 +1,11 @@
import importlib
import json
import math
import os
import sys
from collections import namedtuple
from types import ModuleType
from typing import Set
from typing import Any, Dict, List, Set, Tuple
import pytest
@@ -28,7 +29,12 @@ from google.protobuf.json_format import Parse
class TestCases:
def __init__(self, path, services: Set[str], xfail: Set[str]):
def __init__(
self,
path,
services: Set[str],
xfail: Set[str],
):
_all = set(get_directories(path)) - {"__pycache__"}
_services = services
_messages = (_all - services) - {"__pycache__"}
@@ -69,6 +75,55 @@ def module_has_entry_point(module: ModuleType):
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
def list_replace_nans(items: List) -> List[Any]:
    """Replace float("nan") in a list with the string "NaN"

    Parameters
    ----------
    items : List
        List to update

    Returns
    -------
    List[Any]
        Updated list
    """
    result = []
    for item in items:
        if isinstance(item, list):
            result.append(list_replace_nans(item))
        elif isinstance(item, dict):
            result.append(dict_replace_nans(item))
        elif isinstance(item, float) and math.isnan(item):
            result.append(betterproto.NAN)
        else:
            result.append(item)
    return result
def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
"""Replace float("nan") in a dictionary with the string "NaN"
Parameters
----------
input_dict : Dict[Any, Any]
Dictionary to update
Returns
-------
Dict[Any, Any]
Updated dictionary
"""
result = {}
for key, value in input_dict.items():
if isinstance(value, dict):
value = dict_replace_nans(value)
elif isinstance(value, list):
value = list_replace_nans(value)
elif isinstance(value, float) and math.isnan(value):
value = betterproto.NAN
result[key] = value
return result
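These helpers exist because IEEE-754 `NaN` compares unequal to itself, so a plain `==` between structures containing `NaN` fails even when the values "match". A self-contained sketch of the same idea, using the literal string `"NaN"` where the real code uses `betterproto.NAN`, and a single recursive function in place of the list/dict pair:

```python
import math
from typing import Any

def replace_nans(value: Any) -> Any:
    # Recursively swap IEEE-754 NaN for the string "NaN" so that two
    # structures containing NaN can be compared with ==.
    if isinstance(value, dict):
        return {k: replace_nans(v) for k, v in value.items()}
    if isinstance(value, list):
        return [replace_nans(v) for v in value]
    if isinstance(value, float) and math.isnan(value):
        return "NaN"
    return value

a = {"x": float("nan"), "ys": [1.0, float("nan")]}
b = {"x": float("nan"), "ys": [1.0, float("nan")]}
assert a != b                              # NaN != NaN, so plain comparison fails
assert replace_nans(a) == replace_nans(b)  # after substitution, structures compare equal
```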
@pytest.fixture
def test_data(request):
test_case_name = request.param
@@ -81,7 +136,6 @@ def test_data(request):
     reference_module_root = os.path.join(
         *reference_output_package.split("."), test_case_name
     )
-    sys.path.append(reference_module_root)

     plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
@@ -126,42 +180,48 @@ def test_message_json(repeat, test_data: TestData) -> None:
     plugin_module, _, json_data = test_data

     for _ in range(repeat):
-        message: betterproto.Message = plugin_module.Test()
-
-        message.from_json(json_data)
-        message_json = message.to_json(0)
-
-        assert json.loads(message_json) == json.loads(json_data)
+        for sample in json_data:
+            if sample.belongs_to(test_input_config.non_symmetrical_json):
+                continue
+
+            message: betterproto.Message = plugin_module.Test()
+
+            message.from_json(sample.json)
+            message_json = message.to_json(0)
+
+            assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
+                json.loads(sample.json)
+            )
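Note that the round-trip assertion compares `json.loads(...)` results rather than raw strings: serialized JSON is not canonical, so key order and whitespace differences would make a text comparison brittle. For example:

```python
import json

# Two JSON documents that encode the same object but differ as text.
emitted = '{"b": 2, "a": 1}'
expected = '{"a": 1, "b": 2}'

assert emitted != expected                           # raw strings differ
assert json.loads(emitted) == json.loads(expected)   # parsed values match
```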
 @pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
 def test_service_can_be_instantiated(test_data: TestData) -> None:
-    plugin_module, _, json_data = test_data
-
-    plugin_module.TestStub(MockChannel())
+    test_data.plugin_module.TestStub(MockChannel())
 @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
 def test_binary_compatibility(repeat, test_data: TestData) -> None:
     plugin_module, reference_module, json_data = test_data

-    reference_instance = Parse(json_data, reference_module().Test())
-    reference_binary_output = reference_instance.SerializeToString()
-
-    for _ in range(repeat):
-        plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
-            json_data
-        )
-        plugin_instance_from_binary = plugin_module.Test.FromString(
-            reference_binary_output
-        )
-
-        # Generally this can't be relied on, but here we are aiming to match the
-        # existing Python implementation and aren't doing anything tricky.
-        # https://developers.google.com/protocol-buffers/docs/encoding#implications
-        assert bytes(plugin_instance_from_json) == reference_binary_output
-        assert bytes(plugin_instance_from_binary) == reference_binary_output
-
-        assert plugin_instance_from_json == plugin_instance_from_binary
-        assert (
-            plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict()
-        )
+    for sample in json_data:
+        reference_instance = Parse(sample.json, reference_module().Test())
+        reference_binary_output = reference_instance.SerializeToString()
+
+        for _ in range(repeat):
+            plugin_instance_from_json: betterproto.Message = (
+                plugin_module.Test().from_json(sample.json)
+            )
+            plugin_instance_from_binary = plugin_module.Test.FromString(
+                reference_binary_output
+            )
+
+            # Generally this can't be relied on, but here we are aiming to match the
+            # existing Python implementation and aren't doing anything tricky.
+            # https://developers.google.com/protocol-buffers/docs/encoding#implications
+            assert bytes(plugin_instance_from_json) == reference_binary_output
+            assert bytes(plugin_instance_from_binary) == reference_binary_output
+
+            assert plugin_instance_from_json == plugin_instance_from_binary
+            assert dict_replace_nans(
+                plugin_instance_from_json.to_dict()
+            ) == dict_replace_nans(plugin_instance_from_binary.to_dict())

tests/test_version.py (new file)

@@ -0,0 +1,13 @@
+from betterproto import __version__
+from pathlib import Path
+
+import tomlkit
+
+PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve()
+
+
+def test_version():
+    with PROJECT_TOML.open() as toml_file:
+        project_config = tomlkit.loads(toml_file.read())
+    assert (
+        __version__ == project_config["tool"]["poetry"]["version"]
+    ), "Project version should match in package and package config"
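The new test pins `betterproto.__version__` to the version declared in `pyproject.toml` via `tomlkit`. For illustration only, the same consistency check can be sketched without third-party dependencies by matching the `version` key in an inline, hypothetical TOML snippet (a real check should use a TOML parser: `tomlkit`, or `tomllib` on Python 3.11+):

```python
import re

# Hypothetical pyproject.toml contents, inlined for the sketch.
PYPROJECT = """
[tool.poetry]
name = "betterproto"
version = "2.0.0b4"
"""

# Match the first `version = "..."` assignment at the start of a line.
match = re.search(r'^version\s*=\s*"([^"]+)"', PYPROJECT, re.MULTILINE)
assert match is not None
assert match.group(1) == "2.0.0b4"
```

A regex is good enough for a sketch, but it would misfire on real-world TOML (comments, multiple tables), which is why the actual test uses a proper parser.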


@@ -1,11 +1,11 @@
 import asyncio
+from dataclasses import dataclass
 import importlib
 import os
-import pathlib
+from pathlib import Path
 import sys
 from types import ModuleType
-from typing import Callable, Generator, Optional, Union
+from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@@ -47,15 +47,44 @@ async def protoc(
     return stdout, stderr, proc.returncode

-def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
-    test_data_file_name = json_file_name or f"{test_case_name}.json"
-    test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
-
-    if not test_data_file_path.exists():
-        return None
-
-    with test_data_file_path.open("r") as fh:
-        return fh.read()
+@dataclass
+class TestCaseJsonFile:
+    json: str
+    test_name: str
+    file_name: str
+
+    def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
+        return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
+
+
+def get_test_case_json_data(
+    test_case_name: str, *json_file_names: str
+) -> List[TestCaseJsonFile]:
+    """
+    :return:
+        A list of all files found in "{inputs_path}/test_case_name" with names matching
+        f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
+        json_file_names
+    """
+    test_case_dir = inputs_path.joinpath(test_case_name)
+    possible_file_paths = [
+        *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names),
+        test_case_dir.joinpath(f"{test_case_name}.json"),
+        *test_case_dir.glob(f"{test_case_name}_*.json"),
+    ]
+
+    result = []
+    for test_data_file_path in possible_file_paths:
+        if not test_data_file_path.exists():
+            continue
+        with test_data_file_path.open("r") as fh:
+            result.append(
+                TestCaseJsonFile(
+                    fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
+                )
+            )
+    return result
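The discovery rule in `get_test_case_json_data` is: explicitly named files first, then `{test_case_name}.json`, then any `{test_case_name}_*.json` siblings. A minimal sketch of that glob logic against a throwaway directory (file names here are made up for illustration):

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Lay out a fake test-case directory with one matching variant
    # and one unrelated file that the glob must ignore.
    case_dir = Path(tmp) / "int32"
    case_dir.mkdir()
    for fname in ("int32.json", "int32_negative.json", "unrelated.json"):
        (case_dir / fname).write_text("{}")

    # Same pattern as get_test_case_json_data: the base file plus variants.
    found = sorted(
        p.name
        for p in [case_dir / "int32.json", *case_dir.glob("int32_*.json")]
        if p.exists()
    )

print(found)  # ['int32.json', 'int32_negative.json']
```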
def find_module(
@@ -74,7 +103,7 @@ def find_module(
         if predicate(module):
             return module

-        module_path = pathlib.Path(*module.__path__)
+        module_path = Path(*module.__path__)

         for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
             if sub == module_path: