Generate/test refactoring
This commit is contained in:
parent
1f46e10ba7
commit
1a488faf7a
@ -5,6 +5,7 @@
|
|||||||
- [x] Zig-zag signed fields (sint32, sint64)
|
- [x] Zig-zag signed fields (sint32, sint64)
|
||||||
- [x] Don't encode zero values for nested types
|
- [x] Don't encode zero values for nested types
|
||||||
- [x] Enums
|
- [x] Enums
|
||||||
|
- [ ] Repeated message fields
|
||||||
- [ ] Maps
|
- [ ] Maps
|
||||||
- [ ] Support passthrough of unknown fields
|
- [ ] Support passthrough of unknown fields
|
||||||
- [ ] Refs to nested types
|
- [ ] Refs to nested types
|
||||||
|
@ -303,7 +303,11 @@ def _postprocess_single(
|
|||||||
if meta.proto_type in ["string"]:
|
if meta.proto_type in ["string"]:
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
elif meta.proto_type in ["message"]:
|
elif meta.proto_type in ["message"]:
|
||||||
value = field.default_factory().parse(value)
|
orig = value
|
||||||
|
value = field.default_factory()
|
||||||
|
if isinstance(value, Message):
|
||||||
|
# If it's a message (instead of e.g. list) then keep going!
|
||||||
|
value.parse(orig)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
|||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import importlib
|
import importlib
|
||||||
|
import sys
|
||||||
from typing import Generator, Tuple
|
from typing import Generator, Tuple
|
||||||
|
|
||||||
from google.protobuf.json_format import Parse
|
from google.protobuf.json_format import Parse
|
||||||
@ -17,17 +18,37 @@ from google.protobuf.descriptor_pool import DescriptorPool
|
|||||||
root = os.path.dirname(os.path.realpath(__file__))
|
root = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
|
||||||
def get_files(end: str) -> Generator[Tuple[str, str], None, None]:
|
def get_files(end: str) -> Generator[str, None, None]:
|
||||||
for r, dirs, files in os.walk(root):
|
for r, dirs, files in os.walk(root):
|
||||||
for filename in [f for f in files if f.endswith(end)]:
|
for filename in [f for f in files if f.endswith(end)]:
|
||||||
parts = os.path.splitext(filename)[0].split("-")
|
yield os.path.join(r, filename)
|
||||||
yield [parts[0], os.path.join(r, filename)]
|
|
||||||
|
|
||||||
|
def get_base(filename: str) -> str:
|
||||||
|
return os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_ext(filename: str, ext: str) -> str:
|
||||||
|
if not filename.endswith(ext):
|
||||||
|
return filename + ext
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
os.chdir(root)
|
os.chdir(root)
|
||||||
|
|
||||||
for base, filename in get_files(".proto"):
|
if len(sys.argv) > 1:
|
||||||
|
proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]]
|
||||||
|
bases = {get_base(f) for f in proto_files}
|
||||||
|
json_files = [
|
||||||
|
f for f in get_files(".json") if get_base(f).split("-")[0] in bases
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
proto_files = get_files(".proto")
|
||||||
|
json_files = get_files(".json")
|
||||||
|
|
||||||
|
for filename in proto_files:
|
||||||
|
print(f"Generatinng code for {os.path.basename(filename)}")
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
||||||
)
|
)
|
||||||
@ -36,12 +57,16 @@ if __name__ == "__main__":
|
|||||||
shell=True,
|
shell=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for base, filename in get_files(".json"):
|
for filename in json_files:
|
||||||
# Reset the internal symbol database so we can import the `Test` message
|
# Reset the internal symbol database so we can import the `Test` message
|
||||||
# multiple times. Ugh.
|
# multiple times. Ugh.
|
||||||
sym = symbol_database.Default()
|
sym = symbol_database.Default()
|
||||||
sym.pool = DescriptorPool()
|
sym.pool = DescriptorPool()
|
||||||
imported = importlib.import_module(f"{base}_pb2")
|
|
||||||
|
parts = get_base(filename).split("-")
|
||||||
out = filename.replace(".json", ".bin")
|
out = filename.replace(".json", ".bin")
|
||||||
|
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
||||||
|
|
||||||
|
imported = importlib.import_module(f"{parts[0]}_pb2")
|
||||||
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
|
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
|
||||||
open(out, "wb").write(serialized)
|
open(out, "wb").write(serialized)
|
||||||
|
@ -2,22 +2,30 @@ import importlib
|
|||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from generate import get_files
|
from generate import get_files, get_base
|
||||||
|
|
||||||
inputs = get_files(".bin")
|
inputs = get_files(".bin")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name,filename", inputs)
|
@pytest.mark.parametrize("filename", inputs)
|
||||||
def test_sample(name: str, filename: str) -> None:
|
def test_sample(filename: str) -> None:
|
||||||
imported = importlib.import_module(name)
|
module = get_base(filename).split("-")[0]
|
||||||
|
imported = importlib.import_module(module)
|
||||||
data_binary = open(filename, "rb").read()
|
data_binary = open(filename, "rb").read()
|
||||||
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
|
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
|
||||||
t1 = imported.Test().parse(data_binary)
|
t1 = imported.Test().parse(data_binary)
|
||||||
t2 = imported.Test().from_dict(data_dict)
|
t2 = imported.Test().from_dict(data_dict)
|
||||||
print(t1)
|
print(t1)
|
||||||
print(t2)
|
print(t2)
|
||||||
|
|
||||||
|
# Equality should automagically work for dataclasses!
|
||||||
assert t1 == t2
|
assert t1 == t2
|
||||||
|
|
||||||
|
# Generally this can't be relied on, but here we are aiming to match the
|
||||||
|
# existing Python implementation and aren't doing anything tricky.
|
||||||
|
# https://developers.google.com/protocol-buffers/docs/encoding#implications
|
||||||
assert bytes(t1) == data_binary
|
assert bytes(t1) == data_binary
|
||||||
assert bytes(t2) == data_binary
|
assert bytes(t2) == data_binary
|
||||||
|
|
||||||
assert t1.to_dict() == data_dict
|
assert t1.to_dict() == data_dict
|
||||||
assert t2.to_dict() == data_dict
|
assert t2.to_dict() == data_dict
|
||||||
|
Loading…
x
Reference in New Issue
Block a user