Generate __init__.py files
This commit is contained in:
		
							
								
								
									
										2
									
								
								Pipfile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Pipfile
									
									
									
									
									
								
							| @@ -18,6 +18,6 @@ jinja2 = "*" | |||||||
| python_version = "3.7" | python_version = "3.7" | ||||||
|  |  | ||||||
| [scripts] | [scripts] | ||||||
| plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=." | plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output" | ||||||
| generate = "python betterproto/tests/generate.py" | generate = "python betterproto/tests/generate.py" | ||||||
| test = "pytest ./betterproto/tests" | test = "pytest ./betterproto/tests" | ||||||
|   | |||||||
| @@ -209,6 +209,10 @@ def string_field(number: int, default: str = "") -> Any: | |||||||
|     return dataclass_field(number, TYPE_STRING, default=default) |     return dataclass_field(number, TYPE_STRING, default=default) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def bytes_field(number: int, default: bytes = b"") -> Any: | ||||||
|  |     return dataclass_field(number, TYPE_BYTES, default=default) | ||||||
|  |  | ||||||
|  |  | ||||||
| def message_field(number: int) -> Any: | def message_field(number: int) -> Any: | ||||||
|     return dataclass_field(number, TYPE_MESSAGE) |     return dataclass_field(number, TYPE_MESSAGE) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -40,12 +40,13 @@ def py_type( | |||||||
|         message_type = descriptor.type_name.lstrip(".") |         message_type = descriptor.type_name.lstrip(".") | ||||||
|         if message_type.startswith(package): |         if message_type.startswith(package): | ||||||
|             # This is the current package, which has nested types flattened. |             # This is the current package, which has nested types flattened. | ||||||
|             message_type = message_type.lstrip(package).lstrip(".").replace(".", "") |             message_type = ( | ||||||
|  |                 f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"' | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         if "." in message_type: |         if "." in message_type: | ||||||
|             # This is imported from another package. No need |             # This is imported from another package. No need | ||||||
|             # to use a forward ref and we need to add the import. |             # to use a forward ref and we need to add the import. | ||||||
|             message_type = message_type.strip('"') |  | ||||||
|             parts = message_type.split(".") |             parts = message_type.split(".") | ||||||
|             imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") |             imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") | ||||||
|             message_type = f"{parts[-2]}.{parts[-1]}" |             message_type = f"{parts[-2]}.{parts[-1]}" | ||||||
| @@ -58,7 +59,7 @@ def py_type( | |||||||
|         #     file=sys.stderr, |         #     file=sys.stderr, | ||||||
|         # ) |         # ) | ||||||
|  |  | ||||||
|         return f'"{message_type}"' |         return message_type | ||||||
|     elif descriptor.type == 12: |     elif descriptor.type == 12: | ||||||
|         return "bytes" |         return "bytes" | ||||||
|     else: |     else: | ||||||
| @@ -247,10 +248,28 @@ def generate_code(request, response): | |||||||
|         # Fill response |         # Fill response | ||||||
|         f = response.file.add() |         f = response.file.add() | ||||||
|         # print(filename, file=sys.stderr) |         # print(filename, file=sys.stderr) | ||||||
|         f.name = filename + ".py" |         f.name = filename.replace(".", os.path.sep) + ".py" | ||||||
|  |  | ||||||
|         # f.content = json.dumps(output, indent=2) |         # f.content = json.dumps(output, indent=2) | ||||||
|         f.content = template.render(description=output).rstrip("\n") + "\n" |         f.content = template.render(description=output).rstrip("\n") + "\n" | ||||||
|  |  | ||||||
|  |     inits = set([""]) | ||||||
|  |     for f in response.file: | ||||||
|  |         # Ensure output paths exist | ||||||
|  |         print(f.name, file=sys.stderr) | ||||||
|  |         dirnames = os.path.dirname(f.name) | ||||||
|  |         if dirnames: | ||||||
|  |             os.makedirs(dirnames, exist_ok=True) | ||||||
|  |             base = "" | ||||||
|  |             for part in dirnames.split(os.path.sep): | ||||||
|  |                 base = os.path.join(base, part) | ||||||
|  |                 inits.add(base) | ||||||
|  |  | ||||||
|  |     for base in inits: | ||||||
|  |         init = response.file.add() | ||||||
|  |         init.name = os.path.join(base, "__init__.py") | ||||||
|  |         init.content = b"" | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # Read request message from stdin |     # Read request message from stdin | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user