Compare commits
17 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
6fd9612ee1 | ||
|
ba520f88a4 | ||
|
b0b64fcbaf | ||
|
7900c7c9db | ||
|
fcc273e294 | ||
|
f820397751 | ||
|
16687211a2 | ||
|
eb5020db2a | ||
|
035793aec3 | ||
|
c79535b614 | ||
|
5daf61f64c | ||
|
4679c571c3 | ||
|
ff8463cf12 | ||
|
eff9021529 | ||
|
d43d5af5ce | ||
|
ef0a1bf50c | ||
|
0e389abbef |
40
.github/workflows/ci.yml
vendored
40
.github/workflows/ci.yml
vendored
@@ -4,20 +4,30 @@ on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.7
|
||||
- uses: dschep/install-pipenv-action@v1
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt install protobuf-compiler
|
||||
pipenv install --dev
|
||||
- name: Run tests
|
||||
run: |
|
||||
pipenv run generate
|
||||
pipenv run test
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.7
|
||||
- uses: dschep/install-pipenv-action@v1
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt install protobuf-compiler libprotobuf-dev
|
||||
pipenv install --dev
|
||||
- name: Run tests
|
||||
run: |
|
||||
cp .env.default .env
|
||||
pipenv run pip install -e .
|
||||
pipenv run generate
|
||||
pipenv run test
|
||||
- name: Build package
|
||||
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
||||
run: pipenv run python setup.py sdist
|
||||
- name: Publish package
|
||||
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
|
||||
uses: pypa/gh-action-pypi-publish@v1.0.0a0
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.pypi }}
|
||||
|
1
Pipfile
1
Pipfile
@@ -14,6 +14,7 @@ rope = "*"
|
||||
protobuf = "*"
|
||||
jinja2 = "*"
|
||||
grpclib = "*"
|
||||
stringcase = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.7"
|
||||
|
52
Pipfile.lock
generated
52
Pipfile.lock
generated
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40"
|
||||
"sha256": "28c38cd6c4eafb0b9ac9a64cf623145868fdee163111d3b941b34d23011db6ca"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
@@ -147,6 +147,13 @@
|
||||
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
|
||||
],
|
||||
"version": "==1.12.0"
|
||||
},
|
||||
"stringcase": {
|
||||
"hashes": [
|
||||
"sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==1.2.0"
|
||||
}
|
||||
},
|
||||
"develop": {
|
||||
@@ -159,10 +166,10 @@
|
||||
},
|
||||
"attrs": {
|
||||
"hashes": [
|
||||
"sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2",
|
||||
"sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"
|
||||
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
|
||||
"sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"
|
||||
],
|
||||
"version": "==19.2.0"
|
||||
"version": "==19.3.0"
|
||||
},
|
||||
"entrypoints": {
|
||||
"hashes": [
|
||||
@@ -211,26 +218,30 @@
|
||||
},
|
||||
"mypy": {
|
||||
"hashes": [
|
||||
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
|
||||
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
|
||||
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
|
||||
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
|
||||
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
|
||||
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
|
||||
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
|
||||
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
|
||||
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
|
||||
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
|
||||
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
|
||||
"sha256:1521c186a3d200c399bd5573c828ea2db1362af7209b2adb1bb8532cea2fb36f",
|
||||
"sha256:31a046ab040a84a0fc38bc93694876398e62bc9f35eca8ccbf6418b7297f4c00",
|
||||
"sha256:3b1a411909c84b2ae9b8283b58b48541654b918e8513c20a400bb946aa9111ae",
|
||||
"sha256:48c8bc99380575deb39f5d3400ebb6a8a1cb5cc669bbba4d3bb30f904e0a0e7d",
|
||||
"sha256:540c9caa57a22d0d5d3c69047cc9dd0094d49782603eb03069821b41f9e970e9",
|
||||
"sha256:672e418425d957e276c291930a3921b4a6413204f53fe7c37cad7bc57b9a3391",
|
||||
"sha256:6ed3b9b3fdc7193ea7aca6f3c20549b377a56f28769783a8f27191903a54170f",
|
||||
"sha256:9371290aa2cad5ad133e4cdc43892778efd13293406f7340b9ffe99d5ec7c1d9",
|
||||
"sha256:ace6ac1d0f87d4072f05b5468a084a45b4eda970e4d26704f201e06d47ab2990",
|
||||
"sha256:b428f883d2b3fe1d052c630642cc6afddd07d5cd7873da948644508be3b9d4a7",
|
||||
"sha256:d5bf0e6ec8ba346a2cf35cb55bf4adfddbc6b6576fcc9e10863daa523e418dbb",
|
||||
"sha256:d7574e283f83c08501607586b3167728c58e8442947e027d2d4c7dcd6d82f453",
|
||||
"sha256:dc889c84241a857c263a2b1cd1121507db7d5b5f5e87e77147097230f374d10b",
|
||||
"sha256:f4748697b349f373002656bf32fede706a0e713d67bfdcf04edf39b1f61d46eb"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==0.730"
|
||||
"version": "==0.740"
|
||||
},
|
||||
"mypy-extensions": {
|
||||
"hashes": [
|
||||
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
|
||||
"sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d",
|
||||
"sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"
|
||||
],
|
||||
"version": "==0.4.2"
|
||||
"version": "==0.4.3"
|
||||
},
|
||||
"packaging": {
|
||||
"hashes": [
|
||||
@@ -300,20 +311,25 @@
|
||||
},
|
||||
"typed-ast": {
|
||||
"hashes": [
|
||||
"sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161",
|
||||
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
|
||||
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
|
||||
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
|
||||
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
|
||||
"sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47",
|
||||
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
|
||||
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
|
||||
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
|
||||
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
|
||||
"sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2",
|
||||
"sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e",
|
||||
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
|
||||
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
|
||||
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
|
||||
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
|
||||
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
|
||||
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
|
||||
"sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66",
|
||||
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
|
||||
],
|
||||
"version": "==1.4.0"
|
||||
|
66
README.md
66
README.md
@@ -10,6 +10,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
|
||||
- Enums
|
||||
- Dataclasses
|
||||
- `async`/`await`
|
||||
- Timezone-aware `datetime` and `timedelta` objects
|
||||
- Relative imports
|
||||
- Mypy type checking
|
||||
|
||||
@@ -34,6 +35,8 @@ This project exists because I am unhappy with the state of the official Google p
|
||||
- Much code looks like C++ or Java ported 1:1 to Python
|
||||
- Capitalized function names like `HasField()` and `SerializeToString()`
|
||||
- Uses `SerializeToString()` rather than the built-in `__bytes__()`
|
||||
- Special wrapped types don't use Python's `None`
|
||||
- Timestamp/duration types don't use Python's built-in `datetime` module
|
||||
|
||||
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
|
||||
|
||||
@@ -155,7 +158,7 @@ You can use it like so (enable async in the interactive shell first):
|
||||
EchoResponse(values=["hello", "hello"])
|
||||
|
||||
>>> async for response in service.echo_stream(value="hello", extra_times=1)
|
||||
print(response)
|
||||
print(response)
|
||||
|
||||
EchoStreamResponse(value="hello")
|
||||
EchoStreamResponse(value="hello")
|
||||
@@ -168,6 +171,12 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries
|
||||
- Dicts: `Message().to_dict()`, `Message().from_dict(...)`
|
||||
- JSON: `Message().to_json()`, `Message().from_json(...)`
|
||||
|
||||
For compatibility the default is to convert field names to `camelCase`. You can control this behavior by passing a casing value, e.g:
|
||||
|
||||
```py
|
||||
>>> MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
|
||||
```
|
||||
|
||||
### Determining if a message was sent
|
||||
|
||||
Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example.
|
||||
@@ -238,6 +247,53 @@ Again this is a little different than the official Google code generator:
|
||||
["foo", "foo's value"]
|
||||
```
|
||||
|
||||
### Well-Known Google Types
|
||||
|
||||
Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
|
||||
|
||||
| Google Message | Python Type | Default |
|
||||
| --------------------------- | ---------------------------------------- | ---------------------- |
|
||||
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
|
||||
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
|
||||
| `google.protobuf.*Value` | `Optional[...]` | `None` |
|
||||
|
||||
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
|
||||
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
|
||||
|
||||
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
|
||||
|
||||
```protobuf
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
message Test {
|
||||
google.protobuf.BoolValue maybe = 1;
|
||||
google.protobuf.Timestamp ts = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
}
|
||||
```
|
||||
|
||||
You can do stuff like:
|
||||
|
||||
```py
|
||||
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
|
||||
>>> t
|
||||
st(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
|
||||
|
||||
>>> t.ts - t.duration
|
||||
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
|
||||
|
||||
>>> t.ts.isoformat()
|
||||
'2019-01-01T12:00:00+00:00'
|
||||
|
||||
>>> t.maybe = None
|
||||
>>> t.to_dict()
|
||||
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
|
||||
@@ -295,14 +351,14 @@ $ pipenv run tests
|
||||
- [x] Bytes as base64
|
||||
- [ ] Any support
|
||||
- [x] Enum strings
|
||||
- [ ] Well known types support (timestamp, duration, wrappers)
|
||||
- [ ] Support different casing (orig vs. camel vs. others?)
|
||||
- [x] Well known types support (timestamp, duration, wrappers)
|
||||
- [x] Support different casing (orig vs. camel vs. others?)
|
||||
- [ ] Async service stubs
|
||||
- [x] Unary-unary
|
||||
- [x] Server streaming response
|
||||
- [ ] Client streaming request
|
||||
- [ ] Renaming messages and fields to conform to Python name standards
|
||||
- [ ] Renaming clashes with language keywords and standard library top-level packages
|
||||
- [x] Renaming messages and fields to conform to Python name standards
|
||||
- [x] Renaming clashes with language keywords
|
||||
- [x] Python package
|
||||
- [x] Automate running tests
|
||||
- [ ] Cleanup!
|
||||
|
@@ -5,6 +5,7 @@ import json
|
||||
import struct
|
||||
from abc import ABC
|
||||
from base64 import b64encode, b64decode
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@@ -24,6 +25,9 @@ from typing import (
|
||||
|
||||
import grpclib.client
|
||||
import grpclib.const
|
||||
import stringcase
|
||||
|
||||
from .casing import safe_snake_case
|
||||
|
||||
# Proto 3 data types
|
||||
TYPE_ENUM = "enum"
|
||||
@@ -101,6 +105,17 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
CAMEL = stringcase.camelcase
|
||||
SNAKE = stringcase.snakecase
|
||||
|
||||
|
||||
class _PLACEHOLDER:
|
||||
pass
|
||||
|
||||
@@ -108,18 +123,6 @@ class _PLACEHOLDER:
|
||||
PLACEHOLDER: Any = _PLACEHOLDER()
|
||||
|
||||
|
||||
def get_default(proto_type: str) -> Any:
|
||||
"""Get the default (zero value) for a given type."""
|
||||
return {
|
||||
TYPE_BOOL: False,
|
||||
TYPE_FLOAT: 0.0,
|
||||
TYPE_DOUBLE: 0.0,
|
||||
TYPE_STRING: "",
|
||||
TYPE_BYTES: b"",
|
||||
TYPE_MAP: {},
|
||||
}.get(proto_type, 0)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FieldMetadata:
|
||||
"""Stores internal metadata used for parsing & serialization."""
|
||||
@@ -129,9 +132,11 @@ class FieldMetadata:
|
||||
# Protobuf type name
|
||||
proto_type: str
|
||||
# Map information if the proto_type is a map
|
||||
map_types: Optional[Tuple[str, str]]
|
||||
map_types: Optional[Tuple[str, str]] = None
|
||||
# Groups several "one-of" fields together
|
||||
group: Optional[str]
|
||||
group: Optional[str] = None
|
||||
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
|
||||
wraps: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def get(field: dataclasses.Field) -> "FieldMetadata":
|
||||
@@ -145,11 +150,14 @@ def dataclass_field(
|
||||
*,
|
||||
map_types: Optional[Tuple[str, str]] = None,
|
||||
group: Optional[str] = None,
|
||||
wraps: Optional[str] = None,
|
||||
) -> dataclasses.Field:
|
||||
"""Creates a dataclass field with attached protobuf metadata."""
|
||||
return dataclasses.field(
|
||||
default=PLACEHOLDER,
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
|
||||
metadata={
|
||||
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -222,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_BYTES, group=group)
|
||||
|
||||
|
||||
def message_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE, group=group)
|
||||
def message_field(
|
||||
number: int, group: Optional[str] = None, wraps: Optional[str] = None
|
||||
) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
|
||||
|
||||
|
||||
def map_field(
|
||||
@@ -274,7 +284,7 @@ def encode_varint(value: int) -> bytes:
|
||||
return bytes(b + [bits])
|
||||
|
||||
|
||||
def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
||||
"""Adjusts values before serialization."""
|
||||
if proto_type in [
|
||||
TYPE_ENUM,
|
||||
@@ -297,16 +307,37 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
elif proto_type == TYPE_STRING:
|
||||
return value.encode("utf-8")
|
||||
elif proto_type == TYPE_MESSAGE:
|
||||
if isinstance(value, datetime):
|
||||
# Convert the `datetime` to a timestamp message.
|
||||
seconds = int(value.timestamp())
|
||||
nanos = int(value.microsecond * 1e3)
|
||||
value = _Timestamp(seconds=seconds, nanos=nanos)
|
||||
elif isinstance(value, timedelta):
|
||||
# Convert the `timedelta` to a duration message.
|
||||
total_ms = value // timedelta(microseconds=1)
|
||||
seconds = int(total_ms / 1e6)
|
||||
nanos = int((total_ms % 1e6) * 1e3)
|
||||
value = _Duration(seconds=seconds, nanos=nanos)
|
||||
elif wraps:
|
||||
if value is None:
|
||||
return b""
|
||||
value = _get_wrapper(wraps)(value=value)
|
||||
|
||||
return bytes(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_single(
|
||||
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
|
||||
field_number: int,
|
||||
proto_type: str,
|
||||
value: Any,
|
||||
*,
|
||||
serialize_empty: bool = False,
|
||||
wraps: str = "",
|
||||
) -> bytes:
|
||||
"""Serializes a single field and value."""
|
||||
value = _preprocess_single(proto_type, value)
|
||||
value = _preprocess_single(proto_type, wraps, value)
|
||||
|
||||
output = b""
|
||||
if proto_type in WIRE_VARINT_TYPES:
|
||||
@@ -319,7 +350,7 @@ def _serialize_single(
|
||||
key = encode_varint((field_number << 3) | 1)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||
if len(value) or serialize_empty:
|
||||
if len(value) or serialize_empty or wraps:
|
||||
key = encode_varint((field_number << 3) | 2)
|
||||
output += key + encode_varint(len(value)) + value
|
||||
else:
|
||||
@@ -359,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
while i < len(value):
|
||||
start = i
|
||||
num_wire, i = decode_varint(value, i)
|
||||
# print(num_wire, i)
|
||||
number = num_wire >> 3
|
||||
wire_type = num_wire & 0x7
|
||||
|
||||
@@ -375,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
elif wire_type == 5:
|
||||
decoded, i = value[i : i + 4], i + 4
|
||||
|
||||
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
||||
|
||||
yield ParsedField(
|
||||
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
|
||||
)
|
||||
@@ -392,15 +420,19 @@ class Message(ABC):
|
||||
register the message fields which get used by the serializers and parsers
|
||||
to go between Python, binary and JSON protobuf message representations.
|
||||
"""
|
||||
_serialized_on_wire: bool
|
||||
_unknown_fields: bytes
|
||||
_group_map: Dict[str, dict]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
group_map = {"fields": {}, "groups": {}}
|
||||
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if meta.group:
|
||||
# This is part of a one-of group.
|
||||
group_map["fields"][field.name] = meta.group
|
||||
|
||||
if meta.group not in group_map["groups"]:
|
||||
@@ -450,6 +482,11 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
value = getattr(self, field.name)
|
||||
|
||||
if value is None:
|
||||
# Optional items should be skipped. This is used for the Google
|
||||
# wrapper types.
|
||||
continue
|
||||
|
||||
# Being selected in a a group means this field is the one that is
|
||||
# currently set in a `oneof` group, so it must be serialized even
|
||||
# if the value is the default zero value.
|
||||
@@ -457,42 +494,48 @@ class Message(ABC):
|
||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
||||
selected_in_group = True
|
||||
|
||||
if isinstance(value, list):
|
||||
if not len(value) and not selected_in_group:
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
# Empty messages can still be sent on the wire if they were
|
||||
# set (or received empty).
|
||||
serialize_empty = True
|
||||
|
||||
if value == self._get_field_default(field, meta) and not (
|
||||
selected_in_group or serialize_empty
|
||||
):
|
||||
# Default (zero) values are not serialized. Two exceptions are
|
||||
# if this is the selected oneof item or if we know we have to
|
||||
# serialize an empty message (i.e. zero value was explicitly
|
||||
# set by the user).
|
||||
continue
|
||||
|
||||
if isinstance(value, list):
|
||||
if meta.proto_type in PACKED_TYPES:
|
||||
# Packed lists look like a length-delimited field. First,
|
||||
# preprocess/encode each value into a buffer and then
|
||||
# treat it like a field of raw bytes.
|
||||
buf = b""
|
||||
for item in value:
|
||||
buf += _preprocess_single(meta.proto_type, item)
|
||||
buf += _preprocess_single(meta.proto_type, "", item)
|
||||
output += _serialize_single(meta.number, TYPE_BYTES, buf)
|
||||
else:
|
||||
for item in value:
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
|
||||
)
|
||||
elif isinstance(value, dict):
|
||||
if not len(value) and not selected_in_group:
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
|
||||
for k, v in value.items():
|
||||
assert meta.map_types
|
||||
sk = _serialize_single(1, meta.map_types[0], k)
|
||||
sv = _serialize_single(2, meta.map_types[1], v)
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
else:
|
||||
if value == get_default(meta.proto_type) and not selected_in_group:
|
||||
# Default (zero) values are not serialized
|
||||
continue
|
||||
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
serialize_empty = True
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
meta.number,
|
||||
meta.proto_type,
|
||||
value,
|
||||
serialize_empty=serialize_empty,
|
||||
wraps=meta.wraps or "",
|
||||
)
|
||||
|
||||
return output + self._unknown_fields
|
||||
@@ -500,30 +543,45 @@ class Message(ABC):
|
||||
# For compatibility with other libraries
|
||||
SerializeToString = __bytes__
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
def _type_hint(self, field_name: str) -> Type:
|
||||
module = inspect.getmodule(self.__class__)
|
||||
type_hints = get_type_hints(self.__class__, vars(module))
|
||||
cls = type_hints[field.name]
|
||||
return type_hints[field_name]
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
cls = self._type_hint(field.name)
|
||||
if hasattr(cls, "__args__") and index >= 0:
|
||||
cls = type_hints[field.name].__args__[index]
|
||||
cls = cls.__args__[index]
|
||||
return cls
|
||||
|
||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||
t = self._cls_for(field, index=-1)
|
||||
t = self._type_hint(field.name)
|
||||
|
||||
value: Any = 0
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Maps cannot be repeated, so we check these first.
|
||||
value = {}
|
||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
||||
# Anything else with type args is a list.
|
||||
value = []
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
# Message means creating an instance of the right type.
|
||||
value = t()
|
||||
if hasattr(t, "__origin__"):
|
||||
if t.__origin__ == dict:
|
||||
# This is some kind of map (dict in Python).
|
||||
value = {}
|
||||
elif t.__origin__ == list:
|
||||
# This is some kind of list (repeated) field.
|
||||
value = []
|
||||
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
||||
# This is an optional (wrapped) field. For setting the default we
|
||||
# really don't care what kind of field it is.
|
||||
value = None
|
||||
else:
|
||||
value = t()
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
value = 0
|
||||
elif t == datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
value = DATETIME_ZERO
|
||||
else:
|
||||
value = get_default(meta.proto_type)
|
||||
# This is either a primitive scalar or another message type. Calling
|
||||
# it should result in its zero value.
|
||||
value = t()
|
||||
|
||||
return value
|
||||
|
||||
@@ -540,6 +598,9 @@ class Message(ABC):
|
||||
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||
# Undo zig-zag encoding
|
||||
value = (value >> 1) ^ (-(value & 1))
|
||||
elif meta.proto_type == TYPE_BOOL:
|
||||
# Booleans use a varint encoding, so convert it to true/false.
|
||||
value = value > 0
|
||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
@@ -548,8 +609,18 @@ class Message(ABC):
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
|
||||
if cls == datetime:
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
elif cls == timedelta:
|
||||
value = _Duration().parse(value).to_timedelta()
|
||||
elif meta.wraps:
|
||||
# This is a Google wrapper value message around a single
|
||||
# scalar type.
|
||||
value = _get_wrapper(meta.wraps)().parse(value).value
|
||||
else:
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
@@ -624,48 +695,59 @@ class Message(ABC):
|
||||
def FromString(cls: Type[T], data: bytes) -> T:
|
||||
return cls().parse(data)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
|
||||
"""
|
||||
Returns a dict representation of this message instance which can be
|
||||
used to serialize to e.g. JSON.
|
||||
used to serialize to e.g. JSON. Defaults to camel casing for
|
||||
compatibility but can be set to other modes.
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name).rstrip("_") # type: ignore
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, list):
|
||||
if isinstance(v, datetime):
|
||||
if v != DATETIME_ZERO:
|
||||
output[cased_name] = _Timestamp.timestamp_to_json(v)
|
||||
elif isinstance(v, timedelta):
|
||||
if v != timedelta(0):
|
||||
output[cased_name] = _Duration.delta_to_json(v)
|
||||
elif meta.wraps:
|
||||
if v is not None:
|
||||
output[cased_name] = v
|
||||
elif isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
output[field.name] = v
|
||||
output[cased_name] = v
|
||||
elif v._serialized_on_wire:
|
||||
output[field.name] = v.to_dict()
|
||||
output[cased_name] = v.to_dict()
|
||||
elif meta.proto_type == "map":
|
||||
for k in v:
|
||||
if hasattr(v[k], "to_dict"):
|
||||
v[k] = v[k].to_dict()
|
||||
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v != get_default(meta.proto_type):
|
||||
output[cased_name] = v
|
||||
elif v != self._get_field_default(field, meta):
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [str(n) for n in v]
|
||||
output[cased_name] = [str(n) for n in v]
|
||||
else:
|
||||
output[field.name] = str(v)
|
||||
output[cased_name] = str(v)
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [b64encode(b).decode("utf8") for b in v]
|
||||
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
|
||||
else:
|
||||
output[field.name] = b64encode(v).decode("utf8")
|
||||
output[cased_name] = b64encode(v).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_values = list(self._cls_for(field))
|
||||
enum_values = list(self._cls_for(field)) # type: ignore
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [enum_values[e].name for e in v]
|
||||
output[cased_name] = [enum_values[e].name for e in v]
|
||||
else:
|
||||
output[field.name] = enum_values[v].name
|
||||
output[cased_name] = enum_values[v].name
|
||||
else:
|
||||
output[field.name] = v
|
||||
output[cased_name] = v
|
||||
return output
|
||||
|
||||
def from_dict(self: T, value: dict) -> T:
|
||||
@@ -674,44 +756,58 @@ class Message(ABC):
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
self._serialized_on_wire = True
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value and value[field.name] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
# print(v, value[field.name])
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[field.name])):
|
||||
v.append(cls().from_dict(value[field.name][i]))
|
||||
else:
|
||||
v.from_dict(value[field.name])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
for k in value[field.name]:
|
||||
v[k] = cls().from_dict(value[field.name][k])
|
||||
else:
|
||||
v = value[field.name]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[field.name], list):
|
||||
v = [int(n) for n in value[field.name]]
|
||||
else:
|
||||
v = int(value[field.name])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[field.name], list):
|
||||
v = [b64decode(n) for n in value[field.name]]
|
||||
else:
|
||||
v = b64decode(value[field.name])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._cls_for(field)
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
|
||||
for key in value:
|
||||
snake_cased = safe_snake_case(key)
|
||||
if snake_cased in fields_by_name:
|
||||
field = fields_by_name[snake_cased]
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field.name, v)
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(
|
||||
value[key].replace("Z", "+00:00")
|
||||
)
|
||||
setattr(self, field.name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field.name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field.name, value[key])
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._cls_for(field)
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field.name, v)
|
||||
return self
|
||||
|
||||
def to_json(self, indent: Union[None, int, str] = None) -> str:
|
||||
@@ -743,6 +839,140 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
return (field.name, getattr(message, field.name))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Duration(Message):
|
||||
# Signed seconds of the span of time. Must be from -315,576,000,000 to
|
||||
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
|
||||
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
|
||||
seconds: int = int64_field(1)
|
||||
# Signed fractions of a second at nanosecond resolution of the span of time.
|
||||
# Durations less than one second are represented with a 0 `seconds` field and
|
||||
# a positive or negative `nanos` field. For durations of one second or more,
|
||||
# a non-zero value for the `nanos` field must be of the same sign as the
|
||||
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
def to_timedelta(self) -> timedelta:
|
||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||
|
||||
@staticmethod
|
||||
def delta_to_json(delta: timedelta) -> str:
|
||||
parts = str(delta.total_seconds()).split(".")
|
||||
if len(parts) > 1:
|
||||
while len(parts[1]) not in [3, 6, 9]:
|
||||
parts[1] = parts[1] + "0"
|
||||
return ".".join(parts) + "s"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Timestamp(Message):
|
||||
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
|
||||
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
|
||||
seconds: int = int64_field(1)
|
||||
# Non-negative fractions of a second at nanosecond resolution. Negative
|
||||
# second values with fractions must still have non-negative nanos values that
|
||||
# count forward in time. Must be from 0 to 999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
|
||||
@staticmethod
|
||||
def timestamp_to_json(dt: datetime) -> str:
|
||||
nanos = dt.microsecond * 1e3
|
||||
copy = dt.replace(microsecond=0, tzinfo=None)
|
||||
result = copy.isoformat()
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + "Z"
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + ".%03dZ" % (nanos / 1e6)
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + ".%06dZ" % (nanos / 1e3)
|
||||
# Serialize 9 fractional digits.
|
||||
return result + ".%09dZ" % nanos
|
||||
|
||||
|
||||
class _WrappedMessage(Message):
|
||||
"""
|
||||
Google protobuf wrapper types base class. JSON representation is just the
|
||||
value itself.
|
||||
"""
|
||||
value: Any
|
||||
|
||||
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
|
||||
return self.value
|
||||
|
||||
def from_dict(self: T, value: Any) -> T:
|
||||
if value is not None:
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BoolValue(_WrappedMessage):
|
||||
value: bool = bool_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Int32Value(_WrappedMessage):
|
||||
value: int = int32_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _UInt32Value(_WrappedMessage):
|
||||
value: int = uint32_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Int64Value(_WrappedMessage):
|
||||
value: int = int64_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _UInt64Value(_WrappedMessage):
|
||||
value: int = uint64_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _FloatValue(_WrappedMessage):
|
||||
value: float = float_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DoubleValue(_WrappedMessage):
|
||||
value: float = double_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _StringValue(_WrappedMessage):
|
||||
value: str = string_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BytesValue(_WrappedMessage):
|
||||
value: bytes = bytes_field(1)
|
||||
|
||||
|
||||
def _get_wrapper(proto_type: str) -> Type:
|
||||
"""Get the wrapper message class for a wrapped type."""
|
||||
return {
|
||||
TYPE_BOOL: _BoolValue,
|
||||
TYPE_INT32: _Int32Value,
|
||||
TYPE_UINT32: _UInt32Value,
|
||||
TYPE_INT64: _Int64Value,
|
||||
TYPE_UINT64: _UInt64Value,
|
||||
TYPE_FLOAT: _FloatValue,
|
||||
TYPE_DOUBLE: _DoubleValue,
|
||||
TYPE_STRING: _StringValue,
|
||||
TYPE_BYTES: _BytesValue,
|
||||
}[proto_type]
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
|
41
betterproto/casing.py
Normal file
41
betterproto/casing.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import stringcase
|
||||
|
||||
|
||||
def safe_snake_case(value: str) -> str:
|
||||
"""Snake case a value taking into account Python keywords."""
|
||||
value = stringcase.snakecase(value)
|
||||
if value in [
|
||||
"and",
|
||||
"as",
|
||||
"assert",
|
||||
"break",
|
||||
"class",
|
||||
"continue",
|
||||
"def",
|
||||
"del",
|
||||
"elif",
|
||||
"else",
|
||||
"except",
|
||||
"finally",
|
||||
"for",
|
||||
"from",
|
||||
"global",
|
||||
"if",
|
||||
"import",
|
||||
"in",
|
||||
"is",
|
||||
"lambda",
|
||||
"nonlocal",
|
||||
"not",
|
||||
"or",
|
||||
"pass",
|
||||
"raise",
|
||||
"return",
|
||||
"try",
|
||||
"while",
|
||||
"with",
|
||||
"yield",
|
||||
]:
|
||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||
value += "_"
|
||||
return value
|
@@ -16,6 +16,8 @@ except ImportError:
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
import stringcase
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
@@ -25,11 +27,20 @@ from google.protobuf.descriptor_pb2 import (
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
|
||||
from betterproto.casing import safe_snake_case
|
||||
|
||||
def snake_case(value: str) -> str:
|
||||
return (
|
||||
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
|
||||
)
|
||||
|
||||
WRAPPER_TYPES = {
|
||||
"google.protobuf.DoubleValue": "float",
|
||||
"google.protobuf.FloatValue": "float",
|
||||
"google.protobuf.Int64Value": "int",
|
||||
"google.protobuf.UInt64Value": "int",
|
||||
"google.protobuf.Int32Value": "int",
|
||||
"google.protobuf.UInt32Value": "int",
|
||||
"google.protobuf.BoolValue": "bool",
|
||||
"google.protobuf.StringValue": "str",
|
||||
"google.protobuf.BytesValue": "bytes",
|
||||
}
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
@@ -37,15 +48,33 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
if type_name in WRAPPER_TYPES:
|
||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
if type_name.startswith(package):
|
||||
# This is the current package, which has nested types flattened.
|
||||
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
|
||||
if "." in type_name:
|
||||
# This is imported from another package. No need
|
||||
# to use a forward ref and we need to add the import.
|
||||
parts = type_name.split(".")
|
||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
||||
|
||||
@@ -122,7 +151,7 @@ def get_comment(proto_file, path: List[int]) -> str:
|
||||
|
||||
if path[-2] == 2 and path[-4] != 6:
|
||||
# This is a field
|
||||
return " # " + " # ".join(lines)
|
||||
return " # " + "\n # ".join(lines)
|
||||
else:
|
||||
# This is a message, enum, service, or method
|
||||
if len(lines) == 1 and len(lines[0]) < 70:
|
||||
@@ -146,6 +175,9 @@ def generate_code(request, response):
|
||||
output_map = {}
|
||||
for proto_file in request.proto_file:
|
||||
out = proto_file.package
|
||||
if out == "google.protobuf":
|
||||
continue
|
||||
|
||||
if not out:
|
||||
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
|
||||
|
||||
@@ -163,6 +195,7 @@ def generate_code(request, response):
|
||||
"package": package,
|
||||
"files": [f.name for f in options["files"]],
|
||||
"imports": set(),
|
||||
"datetime_imports": set(),
|
||||
"typing_imports": set(),
|
||||
"messages": [],
|
||||
"enums": [],
|
||||
@@ -179,7 +212,7 @@ def generate_code(request, response):
|
||||
for item, path in traverse(proto_file):
|
||||
# print(item, file=sys.stderr)
|
||||
# print(path, file=sys.stderr)
|
||||
data = {"name": item.name}
|
||||
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# print(item, file=sys.stderr)
|
||||
@@ -203,6 +236,14 @@ def generate_code(request, response):
|
||||
packed = False
|
||||
|
||||
field_type = f.Type.Name(f.type).lower()[5:]
|
||||
|
||||
field_wraps = ""
|
||||
if f.type_name.startswith(
|
||||
".google.protobuf"
|
||||
) and f.type_name.endswith("Value"):
|
||||
w = f.type_name.split(".").pop()[:-5].upper()
|
||||
field_wraps = f"betterproto.TYPE_{w}"
|
||||
|
||||
map_types = None
|
||||
if f.type == 11:
|
||||
# This might be a map...
|
||||
@@ -252,13 +293,23 @@ def generate_code(request, response):
|
||||
if f.HasField("oneof_index"):
|
||||
one_of = item.oneof_decl[f.oneof_index].name
|
||||
|
||||
if "Optional[" in t:
|
||||
output["typing_imports"].add("Optional")
|
||||
|
||||
if "timedelta" in t:
|
||||
output["datetime_imports"].add("timedelta")
|
||||
elif "datetime" in t:
|
||||
output["datetime_imports"].add("datetime")
|
||||
|
||||
data["properties"].append(
|
||||
{
|
||||
"name": f.name,
|
||||
"py_name": safe_snake_case(f.name),
|
||||
"number": f.number,
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
"field_type": field_type,
|
||||
"field_wraps": field_wraps,
|
||||
"map_types": map_types,
|
||||
"type": t,
|
||||
"zero": zero,
|
||||
@@ -294,6 +345,7 @@ def generate_code(request, response):
|
||||
|
||||
data = {
|
||||
"name": service.name,
|
||||
"py_name": stringcase.pascalcase(service.name),
|
||||
"comment": get_comment(proto_file, [6, i]),
|
||||
"methods": [],
|
||||
}
|
||||
@@ -317,7 +369,7 @@ def generate_code(request, response):
|
||||
data["methods"].append(
|
||||
{
|
||||
"name": method.name,
|
||||
"py_name": snake_case(method.name),
|
||||
"py_name": stringcase.snakecase(method.name),
|
||||
"comment": get_comment(proto_file, [6, i, 2, j]),
|
||||
"route": f"/{package}.{service.name}/{method.name}",
|
||||
"input": get_ref_type(
|
||||
@@ -338,6 +390,7 @@ def generate_code(request, response):
|
||||
output["services"].append(data)
|
||||
|
||||
output["imports"] = sorted(output["imports"])
|
||||
output["datetime_imports"] = sorted(output["datetime_imports"])
|
||||
output["typing_imports"] = sorted(output["typing_imports"])
|
||||
|
||||
# Fill response
|
||||
@@ -361,10 +414,20 @@ def generate_code(request, response):
|
||||
inits.add(base)
|
||||
|
||||
for base in inits:
|
||||
name = os.path.join(base, "__init__.py")
|
||||
|
||||
if os.path.exists(name):
|
||||
# Never overwrite inits as they may have custom stuff in them.
|
||||
continue
|
||||
|
||||
init = response.file.add()
|
||||
init.name = os.path.join(base, "__init__.py")
|
||||
init.name = name
|
||||
init.content = b""
|
||||
|
||||
filenames = sorted([f.name for f in response.file])
|
||||
for fname in filenames:
|
||||
print(f"Writing {fname}", file=sys.stderr)
|
||||
|
||||
|
||||
def main():
|
||||
"""The plugin's main entry point."""
|
||||
|
@@ -2,6 +2,10 @@
|
||||
# sources: {{ ', '.join(description.files) }}
|
||||
# plugin: python-betterproto
|
||||
from dataclasses import dataclass
|
||||
{% if description.datetime_imports %}
|
||||
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif%}
|
||||
{% if description.typing_imports %}
|
||||
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
@@ -18,7 +22,7 @@ import grpclib
|
||||
|
||||
|
||||
{% if description.enums %}{% for enum in description.enums %}
|
||||
class {{ enum.name }}(betterproto.Enum):
|
||||
class {{ enum.py_name }}(betterproto.Enum):
|
||||
{% if enum.comment %}
|
||||
{{ enum.comment }}
|
||||
|
||||
@@ -35,7 +39,7 @@ class {{ enum.name }}(betterproto.Enum):
|
||||
{% endif %}
|
||||
{% for message in description.messages %}
|
||||
@dataclass
|
||||
class {{ message.name }}(betterproto.Message):
|
||||
class {{ message.py_name }}(betterproto.Message):
|
||||
{% if message.comment %}
|
||||
{{ message.comment }}
|
||||
|
||||
@@ -44,7 +48,7 @@ class {{ message.name }}(betterproto.Message):
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
{% endif %}
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
|
||||
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
|
||||
{% endfor %}
|
||||
{% if not message.properties %}
|
||||
pass
|
||||
@@ -53,13 +57,13 @@ class {{ message.name }}(betterproto.Message):
|
||||
|
||||
{% endfor %}
|
||||
{% for service in description.services %}
|
||||
class {{ service.name }}Stub(betterproto.ServiceStub):
|
||||
class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% if service.comment %}
|
||||
{{ service.comment }}
|
||||
|
||||
{% endif %}
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
|
3
betterproto/tests/bool.json
Normal file
3
betterproto/tests/bool.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"value": true
|
||||
}
|
5
betterproto/tests/bool.proto
Normal file
5
betterproto/tests/bool.proto
Normal file
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
bool value = 1;
|
||||
}
|
4
betterproto/tests/casing.json
Normal file
4
betterproto/tests/casing.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"camelCase": 1,
|
||||
"snakeCase": "ONE"
|
||||
}
|
12
betterproto/tests/casing.proto
Normal file
12
betterproto/tests/casing.proto
Normal file
@@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
enum my_enum {
|
||||
ZERO = 0;
|
||||
ONE = 1;
|
||||
TWO = 2;
|
||||
}
|
||||
|
||||
message Test {
|
||||
int32 camelCase = 1;
|
||||
my_enum snake_case = 2;
|
||||
}
|
@@ -72,7 +72,8 @@ if __name__ == "__main__":
|
||||
input_json = open(filename).read()
|
||||
parsed = Parse(input_json, imported.Test())
|
||||
serialized = parsed.SerializeToString()
|
||||
serialized_json = MessageToJson(parsed, preserving_proto_field_name=True)
|
||||
preserve = "casing" not in filename
|
||||
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
|
||||
|
||||
s_loaded = json.loads(serialized_json)
|
||||
in_loaded = json.loads(input_json)
|
||||
|
1
betterproto/tests/googletypes-missing.json
Normal file
1
betterproto/tests/googletypes-missing.json
Normal file
@@ -0,0 +1 @@
|
||||
{}
|
5
betterproto/tests/googletypes.json
Normal file
5
betterproto/tests/googletypes.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"maybe": false,
|
||||
"ts": "1972-01-01T10:00:20.021Z",
|
||||
"duration": "1.200s"
|
||||
}
|
12
betterproto/tests/googletypes.proto
Normal file
12
betterproto/tests/googletypes.proto
Normal file
@@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
message Test {
|
||||
google.protobuf.BoolValue maybe = 1;
|
||||
google.protobuf.Timestamp ts = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
google.protobuf.Int32Value important = 4;
|
||||
}
|
5
betterproto/tests/keywords.json
Normal file
5
betterproto/tests/keywords.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"for": 1,
|
||||
"with": 2,
|
||||
"as": 3
|
||||
}
|
7
betterproto/tests/keywords.proto
Normal file
7
betterproto/tests/keywords.proto
Normal file
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
int32 for = 1;
|
||||
int32 with = 2;
|
||||
int32 as = 3;
|
||||
}
|
@@ -1,5 +1,6 @@
|
||||
import betterproto
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def test_has_field():
|
||||
@@ -115,3 +116,49 @@ def test_oneof_support():
|
||||
assert betterproto.which_one_of(foo2, "group1")[0] == "bar"
|
||||
assert foo.bar == 0
|
||||
assert betterproto.which_one_of(foo2, "group2")[0] == ""
|
||||
|
||||
|
||||
def test_json_casing():
|
||||
@dataclass
|
||||
class CasingTest(betterproto.Message):
|
||||
pascal_case: int = betterproto.int32_field(1)
|
||||
camel_case: int = betterproto.int32_field(2)
|
||||
snake_case: int = betterproto.int32_field(3)
|
||||
kabob_case: int = betterproto.int32_field(4)
|
||||
|
||||
# Parsing should accept almost any input
|
||||
test = CasingTest().from_dict(
|
||||
{"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
|
||||
)
|
||||
|
||||
assert test == CasingTest(1, 2, 3, 4)
|
||||
|
||||
# Serializing should be strict.
|
||||
assert test.to_dict() == {
|
||||
"pascalCase": 1,
|
||||
"camelCase": 2,
|
||||
"snakeCase": 3,
|
||||
"kabobCase": 4,
|
||||
}
|
||||
|
||||
assert test.to_dict(casing=betterproto.Casing.SNAKE) == {
|
||||
"pascal_case": 1,
|
||||
"camel_case": 2,
|
||||
"snake_case": 3,
|
||||
"kabob_case": 4,
|
||||
}
|
||||
|
||||
|
||||
def test_optional_flag():
|
||||
@dataclass
|
||||
class Request(betterproto.Message):
|
||||
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
|
||||
|
||||
# Serialization of not passed vs. set vs. zero-value.
|
||||
assert bytes(Request()) == b""
|
||||
assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
|
||||
assert bytes(Request(flag=False)) == b"\n\x00"
|
||||
|
||||
# Differentiate between not passed and the zero-value.
|
||||
assert Request().parse(b"").flag == None
|
||||
assert Request().parse(b"\n\x00").flag == False
|
||||
|
6
setup.py
6
setup.py
@@ -2,8 +2,10 @@ from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="betterproto",
|
||||
version="1.0",
|
||||
version="1.1.0",
|
||||
description="A better Protobuf / gRPC generator & library",
|
||||
long_description=open("README.md", "r").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="http://github.com/danielgtaylor/python-betterproto",
|
||||
author="Daniel G. Taylor",
|
||||
author_email="danielgtaylor@gmail.com",
|
||||
@@ -16,7 +18,7 @@ setup(
|
||||
),
|
||||
package_data={"betterproto": ["py.typed", "templates/template.py"]},
|
||||
python_requires=">=3.7",
|
||||
install_requires=["grpclib"],
|
||||
install_requires=["grpclib", "stringcase"],
|
||||
extras_require={"compiler": ["jinja2", "protobuf"]},
|
||||
zip_safe=False,
|
||||
)
|
||||
|
Reference in New Issue
Block a user