From 1a488faf7a4c7065f48246e72aac7afc1f3b9a9f Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Wed, 9 Oct 2019 17:21:29 -0700 Subject: [PATCH] Generate/test refactoring --- README.md | 1 + betterproto/__init__.py | 6 +++++- betterproto/tests/generate.py | 37 ++++++++++++++++++++++++++------ betterproto/tests/test_inputs.py | 16 ++++++++++---- 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 565c91a..897b31d 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ - [x] Zig-zag signed fields (sint32, sint64) - [x] Don't encode zero values for nested types - [x] Enums +- [ ] Repeated message fields - [ ] Maps - [ ] Support passthrough of unknown fields - [ ] Refs to nested types diff --git a/betterproto/__init__.py b/betterproto/__init__.py index d256c3e..966b9a6 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -303,7 +303,11 @@ def _postprocess_single( if meta.proto_type in ["string"]: value = value.decode("utf-8") elif meta.proto_type in ["message"]: - value = field.default_factory().parse(value) + orig = value + value = field.default_factory() + if isinstance(value, Message): + # If it's a message (instead of e.g. list) then keep going! + value.parse(orig) return value diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index c1cfbde..5fd037e 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -8,6 +8,7 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" import subprocess import importlib +import sys from typing import Generator, Tuple from google.protobuf.json_format import Parse @@ -17,17 +18,37 @@ from google.protobuf.descriptor_pool import DescriptorPool root = os.path.dirname(os.path.realpath(__file__)) -def get_files(end: str) -> Generator[Tuple[str, str], None, None]: +def get_files(end: str) -> Generator[str, None, None]: for r, dirs, files in os.walk(root): for filename in [f for f in files if f.endswith(end)]: - parts = os.path.splitext(filename)[0].split("-") - yield [parts[0], os.path.join(r, filename)] + yield os.path.join(r, filename) + + +def get_base(filename: str) -> str: + return os.path.splitext(os.path.basename(filename))[0] + + +def ensure_ext(filename: str, ext: str) -> str: + if not filename.endswith(ext): + return filename + ext + return filename if __name__ == "__main__": os.chdir(root) - for base, filename in get_files(".proto"): + if len(sys.argv) > 1: + proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]] + bases = {get_base(f) for f in proto_files} + json_files = [ + f for f in get_files(".json") if get_base(f).split("-")[0] in bases + ] + else: + proto_files = get_files(".proto") + json_files = get_files(".json") + + for filename in proto_files: + print(f"Generatinng code for {os.path.basename(filename)}") subprocess.run( f"protoc --python_out=. {os.path.basename(filename)}", shell=True ) @@ -36,12 +57,16 @@ if __name__ == "__main__": shell=True, ) - for base, filename in get_files(".json"): + for filename in json_files: # Reset the internal symbol database so we can import the `Test` message # multiple times. Ugh. sym = symbol_database.Default() sym.pool = DescriptorPool() - imported = importlib.import_module(f"{base}_pb2") + + parts = get_base(filename).split("-") out = filename.replace(".json", ".bin") + print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}") + + imported = importlib.import_module(f"{parts[0]}_pb2") serialized = Parse(open(filename).read(), imported.Test()).SerializeToString() open(out, "wb").write(serialized) diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index 1ac7f6f..49b7a44 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -2,22 +2,30 @@ import importlib import pytest import json -from generate import get_files +from generate import get_files, get_base inputs = get_files(".bin") -@pytest.mark.parametrize("name,filename", inputs) -def test_sample(name: str, filename: str) -> None: - imported = importlib.import_module(name) +@pytest.mark.parametrize("filename", inputs) +def test_sample(filename: str) -> None: + module = get_base(filename).split("-")[0] + imported = importlib.import_module(module) data_binary = open(filename, "rb").read() data_dict = json.loads(open(filename.replace(".bin", ".json")).read()) t1 = imported.Test().parse(data_binary) t2 = imported.Test().from_dict(data_dict) print(t1) print(t2) + + # Equality should automagically work for dataclasses! assert t1 == t2 + + # Generally this can't be relied on, but here we are aiming to match the + # existing Python implementation and aren't doing anything tricky. + # https://developers.google.com/protocol-buffers/docs/encoding#implications assert bytes(t1) == data_binary assert bytes(t2) == data_binary + assert t1.to_dict() == data_dict assert t2.to_dict() == data_dict