From b5c1f1aa7c25f74778f3609ade2b3e4431e2c774 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Sat, 19 Oct 2019 12:31:22 -0700 Subject: [PATCH] Support JSON base64 bytes and enums as strings --- README.md | 7 +++--- betterproto/__init__.py | 40 +++++++++++++++++++++++++++++- betterproto/templates/template.py | 4 +-- betterproto/tests/bytes.json | 3 +++ betterproto/tests/bytes.proto | 5 ++++ betterproto/tests/enums.json | 2 +- betterproto/tests/generate.py | 15 +++++++---- betterproto/tests/test_features.py | 18 ++++++++++++++ 8 files changed, 81 insertions(+), 13 deletions(-) create mode 100644 betterproto/tests/bytes.json create mode 100644 betterproto/tests/bytes.proto diff --git a/README.md b/README.md index 24e0e27..b1b5fa1 100644 --- a/README.md +++ b/README.md @@ -169,10 +169,10 @@ Sometimes it is useful to be able to determine whether a message has been sent o Use `Message().serialized_on_wire` to determine if it was sent. This is a little bit different from the official Google generated Python code: ```py -# Old way +# Old way (official Google Protobuf package) >>> mymessage.HasField('myfield') -# New way +# New way (this project) >>> mymessage.myfield.serialized_on_wire ``` @@ -226,8 +226,9 @@ $ pipenv run tests - [x] 64-bit ints as strings - [x] Maps - [x] Lists - - [ ] Bytes as base64 + - [x] Bytes as base64 - [ ] Any support + - [x] Enum strings - [ ] Well known types support (timestamp, duration, wrappers) - [ ] Async service stubs - [x] Unary-unary diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 7a1bb3a..e53f869 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -1,8 +1,10 @@ import dataclasses +import enum import inspect import json import struct from abc import ABC +from base64 import b64encode, b64decode from typing import ( Any, AsyncGenerator, @@ -222,6 +224,18 @@ def map_field(number: int, key_type: str, value_type: str) -> Any: return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type)) +class Enum(int, enum.Enum): + """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" + + @classmethod + def from_string(cls, name: str) -> int: + """Return the value which corresponds to the string name.""" + try: + return cls.__members__[name] + except KeyError as e: + raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e + + def _pack_fmt(proto_type: str) -> str: """Returns a little-endian format string for reading/writing binary.""" return { @@ -596,6 +610,17 @@ class Message(ABC): output[field.name] = [str(n) for n in v] else: output[field.name] = str(v) + elif meta.proto_type == TYPE_BYTES: + if isinstance(v, list): + output[field.name] = [b64encode(b).decode("utf8") for b in v] + else: + output[field.name] = b64encode(v).decode("utf8") + elif meta.proto_type == TYPE_ENUM: + enum_values = list(self._cls_for(field)) + if isinstance(v, list): + output[field.name] = [enum_values[e].name for e in v] + else: + output[field.name] = enum_values[v].name else: output[field.name] = v return output @@ -630,7 +655,20 @@ class Message(ABC): v = [int(n) for n in value[field.name]] else: v = int(value[field.name]) - setattr(self, field.name, v) + elif meta.proto_type == TYPE_BYTES: + if isinstance(value[field.name], list): + v = [b64decode(n) for n in value[field.name]] + else: + v = b64decode(value[field.name]) + elif meta.proto_type == TYPE_ENUM: + enum_cls = self._cls_for(field) + if isinstance(v, list): + v = [enum_cls.from_string(e) for e in v] + elif isinstance(v, str): + v = enum_cls.from_string(v) + + if v is not None: + setattr(self, field.name, v) return self def to_json(self, indent: Union[None, int, str] = None) -> str: diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index ab4a9fa..2e3441a 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -1,8 +1,6 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # sources: {{ ', '.join(description.files) }} # plugin: python-betterproto -{% if description.enums %}import enum -{% endif %} from dataclasses import dataclass {% if description.typing_imports %} from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} @@ -20,7 +18,7 @@ import grpclib {% if description.enums %}{% for enum in description.enums %} -class {{ enum.name }}(enum.IntEnum): +class {{ enum.name }}(betterproto.Enum): {% if enum.comment %} {{ enum.comment }} diff --git a/betterproto/tests/bytes.json b/betterproto/tests/bytes.json new file mode 100644 index 0000000..34c4554 --- /dev/null +++ b/betterproto/tests/bytes.json @@ -0,0 +1,3 @@ +{ + "data": "SGVsbG8sIFdvcmxkIQ==" +} diff --git a/betterproto/tests/bytes.proto b/betterproto/tests/bytes.proto new file mode 100644 index 0000000..de677e3 --- /dev/null +++ b/betterproto/tests/bytes.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message Test { + bytes data = 1; +} diff --git a/betterproto/tests/enums.json b/betterproto/tests/enums.json index 182f73c..a4d009c 100644 --- a/betterproto/tests/enums.json +++ b/betterproto/tests/enums.json @@ -1,3 +1,3 @@ { - "greeting": 1 + "greeting": "HEY" } diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index 49ef870..b09463e 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -69,10 +69,15 @@ if __name__ == "__main__": print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}") imported = importlib.import_module(f"{parts[0]}_pb2") - parsed = Parse(open(filename).read(), imported.Test()) + input_json = open(filename).read() + parsed = Parse(input_json, imported.Test()) serialized = parsed.SerializeToString() - serialized_json = MessageToJson( - parsed, preserving_proto_field_name=True, use_integers_for_enums=True - ) - assert json.loads(serialized_json) == json.load(open(filename)) + serialized_json = MessageToJson(parsed, preserving_proto_field_name=True) + + s_loaded = json.loads(serialized_json) + in_loaded = json.loads(input_json) + + if s_loaded != in_loaded: + raise AssertionError("Expected JSON to be equal:", s_loaded, in_loaded) + open(out, "wb").write(serialized) diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index a5e026e..4a6c4d4 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -30,3 +30,21 @@ def test_has_field(): # Can manually set it but defaults to false foo.bar = Bar() assert foo.bar.serialized_on_wire == False + + +def test_enum_as_int_json(): + class TestEnum(betterproto.Enum): + ZERO = 0 + ONE = 1 + + @dataclass + class Foo(betterproto.Message): + bar: TestEnum = betterproto.enum_field(1) + + # JSON strings are supported, but ints should still be supported too. + foo = Foo().from_dict({"bar": 1}) + assert foo.bar == TestEnum.ONE + + # Plain-ol'-ints should serialize properly too. + foo.bar = 1 + assert foo.to_dict() == {"bar": "ONE"}