16 Commits

Author SHA1 Message Date
Daniel G. Taylor
6fd9612ee1 Doc updates, version bump for release 2019-10-27 15:43:52 -07:00
Daniel G. Taylor
ba520f88a4 Install Protobuf include files on CI host 2019-10-27 15:40:33 -07:00
Daniel G. Taylor
b0b64fcbaf Fix tests attempt 3 2019-10-27 15:29:04 -07:00
Daniel G. Taylor
7900c7c9db Fix tests 2019-10-27 15:21:20 -07:00
Daniel G. Taylor
fcc273e294 Fix tests 2019-10-27 15:18:10 -07:00
Daniel G. Taylor
f820397751 Add missing optional types test 2019-10-27 15:14:06 -07:00
Daniel G. Taylor
16687211a2 Typing fixes 2019-10-27 15:13:51 -07:00
Daniel G. Taylor
eb5020db2a Fix bool parsing bug 2019-10-27 14:59:38 -07:00
Daniel G. Taylor
035793aec3 Support wrapper types 2019-10-27 14:55:25 -07:00
Daniel G. Taylor
c79535b614 Support Duration/Timestamp Google well-known types 2019-10-26 23:07:30 -07:00
Daniel G. Taylor
5daf61f64c Refactor default value code 2019-10-25 21:16:32 -07:00
Daniel G. Taylor
4679c571c3 Fix comment newlines 2019-10-25 12:28:26 -07:00
Daniel G. Taylor
ff8463cf12 Handle fields that clash with Python reserved keywords 2019-10-23 21:28:31 -07:00
Daniel G. Taylor
eff9021529 Some informational output from the plugin, do not overwrite __init__.py 2019-10-23 15:07:05 -07:00
Daniel G. Taylor
d43d5af5ce Better JSON casing support, renaming messages/fields 2019-10-23 15:06:34 -07:00
Daniel G. Taylor
ef0a1bf50c Use specific version of pypi publish image 2019-10-23 15:03:13 -07:00
20 changed files with 670 additions and 155 deletions


@@ -14,10 +14,12 @@ jobs:
- uses: dschep/install-pipenv-action@v1 - uses: dschep/install-pipenv-action@v1
- name: Install dependencies - name: Install dependencies
run: | run: |
sudo apt install protobuf-compiler sudo apt install protobuf-compiler libprotobuf-dev
pipenv install --dev pipenv install --dev
- name: Run tests - name: Run tests
run: | run: |
cp .env.default .env
pipenv run pip install -e .
pipenv run generate pipenv run generate
pipenv run test pipenv run test
- name: Build package - name: Build package
@@ -25,7 +27,7 @@ jobs:
run: pipenv run python setup.py sdist run: pipenv run python setup.py sdist
- name: Publish package - name: Publish package
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@master uses: pypa/gh-action-pypi-publish@v1.0.0a0
with: with:
user: __token__ user: __token__
password: ${{ secrets.pypi }} password: ${{ secrets.pypi }}


@@ -14,6 +14,7 @@ rope = "*"
protobuf = "*" protobuf = "*"
jinja2 = "*" jinja2 = "*"
grpclib = "*" grpclib = "*"
stringcase = "*"
[requires] [requires]
python_version = "3.7" python_version = "3.7"

Pipfile.lock (generated)

@@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40" "sha256": "28c38cd6c4eafb0b9ac9a64cf623145868fdee163111d3b941b34d23011db6ca"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@@ -147,6 +147,13 @@
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
], ],
"version": "==1.12.0" "version": "==1.12.0"
},
"stringcase": {
"hashes": [
"sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"
],
"index": "pypi",
"version": "==1.2.0"
} }
}, },
"develop": { "develop": {
@@ -159,10 +166,10 @@
}, },
"attrs": { "attrs": {
"hashes": [ "hashes": [
"sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
"sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"
], ],
"version": "==19.2.0" "version": "==19.3.0"
}, },
"entrypoints": { "entrypoints": {
"hashes": [ "hashes": [
@@ -211,26 +218,30 @@
}, },
"mypy": { "mypy": {
"hashes": [ "hashes": [
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b", "sha256:1521c186a3d200c399bd5573c828ea2db1362af7209b2adb1bb8532cea2fb36f",
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac", "sha256:31a046ab040a84a0fc38bc93694876398e62bc9f35eca8ccbf6418b7297f4c00",
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c", "sha256:3b1a411909c84b2ae9b8283b58b48541654b918e8513c20a400bb946aa9111ae",
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879", "sha256:48c8bc99380575deb39f5d3400ebb6a8a1cb5cc669bbba4d3bb30f904e0a0e7d",
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795", "sha256:540c9caa57a22d0d5d3c69047cc9dd0094d49782603eb03069821b41f9e970e9",
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021", "sha256:672e418425d957e276c291930a3921b4a6413204f53fe7c37cad7bc57b9a3391",
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3", "sha256:6ed3b9b3fdc7193ea7aca6f3c20549b377a56f28769783a8f27191903a54170f",
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23", "sha256:9371290aa2cad5ad133e4cdc43892778efd13293406f7340b9ffe99d5ec7c1d9",
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd", "sha256:ace6ac1d0f87d4072f05b5468a084a45b4eda970e4d26704f201e06d47ab2990",
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b", "sha256:b428f883d2b3fe1d052c630642cc6afddd07d5cd7873da948644508be3b9d4a7",
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046" "sha256:d5bf0e6ec8ba346a2cf35cb55bf4adfddbc6b6576fcc9e10863daa523e418dbb",
"sha256:d7574e283f83c08501607586b3167728c58e8442947e027d2d4c7dcd6d82f453",
"sha256:dc889c84241a857c263a2b1cd1121507db7d5b5f5e87e77147097230f374d10b",
"sha256:f4748697b349f373002656bf32fede706a0e713d67bfdcf04edf39b1f61d46eb"
], ],
"index": "pypi", "index": "pypi",
"version": "==0.730" "version": "==0.740"
}, },
"mypy-extensions": { "mypy-extensions": {
"hashes": [ "hashes": [
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458" "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d",
"sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"
], ],
"version": "==0.4.2" "version": "==0.4.3"
}, },
"packaging": { "packaging": {
"hashes": [ "hashes": [
@@ -300,20 +311,25 @@
}, },
"typed-ast": { "typed-ast": {
"hashes": [ "hashes": [
"sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161",
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
"sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47",
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
"sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2",
"sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e",
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
"sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66",
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
], ],
"version": "==1.4.0" "version": "==1.4.0"


@@ -10,6 +10,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
- Enums - Enums
- Dataclasses - Dataclasses
- `async`/`await` - `async`/`await`
- Timezone-aware `datetime` and `timedelta` objects
- Relative imports - Relative imports
- Mypy type checking - Mypy type checking
@@ -34,6 +35,8 @@ This project exists because I am unhappy with the state of the official Google p
- Much code looks like C++ or Java ported 1:1 to Python - Much code looks like C++ or Java ported 1:1 to Python
- Capitalized function names like `HasField()` and `SerializeToString()` - Capitalized function names like `HasField()` and `SerializeToString()`
- Uses `SerializeToString()` rather than the built-in `__bytes__()` - Uses `SerializeToString()` rather than the built-in `__bytes__()`
- Special wrapped types don't use Python's `None`
- Timestamp/duration types don't use Python's built-in `datetime` module
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical. This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
@@ -155,7 +158,7 @@ You can use it like so (enable async in the interactive shell first):
EchoResponse(values=["hello", "hello"]) EchoResponse(values=["hello", "hello"])
>>> async for response in service.echo_stream(value="hello", extra_times=1) >>> async for response in service.echo_stream(value="hello", extra_times=1)
print(response) print(response)
EchoStreamResponse(value="hello") EchoStreamResponse(value="hello")
EchoStreamResponse(value="hello") EchoStreamResponse(value="hello")
@@ -168,6 +171,12 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries
- Dicts: `Message().to_dict()`, `Message().from_dict(...)` - Dicts: `Message().to_dict()`, `Message().from_dict(...)`
- JSON: `Message().to_json()`, `Message().from_json(...)` - JSON: `Message().to_json()`, `Message().from_json(...)`
For compatibility, the default is to convert field names to `camelCase`. You can control this behavior by passing a casing value, e.g.:
```py
>>> MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
```
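As a quick illustration (a sketch based on this project's test suite, using a hypothetical `CasingExample` message rather than anything defined above), parsing accepts several input casings while serialization follows the requested casing:

```py
from dataclasses import dataclass

import betterproto


@dataclass
class CasingExample(betterproto.Message):
    snake_case: int = betterproto.int32_field(1)


msg = CasingExample().from_dict({"snakeCase": 1})  # camelCase input is accepted
assert msg.to_dict() == {"snakeCase": 1}  # camelCase output by default
assert msg.to_dict(casing=betterproto.Casing.SNAKE) == {"snake_case": 1}
```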
### Determining if a message was sent ### Determining if a message was sent
Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example. Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example.
@@ -238,6 +247,53 @@ Again this is a little different than the official Google code generator:
["foo", "foo's value"] ["foo", "foo's value"]
``` ```
### Well-Known Google Types
Google provides several well-known message types, such as timestamps, durations, and wrappers that provide optional zero-value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
| Google Message | Python Type | Default |
| --------------------------- | ---------------------------------------- | ---------------------- |
| `google.protobuf.Duration`  | [`datetime.timedelta`][td]               | `0`                    |
| `google.protobuf.Timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
| `google.protobuf.*Value` | `Optional[...]` | `None` |
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
```protobuf
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
}
```
You can do stuff like:
```py
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
>>> t
Test(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
>>> t.ts - t.duration
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
>>> t.ts.isoformat()
'2019-01-01T12:00:00+00:00'
>>> t.maybe = None
>>> t.to_dict()
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
```
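Because wrapped fields are ordinary messages on the wire, an explicitly set zero value is still serialized, which is what lets you tell `None` (unset) apart from `False`. A small sketch mirroring the project's test suite (the `Request` message here is illustrative):

```py
from dataclasses import dataclass
from typing import Optional

import betterproto


@dataclass
class Request(betterproto.Message):
    flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)


assert bytes(Request()) == b""                  # unset: nothing on the wire
assert bytes(Request(flag=False)) == b"\n\x00"  # explicit zero value is still sent
assert Request().parse(b"").flag is None
assert Request().parse(b"\n\x00").flag is False
```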
## Development ## Development
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then: First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
@@ -295,14 +351,14 @@ $ pipenv run tests
- [x] Bytes as base64 - [x] Bytes as base64
- [ ] Any support - [ ] Any support
- [x] Enum strings - [x] Enum strings
- [ ] Well known types support (timestamp, duration, wrappers) - [x] Well known types support (timestamp, duration, wrappers)
- [ ] Support different casing (orig vs. camel vs. others?) - [x] Support different casing (orig vs. camel vs. others?)
- [ ] Async service stubs - [ ] Async service stubs
- [x] Unary-unary - [x] Unary-unary
- [x] Server streaming response - [x] Server streaming response
- [ ] Client streaming request - [ ] Client streaming request
- [ ] Renaming messages and fields to conform to Python name standards - [x] Renaming messages and fields to conform to Python name standards
- [ ] Renaming clashes with language keywords and standard library top-level packages - [x] Renaming clashes with language keywords
- [x] Python package - [x] Python package
- [x] Automate running tests - [x] Automate running tests
- [ ] Cleanup! - [ ] Cleanup!


@@ -5,6 +5,7 @@ import json
import struct import struct
from abc import ABC from abc import ABC
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
from datetime import datetime, timedelta, timezone
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
@@ -24,6 +25,9 @@ from typing import (
import grpclib.client import grpclib.client
import grpclib.const import grpclib.const
import stringcase
from .casing import safe_snake_case
# Proto 3 data types # Proto 3 data types
TYPE_ENUM = "enum" TYPE_ENUM = "enum"
@@ -101,6 +105,17 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
class Casing(enum.Enum):
"""Casing constants for serialization."""
CAMEL = stringcase.camelcase
SNAKE = stringcase.snakecase
class _PLACEHOLDER: class _PLACEHOLDER:
pass pass
@@ -108,18 +123,6 @@ class _PLACEHOLDER:
PLACEHOLDER: Any = _PLACEHOLDER() PLACEHOLDER: Any = _PLACEHOLDER()
def get_default(proto_type: str) -> Any:
"""Get the default (zero value) for a given type."""
return {
TYPE_BOOL: False,
TYPE_FLOAT: 0.0,
TYPE_DOUBLE: 0.0,
TYPE_STRING: "",
TYPE_BYTES: b"",
TYPE_MAP: {},
}.get(proto_type, 0)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FieldMetadata: class FieldMetadata:
"""Stores internal metadata used for parsing & serialization.""" """Stores internal metadata used for parsing & serialization."""
@@ -129,9 +132,11 @@ class FieldMetadata:
# Protobuf type name # Protobuf type name
proto_type: str proto_type: str
# Map information if the proto_type is a map # Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]] map_types: Optional[Tuple[str, str]] = None
# Groups several "one-of" fields together # Groups several "one-of" fields together
group: Optional[str] group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
@staticmethod @staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata": def get(field: dataclasses.Field) -> "FieldMetadata":
@@ -145,11 +150,14 @@ def dataclass_field(
*, *,
map_types: Optional[Tuple[str, str]] = None, map_types: Optional[Tuple[str, str]] = None,
group: Optional[str] = None, group: Optional[str] = None,
wraps: Optional[str] = None,
) -> dataclasses.Field: ) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata.""" """Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field( return dataclasses.field(
default=PLACEHOLDER, default=PLACEHOLDER,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)}, metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
},
) )
@@ -222,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group) return dataclass_field(number, TYPE_BYTES, group=group)
def message_field(number: int, group: Optional[str] = None) -> Any: def message_field(
return dataclass_field(number, TYPE_MESSAGE, group=group) number: int, group: Optional[str] = None, wraps: Optional[str] = None
) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
def map_field( def map_field(
@@ -274,7 +284,7 @@ def encode_varint(value: int) -> bytes:
return bytes(b + [bits]) return bytes(b + [bits])
def _preprocess_single(proto_type: str, value: Any) -> bytes: def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization.""" """Adjusts values before serialization."""
if proto_type in [ if proto_type in [
TYPE_ENUM, TYPE_ENUM,
@@ -297,16 +307,37 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
elif proto_type == TYPE_STRING: elif proto_type == TYPE_STRING:
return value.encode("utf-8") return value.encode("utf-8")
elif proto_type == TYPE_MESSAGE: elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
elif wraps:
if value is None:
return b""
value = _get_wrapper(wraps)(value=value)
return bytes(value) return bytes(value)
return value return value
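To make the seconds/nanos split above concrete, here is the same arithmetic as a standalone helper (illustrative only; `split_timedelta` is not a betterproto function):

```py
from datetime import timedelta


def split_timedelta(value: timedelta) -> tuple:
    # Mirrors the Duration conversion above: whole microseconds -> (seconds, nanos).
    total_us = value // timedelta(microseconds=1)
    seconds = int(total_us / 1e6)
    nanos = int((total_us % 1e6) * 1e3)
    return seconds, nanos


assert split_timedelta(timedelta(seconds=1, milliseconds=200)) == (1, 200_000_000)
```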
def _serialize_single( def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False field_number: int,
proto_type: str,
value: Any,
*,
serialize_empty: bool = False,
wraps: str = "",
) -> bytes: ) -> bytes:
"""Serializes a single field and value.""" """Serializes a single field and value."""
value = _preprocess_single(proto_type, value) value = _preprocess_single(proto_type, wraps, value)
output = b"" output = b""
if proto_type in WIRE_VARINT_TYPES: if proto_type in WIRE_VARINT_TYPES:
@@ -319,7 +350,7 @@ def _serialize_single(
key = encode_varint((field_number << 3) | 1) key = encode_varint((field_number << 3) | 1)
output += key + value output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES: elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value) or serialize_empty: if len(value) or serialize_empty or wraps:
key = encode_varint((field_number << 3) | 2) key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value output += key + encode_varint(len(value)) + value
else: else:
@@ -359,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
while i < len(value): while i < len(value):
start = i start = i
num_wire, i = decode_varint(value, i) num_wire, i = decode_varint(value, i)
# print(num_wire, i)
number = num_wire >> 3 number = num_wire >> 3
wire_type = num_wire & 0x7 wire_type = num_wire & 0x7
@@ -375,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
elif wire_type == 5: elif wire_type == 5:
decoded, i = value[i : i + 4], i + 4 decoded, i = value[i : i + 4], i + 4
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield ParsedField( yield ParsedField(
number=number, wire_type=wire_type, value=decoded, raw=value[start:i] number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
) )
@@ -392,15 +420,19 @@ class Message(ABC):
register the message fields which get used by the serializers and parsers register the message fields which get used by the serializers and parsers
to go between Python, binary and JSON protobuf message representations. to go between Python, binary and JSON protobuf message representations.
""" """
_serialized_on_wire: bool
_unknown_fields: bytes
_group_map: Dict[str, dict]
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has # Set a default value for each field in the class after `__init__` has
# already been run. # already been run.
group_map = {"fields": {}, "groups": {}} group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
if meta.group: if meta.group:
# This is part of a one-of group.
group_map["fields"][field.name] = meta.group group_map["fields"][field.name] = meta.group
if meta.group not in group_map["groups"]: if meta.group not in group_map["groups"]:
@@ -450,6 +482,11 @@ class Message(ABC):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
value = getattr(self, field.name) value = getattr(self, field.name)
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types.
continue
# Being selected in a a group means this field is the one that is # Being selected in a a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even # currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value. # if the value is the default zero value.
@@ -457,42 +494,48 @@ class Message(ABC):
if meta.group and self._group_map["groups"][meta.group]["current"] == field: if meta.group and self._group_map["groups"][meta.group]["current"] == field:
selected_in_group = True selected_in_group = True
if isinstance(value, list): serialize_empty = False
if not len(value) and not selected_in_group: if isinstance(value, Message) and value._serialized_on_wire:
# Empty values are not serialized # Empty messages can still be sent on the wire if they were
continue # set (or received empty).
serialize_empty = True
if value == self._get_field_default(field, meta) and not (
selected_in_group or serialize_empty
):
# Default (zero) values are not serialized. Two exceptions are
# if this is the selected oneof item or if we know we have to
# serialize an empty message (i.e. zero value was explicitly
# set by the user).
continue
if isinstance(value, list):
if meta.proto_type in PACKED_TYPES: if meta.proto_type in PACKED_TYPES:
# Packed lists look like a length-delimited field. First, # Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then # preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes. # treat it like a field of raw bytes.
buf = b"" buf = b""
for item in value: for item in value:
buf += _preprocess_single(meta.proto_type, item) buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf) output += _serialize_single(meta.number, TYPE_BYTES, buf)
else: else:
for item in value: for item in value:
output += _serialize_single(meta.number, meta.proto_type, item) output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
)
elif isinstance(value, dict): elif isinstance(value, dict):
if not len(value) and not selected_in_group:
# Empty values are not serialized
continue
for k, v in value.items(): for k, v in value.items():
assert meta.map_types assert meta.map_types
sk = _serialize_single(1, meta.map_types[0], k) sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v) sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv) output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else: else:
if value == get_default(meta.proto_type) and not selected_in_group:
# Default (zero) values are not serialized
continue
serialize_empty = False
if isinstance(value, Message) and value._serialized_on_wire:
serialize_empty = True
output += _serialize_single( output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
wraps=meta.wraps or "",
) )
return output + self._unknown_fields return output + self._unknown_fields
@@ -500,30 +543,45 @@ class Message(ABC):
# For compatibility with other libraries # For compatibility with other libraries
SerializeToString = __bytes__ SerializeToString = __bytes__
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: def _type_hint(self, field_name: str) -> Type:
"""Get the message class for a field from the type hints."""
module = inspect.getmodule(self.__class__) module = inspect.getmodule(self.__class__)
type_hints = get_type_hints(self.__class__, vars(module)) type_hints = get_type_hints(self.__class__, vars(module))
cls = type_hints[field.name] return type_hints[field_name]
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints."""
cls = self._type_hint(field.name)
if hasattr(cls, "__args__") and index >= 0: if hasattr(cls, "__args__") and index >= 0:
cls = type_hints[field.name].__args__[index] cls = cls.__args__[index]
return cls return cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
t = self._cls_for(field, index=-1) t = self._type_hint(field.name)
value: Any = 0 value: Any = 0
if meta.proto_type == TYPE_MAP: if hasattr(t, "__origin__"):
# Maps cannot be repeated, so we check these first. if t.__origin__ == dict:
value = {} # This is some kind of map (dict in Python).
elif hasattr(t, "__args__") and len(t.__args__) == 1: value = {}
# Anything else with type args is a list. elif t.__origin__ == list:
value = [] # This is some kind of list (repeated) field.
elif meta.proto_type == TYPE_MESSAGE: value = []
# Message means creating an instance of the right type. elif t.__origin__ == Union and t.__args__[1] == type(None):
value = t() # This is an optional (wrapped) field. For setting the default we
# really don't care what kind of field it is.
value = None
else:
value = t()
elif issubclass(t, Enum):
# Enums always default to zero.
value = 0
elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
value = DATETIME_ZERO
else: else:
value = get_default(meta.proto_type) # This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
value = t()
return value return value
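The default-value logic above now leans on `typing` introspection instead of proto type names; the checks reduce to the following (standard library only, for illustration):

```py
from typing import Dict, List, Optional, Union

assert List[int].__origin__ is list        # repeated field -> []
assert Dict[str, int].__origin__ is dict   # map field      -> {}
t = Optional[bool]                         # wrapped field  -> None
assert t.__origin__ is Union and t.__args__[1] is type(None)
```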
@@ -540,6 +598,9 @@ class Message(ABC):
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]: elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Undo zig-zag encoding # Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1)) value = (value >> 1) ^ (-(value & 1))
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]: elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
fmt = _pack_fmt(meta.proto_type) fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0] value = struct.unpack(fmt, value)[0]
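For reference, the zig-zag decoding a few lines above maps unsigned varints back to signed integers; a quick standalone check (not library code):

```py
def zigzag_decode(value: int) -> int:
    # Same expression as above: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
    return (value >> 1) ^ (-(value & 1))


assert [zigzag_decode(n) for n in range(5)] == [0, -1, 1, -2, 2]
```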
@@ -548,8 +609,18 @@ class Message(ABC):
value = value.decode("utf-8") value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE: elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field) cls = self._cls_for(field)
value = cls().parse(value)
value._serialized_on_wire = True if cls == datetime:
value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta:
value = _Duration().parse(value).to_timedelta()
elif meta.wraps:
# This is a Google wrapper value message around a single
# scalar type.
value = _get_wrapper(meta.wraps)().parse(value).value
else:
value = cls().parse(value)
value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP: elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each # TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class. # key/value pair will recreate the class.
@@ -624,48 +695,59 @@ class Message(ABC):
def FromString(cls: Type[T], data: bytes) -> T: def FromString(cls: Type[T], data: bytes) -> T:
return cls().parse(data) return cls().parse(data)
def to_dict(self) -> dict: def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
""" """
Returns a dict representation of this message instance which can be Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON. used to serialize to e.g. JSON. Defaults to camel casing for
compatibility but can be set to other modes.
""" """
output: Dict[str, Any] = {} output: Dict[str, Any] = {}
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
v = getattr(self, field.name) v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message": if meta.proto_type == "message":
if isinstance(v, list): if isinstance(v, datetime):
if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta):
if v != timedelta(0):
output[cased_name] = _Duration.delta_to_json(v)
elif meta.wraps:
if v is not None:
output[cased_name] = v
elif isinstance(v, list):
# Convert each item. # Convert each item.
v = [i.to_dict() for i in v] v = [i.to_dict() for i in v]
output[field.name] = v output[cased_name] = v
elif v._serialized_on_wire: elif v._serialized_on_wire:
output[field.name] = v.to_dict() output[cased_name] = v.to_dict()
elif meta.proto_type == "map": elif meta.proto_type == "map":
for k in v: for k in v:
if hasattr(v[k], "to_dict"): if hasattr(v[k], "to_dict"):
v[k] = v[k].to_dict() v[k] = v[k].to_dict()
if v: if v:
output[field.name] = v output[cased_name] = v
elif v != get_default(meta.proto_type): elif v != self._get_field_default(field, meta):
if meta.proto_type in INT_64_TYPES: if meta.proto_type in INT_64_TYPES:
if isinstance(v, list): if isinstance(v, list):
output[field.name] = [str(n) for n in v] output[cased_name] = [str(n) for n in v]
else: else:
output[field.name] = str(v) output[cased_name] = str(v)
elif meta.proto_type == TYPE_BYTES: elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list): if isinstance(v, list):
output[field.name] = [b64encode(b).decode("utf8") for b in v] output[cased_name] = [b64encode(b).decode("utf8") for b in v]
else: else:
output[field.name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field)) enum_values = list(self._cls_for(field)) # type: ignore
if isinstance(v, list): if isinstance(v, list):
output[field.name] = [enum_values[e].name for e in v] output[cased_name] = [enum_values[e].name for e in v]
else: else:
output[field.name] = enum_values[v].name output[cased_name] = enum_values[v].name
else: else:
output[field.name] = v output[cased_name] = v
return output return output
def from_dict(self: T, value: dict) -> T: def from_dict(self: T, value: dict) -> T:
@@ -674,44 +756,58 @@ class Message(ABC):
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
""" """
self._serialized_on_wire = True self._serialized_on_wire = True
for field in dataclasses.fields(self): fields_by_name = {f.name: f for f in dataclasses.fields(self)}
meta = FieldMetadata.get(field) for key in value:
if field.name in value and value[field.name] is not None: snake_cased = safe_snake_case(key)
if meta.proto_type == "message": if snake_cased in fields_by_name:
v = getattr(self, field.name) field = fields_by_name[snake_cased]
# print(v, value[field.name]) meta = FieldMetadata.get(field)
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[field.name])):
v.append(cls().from_dict(value[field.name][i]))
else:
v.from_dict(value[field.name])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._cls_for(field, index=1)
for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k])
else:
v = value[field.name]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
else:
v = int(value[field.name])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[field.name], list):
v = [b64decode(n) for n in value[field.name]]
else:
v = b64decode(value[field.name])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field)
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
if v is not None: if value[key] is not None:
setattr(self, field.name, v) if meta.proto_type == "message":
v = getattr(self, field.name)
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(
value[key].replace("Z", "+00:00")
)
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
elif meta.wraps:
setattr(self, field.name, value[key])
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._cls_for(field, index=1)
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field)
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
if v is not None:
setattr(self, field.name, v)
return self return self
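The timestamp branch above uses `datetime.fromisoformat`, which does not accept a trailing `Z`, hence the replacement with an explicit offset. For example (standard library only):

```py
from datetime import datetime, timezone

parsed = datetime.fromisoformat("2019-01-01T12:00:00Z".replace("Z", "+00:00"))
assert parsed == datetime(2019, 1, 1, 12, 0, tzinfo=timezone.utc)
```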
def to_json(self, indent: Union[None, int, str] = None) -> str: def to_json(self, indent: Union[None, int, str] = None) -> str:
@@ -743,6 +839,140 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
return (field.name, getattr(message, field.name)) return (field.name, getattr(message, field.name))
@dataclasses.dataclass
class _Duration(Message):
# Signed seconds of the span of time. Must be from -315,576,000,000 to
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
seconds: int = int64_field(1)
# Signed fractions of a second at nanosecond resolution of the span of time.
# Durations less than one second are represented with a 0 `seconds` field and
# a positive or negative `nanos` field. For durations of one second or more,
# a non-zero value for the `nanos` field must be of the same sign as the
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
nanos: int = int32_field(2)
def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@staticmethod
def delta_to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0"
return ".".join(parts) + "s"
@dataclasses.dataclass
class _Timestamp(Message):
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
seconds: int = int64_field(1)
# Non-negative fractions of a second at nanosecond resolution. Negative
# second values with fractions must still have non-negative nanos values that
# count forward in time. Must be from 0 to 999,999,999 inclusive.
nanos: int = int32_field(2)
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
def timestamp_to_json(dt: datetime) -> str:
nanos = dt.microsecond * 1e3
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + "Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + ".%09dZ" % nanos
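The JSON forms pad fractional seconds to 3, 6, or 9 digits. A worked example that simply mirrors `delta_to_json` above so the output can be checked standalone:

```py
from datetime import timedelta


def delta_to_json(delta: timedelta) -> str:
    # Same padding rule as _Duration.delta_to_json above.
    parts = str(delta.total_seconds()).split(".")
    if len(parts) > 1:
        while len(parts[1]) not in [3, 6, 9]:
            parts[1] = parts[1] + "0"
    return ".".join(parts) + "s"


assert delta_to_json(timedelta(seconds=1, milliseconds=200)) == "1.200s"
assert delta_to_json(timedelta(microseconds=1500)) == "0.001500s"
```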
class _WrappedMessage(Message):
"""
Google protobuf wrapper types base class. JSON representation is just the
value itself.
"""
value: Any
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
return self.value
def from_dict(self: T, value: Any) -> T:
if value is not None:
self.value = value
return self
@dataclasses.dataclass
class _BoolValue(_WrappedMessage):
value: bool = bool_field(1)
@dataclasses.dataclass
class _Int32Value(_WrappedMessage):
value: int = int32_field(1)
@dataclasses.dataclass
class _UInt32Value(_WrappedMessage):
value: int = uint32_field(1)
@dataclasses.dataclass
class _Int64Value(_WrappedMessage):
value: int = int64_field(1)
@dataclasses.dataclass
class _UInt64Value(_WrappedMessage):
value: int = uint64_field(1)
@dataclasses.dataclass
class _FloatValue(_WrappedMessage):
value: float = float_field(1)
@dataclasses.dataclass
class _DoubleValue(_WrappedMessage):
value: float = double_field(1)
@dataclasses.dataclass
class _StringValue(_WrappedMessage):
value: str = string_field(1)
@dataclasses.dataclass
class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> Type:
"""Get the wrapper message class for a wrapped type."""
return {
TYPE_BOOL: _BoolValue,
TYPE_INT32: _Int32Value,
TYPE_UINT32: _UInt32Value,
TYPE_INT64: _Int64Value,
TYPE_UINT64: _UInt64Value,
TYPE_FLOAT: _FloatValue,
TYPE_DOUBLE: _DoubleValue,
TYPE_STRING: _StringValue,
TYPE_BYTES: _BytesValue,
}[proto_type]
class ServiceStub(ABC): class ServiceStub(ABC):
""" """
Base class for async gRPC service stubs. Base class for async gRPC service stubs.

betterproto/casing.py (new file)

@@ -0,0 +1,41 @@
import stringcase
def safe_snake_case(value: str) -> str:
"""Snake case a value taking into account Python keywords."""
value = stringcase.snakecase(value)
if value in [
"and",
"as",
"assert",
"break",
"class",
"continue",
"def",
"del",
"elif",
"else",
"except",
"finally",
"for",
"from",
"global",
"if",
"import",
"in",
"is",
"lambda",
"nonlocal",
"not",
"or",
"pass",
"raise",
"return",
"try",
"while",
"with",
"yield",
]:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
value += "_"
return value
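For example (assuming `betterproto` and `stringcase` are installed):

```py
from betterproto.casing import safe_snake_case

assert safe_snake_case("camelCase") == "camel_case"
assert safe_snake_case("for") == "for_"  # keyword clash gets a trailing underscore
```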


@@ -16,6 +16,8 @@ except ImportError:
) )
raise SystemExit(1) raise SystemExit(1)
import stringcase
from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import ( from google.protobuf.descriptor_pb2 import (
DescriptorProto, DescriptorProto,
@@ -25,11 +27,20 @@ from google.protobuf.descriptor_pb2 import (
ServiceDescriptorProto, ServiceDescriptorProto,
) )
from betterproto.casing import safe_snake_case
def snake_case(value: str) -> str:
return ( WRAPPER_TYPES = {
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_") "google.protobuf.DoubleValue": "float",
) "google.protobuf.FloatValue": "float",
"google.protobuf.Int64Value": "int",
"google.protobuf.UInt64Value": "int",
"google.protobuf.Int32Value": "int",
"google.protobuf.UInt32Value": "int",
"google.protobuf.BoolValue": "bool",
"google.protobuf.StringValue": "str",
"google.protobuf.BytesValue": "bytes",
}
def get_ref_type(package: str, imports: set, type_name: str) -> str: def get_ref_type(package: str, imports: set, type_name: str) -> str:
@@ -37,15 +48,33 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
necessary. necessary.
""" """
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".") type_name = type_name.lstrip(".")
if type_name in WRAPPER_TYPES:
return f"Optional[{WRAPPER_TYPES[type_name]}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
if type_name.startswith(package): if type_name.startswith(package):
# This is the current package, which has nested types flattened. parts = type_name.lstrip(package).lstrip(".").split(".")
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"' if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
if "." in type_name: if "." in type_name:
# This is imported from another package. No need # This is imported from another package. No need
# to use a forward ref and we need to add the import. # to use a forward ref and we need to add the import.
parts = type_name.split(".") parts = type_name.split(".")
parts[-1] = stringcase.pascalcase(parts[-1])
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}" type_name = f"{parts[-2]}.{parts[-1]}"
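A hedged sketch of what this mapping produces for the well-known types (assumes `betterproto` is installed with its compiler extras so that `betterproto.plugin` can be imported):

```py
from betterproto.plugin import get_ref_type

imports = set()
assert get_ref_type("", imports, ".google.protobuf.BoolValue") == "Optional[bool]"
assert get_ref_type("", imports, ".google.protobuf.Duration") == "timedelta"
assert get_ref_type("", imports, ".google.protobuf.Timestamp") == "datetime"
```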
@@ -122,7 +151,7 @@ def get_comment(proto_file, path: List[int]) -> str:
if path[-2] == 2 and path[-4] != 6: if path[-2] == 2 and path[-4] != 6:
# This is a field # This is a field
return " # " + " # ".join(lines) return " # " + "\n # ".join(lines)
else: else:
# This is a message, enum, service, or method # This is a message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 70: if len(lines) == 1 and len(lines[0]) < 70:
@@ -146,6 +175,9 @@ def generate_code(request, response):
output_map = {} output_map = {}
for proto_file in request.proto_file: for proto_file in request.proto_file:
out = proto_file.package out = proto_file.package
if out == "google.protobuf":
continue
if not out: if not out:
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".") out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
@@ -163,6 +195,7 @@ def generate_code(request, response):
"package": package, "package": package,
"files": [f.name for f in options["files"]], "files": [f.name for f in options["files"]],
"imports": set(), "imports": set(),
"datetime_imports": set(),
"typing_imports": set(), "typing_imports": set(),
"messages": [], "messages": [],
"enums": [], "enums": [],
@@ -179,7 +212,7 @@ def generate_code(request, response):
for item, path in traverse(proto_file): for item, path in traverse(proto_file):
# print(item, file=sys.stderr) # print(item, file=sys.stderr)
# print(path, file=sys.stderr) # print(path, file=sys.stderr)
data = {"name": item.name} data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
if isinstance(item, DescriptorProto): if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr) # print(item, file=sys.stderr)
@@ -203,6 +236,14 @@ def generate_code(request, response):
packed = False packed = False
field_type = f.Type.Name(f.type).lower()[5:] field_type = f.Type.Name(f.type).lower()[5:]
field_wraps = ""
if f.type_name.startswith(
".google.protobuf"
) and f.type_name.endswith("Value"):
w = f.type_name.split(".").pop()[:-5].upper()
field_wraps = f"betterproto.TYPE_{w}"
map_types = None map_types = None
if f.type == 11: if f.type == 11:
# This might be a map... # This might be a map...
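The `field_wraps` value computed a few lines above is derived purely from the type name; for example:

```py
# ".google.protobuf.Int32Value" -> "Int32Value" -> strip "Value" -> "INT32"
w = ".google.protobuf.Int32Value".split(".").pop()[:-5].upper()
assert f"betterproto.TYPE_{w}" == "betterproto.TYPE_INT32"
```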
@@ -252,13 +293,23 @@ def generate_code(request, response):
if f.HasField("oneof_index"): if f.HasField("oneof_index"):
one_of = item.oneof_decl[f.oneof_index].name one_of = item.oneof_decl[f.oneof_index].name
if "Optional[" in t:
output["typing_imports"].add("Optional")
if "timedelta" in t:
output["datetime_imports"].add("timedelta")
elif "datetime" in t:
output["datetime_imports"].add("datetime")
data["properties"].append( data["properties"].append(
{ {
"name": f.name, "name": f.name,
"py_name": safe_snake_case(f.name),
"number": f.number, "number": f.number,
"comment": get_comment(proto_file, path + [2, i]), "comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type), "proto_type": int(f.type),
"field_type": field_type, "field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types, "map_types": map_types,
"type": t, "type": t,
"zero": zero, "zero": zero,
@@ -294,6 +345,7 @@ def generate_code(request, response):
data = { data = {
"name": service.name, "name": service.name,
"py_name": stringcase.pascalcase(service.name),
"comment": get_comment(proto_file, [6, i]), "comment": get_comment(proto_file, [6, i]),
"methods": [], "methods": [],
} }
@@ -317,7 +369,7 @@ def generate_code(request, response):
data["methods"].append( data["methods"].append(
{ {
"name": method.name, "name": method.name,
"py_name": snake_case(method.name), "py_name": stringcase.snakecase(method.name),
"comment": get_comment(proto_file, [6, i, 2, j]), "comment": get_comment(proto_file, [6, i, 2, j]),
"route": f"/{package}.{service.name}/{method.name}", "route": f"/{package}.{service.name}/{method.name}",
"input": get_ref_type( "input": get_ref_type(
@@ -338,6 +390,7 @@ def generate_code(request, response):
output["services"].append(data) output["services"].append(data)
output["imports"] = sorted(output["imports"]) output["imports"] = sorted(output["imports"])
output["datetime_imports"] = sorted(output["datetime_imports"])
output["typing_imports"] = sorted(output["typing_imports"]) output["typing_imports"] = sorted(output["typing_imports"])
# Fill response # Fill response
@@ -361,10 +414,20 @@ def generate_code(request, response):
inits.add(base) inits.add(base)
for base in inits: for base in inits:
name = os.path.join(base, "__init__.py")
if os.path.exists(name):
# Never overwrite inits as they may have custom stuff in them.
continue
init = response.file.add() init = response.file.add()
init.name = os.path.join(base, "__init__.py") init.name = name
init.content = b"" init.content = b""
filenames = sorted([f.name for f in response.file])
for fname in filenames:
print(f"Writing {fname}", file=sys.stderr)
def main(): def main():
"""The plugin's main entry point.""" """The plugin's main entry point."""


@@ -2,6 +2,10 @@
# sources: {{ ', '.join(description.files) }} # sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto # plugin: python-betterproto
from dataclasses import dataclass from dataclasses import dataclass
{% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if description.typing_imports %} {% if description.typing_imports %}
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
@@ -18,7 +22,7 @@ import grpclib
{% if description.enums %}{% for enum in description.enums %} {% if description.enums %}{% for enum in description.enums %}
class {{ enum.name }}(betterproto.Enum): class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %} {% if enum.comment %}
{{ enum.comment }} {{ enum.comment }}
@@ -35,7 +39,7 @@ class {{ enum.name }}(betterproto.Enum):
{% endif %} {% endif %}
{% for message in description.messages %} {% for message in description.messages %}
@dataclass @dataclass
class {{ message.name }}(betterproto.Message): class {{ message.py_name }}(betterproto.Message):
{% if message.comment %} {% if message.comment %}
{{ message.comment }} {{ message.comment }}
@@ -44,7 +48,7 @@ class {{ message.name }}(betterproto.Message):
{% if field.comment %} {% if field.comment %}
{{ field.comment }} {{ field.comment }}
{% endif %} {% endif %}
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{% endfor %} {% endfor %}
{% if not message.properties %} {% if not message.properties %}
pass pass
@@ -53,13 +57,13 @@ class {{ message.name }}(betterproto.Message):
{% endfor %} {% endfor %}
{% for service in description.services %} {% for service in description.services %}
class {{ service.name }}Stub(betterproto.ServiceStub): class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %} {% if service.comment %}
{{ service.comment }} {{ service.comment }}
{% endif %} {% endif %}
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %} {% if method.comment %}
{{ method.comment }} {{ method.comment }}
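For the googletypes test proto that appears later in this diff, the field template line above renders roughly the following dataclass (a sketch, not verbatim plugin output):

```py
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import betterproto


@dataclass
class Test(betterproto.Message):
    maybe: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
    ts: datetime = betterproto.message_field(2)
    duration: timedelta = betterproto.message_field(3)
```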


@@ -0,0 +1,3 @@
{
"value": true
}


@@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
bool value = 1;
}


@@ -0,0 +1,4 @@
{
"camelCase": 1,
"snakeCase": "ONE"
}


@@ -0,0 +1,12 @@
syntax = "proto3";
enum my_enum {
ZERO = 0;
ONE = 1;
TWO = 2;
}
message Test {
int32 camelCase = 1;
my_enum snake_case = 2;
}


@@ -72,7 +72,8 @@ if __name__ == "__main__":
input_json = open(filename).read() input_json = open(filename).read()
parsed = Parse(input_json, imported.Test()) parsed = Parse(input_json, imported.Test())
serialized = parsed.SerializeToString() serialized = parsed.SerializeToString()
serialized_json = MessageToJson(parsed, preserving_proto_field_name=True) preserve = "casing" not in filename
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
s_loaded = json.loads(serialized_json) s_loaded = json.loads(serialized_json)
in_loaded = json.loads(input_json) in_loaded = json.loads(input_json)


@@ -0,0 +1 @@
{}


@@ -0,0 +1,5 @@
{
"maybe": false,
"ts": "1972-01-01T10:00:20.021Z",
"duration": "1.200s"
}


@@ -0,0 +1,12 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
google.protobuf.Int32Value important = 4;
}


@@ -0,0 +1,5 @@
{
"for": 1,
"with": 2,
"as": 3
}


@@ -0,0 +1,7 @@
syntax = "proto3";
message Test {
int32 for = 1;
int32 with = 2;
int32 as = 3;
}


@@ -1,5 +1,6 @@
import betterproto import betterproto
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
def test_has_field(): def test_has_field():
@@ -115,3 +116,49 @@ def test_oneof_support():
assert betterproto.which_one_of(foo2, "group1")[0] == "bar" assert betterproto.which_one_of(foo2, "group1")[0] == "bar"
assert foo.bar == 0 assert foo.bar == 0
assert betterproto.which_one_of(foo2, "group2")[0] == "" assert betterproto.which_one_of(foo2, "group2")[0] == ""
def test_json_casing():
@dataclass
class CasingTest(betterproto.Message):
pascal_case: int = betterproto.int32_field(1)
camel_case: int = betterproto.int32_field(2)
snake_case: int = betterproto.int32_field(3)
kabob_case: int = betterproto.int32_field(4)
# Parsing should accept almost any input
test = CasingTest().from_dict(
{"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
)
assert test == CasingTest(1, 2, 3, 4)
# Serializing should be strict.
assert test.to_dict() == {
"pascalCase": 1,
"camelCase": 2,
"snakeCase": 3,
"kabobCase": 4,
}
assert test.to_dict(casing=betterproto.Casing.SNAKE) == {
"pascal_case": 1,
"camel_case": 2,
"snake_case": 3,
"kabob_case": 4,
}
def test_optional_flag():
@dataclass
class Request(betterproto.Message):
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
# Serialization of not passed vs. set vs. zero-value.
assert bytes(Request()) == b""
assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
assert bytes(Request(flag=False)) == b"\n\x00"
# Differentiate between not passed and the zero-value.
assert Request().parse(b"").flag == None
assert Request().parse(b"\n\x00").flag == False


@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="betterproto", name="betterproto",
version="1.0.1", version="1.1.0",
description="A better Protobuf / gRPC generator & library", description="A better Protobuf / gRPC generator & library",
long_description=open("README.md", "r").read(), long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
@@ -18,7 +18,7 @@ setup(
), ),
package_data={"betterproto": ["py.typed", "templates/template.py"]}, package_data={"betterproto": ["py.typed", "templates/template.py"]},
python_requires=">=3.7", python_requires=">=3.7",
install_requires=["grpclib"], install_requires=["grpclib", "stringcase"],
extras_require={"compiler": ["jinja2", "protobuf"]}, extras_require={"compiler": ["jinja2", "protobuf"]},
zip_safe=False, zip_safe=False,
) )