From 5daf61f64c864a37349384af66670b43a2f6bbc8 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Fri, 25 Oct 2019 21:16:32 -0700 Subject: [PATCH] Refactor default value code --- betterproto/__init__.py | 86 ++++++++++++++++++++--------------------- betterproto/plugin.py | 12 ++++-- 2 files changed, 51 insertions(+), 47 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index c698f95..8fa5819 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -118,18 +118,6 @@ class _PLACEHOLDER: PLACEHOLDER: Any = _PLACEHOLDER() -def get_default(proto_type: str) -> Any: - """Get the default (zero value) for a given type.""" - return { - TYPE_BOOL: False, - TYPE_FLOAT: 0.0, - TYPE_DOUBLE: 0.0, - TYPE_STRING: "", - TYPE_BYTES: b"", - TYPE_MAP: {}, - }.get(proto_type, 0) - - @dataclasses.dataclass(frozen=True) class FieldMetadata: """Stores internal metadata used for parsing & serialization.""" @@ -467,11 +455,22 @@ class Message(ABC): if meta.group and self._group_map["groups"][meta.group]["current"] == field: selected_in_group = True - if isinstance(value, list): - if not len(value) and not selected_in_group: - # Empty values are not serialized - continue + serialize_empty = False + if isinstance(value, Message) and value._serialized_on_wire: + # Empty messages can still be sent on the wire if they were + # set (or received empty). + serialize_empty = True + if value == self._get_field_default(field, meta) and not ( + selected_in_group or serialize_empty + ): + # Default (zero) values are not serialized. Two exceptions are + # if this is the selected oneof item or if we know we have to + # serialize an empty message (i.e. zero value was explicitly + # set by the user). + continue + + if isinstance(value, list): if meta.proto_type in PACKED_TYPES: # Packed lists look like a length-delimited field. First, # preprocess/encode each value into a buffer and then @@ -484,23 +483,12 @@ class Message(ABC): for item in value: output += _serialize_single(meta.number, meta.proto_type, item) elif isinstance(value, dict): - if not len(value) and not selected_in_group: - # Empty values are not serialized - continue - for k, v in value.items(): assert meta.map_types sk = _serialize_single(1, meta.map_types[0], k) sv = _serialize_single(2, meta.map_types[1], v) output += _serialize_single(meta.number, meta.proto_type, sk + sv) else: - if value == get_default(meta.proto_type) and not selected_in_group: - # Default (zero) values are not serialized - continue - - serialize_empty = False - if isinstance(value, Message) and value._serialized_on_wire: - serialize_empty = True output += _serialize_single( meta.number, meta.proto_type, value, serialize_empty=serialize_empty ) @@ -510,30 +498,42 @@ class Message(ABC): # For compatibility with other libraries SerializeToString = __bytes__ - def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: - """Get the message class for a field from the type hints.""" + def _type_hint(self, field_name: str) -> Type: module = inspect.getmodule(self.__class__) type_hints = get_type_hints(self.__class__, vars(module)) - cls = type_hints[field.name] + return type_hints[field_name] + + def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: + """Get the message class for a field from the type hints.""" + cls = self._type_hint(field.name) if hasattr(cls, "__args__") and index >= 0: - cls = type_hints[field.name].__args__[index] + cls = cls.__args__[index] return cls def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: - t = self._cls_for(field, index=-1) + t = self._type_hint(field.name) value: Any = 0 - if meta.proto_type == TYPE_MAP: - # Maps cannot be repeated, so we check these first. - value = {} - elif hasattr(t, "__args__") and len(t.__args__) == 1: - # Anything else with type args is a list. - value = [] - elif meta.proto_type == TYPE_MESSAGE: - # Message means creating an instance of the right type. - value = t() + if hasattr(t, "__origin__"): + if t.__origin__ == dict: + # This is some kind of map (dict in Python). + value = {} + elif t.__origin__ == list: + # This is some kind of list (repeated) field. + value = [] + elif t.__origin__ == Union and t.__args__[1] == type(None): + # This is an optional (wrapped) field. For setting the default we + # really don't care what kind of field it is. + value = None + else: + value = t() + elif issubclass(t, Enum): + # Enums always default to zero. + value = 0 else: - value = get_default(meta.proto_type) + # This is either a primitive scalar or another message type. Calling + # it should result in its zero value. + value = t() return value @@ -659,7 +659,7 @@ class Message(ABC): if v: output[cased_name] = v - elif v != get_default(meta.proto_type): + elif v != self._get_field_default(field, meta): if meta.proto_type in INT_64_TYPES: if isinstance(v, list): output[cased_name] = [str(n) for n in v] diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 7795623..38c08d4 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -35,13 +35,17 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str: Return a Python type name for a proto type reference. Adds the import if necessary. """ + # If the package name is a blank string, then this should still work + # because by convention packages are lowercase and message/enum types are + # pascal-cased. May require refactoring in the future. type_name = type_name.lstrip(".") if type_name.startswith(package): - # This is the current package, which has nested types flattened. - # foo.bar_thing => FooBarThing parts = type_name.lstrip(package).lstrip(".").split(".") - cased = [stringcase.pascalcase(part) for part in parts] - type_name = f'"{"".join(cased)}"' + if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()): + # This is the current package, which has nested types flattened. + # foo.bar_thing => FooBarThing + cased = [stringcase.pascalcase(part) for part in parts] + type_name = f'"{"".join(cased)}"' if "." in type_name: # This is imported from another package. No need