17 Commits

Author SHA1 Message Date
Daniel G. Taylor
6fd9612ee1 Doc updates, version bump for release 2019-10-27 15:43:52 -07:00
Daniel G. Taylor
ba520f88a4 Install Protobuf include files on CI host 2019-10-27 15:40:33 -07:00
Daniel G. Taylor
b0b64fcbaf Fix tests attempt 3 2019-10-27 15:29:04 -07:00
Daniel G. Taylor
7900c7c9db Fix tests 2019-10-27 15:21:20 -07:00
Daniel G. Taylor
fcc273e294 Fix tests 2019-10-27 15:18:10 -07:00
Daniel G. Taylor
f820397751 Add missing optional types test 2019-10-27 15:14:06 -07:00
Daniel G. Taylor
16687211a2 Typing fixes 2019-10-27 15:13:51 -07:00
Daniel G. Taylor
eb5020db2a Fix bool parsing bug 2019-10-27 14:59:38 -07:00
Daniel G. Taylor
035793aec3 Support wrapper types 2019-10-27 14:55:25 -07:00
Daniel G. Taylor
c79535b614 Support Duration/Timestamp Google well-known types 2019-10-26 23:07:30 -07:00
Daniel G. Taylor
5daf61f64c Refactor default value code 2019-10-25 21:16:32 -07:00
Daniel G. Taylor
4679c571c3 Fix comment newlines 2019-10-25 12:28:26 -07:00
Daniel G. Taylor
ff8463cf12 Handle fields that clash with Python reserved keywords 2019-10-23 21:28:31 -07:00
Daniel G. Taylor
eff9021529 Some informational output from the plugin, do not overwrite __init__.py 2019-10-23 15:07:05 -07:00
Daniel G. Taylor
d43d5af5ce Better JSON casing support, renaming messages/fields 2019-10-23 15:06:34 -07:00
Daniel G. Taylor
ef0a1bf50c Use specific version of pypi publish image 2019-10-23 15:03:13 -07:00
Daniel G. Taylor
0e389abbef Add Python package long description 2019-10-22 21:31:42 -07:00
20 changed files with 693 additions and 168 deletions

View File

@@ -4,7 +4,6 @@ on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
@@ -15,9 +14,20 @@ jobs:
- uses: dschep/install-pipenv-action@v1
- name: Install dependencies
run: |
sudo apt install protobuf-compiler
sudo apt install protobuf-compiler libprotobuf-dev
pipenv install --dev
- name: Run tests
run: |
cp .env.default .env
pipenv run pip install -e .
pipenv run generate
pipenv run test
- name: Build package
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
run: pipenv run python setup.py sdist
- name: Publish package
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@v1.0.0a0
with:
user: __token__
password: ${{ secrets.pypi }}

View File

@@ -14,6 +14,7 @@ rope = "*"
protobuf = "*"
jinja2 = "*"
grpclib = "*"
stringcase = "*"
[requires]
python_version = "3.7"

52
Pipfile.lock generated
View File

@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40"
"sha256": "28c38cd6c4eafb0b9ac9a64cf623145868fdee163111d3b941b34d23011db6ca"
},
"pipfile-spec": 6,
"requires": {
@@ -147,6 +147,13 @@
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
],
"version": "==1.12.0"
},
"stringcase": {
"hashes": [
"sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"
],
"index": "pypi",
"version": "==1.2.0"
}
},
"develop": {
@@ -159,10 +166,10 @@
},
"attrs": {
"hashes": [
"sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2",
"sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
"sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"
],
"version": "==19.2.0"
"version": "==19.3.0"
},
"entrypoints": {
"hashes": [
@@ -211,26 +218,30 @@
},
"mypy": {
"hashes": [
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
"sha256:1521c186a3d200c399bd5573c828ea2db1362af7209b2adb1bb8532cea2fb36f",
"sha256:31a046ab040a84a0fc38bc93694876398e62bc9f35eca8ccbf6418b7297f4c00",
"sha256:3b1a411909c84b2ae9b8283b58b48541654b918e8513c20a400bb946aa9111ae",
"sha256:48c8bc99380575deb39f5d3400ebb6a8a1cb5cc669bbba4d3bb30f904e0a0e7d",
"sha256:540c9caa57a22d0d5d3c69047cc9dd0094d49782603eb03069821b41f9e970e9",
"sha256:672e418425d957e276c291930a3921b4a6413204f53fe7c37cad7bc57b9a3391",
"sha256:6ed3b9b3fdc7193ea7aca6f3c20549b377a56f28769783a8f27191903a54170f",
"sha256:9371290aa2cad5ad133e4cdc43892778efd13293406f7340b9ffe99d5ec7c1d9",
"sha256:ace6ac1d0f87d4072f05b5468a084a45b4eda970e4d26704f201e06d47ab2990",
"sha256:b428f883d2b3fe1d052c630642cc6afddd07d5cd7873da948644508be3b9d4a7",
"sha256:d5bf0e6ec8ba346a2cf35cb55bf4adfddbc6b6576fcc9e10863daa523e418dbb",
"sha256:d7574e283f83c08501607586b3167728c58e8442947e027d2d4c7dcd6d82f453",
"sha256:dc889c84241a857c263a2b1cd1121507db7d5b5f5e87e77147097230f374d10b",
"sha256:f4748697b349f373002656bf32fede706a0e713d67bfdcf04edf39b1f61d46eb"
],
"index": "pypi",
"version": "==0.730"
"version": "==0.740"
},
"mypy-extensions": {
"hashes": [
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
"sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d",
"sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"
],
"version": "==0.4.2"
"version": "==0.4.3"
},
"packaging": {
"hashes": [
@@ -300,20 +311,25 @@
},
"typed-ast": {
"hashes": [
"sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161",
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
"sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47",
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
"sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2",
"sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e",
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
"sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66",
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
],
"version": "==1.4.0"

View File

@@ -10,6 +10,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
- Enums
- Dataclasses
- `async`/`await`
- Timezone-aware `datetime` and `timedelta` objects
- Relative imports
- Mypy type checking
@@ -34,6 +35,8 @@ This project exists because I am unhappy with the state of the official Google p
- Much code looks like C++ or Java ported 1:1 to Python
- Capitalized function names like `HasField()` and `SerializeToString()`
- Uses `SerializeToString()` rather than the built-in `__bytes__()`
- Special wrapped types don't use Python's `None`
- Timestamp/duration types don't use Python's built-in `datetime` module
This project is a reimplementation from the ground up focused on idiomatic modern Python to help fix some of the above. While it may not be a 1:1 drop-in replacement due to changed method names and call patterns, the wire format is identical.
@@ -168,6 +171,12 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries
- Dicts: `Message().to_dict()`, `Message().from_dict(...)`
- JSON: `Message().to_json()`, `Message().from_json(...)`
For compatibility the default is to convert field names to `camelCase`. You can control this behavior by passing a casing value, e.g:
```py
>>> MyMessage().to_dict(casing=betterproto.Casing.SNAKE)
```
### Determining if a message was sent
Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example.
@@ -238,6 +247,53 @@ Again this is a little different than the official Google code generator:
["foo", "foo's value"]
```
### Well-Known Google Types
Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
| Google Message | Python Type | Default |
| --------------------------- | ---------------------------------------- | ---------------------- |
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
| `google.protobuf.*Value` | `Optional[...]` | `None` |
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
```protobuf
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
}
```
You can do stuff like:
```py
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
>>> t
st(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
>>> t.ts - t.duration
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
>>> t.ts.isoformat()
'2019-01-01T12:00:00+00:00'
>>> t.maybe = None
>>> t.to_dict()
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
```
## Development
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
@@ -295,14 +351,14 @@ $ pipenv run tests
- [x] Bytes as base64
- [ ] Any support
- [x] Enum strings
- [ ] Well known types support (timestamp, duration, wrappers)
- [ ] Support different casing (orig vs. camel vs. others?)
- [x] Well known types support (timestamp, duration, wrappers)
- [x] Support different casing (orig vs. camel vs. others?)
- [ ] Async service stubs
- [x] Unary-unary
- [x] Server streaming response
- [ ] Client streaming request
- [ ] Renaming messages and fields to conform to Python name standards
- [ ] Renaming clashes with language keywords and standard library top-level packages
- [x] Renaming messages and fields to conform to Python name standards
- [x] Renaming clashes with language keywords
- [x] Python package
- [x] Automate running tests
- [ ] Cleanup!

View File

@@ -5,6 +5,7 @@ import json
import struct
from abc import ABC
from base64 import b64encode, b64decode
from datetime import datetime, timedelta, timezone
from typing import (
Any,
AsyncGenerator,
@@ -24,6 +25,9 @@ from typing import (
import grpclib.client
import grpclib.const
import stringcase
from .casing import safe_snake_case
# Proto 3 data types
TYPE_ENUM = "enum"
@@ -101,6 +105,17 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
class Casing(enum.Enum):
"""Casing constants for serialization."""
CAMEL = stringcase.camelcase
SNAKE = stringcase.snakecase
class _PLACEHOLDER:
pass
@@ -108,18 +123,6 @@ class _PLACEHOLDER:
PLACEHOLDER: Any = _PLACEHOLDER()
def get_default(proto_type: str) -> Any:
"""Get the default (zero value) for a given type."""
return {
TYPE_BOOL: False,
TYPE_FLOAT: 0.0,
TYPE_DOUBLE: 0.0,
TYPE_STRING: "",
TYPE_BYTES: b"",
TYPE_MAP: {},
}.get(proto_type, 0)
@dataclasses.dataclass(frozen=True)
class FieldMetadata:
"""Stores internal metadata used for parsing & serialization."""
@@ -129,9 +132,11 @@ class FieldMetadata:
# Protobuf type name
proto_type: str
# Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]]
map_types: Optional[Tuple[str, str]] = None
# Groups several "one-of" fields together
group: Optional[str]
group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
@staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata":
@@ -145,11 +150,14 @@ def dataclass_field(
*,
map_types: Optional[Tuple[str, str]] = None,
group: Optional[str] = None,
wraps: Optional[str] = None,
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=PLACEHOLDER,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
},
)
@@ -222,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group)
def message_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group)
def message_field(
number: int, group: Optional[str] = None, wraps: Optional[str] = None
) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
def map_field(
@@ -274,7 +284,7 @@ def encode_varint(value: int) -> bytes:
return bytes(b + [bits])
def _preprocess_single(proto_type: str, value: Any) -> bytes:
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization."""
if proto_type in [
TYPE_ENUM,
@@ -297,16 +307,37 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
elif proto_type == TYPE_STRING:
return value.encode("utf-8")
elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
elif wraps:
if value is None:
return b""
value = _get_wrapper(wraps)(value=value)
return bytes(value)
return value
def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
field_number: int,
proto_type: str,
value: Any,
*,
serialize_empty: bool = False,
wraps: str = "",
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, value)
value = _preprocess_single(proto_type, wraps, value)
output = b""
if proto_type in WIRE_VARINT_TYPES:
@@ -319,7 +350,7 @@ def _serialize_single(
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value) or serialize_empty:
if len(value) or serialize_empty or wraps:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
@@ -359,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
while i < len(value):
start = i
num_wire, i = decode_varint(value, i)
# print(num_wire, i)
number = num_wire >> 3
wire_type = num_wire & 0x7
@@ -375,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
elif wire_type == 5:
decoded, i = value[i : i + 4], i + 4
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield ParsedField(
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
)
@@ -392,15 +420,19 @@ class Message(ABC):
register the message fields which get used by the serializers and parsers
to go between Python, binary and JSON protobuf message representations.
"""
_serialized_on_wire: bool
_unknown_fields: bytes
_group_map: Dict[str, dict]
def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has
# already been run.
group_map = {"fields": {}, "groups": {}}
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if meta.group:
# This is part of a one-of group.
group_map["fields"][field.name] = meta.group
if meta.group not in group_map["groups"]:
@@ -450,6 +482,11 @@ class Message(ABC):
meta = FieldMetadata.get(field)
value = getattr(self, field.name)
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types.
continue
# Being selected in a a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
@@ -457,42 +494,48 @@ class Message(ABC):
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
selected_in_group = True
if isinstance(value, list):
if not len(value) and not selected_in_group:
# Empty values are not serialized
serialize_empty = False
if isinstance(value, Message) and value._serialized_on_wire:
# Empty messages can still be sent on the wire if they were
# set (or received empty).
serialize_empty = True
if value == self._get_field_default(field, meta) and not (
selected_in_group or serialize_empty
):
# Default (zero) values are not serialized. Two exceptions are
# if this is the selected oneof item or if we know we have to
# serialize an empty message (i.e. zero value was explicitly
# set by the user).
continue
if isinstance(value, list):
if meta.proto_type in PACKED_TYPES:
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes.
buf = b""
for item in value:
buf += _preprocess_single(meta.proto_type, item)
buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(meta.number, meta.proto_type, item)
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
)
elif isinstance(value, dict):
if not len(value) and not selected_in_group:
# Empty values are not serialized
continue
for k, v in value.items():
assert meta.map_types
sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else:
if value == get_default(meta.proto_type) and not selected_in_group:
# Default (zero) values are not serialized
continue
serialize_empty = False
if isinstance(value, Message) and value._serialized_on_wire:
serialize_empty = True
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
wraps=meta.wraps or "",
)
return output + self._unknown_fields
@@ -500,30 +543,45 @@ class Message(ABC):
# For compatibility with other libraries
SerializeToString = __bytes__
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints."""
def _type_hint(self, field_name: str) -> Type:
module = inspect.getmodule(self.__class__)
type_hints = get_type_hints(self.__class__, vars(module))
cls = type_hints[field.name]
return type_hints[field_name]
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints."""
cls = self._type_hint(field.name)
if hasattr(cls, "__args__") and index >= 0:
cls = type_hints[field.name].__args__[index]
cls = cls.__args__[index]
return cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
t = self._cls_for(field, index=-1)
t = self._type_hint(field.name)
value: Any = 0
if meta.proto_type == TYPE_MAP:
# Maps cannot be repeated, so we check these first.
if hasattr(t, "__origin__"):
if t.__origin__ == dict:
# This is some kind of map (dict in Python).
value = {}
elif hasattr(t, "__args__") and len(t.__args__) == 1:
# Anything else with type args is a list.
elif t.__origin__ == list:
# This is some kind of list (repeated) field.
value = []
elif meta.proto_type == TYPE_MESSAGE:
# Message means creating an instance of the right type.
value = t()
elif t.__origin__ == Union and t.__args__[1] == type(None):
# This is an optional (wrapped) field. For setting the default we
# really don't care what kind of field it is.
value = None
else:
value = get_default(meta.proto_type)
value = t()
elif issubclass(t, Enum):
# Enums always default to zero.
value = 0
elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
value = DATETIME_ZERO
else:
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
value = t()
return value
@@ -540,6 +598,9 @@ class Message(ABC):
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1))
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
@@ -548,6 +609,16 @@ class Message(ABC):
value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field)
if cls == datetime:
value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta:
value = _Duration().parse(value).to_timedelta()
elif meta.wraps:
# This is a Google wrapper value message around a single
# scalar type.
value = _get_wrapper(meta.wraps)().parse(value).value
else:
value = cls().parse(value)
value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
@@ -624,48 +695,59 @@ class Message(ABC):
def FromString(cls: Type[T], data: bytes) -> T:
return cls().parse(data)
def to_dict(self) -> dict:
def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
"""
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON.
used to serialize to e.g. JSON. Defaults to camel casing for
compatibility but can be set to other modes.
"""
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, list):
if isinstance(v, datetime):
if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta):
if v != timedelta(0):
output[cased_name] = _Duration.delta_to_json(v)
elif meta.wraps:
if v is not None:
output[cased_name] = v
elif isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
output[field.name] = v
output[cased_name] = v
elif v._serialized_on_wire:
output[field.name] = v.to_dict()
output[cased_name] = v.to_dict()
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
v[k] = v[k].to_dict()
if v:
output[field.name] = v
elif v != get_default(meta.proto_type):
output[cased_name] = v
elif v != self._get_field_default(field, meta):
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[field.name] = [str(n) for n in v]
output[cased_name] = [str(n) for n in v]
else:
output[field.name] = str(v)
output[cased_name] = str(v)
elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list):
output[field.name] = [b64encode(b).decode("utf8") for b in v]
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
else:
output[field.name] = b64encode(v).decode("utf8")
output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field))
enum_values = list(self._cls_for(field)) # type: ignore
if isinstance(v, list):
output[field.name] = [enum_values[e].name for e in v]
output[cased_name] = [enum_values[e].name for e in v]
else:
output[field.name] = enum_values[v].name
output[cased_name] = enum_values[v].name
else:
output[field.name] = v
output[cased_name] = v
return output
def from_dict(self: T, value: dict) -> T:
@@ -674,35 +756,49 @@ class Message(ABC):
returns the instance itself and is therefore assignable and chainable.
"""
self._serialized_on_wire = True
for field in dataclasses.fields(self):
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value:
snake_cased = safe_snake_case(key)
if snake_cased in fields_by_name:
field = fields_by_name[snake_cased]
meta = FieldMetadata.get(field)
if field.name in value and value[field.name] is not None:
if value[key] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
# print(v, value[field.name])
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[field.name])):
v.append(cls().from_dict(value[field.name][i]))
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(
value[key].replace("Z", "+00:00")
)
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
elif meta.wraps:
setattr(self, field.name, value[key])
else:
v.from_dict(value[field.name])
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._cls_for(field, index=1)
for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k])
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[field.name]
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[field.name])
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[field.name], list):
v = [b64decode(n) for n in value[field.name]]
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[field.name])
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field)
if isinstance(v, list):
@@ -743,6 +839,140 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
return (field.name, getattr(message, field.name))
@dataclasses.dataclass
class _Duration(Message):
# Signed seconds of the span of time. Must be from -315,576,000,000 to
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
seconds: int = int64_field(1)
# Signed fractions of a second at nanosecond resolution of the span of time.
# Durations less than one second are represented with a 0 `seconds` field and
# a positive or negative `nanos` field. For durations of one second or more,
# a non-zero value for the `nanos` field must be of the same sign as the
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
nanos: int = int32_field(2)
def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@staticmethod
def delta_to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0"
return ".".join(parts) + "s"
@dataclasses.dataclass
class _Timestamp(Message):
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
seconds: int = int64_field(1)
# Non-negative fractions of a second at nanosecond resolution. Negative
# second values with fractions must still have non-negative nanos values that
# count forward in time. Must be from 0 to 999,999,999 inclusive.
nanos: int = int32_field(2)
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
def timestamp_to_json(dt: datetime) -> str:
nanos = dt.microsecond * 1e3
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + "Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + ".%09dZ" % nanos
class _WrappedMessage(Message):
"""
Google protobuf wrapper types base class. JSON representation is just the
value itself.
"""
value: Any
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
return self.value
def from_dict(self: T, value: Any) -> T:
if value is not None:
self.value = value
return self
@dataclasses.dataclass
class _BoolValue(_WrappedMessage):
value: bool = bool_field(1)
@dataclasses.dataclass
class _Int32Value(_WrappedMessage):
value: int = int32_field(1)
@dataclasses.dataclass
class _UInt32Value(_WrappedMessage):
value: int = uint32_field(1)
@dataclasses.dataclass
class _Int64Value(_WrappedMessage):
value: int = int64_field(1)
@dataclasses.dataclass
class _UInt64Value(_WrappedMessage):
value: int = uint64_field(1)
@dataclasses.dataclass
class _FloatValue(_WrappedMessage):
value: float = float_field(1)
@dataclasses.dataclass
class _DoubleValue(_WrappedMessage):
value: float = double_field(1)
@dataclasses.dataclass
class _StringValue(_WrappedMessage):
value: str = string_field(1)
@dataclasses.dataclass
class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> Type:
"""Get the wrapper message class for a wrapped type."""
return {
TYPE_BOOL: _BoolValue,
TYPE_INT32: _Int32Value,
TYPE_UINT32: _UInt32Value,
TYPE_INT64: _Int64Value,
TYPE_UINT64: _UInt64Value,
TYPE_FLOAT: _FloatValue,
TYPE_DOUBLE: _DoubleValue,
TYPE_STRING: _StringValue,
TYPE_BYTES: _BytesValue,
}[proto_type]
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.

41
betterproto/casing.py Normal file
View File

@@ -0,0 +1,41 @@
import stringcase
def safe_snake_case(value: str) -> str:
"""Snake case a value taking into account Python keywords."""
value = stringcase.snakecase(value)
if value in [
"and",
"as",
"assert",
"break",
"class",
"continue",
"def",
"del",
"elif",
"else",
"except",
"finally",
"for",
"from",
"global",
"if",
"import",
"in",
"is",
"lambda",
"nonlocal",
"not",
"or",
"pass",
"raise",
"return",
"try",
"while",
"with",
"yield",
]:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
value += "_"
return value

View File

@@ -16,6 +16,8 @@ except ImportError:
)
raise SystemExit(1)
import stringcase
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
@@ -25,11 +27,20 @@ from google.protobuf.descriptor_pb2 import (
ServiceDescriptorProto,
)
from betterproto.casing import safe_snake_case
def snake_case(value: str) -> str:
return (
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
)
WRAPPER_TYPES = {
"google.protobuf.DoubleValue": "float",
"google.protobuf.FloatValue": "float",
"google.protobuf.Int64Value": "int",
"google.protobuf.UInt64Value": "int",
"google.protobuf.Int32Value": "int",
"google.protobuf.UInt32Value": "int",
"google.protobuf.BoolValue": "bool",
"google.protobuf.StringValue": "str",
"google.protobuf.BytesValue": "bytes",
}
def get_ref_type(package: str, imports: set, type_name: str) -> str:
@@ -37,15 +48,33 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
Return a Python type name for a proto type reference. Adds the import if
necessary.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
if type_name in WRAPPER_TYPES:
return f"Optional[{WRAPPER_TYPES[type_name]}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
parts[-1] = stringcase.pascalcase(parts[-1])
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
@@ -122,7 +151,7 @@ def get_comment(proto_file, path: List[int]) -> str:
if path[-2] == 2 and path[-4] != 6:
# This is a field
return " # " + " # ".join(lines)
return " # " + "\n # ".join(lines)
else:
# This is a message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 70:
@@ -146,6 +175,9 @@ def generate_code(request, response):
output_map = {}
for proto_file in request.proto_file:
out = proto_file.package
if out == "google.protobuf":
continue
if not out:
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
@@ -163,6 +195,7 @@ def generate_code(request, response):
"package": package,
"files": [f.name for f in options["files"]],
"imports": set(),
"datetime_imports": set(),
"typing_imports": set(),
"messages": [],
"enums": [],
@@ -179,7 +212,7 @@ def generate_code(request, response):
for item, path in traverse(proto_file):
# print(item, file=sys.stderr)
# print(path, file=sys.stderr)
data = {"name": item.name}
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
@@ -203,6 +236,14 @@ def generate_code(request, response):
packed = False
field_type = f.Type.Name(f.type).lower()[5:]
field_wraps = ""
if f.type_name.startswith(
".google.protobuf"
) and f.type_name.endswith("Value"):
w = f.type_name.split(".").pop()[:-5].upper()
field_wraps = f"betterproto.TYPE_{w}"
map_types = None
if f.type == 11:
# This might be a map...
@@ -252,13 +293,23 @@ def generate_code(request, response):
if f.HasField("oneof_index"):
one_of = item.oneof_decl[f.oneof_index].name
if "Optional[" in t:
output["typing_imports"].add("Optional")
if "timedelta" in t:
output["datetime_imports"].add("timedelta")
elif "datetime" in t:
output["datetime_imports"].add("datetime")
data["properties"].append(
{
"name": f.name,
"py_name": safe_snake_case(f.name),
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types,
"type": t,
"zero": zero,
@@ -294,6 +345,7 @@ def generate_code(request, response):
data = {
"name": service.name,
"py_name": stringcase.pascalcase(service.name),
"comment": get_comment(proto_file, [6, i]),
"methods": [],
}
@@ -317,7 +369,7 @@ def generate_code(request, response):
data["methods"].append(
{
"name": method.name,
"py_name": snake_case(method.name),
"py_name": stringcase.snakecase(method.name),
"comment": get_comment(proto_file, [6, i, 2, j]),
"route": f"/{package}.{service.name}/{method.name}",
"input": get_ref_type(
@@ -338,6 +390,7 @@ def generate_code(request, response):
output["services"].append(data)
output["imports"] = sorted(output["imports"])
output["datetime_imports"] = sorted(output["datetime_imports"])
output["typing_imports"] = sorted(output["typing_imports"])
# Fill response
@@ -361,10 +414,20 @@ def generate_code(request, response):
inits.add(base)
for base in inits:
name = os.path.join(base, "__init__.py")
if os.path.exists(name):
# Never overwrite inits as they may have custom stuff in them.
continue
init = response.file.add()
init.name = os.path.join(base, "__init__.py")
init.name = name
init.content = b""
filenames = sorted([f.name for f in response.file])
for fname in filenames:
print(f"Writing {fname}", file=sys.stderr)
def main():
"""The plugin's main entry point."""

View File

@@ -2,6 +2,10 @@
# sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto
from dataclasses import dataclass
{% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if description.typing_imports %}
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
@@ -18,7 +22,7 @@ import grpclib
{% if description.enums %}{% for enum in description.enums %}
class {{ enum.name }}(betterproto.Enum):
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}
@@ -35,7 +39,7 @@ class {{ enum.name }}(betterproto.Enum):
{% endif %}
{% for message in description.messages %}
@dataclass
class {{ message.name }}(betterproto.Message):
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
@@ -44,7 +48,7 @@ class {{ message.name }}(betterproto.Message):
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{% endfor %}
{% if not message.properties %}
pass
@@ -53,13 +57,13 @@ class {{ message.name }}(betterproto.Message):
{% endfor %}
{% for service in description.services %}
class {{ service.name }}Stub(betterproto.ServiceStub):
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %}
{{ method.comment }}

View File

@@ -0,0 +1,3 @@
{
"value": true
}

View File

@@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
bool value = 1;
}

View File

@@ -0,0 +1,4 @@
{
"camelCase": 1,
"snakeCase": "ONE"
}

View File

@@ -0,0 +1,12 @@
syntax = "proto3";
enum my_enum {
ZERO = 0;
ONE = 1;
TWO = 2;
}
message Test {
int32 camelCase = 1;
my_enum snake_case = 2;
}

View File

@@ -72,7 +72,8 @@ if __name__ == "__main__":
input_json = open(filename).read()
parsed = Parse(input_json, imported.Test())
serialized = parsed.SerializeToString()
serialized_json = MessageToJson(parsed, preserving_proto_field_name=True)
preserve = "casing" not in filename
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
s_loaded = json.loads(serialized_json)
in_loaded = json.loads(input_json)

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1,5 @@
{
"maybe": false,
"ts": "1972-01-01T10:00:20.021Z",
"duration": "1.200s"
}

View File

@@ -0,0 +1,12 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
google.protobuf.Int32Value important = 4;
}

View File

@@ -0,0 +1,5 @@
{
"for": 1,
"with": 2,
"as": 3
}

View File

@@ -0,0 +1,7 @@
syntax = "proto3";
message Test {
int32 for = 1;
int32 with = 2;
int32 as = 3;
}

View File

@@ -1,5 +1,6 @@
import betterproto
from dataclasses import dataclass
from typing import Optional
def test_has_field():
@@ -115,3 +116,49 @@ def test_oneof_support():
assert betterproto.which_one_of(foo2, "group1")[0] == "bar"
assert foo.bar == 0
assert betterproto.which_one_of(foo2, "group2")[0] == ""
def test_json_casing():
@dataclass
class CasingTest(betterproto.Message):
pascal_case: int = betterproto.int32_field(1)
camel_case: int = betterproto.int32_field(2)
snake_case: int = betterproto.int32_field(3)
kabob_case: int = betterproto.int32_field(4)
# Parsing should accept almost any input
test = CasingTest().from_dict(
{"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
)
assert test == CasingTest(1, 2, 3, 4)
# Serializing should be strict.
assert test.to_dict() == {
"pascalCase": 1,
"camelCase": 2,
"snakeCase": 3,
"kabobCase": 4,
}
assert test.to_dict(casing=betterproto.Casing.SNAKE) == {
"pascal_case": 1,
"camel_case": 2,
"snake_case": 3,
"kabob_case": 4,
}
def test_optional_flag():
@dataclass
class Request(betterproto.Message):
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
# Serialization of not passed vs. set vs. zero-value.
assert bytes(Request()) == b""
assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
assert bytes(Request(flag=False)) == b"\n\x00"
# Differentiate between not passed and the zero-value.
assert Request().parse(b"").flag == None
assert Request().parse(b"\n\x00").flag == False

View File

@@ -2,8 +2,10 @@ from setuptools import setup, find_packages
setup(
name="betterproto",
version="1.0",
version="1.1.0",
description="A better Protobuf / gRPC generator & library",
long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown",
url="http://github.com/danielgtaylor/python-betterproto",
author="Daniel G. Taylor",
author_email="danielgtaylor@gmail.com",
@@ -16,7 +18,7 @@ setup(
),
package_data={"betterproto": ["py.typed", "templates/template.py"]},
python_requires=">=3.7",
install_requires=["grpclib"],
install_requires=["grpclib", "stringcase"],
extras_require={"compiler": ["jinja2", "protobuf"]},
zip_safe=False,
)