Better JSON casing support, renaming messages/fields

This commit is contained in:
Daniel G. Taylor
2019-10-23 15:06:34 -07:00
parent ef0a1bf50c
commit d43d5af5ce
11 changed files with 165 additions and 83 deletions

View File

@@ -24,6 +24,7 @@ from typing import (
import grpclib.client
import grpclib.const
import stringcase
# Proto 3 data types
TYPE_ENUM = "enum"
@@ -101,6 +102,13 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
class Casing(enum.Enum):
"""Casing constants for serialization."""
CAMEL = stringcase.camelcase
SNAKE = stringcase.snakecase
class _PLACEHOLDER:
pass
@@ -624,48 +632,50 @@ class Message(ABC):
def FromString(cls: Type[T], data: bytes) -> T:
return cls().parse(data)
def to_dict(self) -> dict:
def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
"""
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON.
used to serialize to e.g. JSON. Defaults to camel casing for
compatibility but can be set to other modes.
"""
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
cased_name = casing(field.name)
if meta.proto_type == "message":
if isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
output[field.name] = v
output[cased_name] = v
elif v._serialized_on_wire:
output[field.name] = v.to_dict()
output[cased_name] = v.to_dict()
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
v[k] = v[k].to_dict()
if v:
output[field.name] = v
output[cased_name] = v
elif v != get_default(meta.proto_type):
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[field.name] = [str(n) for n in v]
output[cased_name] = [str(n) for n in v]
else:
output[field.name] = str(v)
output[cased_name] = str(v)
elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list):
output[field.name] = [b64encode(b).decode("utf8") for b in v]
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
else:
output[field.name] = b64encode(v).decode("utf8")
output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field))
if isinstance(v, list):
output[field.name] = [enum_values[e].name for e in v]
output[cased_name] = [enum_values[e].name for e in v]
else:
output[field.name] = enum_values[v].name
output[cased_name] = enum_values[v].name
else:
output[field.name] = v
output[cased_name] = v
return output
def from_dict(self: T, value: dict) -> T:
@@ -674,44 +684,49 @@ class Message(ABC):
returns the instance itself and is therefore assignable and chainable.
"""
self._serialized_on_wire = True
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if field.name in value and value[field.name] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
# print(v, value[field.name])
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[field.name])):
v.append(cls().from_dict(value[field.name][i]))
else:
v.from_dict(value[field.name])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._cls_for(field, index=1)
for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k])
else:
v = value[field.name]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
else:
v = int(value[field.name])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[field.name], list):
v = [b64decode(n) for n in value[field.name]]
else:
v = b64decode(value[field.name])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field)
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value:
snake_cased = stringcase.snakecase(key)
if snake_cased in fields_by_name:
field = fields_by_name[snake_cased]
meta = FieldMetadata.get(field)
if v is not None:
setattr(self, field.name, v)
if value[key] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
# print(v, value[key])
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._cls_for(field, index=1)
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field)
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
if v is not None:
setattr(self, field.name, v)
return self
def to_json(self, indent: Union[None, int, str] = None) -> str: