diff --git a/Pipfile b/Pipfile index dd34c74..63f99a3 100644 --- a/Pipfile +++ b/Pipfile @@ -18,6 +18,6 @@ jinja2 = "*" python_version = "3.7" [scripts] -plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=." +plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output" generate = "python betterproto/tests/generate.py" test = "pytest ./betterproto/tests" diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 3f4fd00..f9b6f15 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -209,6 +209,10 @@ def string_field(number: int, default: str = "") -> Any: return dataclass_field(number, TYPE_STRING, default=default) +def bytes_field(number: int, default: bytes = b"") -> Any: + return dataclass_field(number, TYPE_BYTES, default=default) + + def message_field(number: int) -> Any: return dataclass_field(number, TYPE_MESSAGE) diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py index a7b2e9d..86f35f0 100755 --- a/protoc-gen-betterpy.py +++ b/protoc-gen-betterpy.py @@ -40,12 +40,13 @@ def py_type( message_type = descriptor.type_name.lstrip(".") if message_type.startswith(package): # This is the current package, which has nested types flattened. - message_type = message_type.lstrip(package).lstrip(".").replace(".", "") + message_type = ( + f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"' + ) if "." in message_type: # This is imported from another package. No need # to use a forward ref and we need to add the import. - message_type = message_type.strip('"') parts = message_type.split(".") imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") message_type = f"{parts[-2]}.{parts[-1]}" @@ -58,7 +59,7 @@ def py_type( # file=sys.stderr, # ) - return f'"{message_type}"' + return message_type elif descriptor.type == 12: return "bytes" else: @@ -247,10 +248,28 @@ def generate_code(request, response): # Fill response f = response.file.add() # print(filename, file=sys.stderr) - f.name = filename + ".py" + f.name = filename.replace(".", os.path.sep) + ".py" + # f.content = json.dumps(output, indent=2) f.content = template.render(description=output).rstrip("\n") + "\n" + inits = set([""]) + for f in response.file: + # Ensure output paths exist + print(f.name, file=sys.stderr) + dirnames = os.path.dirname(f.name) + if dirnames: + os.makedirs(dirnames, exist_ok=True) + base = "" + for part in dirnames.split(os.path.sep): + base = os.path.join(base, part) + inits.add(base) + + for base in inits: + init = response.file.add() + init.name = os.path.join(base, "__init__.py") + init.content = b"" + if __name__ == "__main__": # Read request message from stdin