Generate __init__.py files

This commit is contained in:
Daniel G. Taylor 2019-10-12 10:16:05 -07:00
parent dcb7102d92
commit 130acfffa3
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
3 changed files with 28 additions and 5 deletions

View File

@ -18,6 +18,6 @@ jinja2 = "*"
python_version = "3.7" python_version = "3.7"
[scripts] [scripts]
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=." plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output"
generate = "python betterproto/tests/generate.py" generate = "python betterproto/tests/generate.py"
test = "pytest ./betterproto/tests" test = "pytest ./betterproto/tests"

View File

@ -209,6 +209,10 @@ def string_field(number: int, default: str = "") -> Any:
return dataclass_field(number, TYPE_STRING, default=default) return dataclass_field(number, TYPE_STRING, default=default)
def bytes_field(number: int, default: bytes = b"") -> Any:
return dataclass_field(number, TYPE_BYTES, default=default)
def message_field(number: int) -> Any: def message_field(number: int) -> Any:
return dataclass_field(number, TYPE_MESSAGE) return dataclass_field(number, TYPE_MESSAGE)

View File

@ -40,12 +40,13 @@ def py_type(
message_type = descriptor.type_name.lstrip(".") message_type = descriptor.type_name.lstrip(".")
if message_type.startswith(package): if message_type.startswith(package):
# This is the current package, which has nested types flattened. # This is the current package, which has nested types flattened.
message_type = message_type.lstrip(package).lstrip(".").replace(".", "") message_type = (
f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
)
if "." in message_type: if "." in message_type:
# This is imported from another package. No need # This is imported from another package. No need
# to use a forward ref and we need to add the import. # to use a forward ref and we need to add the import.
message_type = message_type.strip('"')
parts = message_type.split(".") parts = message_type.split(".")
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
message_type = f"{parts[-2]}.{parts[-1]}" message_type = f"{parts[-2]}.{parts[-1]}"
@ -58,7 +59,7 @@ def py_type(
# file=sys.stderr, # file=sys.stderr,
# ) # )
return f'"{message_type}"' return message_type
elif descriptor.type == 12: elif descriptor.type == 12:
return "bytes" return "bytes"
else: else:
@ -247,10 +248,28 @@ def generate_code(request, response):
# Fill response # Fill response
f = response.file.add() f = response.file.add()
# print(filename, file=sys.stderr) # print(filename, file=sys.stderr)
f.name = filename + ".py" f.name = filename.replace(".", os.path.sep) + ".py"
# f.content = json.dumps(output, indent=2) # f.content = json.dumps(output, indent=2)
f.content = template.render(description=output).rstrip("\n") + "\n" f.content = template.render(description=output).rstrip("\n") + "\n"
inits = set([""])
for f in response.file:
# Ensure output paths exist
print(f.name, file=sys.stderr)
dirnames = os.path.dirname(f.name)
if dirnames:
os.makedirs(dirnames, exist_ok=True)
base = ""
for part in dirnames.split(os.path.sep):
base = os.path.join(base, part)
inits.add(base)
for base in inits:
init = response.file.add()
init.name = os.path.join(base, "__init__.py")
init.content = b""
if __name__ == "__main__": if __name__ == "__main__":
# Read request message from stdin # Read request message from stdin