diff --git a/README.md b/README.md index 8a86bf8..333d974 100644 --- a/README.md +++ b/README.md @@ -168,19 +168,75 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example. -Use `Message().serialized_on_wire` to determine if it was sent. This is a little bit different from the official Google generated Python code: +Use `betterproto.serialized_on_wire(message)` to determine if it was sent. This is a little bit different from the official Google generated Python code, and it lives outside the generated `Message` class to prevent name clashes. Note that it **only** supports Proto 3 and thus can **only** be used to check if `Message` fields are set. You cannot check if a scalar was sent on the wire. ```py # Old way (official Google Protobuf package) >>> mymessage.HasField('myfield') # New way (this project) ->>> mymessage.myfield.serialized_on_wire +>>> betterproto.serialized_on_wire(mymessage.myfield) +``` + +### One-of Support + +Protobuf supports grouping fields in a `oneof` clause. Only one of the fields in the group may be set at a given time. For example, given the proto: + +```protobuf +syntax = "proto3"; + +message Test { + oneof foo { + bool on = 1; + int32 count = 2; + string name = 3; + } +} +``` + +You can use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset. + +```py +>>> test = Test() +>>> betterproto.which_one_of(test, "foo") +["", None] + +>>> test.on = True +>>> betterproto.which_one_of(test, "foo") +["on", True] + +# Setting one member of the group resets the others. +>>> test.count = 57 +>>> betterproto.which_one_of(test, "foo") +["count", 57] +>>> test.on +False + +# Default (zero) values also work. +>>> test.name = "" +>>> betterproto.which_one_of(test, "foo") +["name", ""] +>>> test.count +0 +>>> test.on +False +``` + +Again this is a little different than the official Google code generator: + +```py +# Old way (official Google protobuf package) +>>> message.WhichOneof("group") +"foo" + +# New way (this project) +>>> betterproto.which_one_of(message, "group") +["foo", "foo's value"] ``` ## Development -First, make sure you have Python 3.7+ and `pipenv` installed: +First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then: ```sh # Get set up with the virtual env & dependencies @@ -224,10 +280,10 @@ $ pipenv run tests - [x] Refs to nested types - [x] Imports in proto files - [x] Well-known Google types -- [ ] OneOf support +- [x] OneOf support - [x] Basic support on the wire - - [ ] Check which was set from the group - - [ ] Setting one unsets the others + - [x] Check which was set from the group + - [x] Setting one unsets the others - [ ] JSON that isn't completely naive. - [x] 64-bit ints as strings - [x] Maps @@ -236,6 +292,7 @@ $ pipenv run tests - [ ] Any support - [x] Enum strings - [ ] Well known types support (timestamp, duration, wrappers) + - [ ] Support different casing (orig vs. camel vs. others?) - [ ] Async service stubs - [x] Unary-unary - [x] Server streaming response @@ -243,7 +300,7 @@ $ pipenv run tests - [ ] Renaming messages and fields to conform to Python name standards - [ ] Renaming clashes with language keywords and standard library top-level packages - [x] Python package -- [ ] Automate running tests +- [x] Automate running tests - [ ] Cleanup! ## License diff --git a/betterproto/__init__.py b/betterproto/__init__.py index b0fb445..5c1075a 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -130,6 +130,8 @@ class FieldMetadata: proto_type: str # Map information if the proto_type is a map map_types: Optional[Tuple[str, str]] + # Groups several "one-of" fields together + group: Optional[str] @staticmethod def get(field: dataclasses.Field) -> "FieldMetadata": @@ -138,12 +140,16 @@ class FieldMetadata: def dataclass_field( - number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None + number: int, + proto_type: str, + *, + map_types: Optional[Tuple[str, str]] = None, + group: Optional[str] = None, ) -> dataclasses.Field: """Creates a dataclass field with attached protobuf metadata.""" return dataclasses.field( default=PLACEHOLDER, - metadata={"betterproto": FieldMetadata(number, proto_type, map_types)}, + metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)}, ) @@ -152,76 +158,80 @@ def dataclass_field( # out at runtime. The generated dataclass variables are still typed correctly. -def enum_field(number: int) -> Any: - return dataclass_field(number, TYPE_ENUM) +def enum_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_ENUM, group=group) -def bool_field(number: int) -> Any: - return dataclass_field(number, TYPE_BOOL) +def bool_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_BOOL, group=group) -def int32_field(number: int) -> Any: - return dataclass_field(number, TYPE_INT32) +def int32_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_INT32, group=group) -def int64_field(number: int) -> Any: - return dataclass_field(number, TYPE_INT64) +def int64_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_INT64, group=group) -def uint32_field(number: int) -> Any: - return dataclass_field(number, TYPE_UINT32) +def uint32_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_UINT32, group=group) -def uint64_field(number: int) -> Any: - return dataclass_field(number, TYPE_UINT64) +def uint64_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_UINT64, group=group) -def sint32_field(number: int) -> Any: - return dataclass_field(number, TYPE_SINT32) +def sint32_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_SINT32, group=group) -def sint64_field(number: int) -> Any: - return dataclass_field(number, TYPE_SINT64) +def sint64_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_SINT64, group=group) -def float_field(number: int) -> Any: - return dataclass_field(number, TYPE_FLOAT) +def float_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_FLOAT, group=group) -def double_field(number: int) -> Any: - return dataclass_field(number, TYPE_DOUBLE) +def double_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_DOUBLE, group=group) -def fixed32_field(number: int) -> Any: - return dataclass_field(number, TYPE_FIXED32) +def fixed32_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_FIXED32, group=group) -def fixed64_field(number: int) -> Any: - return dataclass_field(number, TYPE_FIXED64) +def fixed64_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_FIXED64, group=group) -def sfixed32_field(number: int) -> Any: - return dataclass_field(number, TYPE_SFIXED32) +def sfixed32_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_SFIXED32, group=group) -def sfixed64_field(number: int) -> Any: - return dataclass_field(number, TYPE_SFIXED64) +def sfixed64_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_SFIXED64, group=group) -def string_field(number: int) -> Any: - return dataclass_field(number, TYPE_STRING) +def string_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_STRING, group=group) -def bytes_field(number: int) -> Any: - return dataclass_field(number, TYPE_BYTES) +def bytes_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_BYTES, group=group) -def message_field(number: int) -> Any: - return dataclass_field(number, TYPE_MESSAGE) +def message_field(number: int, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_MESSAGE, group=group) -def map_field(number: int, key_type: str, value_type: str) -> Any: - return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type)) +def map_field( + number: int, key_type: str, value_type: str, group: Optional[str] = None +) -> Any: + return dataclass_field( + number, TYPE_MAP, map_types=(key_type, value_type), group=group + ) class Enum(int, enum.Enum): @@ -383,46 +393,52 @@ class Message(ABC): to go between Python, binary and JSON protobuf message representations. """ - # True if this message was or should be serialized on the wire. This can - # be used to detect presence (e.g. optional wrapper message) and is used - # internally during parsing/serialization. - serialized_on_wire: bool - def __post_init__(self) -> None: # Set a default value for each field in the class after `__init__` has # already been run. + group_map = {"fields": {}, "groups": {}} for field in dataclasses.fields(self): - if getattr(self, field.name) != PLACEHOLDER: - # Skip anything not set (aka set to the sentinel value) - continue - meta = FieldMetadata.get(field) - t = self._cls_for(field, index=-1) + if meta.group: + group_map["fields"][field.name] = meta.group - value: Any = 0 - if meta.proto_type == TYPE_MAP: - # Maps cannot be repeated, so we check these first. - value = {} - elif hasattr(t, "__args__") and len(t.__args__) == 1: - # Anything else with type args is a list. - value = [] - elif meta.proto_type == TYPE_MESSAGE: - # Message means creating an instance of the right type. - value = t() - else: - value = get_default(meta.proto_type) + if meta.group not in group_map["groups"]: + group_map["groups"][meta.group] = {"current": None, "fields": set()} + group_map["groups"][meta.group]["fields"].add(field) - setattr(self, field.name, value) + if getattr(self, field.name) != PLACEHOLDER: + # Skip anything not set to the sentinel value + + if meta.group: + # This was set, so make it the selected value of the one-of. + group_map["groups"][meta.group]["current"] = field + + continue + + setattr(self, field.name, self._get_field_default(field, meta)) # Now that all the defaults are set, reset it! - self.__dict__["serialized_on_wire"] = False + self.__dict__["_serialized_on_wire"] = False self.__dict__["_unknown_fields"] = b"" + self.__dict__["_group_map"] = group_map def __setattr__(self, attr: str, value: Any) -> None: - if attr != "serialized_on_wire": + if attr != "_serialized_on_wire": # Track when a field has been set. - self.__dict__["serialized_on_wire"] = True + self.__dict__["_serialized_on_wire"] = True + + if attr in getattr(self, "_group_map", {}).get("fields", {}): + group = self._group_map["fields"][attr] + for field in self._group_map["groups"][group]["fields"]: + if field.name == attr: + self._group_map["groups"][group]["current"] = field + else: + super().__setattr__( + field.name, + self._get_field_default(field, FieldMetadata.get(field)), + ) + super().__setattr__(attr, value) def __bytes__(self) -> bytes: @@ -434,8 +450,15 @@ class Message(ABC): meta = FieldMetadata.get(field) value = getattr(self, field.name) + # Being selected in a a group means this field is the one that is + # currently set in a `oneof` group, so it must be serialized even + # if the value is the default zero value. + selected_in_group = False + if meta.group and self._group_map["groups"][meta.group]["current"] == field: + selected_in_group = True + if isinstance(value, list): - if not len(value): + if not len(value) and not selected_in_group: # Empty values are not serialized continue @@ -451,7 +474,7 @@ class Message(ABC): for item in value: output += _serialize_single(meta.number, meta.proto_type, item) elif isinstance(value, dict): - if not len(value): + if not len(value) and not selected_in_group: # Empty values are not serialized continue @@ -461,12 +484,12 @@ class Message(ABC): sv = _serialize_single(2, meta.map_types[1], v) output += _serialize_single(meta.number, meta.proto_type, sk + sv) else: - if value == get_default(meta.proto_type): + if value == get_default(meta.proto_type) and not selected_in_group: # Default (zero) values are not serialized continue serialize_empty = False - if isinstance(value, Message) and value.serialized_on_wire: + if isinstance(value, Message) and value._serialized_on_wire: serialize_empty = True output += _serialize_single( meta.number, meta.proto_type, value, serialize_empty=serialize_empty @@ -486,6 +509,24 @@ class Message(ABC): cls = type_hints[field.name].__args__[index] return cls + def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: + t = self._cls_for(field, index=-1) + + value: Any = 0 + if meta.proto_type == TYPE_MAP: + # Maps cannot be repeated, so we check these first. + value = {} + elif hasattr(t, "__args__") and len(t.__args__) == 1: + # Anything else with type args is a list. + value = [] + elif meta.proto_type == TYPE_MESSAGE: + # Message means creating an instance of the right type. + value = t() + else: + value = get_default(meta.proto_type) + + return value + def _postprocess_single( self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any ) -> Any: @@ -508,7 +549,7 @@ class Message(ABC): elif meta.proto_type == TYPE_MESSAGE: cls = self._cls_for(field) value = cls().parse(value) - value.serialized_on_wire = True + value._serialized_on_wire = True elif meta.proto_type == TYPE_MAP: # TODO: This is slow, use a cache to make it faster since each # key/value pair will recreate the class. @@ -518,8 +559,8 @@ class Message(ABC): Entry = dataclasses.make_dataclass( "Entry", [ - ("key", kt, dataclass_field(1, meta.map_types[0], None)), - ("value", vt, dataclass_field(2, meta.map_types[1], None)), + ("key", kt, dataclass_field(1, meta.map_types[0])), + ("value", vt, dataclass_field(2, meta.map_types[1])), ], bases=(Message,), ) @@ -597,7 +638,7 @@ class Message(ABC): # Convert each item. v = [i.to_dict() for i in v] output[field.name] = v - elif v.serialized_on_wire: + elif v._serialized_on_wire: output[field.name] = v.to_dict() elif meta.proto_type == "map": for k in v: @@ -632,7 +673,7 @@ class Message(ABC): Parse the key/value pairs in `value` into this message instance. This returns the instance itself and is therefore assignable and chainable. """ - self.serialized_on_wire = True + self._serialized_on_wire = True for field in dataclasses.fields(self): meta = FieldMetadata.get(field) if field.name in value and value[field.name] is not None: @@ -685,6 +726,23 @@ class Message(ABC): return self.from_dict(json.loads(value)) +def serialized_on_wire(message: Message) -> bool: + """ + True if this message was or should be serialized on the wire. This can + be used to detect presence (e.g. optional wrapper message) and is used + internally during parsing/serialization. + """ + return message._serialized_on_wire + + +def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: + """Return the name and value of a message's one-of field group.""" + field = message._group_map["groups"].get(group_name, {}).get("current") + if not field: + return ("", None) + return (field.name, getattr(message, field.name)) + + class ServiceStub(ABC): """ Base class for async gRPC service stubs. diff --git a/betterproto/plugin.py b/betterproto/plugin.py index f1a6d30..4e559e3 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -242,6 +242,10 @@ def generate_code(request, response): if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: packed = True + one_of = "" + if f.HasField("oneof_index"): + one_of = item.oneof_decl[f.oneof_index].name + data["properties"].append( { "name": f.name, @@ -254,6 +258,7 @@ def generate_code(request, response): "zero": zero, "repeated": repeated, "packed": packed, + "one_of": one_of, } ) # print(f, file=sys.stderr) diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index a8af85e..8f61ab9 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -44,7 +44,7 @@ class {{ message.name }}(betterproto.Message): {% if field.comment %} {{ field.comment }} {% endif %} - {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}) + {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) {% endfor %} {% if not message.properties %} pass diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 8afaa74..d542baa 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -13,23 +13,23 @@ def test_has_field(): # Unset by default foo = Foo() - assert foo.bar.serialized_on_wire == False + assert betterproto.serialized_on_wire(foo.bar) == False # Serialized after setting something foo.bar.baz = 1 - assert foo.bar.serialized_on_wire == True + assert betterproto.serialized_on_wire(foo.bar) == True # Still has it after setting the default value foo.bar.baz = 0 - assert foo.bar.serialized_on_wire == True + assert betterproto.serialized_on_wire(foo.bar) == True - # Manual override - foo.bar.serialized_on_wire = False - assert foo.bar.serialized_on_wire == False + # Manual override (don't do this) + foo.bar._serialized_on_wire = False + assert betterproto.serialized_on_wire(foo.bar) == False # Can manually set it but defaults to false foo.bar = Bar() - assert foo.bar.serialized_on_wire == False + assert betterproto.serialized_on_wire(foo.bar) == False def test_enum_as_int_json(): @@ -70,3 +70,48 @@ def test_unknown_fields(): new_again = Newer().parse(round_trip) assert newer == new_again + + +def test_oneof_support(): + @dataclass + class Sub(betterproto.Message): + val: int = betterproto.int32_field(1) + + @dataclass + class Foo(betterproto.Message): + bar: int = betterproto.int32_field(1, group="group1") + baz: str = betterproto.string_field(2, group="group1") + sub: Sub = betterproto.message_field(3, group="group2") + abc: str = betterproto.string_field(4, group="group2") + + foo = Foo() + + assert betterproto.which_one_of(foo, "group1")[0] == "" + + foo.bar = 1 + foo.baz = "test" + + # Other oneof fields should now be unset + assert foo.bar == 0 + assert betterproto.which_one_of(foo, "group1")[0] == "baz" + + foo.sub.val = 1 + assert betterproto.serialized_on_wire(foo.sub) + + foo.abc = "test" + + # Group 1 shouldn't be touched, group 2 should have reset + assert foo.sub.val == 0 + assert betterproto.serialized_on_wire(foo.sub) == False + assert betterproto.which_one_of(foo, "group2")[0] == "abc" + + # Zero value should always serialize for one-of + foo = Foo(bar=0) + assert betterproto.which_one_of(foo, "group1")[0] == "bar" + assert bytes(foo) == b"\x08\x00" + + # Round trip should also work + foo2 = Foo().parse(bytes(foo)) + assert betterproto.which_one_of(foo2, "group1")[0] == "bar" + assert foo.bar == 0 + assert betterproto.which_one_of(foo2, "group2")[0] == ""