Refactor default value code
This commit is contained in:
parent
4679c571c3
commit
5daf61f64c
@ -118,18 +118,6 @@ class _PLACEHOLDER:
|
||||
PLACEHOLDER: Any = _PLACEHOLDER()
|
||||
|
||||
|
||||
def get_default(proto_type: str) -> Any:
|
||||
"""Get the default (zero value) for a given type."""
|
||||
return {
|
||||
TYPE_BOOL: False,
|
||||
TYPE_FLOAT: 0.0,
|
||||
TYPE_DOUBLE: 0.0,
|
||||
TYPE_STRING: "",
|
||||
TYPE_BYTES: b"",
|
||||
TYPE_MAP: {},
|
||||
}.get(proto_type, 0)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FieldMetadata:
|
||||
"""Stores internal metadata used for parsing & serialization."""
|
||||
@ -467,11 +455,22 @@ class Message(ABC):
|
||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
||||
selected_in_group = True
|
||||
|
||||
if isinstance(value, list):
|
||||
if not len(value) and not selected_in_group:
|
||||
# Empty values are not serialized
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
# Empty messages can still be sent on the wire if they were
|
||||
# set (or received empty).
|
||||
serialize_empty = True
|
||||
|
||||
if value == self._get_field_default(field, meta) and not (
|
||||
selected_in_group or serialize_empty
|
||||
):
|
||||
# Default (zero) values are not serialized. Two exceptions are
|
||||
# if this is the selected oneof item or if we know we have to
|
||||
# serialize an empty message (i.e. zero value was explicitly
|
||||
# set by the user).
|
||||
continue
|
||||
|
||||
if isinstance(value, list):
|
||||
if meta.proto_type in PACKED_TYPES:
|
||||
# Packed lists look like a length-delimited field. First,
|
||||
# preprocess/encode each value into a buffer and then
|
||||
@ -484,23 +483,12 @@ class Message(ABC):
|
||||
for item in value:
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
elif isinstance(value, dict):
|
||||
if not len(value) and not selected_in_group:
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
|
||||
for k, v in value.items():
|
||||
assert meta.map_types
|
||||
sk = _serialize_single(1, meta.map_types[0], k)
|
||||
sv = _serialize_single(2, meta.map_types[1], v)
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
else:
|
||||
if value == get_default(meta.proto_type) and not selected_in_group:
|
||||
# Default (zero) values are not serialized
|
||||
continue
|
||||
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
serialize_empty = True
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
)
|
||||
@ -510,30 +498,42 @@ class Message(ABC):
|
||||
# For compatibility with other libraries
|
||||
SerializeToString = __bytes__
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
def _type_hint(self, field_name: str) -> Type:
|
||||
module = inspect.getmodule(self.__class__)
|
||||
type_hints = get_type_hints(self.__class__, vars(module))
|
||||
cls = type_hints[field.name]
|
||||
return type_hints[field_name]
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
cls = self._type_hint(field.name)
|
||||
if hasattr(cls, "__args__") and index >= 0:
|
||||
cls = type_hints[field.name].__args__[index]
|
||||
cls = cls.__args__[index]
|
||||
return cls
|
||||
|
||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||
t = self._cls_for(field, index=-1)
|
||||
t = self._type_hint(field.name)
|
||||
|
||||
value: Any = 0
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Maps cannot be repeated, so we check these first.
|
||||
if hasattr(t, "__origin__"):
|
||||
if t.__origin__ == dict:
|
||||
# This is some kind of map (dict in Python).
|
||||
value = {}
|
||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
||||
# Anything else with type args is a list.
|
||||
elif t.__origin__ == list:
|
||||
# This is some kind of list (repeated) field.
|
||||
value = []
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
# Message means creating an instance of the right type.
|
||||
value = t()
|
||||
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
||||
# This is an optional (wrapped) field. For setting the default we
|
||||
# really don't care what kind of field it is.
|
||||
value = None
|
||||
else:
|
||||
value = get_default(meta.proto_type)
|
||||
value = t()
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
value = 0
|
||||
else:
|
||||
# This is either a primitive scalar or another message type. Calling
|
||||
# it should result in its zero value.
|
||||
value = t()
|
||||
|
||||
return value
|
||||
|
||||
@ -659,7 +659,7 @@ class Message(ABC):
|
||||
|
||||
if v:
|
||||
output[cased_name] = v
|
||||
elif v != get_default(meta.proto_type):
|
||||
elif v != self._get_field_default(field, meta):
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [str(n) for n in v]
|
||||
|
@ -35,11 +35,15 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user