Refactor default value code

This commit is contained in:
Daniel G. Taylor 2019-10-25 21:16:32 -07:00
parent 4679c571c3
commit 5daf61f64c
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
2 changed files with 51 additions and 47 deletions

View File

@ -118,18 +118,6 @@ class _PLACEHOLDER:
PLACEHOLDER: Any = _PLACEHOLDER() PLACEHOLDER: Any = _PLACEHOLDER()
def get_default(proto_type: str) -> Any:
    """Return the zero value used for a field of the given proto type.

    Bool, float, double, string, bytes and map types have dedicated zero
    values; every other type falls back to the integer ``0``.
    """
    zero_values = {
        TYPE_BOOL: False,
        TYPE_FLOAT: 0.0,
        TYPE_DOUBLE: 0.0,
        TYPE_STRING: "",
        TYPE_BYTES: b"",
        # Built inside the call, so each caller receives its own dict.
        TYPE_MAP: {},
    }
    if proto_type in zero_values:
        return zero_values[proto_type]
    return 0
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FieldMetadata: class FieldMetadata:
"""Stores internal metadata used for parsing & serialization.""" """Stores internal metadata used for parsing & serialization."""
@ -467,11 +455,22 @@ class Message(ABC):
if meta.group and self._group_map["groups"][meta.group]["current"] == field: if meta.group and self._group_map["groups"][meta.group]["current"] == field:
selected_in_group = True selected_in_group = True
if isinstance(value, list): serialize_empty = False
if not len(value) and not selected_in_group: if isinstance(value, Message) and value._serialized_on_wire:
# Empty values are not serialized # Empty messages can still be sent on the wire if they were
# set (or received empty).
serialize_empty = True
if value == self._get_field_default(field, meta) and not (
selected_in_group or serialize_empty
):
# Default (zero) values are not serialized. Two exceptions are
# if this is the selected oneof item or if we know we have to
# serialize an empty message (i.e. zero value was explicitly
# set by the user).
continue continue
if isinstance(value, list):
if meta.proto_type in PACKED_TYPES: if meta.proto_type in PACKED_TYPES:
# Packed lists look like a length-delimited field. First, # Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then # preprocess/encode each value into a buffer and then
@ -484,23 +483,12 @@ class Message(ABC):
for item in value: for item in value:
output += _serialize_single(meta.number, meta.proto_type, item) output += _serialize_single(meta.number, meta.proto_type, item)
elif isinstance(value, dict): elif isinstance(value, dict):
if not len(value) and not selected_in_group:
# Empty values are not serialized
continue
for k, v in value.items(): for k, v in value.items():
assert meta.map_types assert meta.map_types
sk = _serialize_single(1, meta.map_types[0], k) sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v) sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv) output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else: else:
if value == get_default(meta.proto_type) and not selected_in_group:
# Default (zero) values are not serialized
continue
serialize_empty = False
if isinstance(value, Message) and value._serialized_on_wire:
serialize_empty = True
output += _serialize_single( output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty meta.number, meta.proto_type, value, serialize_empty=serialize_empty
) )
@ -510,30 +498,42 @@ class Message(ABC):
# For compatibility with other libraries # For compatibility with other libraries
SerializeToString = __bytes__ SerializeToString = __bytes__
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: def _type_hint(self, field_name: str) -> Type:
"""Get the message class for a field from the type hints."""
module = inspect.getmodule(self.__class__) module = inspect.getmodule(self.__class__)
type_hints = get_type_hints(self.__class__, vars(module)) type_hints = get_type_hints(self.__class__, vars(module))
cls = type_hints[field.name] return type_hints[field_name]
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints."""
cls = self._type_hint(field.name)
if hasattr(cls, "__args__") and index >= 0: if hasattr(cls, "__args__") and index >= 0:
cls = type_hints[field.name].__args__[index] cls = cls.__args__[index]
return cls return cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
t = self._cls_for(field, index=-1) t = self._type_hint(field.name)
value: Any = 0 value: Any = 0
if meta.proto_type == TYPE_MAP: if hasattr(t, "__origin__"):
# Maps cannot be repeated, so we check these first. if t.__origin__ == dict:
# This is some kind of map (dict in Python).
value = {} value = {}
elif hasattr(t, "__args__") and len(t.__args__) == 1: elif t.__origin__ == list:
# Anything else with type args is a list. # This is some kind of list (repeated) field.
value = [] value = []
elif meta.proto_type == TYPE_MESSAGE: elif t.__origin__ == Union and t.__args__[1] == type(None):
# Message means creating an instance of the right type. # This is an optional (wrapped) field. For setting the default we
value = t() # really don't care what kind of field it is.
value = None
else: else:
value = get_default(meta.proto_type) value = t()
elif issubclass(t, Enum):
# Enums always default to zero.
value = 0
else:
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
value = t()
return value return value
@ -659,7 +659,7 @@ class Message(ABC):
if v: if v:
output[cased_name] = v output[cased_name] = v
elif v != get_default(meta.proto_type): elif v != self._get_field_default(field, meta):
if meta.proto_type in INT_64_TYPES: if meta.proto_type in INT_64_TYPES:
if isinstance(v, list): if isinstance(v, list):
output[cased_name] = [str(n) for n in v] output[cased_name] = [str(n) for n in v]

View File

@ -35,11 +35,15 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
necessary. necessary.
""" """
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".") type_name = type_name.lstrip(".")
if type_name.startswith(package): if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened. # This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing # foo.bar_thing => FooBarThing
parts = type_name.lstrip(package).lstrip(".").split(".")
cased = [stringcase.pascalcase(part) for part in parts] cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"' type_name = f'"{"".join(cased)}"'