Generate/test refactoring
This commit is contained in:
@@ -8,6 +8,7 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
import subprocess
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from google.protobuf.json_format import Parse
|
||||
@@ -17,17 +18,37 @@ from google.protobuf.descriptor_pool import DescriptorPool
|
||||
root = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
def get_files(end: str) -> Generator[Tuple[str, str], None, None]:
|
||||
def get_files(end: str) -> Generator[str, None, None]:
|
||||
for r, dirs, files in os.walk(root):
|
||||
for filename in [f for f in files if f.endswith(end)]:
|
||||
parts = os.path.splitext(filename)[0].split("-")
|
||||
yield [parts[0], os.path.join(r, filename)]
|
||||
yield os.path.join(r, filename)
|
||||
|
||||
|
||||
def get_base(filename: str) -> str:
|
||||
return os.path.splitext(os.path.basename(filename))[0]
|
||||
|
||||
|
||||
def ensure_ext(filename: str, ext: str) -> str:
|
||||
if not filename.endswith(ext):
|
||||
return filename + ext
|
||||
return filename
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.chdir(root)
|
||||
|
||||
for base, filename in get_files(".proto"):
|
||||
if len(sys.argv) > 1:
|
||||
proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]]
|
||||
bases = {get_base(f) for f in proto_files}
|
||||
json_files = [
|
||||
f for f in get_files(".json") if get_base(f).split("-")[0] in bases
|
||||
]
|
||||
else:
|
||||
proto_files = get_files(".proto")
|
||||
json_files = get_files(".json")
|
||||
|
||||
for filename in proto_files:
|
||||
print(f"Generatinng code for {os.path.basename(filename)}")
|
||||
subprocess.run(
|
||||
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
||||
)
|
||||
@@ -36,12 +57,16 @@ if __name__ == "__main__":
|
||||
shell=True,
|
||||
)
|
||||
|
||||
for base, filename in get_files(".json"):
|
||||
for filename in json_files:
|
||||
# Reset the internal symbol database so we can import the `Test` message
|
||||
# multiple times. Ugh.
|
||||
sym = symbol_database.Default()
|
||||
sym.pool = DescriptorPool()
|
||||
imported = importlib.import_module(f"{base}_pb2")
|
||||
|
||||
parts = get_base(filename).split("-")
|
||||
out = filename.replace(".json", ".bin")
|
||||
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
||||
|
||||
imported = importlib.import_module(f"{parts[0]}_pb2")
|
||||
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
|
||||
open(out, "wb").write(serialized)
|
||||
|
||||
Reference in New Issue
Block a user