diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 6ea3891..5e9fa3c 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -1,28 +1,28 @@ -from abc import ABC +import dataclasses +import inspect import json import struct +from abc import ABC from typing import ( - get_type_hints, - AsyncGenerator, - Union, - Generator, Any, - SupportsBytes, - List, - Tuple, + AsyncGenerator, Callable, - Type, + Dict, + Generator, Iterable, - TypeVar, + List, Optional, + SupportsBytes, + Tuple, + Type, + TypeVar, + Union, + get_type_hints, ) -import dataclasses import grpclib.client import grpclib.const -import inspect - # Proto 3 data types TYPE_ENUM = "enum" TYPE_BOOL = "bool" @@ -54,6 +54,9 @@ FIXED_TYPES = [ TYPE_SFIXED64, ] +# Fields that are numerical 64-bit types +INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64] + # Fields that are efficiently packed when PACKED_TYPES = [ TYPE_ENUM, @@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes: return value -def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes: +def _serialize_single( + field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False +) -> bytes: """Serializes a single field and value.""" value = _preprocess_single(proto_type, value) @@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes: key = encode_varint((field_number << 3) | 1) output += key + value elif proto_type in WIRE_LEN_DELIM_TYPES: - if len(value): + if len(value) or serialize_empty: key = encode_varint((field_number << 3) | 2) output += key + encode_varint(len(value)) + value else: @@ -362,6 +367,11 @@ class Message(ABC): to go between Python, binary and JSON protobuf message representations. """ + # True if this message was or should be serialized on the wire. This can + # be used to detect presence (e.g. optional wrapper message) and is used + # internally during parsing/serialization. + serialized_on_wire: bool + def __post_init__(self) -> None: # Set a default value for each field in the class after `__init__` has # already been run. @@ -389,6 +399,15 @@ class Message(ABC): setattr(self, field.name, value) + # Now that all the defaults are set, reset it! + self.__dict__["serialized_on_wire"] = False + + def __setattr__(self, attr: str, value: Any) -> None: + if attr != "serialized_on_wire": + # Track when a field has been set. + self.__dict__["serialized_on_wire"] = True + super().__setattr__(attr, value) + def __bytes__(self) -> bytes: """ Get the binary encoded Protobuf representation of this instance. @@ -429,7 +448,12 @@ class Message(ABC): # Default (zero) values are not serialized continue - output += _serialize_single(meta.number, meta.proto_type, value) + serialize_empty = False + if isinstance(value, Message) and value.serialized_on_wire: + serialize_empty = True + output += _serialize_single( + meta.number, meta.proto_type, value, serialize_empty=serialize_empty + ) return output @@ -462,12 +486,13 @@ class Message(ABC): fmt = _pack_fmt(meta.proto_type) value = struct.unpack(fmt, value)[0] elif wire_type == WIRE_LEN_DELIM: - if meta.proto_type in [TYPE_STRING]: + if meta.proto_type == TYPE_STRING: value = value.decode("utf-8") - elif meta.proto_type in [TYPE_MESSAGE]: + elif meta.proto_type == TYPE_MESSAGE: cls = self._cls_for(field) value = cls().parse(value) - elif meta.proto_type in [TYPE_MAP]: + value.serialized_on_wire = True + elif meta.proto_type == TYPE_MAP: # TODO: This is slow, use a cache to make it faster since each # key/value pair will recreate the class. assert meta.map_types @@ -535,8 +560,6 @@ class Message(ABC): # TODO: handle unknown fields pass - from typing import cast - return self # For compatibility with other libraries. @@ -549,7 +572,7 @@ class Message(ABC): Returns a dict representation of this message instance which can be used to serialize to e.g. JSON. """ - output = {} + output: Dict[str, Any] = {} for field in dataclasses.fields(self): meta = FieldMetadata.get(field) v = getattr(self, field.name) @@ -557,13 +580,9 @@ class Message(ABC): if isinstance(v, list): # Convert each item. v = [i.to_dict() for i in v] - # Filter out empty items which we won't serialize. - v = [i for i in v if i] - else: - v = v.to_dict() - - if v: output[field.name] = v + elif v.serialized_on_wire: + output[field.name] = v.to_dict() elif meta.proto_type == "map": for k in v: if hasattr(v[k], "to_dict"): @@ -572,7 +591,13 @@ class Message(ABC): if v: output[field.name] = v elif v != get_default(meta.proto_type): - output[field.name] = v + if meta.proto_type in INT_64_TYPES: + if isinstance(v, list): + output[field.name] = [str(n) for n in v] + else: + output[field.name] = str(v) + else: + output[field.name] = v return output def from_dict(self: T, value: dict) -> T: @@ -580,6 +605,7 @@ class Message(ABC): Parse the key/value pairs in `value` into this message instance. This returns the instance itself and is therefore assignable and chainable. """ + self.serialized_on_wire = True for field in dataclasses.fields(self): meta = FieldMetadata.get(field) if field.name in value and value[field.name] is not None: @@ -598,7 +624,13 @@ class Message(ABC): for k in value[field.name]: v[k] = cls().from_dict(value[field.name][k]) else: - setattr(self, field.name, value[field.name]) + v = value[field.name] + if meta.proto_type in INT_64_TYPES: + if isinstance(value[field.name], list): + v = [int(n) for n in value[field.name]] + else: + v = int(value[field.name]) + setattr(self, field.name, v) return self def to_json(self) -> str: @@ -613,9 +645,6 @@ class Message(ABC): return self.from_dict(json.loads(value)) -ResponseType = TypeVar("ResponseType", bound="Message") - - class ServiceStub(ABC): """ Base class for async gRPC service stubs. diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index ae3b095..5fa9e55 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -1,19 +1,21 @@ #!/usr/bin/env python +import importlib +import json import os # isort: skip +import subprocess +import sys +from typing import Generator, Tuple + +from google.protobuf import symbol_database +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.json_format import MessageToJson, Parse # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -import subprocess -import importlib -import sys -from typing import Generator, Tuple -from google.protobuf.json_format import Parse -from google.protobuf import symbol_database -from google.protobuf.descriptor_pool import DescriptorPool root = os.path.dirname(os.path.realpath(__file__)) @@ -68,5 +70,10 @@ if __name__ == "__main__": print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}") imported = importlib.import_module(f"{parts[0]}_pb2") - serialized = Parse(open(filename).read(), imported.Test()).SerializeToString() + parsed = Parse(open(filename).read(), imported.Test()) + serialized = parsed.SerializeToString() + serialized_json = MessageToJson( + parsed, preserving_proto_field_name=True, use_integers_for_enums=True + ) + assert json.loads(serialized_json) == json.load(open(filename)) open(out, "wb").write(serialized) diff --git a/betterproto/tests/nested.json b/betterproto/tests/nested.json index 217a7d4..f34f1d7 100644 --- a/betterproto/tests/nested.json +++ b/betterproto/tests/nested.json @@ -1,5 +1,6 @@ { "nested": { "count": 150 - } + }, + "sibling": {} } diff --git a/betterproto/tests/nested.proto b/betterproto/tests/nested.proto index 0ed4540..974bf86 100644 --- a/betterproto/tests/nested.proto +++ b/betterproto/tests/nested.proto @@ -10,8 +10,9 @@ message Test { Nested nested = 1; Sibling sibling = 2; + Sibling sibling2 = 3; } message Sibling { int32 foo = 1; -} \ No newline at end of file +} diff --git a/betterproto/tests/repeatedpacked.json b/betterproto/tests/repeatedpacked.json index 2a19e3d..106fd90 100644 --- a/betterproto/tests/repeatedpacked.json +++ b/betterproto/tests/repeatedpacked.json @@ -1,5 +1,5 @@ { "counts": [1, 2, -1, -2], - "signed": [1, 2, -1, -2], + "signed": ["1", "2", "-1", "-2"], "fixed": [1.0, 2.7, 3.4] } diff --git a/betterproto/tests/signed-negative.json b/betterproto/tests/signed-negative.json index 85e74c8..2f6525a 100644 --- a/betterproto/tests/signed-negative.json +++ b/betterproto/tests/signed-negative.json @@ -1,4 +1,4 @@ { "signed_32": -150, - "signed_64": -150 + "signed_64": "-150" } diff --git a/betterproto/tests/signed.json b/betterproto/tests/signed.json index 3d5696a..6049d88 100644 --- a/betterproto/tests/signed.json +++ b/betterproto/tests/signed.json @@ -1,4 +1,4 @@ { "signed_32": 150, - "signed_64": 150 + "signed_64": "150" } diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py new file mode 100644 index 0000000..a5e026e --- /dev/null +++ b/betterproto/tests/test_features.py @@ -0,0 +1,32 @@ +import betterproto +from dataclasses import dataclass + + +def test_has_field(): + @dataclass + class Bar(betterproto.Message): + baz: int = betterproto.int32_field(1) + + @dataclass + class Foo(betterproto.Message): + bar: Bar = betterproto.message_field(1) + + # Unset by default + foo = Foo() + assert foo.bar.serialized_on_wire == False + + # Serialized after setting something + foo.bar.baz = 1 + assert foo.bar.serialized_on_wire == True + + # Still has it after setting the default value + foo.bar.baz = 0 + assert foo.bar.serialized_on_wire == True + + # Manual override + foo.bar.serialized_on_wire = False + assert foo.bar.serialized_on_wire == False + + # Can manually set it but defaults to false + foo.bar = Bar() + assert foo.bar.serialized_on_wire == False diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index 18d8d6c..e8bc66c 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -1,8 +1,9 @@ import importlib -import pytest import json -from .generate import get_files, get_base +import pytest + +from .generate import get_base, get_files inputs = get_files(".bin") diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py index 6f578ca..2d3e470 100755 --- a/protoc-gen-betterpy.py +++ b/protoc-gen-betterpy.py @@ -1,27 +1,24 @@ #!/usr/bin/env python -import sys - import itertools import json import os.path import re -from typing import Tuple, Any, List +import sys import textwrap +from typing import Any, List, Tuple +from jinja2 import Environment, PackageLoader + +from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.descriptor_pb2 import ( DescriptorProto, EnumDescriptorProto, - FileDescriptorProto, FieldDescriptorProto, + FileDescriptorProto, ServiceDescriptorProto, ) -from google.protobuf.compiler import plugin_pb2 as plugin - - -from jinja2 import Environment, PackageLoader - def snake_case(value: str) -> str: return ( diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f409741 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[tool.black] +target-version = ['py37'] + +[tool.isort] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +line_length = 88