diff --git a/Pipfile b/Pipfile index 53a24fc..0cf0f10 100644 --- a/Pipfile +++ b/Pipfile @@ -14,6 +14,7 @@ rope = "*" protobuf = "*" jinja2 = "*" grpclib = "*" +stringcase = "*" [requires] python_version = "3.7" diff --git a/Pipfile.lock b/Pipfile.lock index 2b39de5..c479fbe 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40" + "sha256": "28c38cd6c4eafb0b9ac9a64cf623145868fdee163111d3b941b34d23011db6ca" }, "pipfile-spec": 6, "requires": { @@ -147,6 +147,13 @@ "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" ], "version": "==1.12.0" + }, + "stringcase": { + "hashes": [ + "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008" + ], + "index": "pypi", + "version": "==1.2.0" } }, "develop": { @@ -159,10 +166,10 @@ }, "attrs": { "hashes": [ - "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", - "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" ], - "version": "==19.2.0" + "version": "==19.3.0" }, "entrypoints": { "hashes": [ @@ -211,26 +218,30 @@ }, "mypy": { "hashes": [ - "sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b", - "sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac", - "sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c", - "sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879", - "sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795", - "sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021", - "sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3", - "sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23", - "sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd", - "sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b", - "sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046" + "sha256:1521c186a3d200c399bd5573c828ea2db1362af7209b2adb1bb8532cea2fb36f", + "sha256:31a046ab040a84a0fc38bc93694876398e62bc9f35eca8ccbf6418b7297f4c00", + "sha256:3b1a411909c84b2ae9b8283b58b48541654b918e8513c20a400bb946aa9111ae", + "sha256:48c8bc99380575deb39f5d3400ebb6a8a1cb5cc669bbba4d3bb30f904e0a0e7d", + "sha256:540c9caa57a22d0d5d3c69047cc9dd0094d49782603eb03069821b41f9e970e9", + "sha256:672e418425d957e276c291930a3921b4a6413204f53fe7c37cad7bc57b9a3391", + "sha256:6ed3b9b3fdc7193ea7aca6f3c20549b377a56f28769783a8f27191903a54170f", + "sha256:9371290aa2cad5ad133e4cdc43892778efd13293406f7340b9ffe99d5ec7c1d9", + "sha256:ace6ac1d0f87d4072f05b5468a084a45b4eda970e4d26704f201e06d47ab2990", + "sha256:b428f883d2b3fe1d052c630642cc6afddd07d5cd7873da948644508be3b9d4a7", + "sha256:d5bf0e6ec8ba346a2cf35cb55bf4adfddbc6b6576fcc9e10863daa523e418dbb", + "sha256:d7574e283f83c08501607586b3167728c58e8442947e027d2d4c7dcd6d82f453", + "sha256:dc889c84241a857c263a2b1cd1121507db7d5b5f5e87e77147097230f374d10b", + "sha256:f4748697b349f373002656bf32fede706a0e713d67bfdcf04edf39b1f61d46eb" ], "index": "pypi", - "version": "==0.730" + "version": "==0.740" }, "mypy-extensions": { "hashes": [ - "sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458" + "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d", + "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8" ], - "version": "==0.4.2" + "version": "==0.4.3" }, "packaging": { "hashes": [ @@ -300,20 +311,25 @@ }, "typed-ast": { "hashes": [ + "sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161", "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", + "sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47", "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", + "sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2", + "sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e", "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", + "sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66", "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" ], "version": "==1.4.0" diff --git a/README.md b/README.md index 8a016e2..836968d 100644 --- a/README.md +++ b/README.md @@ -301,7 +301,7 @@ $ pipenv run tests - [x] Unary-unary - [x] Server streaming response - [ ] Client streaming request -- [ ] Renaming messages and fields to conform to Python name standards +- [x] Renaming messages and fields to conform to Python name standards - [ ] Renaming clashes with language keywords and standard library top-level packages - [x] Python package - [x] Automate running tests diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 5c1075a..a80f6d8 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -24,6 +24,7 @@ from typing import ( import grpclib.client import grpclib.const +import stringcase # Proto 3 data types TYPE_ENUM = "enum" @@ -101,6 +102,13 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] +class Casing(enum.Enum): + """Casing constants for serialization.""" + + CAMEL = stringcase.camelcase + SNAKE = stringcase.snakecase + + class _PLACEHOLDER: pass @@ -624,48 +632,50 @@ class Message(ABC): def FromString(cls: Type[T], data: bytes) -> T: return cls().parse(data) - def to_dict(self) -> dict: + def to_dict(self, casing: Casing = Casing.CAMEL) -> dict: """ Returns a dict representation of this message instance which can be - used to serialize to e.g. JSON. + used to serialize to e.g. JSON. Defaults to camel casing for + compatibility but can be set to other modes. """ output: Dict[str, Any] = {} for field in dataclasses.fields(self): meta = FieldMetadata.get(field) v = getattr(self, field.name) + cased_name = casing(field.name) if meta.proto_type == "message": if isinstance(v, list): # Convert each item. v = [i.to_dict() for i in v] - output[field.name] = v + output[cased_name] = v elif v._serialized_on_wire: - output[field.name] = v.to_dict() + output[cased_name] = v.to_dict() elif meta.proto_type == "map": for k in v: if hasattr(v[k], "to_dict"): v[k] = v[k].to_dict() if v: - output[field.name] = v + output[cased_name] = v elif v != get_default(meta.proto_type): if meta.proto_type in INT_64_TYPES: if isinstance(v, list): - output[field.name] = [str(n) for n in v] + output[cased_name] = [str(n) for n in v] else: - output[field.name] = str(v) + output[cased_name] = str(v) elif meta.proto_type == TYPE_BYTES: if isinstance(v, list): - output[field.name] = [b64encode(b).decode("utf8") for b in v] + output[cased_name] = [b64encode(b).decode("utf8") for b in v] else: - output[field.name] = b64encode(v).decode("utf8") + output[cased_name] = b64encode(v).decode("utf8") elif meta.proto_type == TYPE_ENUM: enum_values = list(self._cls_for(field)) if isinstance(v, list): - output[field.name] = [enum_values[e].name for e in v] + output[cased_name] = [enum_values[e].name for e in v] else: - output[field.name] = enum_values[v].name + output[cased_name] = enum_values[v].name else: - output[field.name] = v + output[cased_name] = v return output def from_dict(self: T, value: dict) -> T: @@ -674,44 +684,49 @@ class Message(ABC): returns the instance itself and is therefore assignable and chainable. """ self._serialized_on_wire = True - for field in dataclasses.fields(self): - meta = FieldMetadata.get(field) - if field.name in value and value[field.name] is not None: - if meta.proto_type == "message": - v = getattr(self, field.name) - # print(v, value[field.name]) - if isinstance(v, list): - cls = self._cls_for(field) - for i in range(len(value[field.name])): - v.append(cls().from_dict(value[field.name][i])) - else: - v.from_dict(value[field.name]) - elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: - v = getattr(self, field.name) - cls = self._cls_for(field, index=1) - for k in value[field.name]: - v[k] = cls().from_dict(value[field.name][k]) - else: - v = value[field.name] - if meta.proto_type in INT_64_TYPES: - if isinstance(value[field.name], list): - v = [int(n) for n in value[field.name]] - else: - v = int(value[field.name]) - elif meta.proto_type == TYPE_BYTES: - if isinstance(value[field.name], list): - v = [b64decode(n) for n in value[field.name]] - else: - v = b64decode(value[field.name]) - elif meta.proto_type == TYPE_ENUM: - enum_cls = self._cls_for(field) - if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] - elif isinstance(v, str): - v = enum_cls.from_string(v) + fields_by_name = {f.name: f for f in dataclasses.fields(self)} + for key in value: + snake_cased = stringcase.snakecase(key) + if snake_cased in fields_by_name: + field = fields_by_name[snake_cased] + meta = FieldMetadata.get(field) - if v is not None: - setattr(self, field.name, v) + if value[key] is not None: + if meta.proto_type == "message": + v = getattr(self, field.name) + # print(v, value[key]) + if isinstance(v, list): + cls = self._cls_for(field) + for i in range(len(value[key])): + v.append(cls().from_dict(value[key][i])) + else: + v.from_dict(value[key]) + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: + v = getattr(self, field.name) + cls = self._cls_for(field, index=1) + for k in value[key]: + v[k] = cls().from_dict(value[key][k]) + else: + v = value[key] + if meta.proto_type in INT_64_TYPES: + if isinstance(value[key], list): + v = [int(n) for n in value[key]] + else: + v = int(value[key]) + elif meta.proto_type == TYPE_BYTES: + if isinstance(value[key], list): + v = [b64decode(n) for n in value[key]] + else: + v = b64decode(value[key]) + elif meta.proto_type == TYPE_ENUM: + enum_cls = self._cls_for(field) + if isinstance(v, list): + v = [enum_cls.from_string(e) for e in v] + elif isinstance(v, str): + v = enum_cls.from_string(v) + + if v is not None: + setattr(self, field.name, v) return self def to_json(self, indent: Union[None, int, str] = None) -> str: diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 98c9cf1..cec96a9 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -16,6 +16,8 @@ except ImportError: ) raise SystemExit(1) +import stringcase + from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.descriptor_pb2 import ( DescriptorProto, @@ -26,12 +28,6 @@ from google.protobuf.descriptor_pb2 import ( ) -def snake_case(value: str) -> str: - return ( - re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_") - ) - - def get_ref_type(package: str, imports: set, type_name: str) -> str: """ Return a Python type name for a proto type reference. Adds the import if @@ -40,12 +36,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str: type_name = type_name.lstrip(".") if type_name.startswith(package): # This is the current package, which has nested types flattened. - type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"' + # foo.bar_thing => FooBarThing + parts = type_name.lstrip(package).lstrip(".").split(".") + cased = [stringcase.pascalcase(part) for part in parts] + type_name = f'"{"".join(cased)}"' if "." in type_name: # This is imported from another package. No need # to use a forward ref and we need to add the import. parts = type_name.split(".") + parts[-1] = stringcase.pascalcase(parts[-1]) imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") type_name = f"{parts[-2]}.{parts[-1]}" @@ -179,7 +179,7 @@ def generate_code(request, response): for item, path in traverse(proto_file): # print(item, file=sys.stderr) # print(path, file=sys.stderr) - data = {"name": item.name} + data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)} if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) @@ -255,6 +255,7 @@ def generate_code(request, response): data["properties"].append( { "name": f.name, + "py_name": stringcase.snakecase(f.name), "number": f.number, "comment": get_comment(proto_file, path + [2, i]), "proto_type": int(f.type), @@ -294,6 +295,7 @@ def generate_code(request, response): data = { "name": service.name, + "py_name": stringcase.pascalcase(service.name), "comment": get_comment(proto_file, [6, i]), "methods": [], } @@ -317,7 +319,7 @@ def generate_code(request, response): data["methods"].append( { "name": method.name, - "py_name": snake_case(method.name), + "py_name": stringcase.snakecase(method.name), "comment": get_comment(proto_file, [6, i, 2, j]), "route": f"/{package}.{service.name}/{method.name}", "input": get_ref_type( diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 8f61ab9..5ae5857 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -18,7 +18,7 @@ import grpclib {% if description.enums %}{% for enum in description.enums %} -class {{ enum.name }}(betterproto.Enum): +class {{ enum.py_name }}(betterproto.Enum): {% if enum.comment %} {{ enum.comment }} @@ -35,7 +35,7 @@ class {{ enum.name }}(betterproto.Enum): {% endif %} {% for message in description.messages %} @dataclass -class {{ message.name }}(betterproto.Message): +class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} @@ -44,7 +44,7 @@ class {{ message.name }}(betterproto.Message): {% if field.comment %} {{ field.comment }} {% endif %} - {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) + {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) {% endfor %} {% if not message.properties %} pass @@ -53,7 +53,7 @@ class {{ message.name }}(betterproto.Message): {% endfor %} {% for service in description.services %} -class {{ service.name }}Stub(betterproto.ServiceStub): +class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} {{ service.comment }} diff --git a/betterproto/tests/casing.json b/betterproto/tests/casing.json new file mode 100644 index 0000000..559104b --- /dev/null +++ b/betterproto/tests/casing.json @@ -0,0 +1,4 @@ +{ + "camelCase": 1, + "snakeCase": "ONE" +} diff --git a/betterproto/tests/casing.proto b/betterproto/tests/casing.proto new file mode 100644 index 0000000..4ab37ae --- /dev/null +++ b/betterproto/tests/casing.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +enum my_enum { + ZERO = 0; + ONE = 1; + TWO = 2; +} + +message Test { + int32 camelCase = 1; + my_enum snake_case = 2; +} diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index b09463e..987f2d9 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -72,7 +72,8 @@ if __name__ == "__main__": input_json = open(filename).read() parsed = Parse(input_json, imported.Test()) serialized = parsed.SerializeToString() - serialized_json = MessageToJson(parsed, preserving_proto_field_name=True) + preserve = "casing" not in filename + serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve) s_loaded = json.loads(serialized_json) in_loaded = json.loads(input_json) diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index d542baa..03c3023 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -115,3 +115,34 @@ def test_oneof_support(): assert betterproto.which_one_of(foo2, "group1")[0] == "bar" assert foo.bar == 0 assert betterproto.which_one_of(foo2, "group2")[0] == "" + + +def test_json_casing(): + @dataclass + class CasingTest(betterproto.Message): + pascal_case: int = betterproto.int32_field(1) + camel_case: int = betterproto.int32_field(2) + snake_case: int = betterproto.int32_field(3) + kabob_case: int = betterproto.int32_field(4) + + # Parsing should accept almost any input + test = CasingTest().from_dict( + {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4} + ) + + assert test == CasingTest(1, 2, 3, 4) + + # Serializing should be strict. + assert test.to_dict() == { + "pascalCase": 1, + "camelCase": 2, + "snakeCase": 3, + "kabobCase": 4, + } + + assert test.to_dict(casing=betterproto.Casing.SNAKE) == { + "pascal_case": 1, + "camel_case": 2, + "snake_case": 3, + "kabob_case": 4, + } diff --git a/setup.py b/setup.py index 8e74191..a2fd4d2 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( ), package_data={"betterproto": ["py.typed", "templates/template.py"]}, python_requires=">=3.7", - install_requires=["grpclib"], + install_requires=["grpclib", "stringcase"], extras_require={"compiler": ["jinja2", "protobuf"]}, zip_safe=False, )