Generate __init__.py files

This commit is contained in:
Daniel G. Taylor 2019-10-12 10:16:05 -07:00
parent dcb7102d92
commit 130acfffa3
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
3 changed files with 28 additions and 5 deletions

View File

@ -18,6 +18,6 @@ jinja2 = "*"
python_version = "3.7"
[scripts]
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=."
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output"
generate = "python betterproto/tests/generate.py"
test = "pytest ./betterproto/tests"

View File

@ -209,6 +209,10 @@ def string_field(number: int, default: str = "") -> Any:
return dataclass_field(number, TYPE_STRING, default=default)
def bytes_field(number: int, default: bytes = b"") -> Any:
return dataclass_field(number, TYPE_BYTES, default=default)
def message_field(number: int) -> Any:
return dataclass_field(number, TYPE_MESSAGE)

View File

@ -40,12 +40,13 @@ def py_type(
message_type = descriptor.type_name.lstrip(".")
if message_type.startswith(package):
# This is the current package, which has nested types flattened.
message_type = message_type.lstrip(package).lstrip(".").replace(".", "")
message_type = (
f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
)
if "." in message_type:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
message_type = message_type.strip('"')
parts = message_type.split(".")
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
message_type = f"{parts[-2]}.{parts[-1]}"
@ -58,7 +59,7 @@ def py_type(
# file=sys.stderr,
# )
return f'"{message_type}"'
return message_type
elif descriptor.type == 12:
return "bytes"
else:
@ -247,10 +248,28 @@ def generate_code(request, response):
# Fill response
f = response.file.add()
# print(filename, file=sys.stderr)
f.name = filename + ".py"
f.name = filename.replace(".", os.path.sep) + ".py"
# f.content = json.dumps(output, indent=2)
f.content = template.render(description=output).rstrip("\n") + "\n"
inits = set([""])
for f in response.file:
# Ensure output paths exist
print(f.name, file=sys.stderr)
dirnames = os.path.dirname(f.name)
if dirnames:
os.makedirs(dirnames, exist_ok=True)
base = ""
for part in dirnames.split(os.path.sep):
base = os.path.join(base, part)
inits.add(base)
for base in inits:
init = response.file.add()
init.name = os.path.join(base, "__init__.py")
init.content = b""
if __name__ == "__main__":
# Read request message from stdin