Refactor default value code
This commit is contained in:
parent
4679c571c3
commit
5daf61f64c
@ -118,18 +118,6 @@ class _PLACEHOLDER:
|
|||||||
PLACEHOLDER: Any = _PLACEHOLDER()
|
PLACEHOLDER: Any = _PLACEHOLDER()
|
||||||
|
|
||||||
|
|
||||||
def get_default(proto_type: str) -> Any:
|
|
||||||
"""Get the default (zero value) for a given type."""
|
|
||||||
return {
|
|
||||||
TYPE_BOOL: False,
|
|
||||||
TYPE_FLOAT: 0.0,
|
|
||||||
TYPE_DOUBLE: 0.0,
|
|
||||||
TYPE_STRING: "",
|
|
||||||
TYPE_BYTES: b"",
|
|
||||||
TYPE_MAP: {},
|
|
||||||
}.get(proto_type, 0)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class FieldMetadata:
|
class FieldMetadata:
|
||||||
"""Stores internal metadata used for parsing & serialization."""
|
"""Stores internal metadata used for parsing & serialization."""
|
||||||
@ -467,11 +455,22 @@ class Message(ABC):
|
|||||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
||||||
selected_in_group = True
|
selected_in_group = True
|
||||||
|
|
||||||
if isinstance(value, list):
|
serialize_empty = False
|
||||||
if not len(value) and not selected_in_group:
|
if isinstance(value, Message) and value._serialized_on_wire:
|
||||||
# Empty values are not serialized
|
# Empty messages can still be sent on the wire if they were
|
||||||
|
# set (or received empty).
|
||||||
|
serialize_empty = True
|
||||||
|
|
||||||
|
if value == self._get_field_default(field, meta) and not (
|
||||||
|
selected_in_group or serialize_empty
|
||||||
|
):
|
||||||
|
# Default (zero) values are not serialized. Two exceptions are
|
||||||
|
# if this is the selected oneof item or if we know we have to
|
||||||
|
# serialize an empty message (i.e. zero value was explicitly
|
||||||
|
# set by the user).
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
if meta.proto_type in PACKED_TYPES:
|
if meta.proto_type in PACKED_TYPES:
|
||||||
# Packed lists look like a length-delimited field. First,
|
# Packed lists look like a length-delimited field. First,
|
||||||
# preprocess/encode each value into a buffer and then
|
# preprocess/encode each value into a buffer and then
|
||||||
@ -484,23 +483,12 @@ class Message(ABC):
|
|||||||
for item in value:
|
for item in value:
|
||||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
if not len(value) and not selected_in_group:
|
|
||||||
# Empty values are not serialized
|
|
||||||
continue
|
|
||||||
|
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
assert meta.map_types
|
assert meta.map_types
|
||||||
sk = _serialize_single(1, meta.map_types[0], k)
|
sk = _serialize_single(1, meta.map_types[0], k)
|
||||||
sv = _serialize_single(2, meta.map_types[1], v)
|
sv = _serialize_single(2, meta.map_types[1], v)
|
||||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||||
else:
|
else:
|
||||||
if value == get_default(meta.proto_type) and not selected_in_group:
|
|
||||||
# Default (zero) values are not serialized
|
|
||||||
continue
|
|
||||||
|
|
||||||
serialize_empty = False
|
|
||||||
if isinstance(value, Message) and value._serialized_on_wire:
|
|
||||||
serialize_empty = True
|
|
||||||
output += _serialize_single(
|
output += _serialize_single(
|
||||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||||
)
|
)
|
||||||
@ -510,30 +498,42 @@ class Message(ABC):
|
|||||||
# For compatibility with other libraries
|
# For compatibility with other libraries
|
||||||
SerializeToString = __bytes__
|
SerializeToString = __bytes__
|
||||||
|
|
||||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
def _type_hint(self, field_name: str) -> Type:
|
||||||
"""Get the message class for a field from the type hints."""
|
|
||||||
module = inspect.getmodule(self.__class__)
|
module = inspect.getmodule(self.__class__)
|
||||||
type_hints = get_type_hints(self.__class__, vars(module))
|
type_hints = get_type_hints(self.__class__, vars(module))
|
||||||
cls = type_hints[field.name]
|
return type_hints[field_name]
|
||||||
|
|
||||||
|
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||||
|
"""Get the message class for a field from the type hints."""
|
||||||
|
cls = self._type_hint(field.name)
|
||||||
if hasattr(cls, "__args__") and index >= 0:
|
if hasattr(cls, "__args__") and index >= 0:
|
||||||
cls = type_hints[field.name].__args__[index]
|
cls = cls.__args__[index]
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||||
t = self._cls_for(field, index=-1)
|
t = self._type_hint(field.name)
|
||||||
|
|
||||||
value: Any = 0
|
value: Any = 0
|
||||||
if meta.proto_type == TYPE_MAP:
|
if hasattr(t, "__origin__"):
|
||||||
# Maps cannot be repeated, so we check these first.
|
if t.__origin__ == dict:
|
||||||
|
# This is some kind of map (dict in Python).
|
||||||
value = {}
|
value = {}
|
||||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
elif t.__origin__ == list:
|
||||||
# Anything else with type args is a list.
|
# This is some kind of list (repeated) field.
|
||||||
value = []
|
value = []
|
||||||
elif meta.proto_type == TYPE_MESSAGE:
|
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
||||||
# Message means creating an instance of the right type.
|
# This is an optional (wrapped) field. For setting the default we
|
||||||
value = t()
|
# really don't care what kind of field it is.
|
||||||
|
value = None
|
||||||
else:
|
else:
|
||||||
value = get_default(meta.proto_type)
|
value = t()
|
||||||
|
elif issubclass(t, Enum):
|
||||||
|
# Enums always default to zero.
|
||||||
|
value = 0
|
||||||
|
else:
|
||||||
|
# This is either a primitive scalar or another message type. Calling
|
||||||
|
# it should result in its zero value.
|
||||||
|
value = t()
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -659,7 +659,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
if v:
|
if v:
|
||||||
output[cased_name] = v
|
output[cased_name] = v
|
||||||
elif v != get_default(meta.proto_type):
|
elif v != self._get_field_default(field, meta):
|
||||||
if meta.proto_type in INT_64_TYPES:
|
if meta.proto_type in INT_64_TYPES:
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
output[cased_name] = [str(n) for n in v]
|
output[cased_name] = [str(n) for n in v]
|
||||||
|
@ -35,11 +35,15 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
|||||||
Return a Python type name for a proto type reference. Adds the import if
|
Return a Python type name for a proto type reference. Adds the import if
|
||||||
necessary.
|
necessary.
|
||||||
"""
|
"""
|
||||||
|
# If the package name is a blank string, then this should still work
|
||||||
|
# because by convention packages are lowercase and message/enum types are
|
||||||
|
# pascal-cased. May require refactoring in the future.
|
||||||
type_name = type_name.lstrip(".")
|
type_name = type_name.lstrip(".")
|
||||||
if type_name.startswith(package):
|
if type_name.startswith(package):
|
||||||
|
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||||
|
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||||
# This is the current package, which has nested types flattened.
|
# This is the current package, which has nested types flattened.
|
||||||
# foo.bar_thing => FooBarThing
|
# foo.bar_thing => FooBarThing
|
||||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
|
||||||
cased = [stringcase.pascalcase(part) for part in parts]
|
cased = [stringcase.pascalcase(part) for part in parts]
|
||||||
type_name = f'"{"".join(cased)}"'
|
type_name = f'"{"".join(cased)}"'
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user