Better JSON casing support, renaming messages/fields
This commit is contained in:
@@ -24,6 +24,7 @@ from typing import (
|
||||
|
||||
import grpclib.client
|
||||
import grpclib.const
|
||||
import stringcase
|
||||
|
||||
# Proto 3 data types
|
||||
TYPE_ENUM = "enum"
|
||||
@@ -101,6 +102,13 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
CAMEL = stringcase.camelcase
|
||||
SNAKE = stringcase.snakecase
|
||||
|
||||
|
||||
class _PLACEHOLDER:
|
||||
pass
|
||||
|
||||
@@ -624,48 +632,50 @@ class Message(ABC):
|
||||
def FromString(cls: Type[T], data: bytes) -> T:
|
||||
return cls().parse(data)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
|
||||
"""
|
||||
Returns a dict representation of this message instance which can be
|
||||
used to serialize to e.g. JSON.
|
||||
used to serialize to e.g. JSON. Defaults to camel casing for
|
||||
compatibility but can be set to other modes.
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name)
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
output[field.name] = v
|
||||
output[cased_name] = v
|
||||
elif v._serialized_on_wire:
|
||||
output[field.name] = v.to_dict()
|
||||
output[cased_name] = v.to_dict()
|
||||
elif meta.proto_type == "map":
|
||||
for k in v:
|
||||
if hasattr(v[k], "to_dict"):
|
||||
v[k] = v[k].to_dict()
|
||||
|
||||
if v:
|
||||
output[field.name] = v
|
||||
output[cased_name] = v
|
||||
elif v != get_default(meta.proto_type):
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [str(n) for n in v]
|
||||
output[cased_name] = [str(n) for n in v]
|
||||
else:
|
||||
output[field.name] = str(v)
|
||||
output[cased_name] = str(v)
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [b64encode(b).decode("utf8") for b in v]
|
||||
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
|
||||
else:
|
||||
output[field.name] = b64encode(v).decode("utf8")
|
||||
output[cased_name] = b64encode(v).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_values = list(self._cls_for(field))
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [enum_values[e].name for e in v]
|
||||
output[cased_name] = [enum_values[e].name for e in v]
|
||||
else:
|
||||
output[field.name] = enum_values[v].name
|
||||
output[cased_name] = enum_values[v].name
|
||||
else:
|
||||
output[field.name] = v
|
||||
output[cased_name] = v
|
||||
return output
|
||||
|
||||
def from_dict(self: T, value: dict) -> T:
|
||||
@@ -674,44 +684,49 @@ class Message(ABC):
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
self._serialized_on_wire = True
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value and value[field.name] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
# print(v, value[field.name])
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[field.name])):
|
||||
v.append(cls().from_dict(value[field.name][i]))
|
||||
else:
|
||||
v.from_dict(value[field.name])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
for k in value[field.name]:
|
||||
v[k] = cls().from_dict(value[field.name][k])
|
||||
else:
|
||||
v = value[field.name]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[field.name], list):
|
||||
v = [int(n) for n in value[field.name]]
|
||||
else:
|
||||
v = int(value[field.name])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[field.name], list):
|
||||
v = [b64decode(n) for n in value[field.name]]
|
||||
else:
|
||||
v = b64decode(value[field.name])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._cls_for(field)
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
|
||||
for key in value:
|
||||
snake_cased = stringcase.snakecase(key)
|
||||
if snake_cased in fields_by_name:
|
||||
field = fields_by_name[snake_cased]
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field.name, v)
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
# print(v, value[key])
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._cls_for(field)
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field.name, v)
|
||||
return self
|
||||
|
||||
def to_json(self, indent: Union[None, int, str] = None) -> str:
|
||||
|
||||
@@ -16,6 +16,8 @@ except ImportError:
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
import stringcase
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
@@ -26,12 +28,6 @@ from google.protobuf.descriptor_pb2 import (
|
||||
)
|
||||
|
||||
|
||||
def snake_case(value: str) -> str:
|
||||
return (
|
||||
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
|
||||
)
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
@@ -40,12 +36,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
type_name = type_name.lstrip(".")
|
||||
if type_name.startswith(package):
|
||||
# This is the current package, which has nested types flattened.
|
||||
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
|
||||
# foo.bar_thing => FooBarThing
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
|
||||
if "." in type_name:
|
||||
# This is imported from another package. No need
|
||||
# to use a forward ref and we need to add the import.
|
||||
parts = type_name.split(".")
|
||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
||||
|
||||
@@ -179,7 +179,7 @@ def generate_code(request, response):
|
||||
for item, path in traverse(proto_file):
|
||||
# print(item, file=sys.stderr)
|
||||
# print(path, file=sys.stderr)
|
||||
data = {"name": item.name}
|
||||
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# print(item, file=sys.stderr)
|
||||
@@ -255,6 +255,7 @@ def generate_code(request, response):
|
||||
data["properties"].append(
|
||||
{
|
||||
"name": f.name,
|
||||
"py_name": stringcase.snakecase(f.name),
|
||||
"number": f.number,
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
@@ -294,6 +295,7 @@ def generate_code(request, response):
|
||||
|
||||
data = {
|
||||
"name": service.name,
|
||||
"py_name": stringcase.pascalcase(service.name),
|
||||
"comment": get_comment(proto_file, [6, i]),
|
||||
"methods": [],
|
||||
}
|
||||
@@ -317,7 +319,7 @@ def generate_code(request, response):
|
||||
data["methods"].append(
|
||||
{
|
||||
"name": method.name,
|
||||
"py_name": snake_case(method.name),
|
||||
"py_name": stringcase.snakecase(method.name),
|
||||
"comment": get_comment(proto_file, [6, i, 2, j]),
|
||||
"route": f"/{package}.{service.name}/{method.name}",
|
||||
"input": get_ref_type(
|
||||
|
||||
@@ -18,7 +18,7 @@ import grpclib
|
||||
|
||||
|
||||
{% if description.enums %}{% for enum in description.enums %}
|
||||
class {{ enum.name }}(betterproto.Enum):
|
||||
class {{ enum.py_name }}(betterproto.Enum):
|
||||
{% if enum.comment %}
|
||||
{{ enum.comment }}
|
||||
|
||||
@@ -35,7 +35,7 @@ class {{ enum.name }}(betterproto.Enum):
|
||||
{% endif %}
|
||||
{% for message in description.messages %}
|
||||
@dataclass
|
||||
class {{ message.name }}(betterproto.Message):
|
||||
class {{ message.py_name }}(betterproto.Message):
|
||||
{% if message.comment %}
|
||||
{{ message.comment }}
|
||||
|
||||
@@ -44,7 +44,7 @@ class {{ message.name }}(betterproto.Message):
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
{% endif %}
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
|
||||
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
|
||||
{% endfor %}
|
||||
{% if not message.properties %}
|
||||
pass
|
||||
@@ -53,7 +53,7 @@ class {{ message.name }}(betterproto.Message):
|
||||
|
||||
{% endfor %}
|
||||
{% for service in description.services %}
|
||||
class {{ service.name }}Stub(betterproto.ServiceStub):
|
||||
class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% if service.comment %}
|
||||
{{ service.comment }}
|
||||
|
||||
|
||||
4
betterproto/tests/casing.json
Normal file
4
betterproto/tests/casing.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"camelCase": 1,
|
||||
"snakeCase": "ONE"
|
||||
}
|
||||
12
betterproto/tests/casing.proto
Normal file
12
betterproto/tests/casing.proto
Normal file
@@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
enum my_enum {
|
||||
ZERO = 0;
|
||||
ONE = 1;
|
||||
TWO = 2;
|
||||
}
|
||||
|
||||
message Test {
|
||||
int32 camelCase = 1;
|
||||
my_enum snake_case = 2;
|
||||
}
|
||||
@@ -72,7 +72,8 @@ if __name__ == "__main__":
|
||||
input_json = open(filename).read()
|
||||
parsed = Parse(input_json, imported.Test())
|
||||
serialized = parsed.SerializeToString()
|
||||
serialized_json = MessageToJson(parsed, preserving_proto_field_name=True)
|
||||
preserve = "casing" not in filename
|
||||
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
|
||||
|
||||
s_loaded = json.loads(serialized_json)
|
||||
in_loaded = json.loads(input_json)
|
||||
|
||||
@@ -115,3 +115,34 @@ def test_oneof_support():
|
||||
assert betterproto.which_one_of(foo2, "group1")[0] == "bar"
|
||||
assert foo.bar == 0
|
||||
assert betterproto.which_one_of(foo2, "group2")[0] == ""
|
||||
|
||||
|
||||
def test_json_casing():
|
||||
@dataclass
|
||||
class CasingTest(betterproto.Message):
|
||||
pascal_case: int = betterproto.int32_field(1)
|
||||
camel_case: int = betterproto.int32_field(2)
|
||||
snake_case: int = betterproto.int32_field(3)
|
||||
kabob_case: int = betterproto.int32_field(4)
|
||||
|
||||
# Parsing should accept almost any input
|
||||
test = CasingTest().from_dict(
|
||||
{"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
|
||||
)
|
||||
|
||||
assert test == CasingTest(1, 2, 3, 4)
|
||||
|
||||
# Serializing should be strict.
|
||||
assert test.to_dict() == {
|
||||
"pascalCase": 1,
|
||||
"camelCase": 2,
|
||||
"snakeCase": 3,
|
||||
"kabobCase": 4,
|
||||
}
|
||||
|
||||
assert test.to_dict(casing=betterproto.Casing.SNAKE) == {
|
||||
"pascal_case": 1,
|
||||
"camel_case": 2,
|
||||
"snake_case": 3,
|
||||
"kabob_case": 4,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user