diff --git a/Pipfile b/Pipfile
index 18f83ea..63f99a3 100644
--- a/Pipfile
+++ b/Pipfile
@@ -18,5 +18,6 @@ jinja2 = "*"
 python_version = "3.7"
 
 [scripts]
+plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output"
 generate = "python betterproto/tests/generate.py"
 test = "pytest ./betterproto/tests"
diff --git a/README.md b/README.md
index d813b82..9da3c77 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
 
 This project is heavily inspired by, and borrows functionality from:
 
+- https://github.com/protocolbuffers/protobuf/tree/master/python
 - https://github.com/eigenein/protobuf/
 - https://github.com/vmagamedov/grpclib
 
@@ -27,8 +28,8 @@ This project is heavily inspired by, and borrows functionality from:
 - [x] Maps
 - [x] Maps of message fields
 - [ ] Support passthrough of unknown fields
-- [ ] Refs to nested types
-- [ ] Imports in proto files
+- [x] Refs to nested types
+- [x] Imports in proto files
 - [ ] Well-known Google types
 - [ ] JSON that isn't completely naive.
 - [ ] Async service stubs
diff --git a/betterproto/__init__.py b/betterproto/__init__.py
index ed44e87..f9b6f15 100644
--- a/betterproto/__init__.py
+++ b/betterproto/__init__.py
@@ -92,6 +92,18 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
 WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
 
 
+def get_default(proto_type: int) -> Any:
+    """Get the default (zero value) for a given type."""
+    return {
+        TYPE_BOOL: False,
+        TYPE_FLOAT: 0.0,
+        TYPE_DOUBLE: 0.0,
+        TYPE_STRING: "",
+        TYPE_BYTES: b"",
+        TYPE_MAP: {},
+    }.get(proto_type, 0)
+
+
 @dataclasses.dataclass(frozen=True)
 class FieldMetadata:
     """Stores internal metadata used for parsing & serialization."""
@@ -114,7 +126,7 @@ class FieldMetadata:
 def dataclass_field(
     number: int,
     proto_type: str,
-    default: Any,
+    default: Any = None,
     map_types: Optional[Tuple[str, str]] = None,
     **kwargs: dict,
 ) -> dataclasses.Field:
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
     return dataclass_field(number, TYPE_ENUM, default=default)
 
 
+def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
+    return dataclass_field(number, TYPE_BOOL, default=default)
+
+
 def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
     return dataclass_field(number, TYPE_INT32, default=default)
 
@@ -193,8 +209,12 @@ def string_field(number: int, default: str = "") -> Any:
     return dataclass_field(number, TYPE_STRING, default=default)
 
 
-def message_field(number: int, default: Type["Message"]) -> Any:
-    return dataclass_field(number, TYPE_MESSAGE, default=default)
+def bytes_field(number: int, default: bytes = b"") -> Any:
+    return dataclass_field(number, TYPE_BYTES, default=default)
+
+
+def message_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_MESSAGE)
 
 
 def map_field(number: int, key_type: str, value_type: str) -> Any:
@@ -345,6 +365,29 @@ class Message(ABC):
     to go between Python, binary and JSON protobuf message representations.
     """
 
+    def __post_init__(self) -> None:
+        # Set a default value for each field in the class after `__init__` has
+        # already been run.
+        for field in dataclasses.fields(self):
+            meta = FieldMetadata.get(field)
+
+            t = self._cls_for(field, index=-1)
+
+            value = 0
+            if meta.proto_type == TYPE_MAP:
+                # Maps cannot be repeated, so we check these first.
+                value = {}
+            elif hasattr(t, "__args__") and len(t.__args__) == 1:
+                # Anything else with type args is a list.
+                value = []
+            elif meta.proto_type == TYPE_MESSAGE:
+                # Message means creating an instance of the right type.
+                value = t()
+            else:
+                value = get_default(meta.proto_type)
+
+            setattr(self, field.name, value)
+
     def __bytes__(self) -> bytes:
         """
         Get the binary encoded Protobuf representation of this instance.
@@ -356,6 +399,7 @@ class Message(ABC):
 
             if isinstance(value, list):
                 if not len(value):
+                    # Empty values are not serialized
                     continue
 
                 if meta.proto_type in PACKED_TYPES:
@@ -371,6 +415,7 @@ class Message(ABC):
                         output += _serialize_single(meta.number, meta.proto_type, item)
             elif isinstance(value, dict):
                 if not len(value):
+                    # Empty values are not serialized
                     continue
 
                 for k, v in value.items():
@@ -378,7 +423,8 @@ class Message(ABC):
                     sv = _serialize_single(2, meta.map_types[1], v)
                     output += _serialize_single(meta.number, meta.proto_type, sk + sv)
             else:
-                if value == field.default:
+                if value == get_default(meta.proto_type):
+                    # Default (zero) values are not serialized
                     continue
 
                 output += _serialize_single(meta.number, meta.proto_type, value)
@@ -390,7 +436,7 @@ class Message(ABC):
         module = inspect.getmodule(self)
         type_hints = get_type_hints(self, vars(module))
         cls = type_hints[field.name]
-        if hasattr(cls, "__args__"):
+        if hasattr(cls, "__args__") and index >= 0:
             cls = type_hints[field.name].__args__[index]
         return cls
 
@@ -522,7 +568,7 @@ class Message(ABC):
         """
         for field in dataclasses.fields(self):
             meta = FieldMetadata.get(field)
-            if field.name in value:
+            if field.name in value and value[field.name] is not None:
                 if meta.proto_type == "message":
                     v = getattr(self, field.name)
                     # print(v, value[field.name])
diff --git a/betterproto/templates/main.py b/betterproto/templates/main.py
index 332a144..08b5dc1 100644
--- a/betterproto/templates/main.py
+++ b/betterproto/templates/main.py
@@ -7,6 +7,10 @@ from dataclasses import dataclass
 from typing import Dict, List
 
 import betterproto
+{% for i in description.imports %}
+
+{{ i }}
+{% endfor %}
 
 
 {% if description.enums %}{% for enum in description.enums %}
@@ -21,9 +25,9 @@ class {{ enum.name }}(enum.IntEnum):
     {% endif %}
     {{ entry.name }} = {{ entry.value }}
     {% endfor %}
+
+
 {% endfor %}
-
-
 {% endif %}
 {% for message in description.messages %}
 @dataclass
@@ -36,8 +40,11 @@ class {{ message.name }}(betterproto.Message):
     {% if field.comment %}
 {{ field.comment }}
     {% endif %}
-    {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
+    {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
     {% endfor %}
+    {% if not message.properties %}
+    pass
+    {% endif %}
 
 
 {% endfor %}
diff --git a/betterproto/tests/ref.json b/betterproto/tests/ref.json
new file mode 100644
index 0000000..2c6bdc1
--- /dev/null
+++ b/betterproto/tests/ref.json
@@ -0,0 +1,5 @@
+{
+  "greeting": {
+    "greeting": "hello"
+  }
+}
diff --git a/betterproto/tests/ref.proto b/betterproto/tests/ref.proto
new file mode 100644
index 0000000..6945590
--- /dev/null
+++ b/betterproto/tests/ref.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package ref;
+
+import "repeatedmessage.proto";
+
+message Test {
+    repeatedmessage.Sub greeting = 1;
+}
diff --git a/betterproto/tests/repeatedmessage.proto b/betterproto/tests/repeatedmessage.proto
index ea4c01f..0ffacaf 100644
--- a/betterproto/tests/repeatedmessage.proto
+++ b/betterproto/tests/repeatedmessage.proto
@@ -1,5 +1,7 @@
 syntax = "proto3";
 
+package repeatedmessage;
+
 message Test {
     repeated Sub greetings = 1;
 }
diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py
index 49b7a44..18d8d6c 100644
--- a/betterproto/tests/test_inputs.py
+++ b/betterproto/tests/test_inputs.py
@@ -2,7 +2,7 @@ import importlib
 import pytest
 import json
 
-from generate import get_files, get_base
+from .generate import get_files, get_base
 
 inputs = get_files(".bin")
 
@@ -10,7 +10,7 @@ inputs = get_files(".bin")
 @pytest.mark.parametrize("filename", inputs)
 def test_sample(filename: str) -> None:
     module = get_base(filename).split("-")[0]
-    imported = importlib.import_module(module)
+    imported = importlib.import_module(f"betterproto.tests.{module}")
     data_binary = open(filename, "rb").read()
     data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
     t1 = imported.Test().parse(data_binary)
diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py
index f779199..86f35f0 100755
--- a/protoc-gen-betterpy.py
+++ b/protoc-gen-betterpy.py
@@ -22,33 +22,46 @@ from jinja2 import Environment, PackageLoader
 
 
 def py_type(
-    message: DescriptorProto, descriptor: FieldDescriptorProto
-) -> Tuple[str, str]:
+    package: str,
+    imports: set,
+    message: DescriptorProto,
+    descriptor: FieldDescriptorProto,
+) -> str:
     if descriptor.type in [1, 2, 6, 7, 15, 16]:
-        return "float", descriptor.default_value
+        return "float"
     elif descriptor.type in [3, 4, 5, 13, 17, 18]:
-        return "int", descriptor.default_value
+        return "int"
     elif descriptor.type == 8:
-        return "bool", descriptor.default_value.capitalize()
+        return "bool"
     elif descriptor.type == 9:
-        default = ""
-        if descriptor.default_value:
-            default = f'"{descriptor.default_value}"'
-        return "str", default
-    elif descriptor.type == 11:
-        # Type referencing another defined Message
-        # print(descriptor.type_name, file=sys.stderr)
-        # message_type = descriptor.type_name.replace(".", "")
-        message_type = descriptor.type_name.split(".").pop()
-        return f'"{message_type}"', f"lambda: {message_type}()"
+        return "str"
+    elif descriptor.type in [11, 14]:
+        # Type referencing another defined Message or a named enum
+        message_type = descriptor.type_name.lstrip(".")
+        if message_type.startswith(package):
+            # This is the current package, which has nested types flattened.
+            message_type = (
+                f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
+            )
+
+        if "." in message_type:
+            # This is imported from another package. No need
+            # to use a forward ref and we need to add the import.
+            parts = message_type.split(".")
+            imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
+            message_type = f"{parts[-2]}.{parts[-1]}"
+
+        # print(
+        #     descriptor.name,
+        #     package,
+        #     descriptor.type_name,
+        #     message_type,
+        #     file=sys.stderr,
+        # )
+
+        return message_type
     elif descriptor.type == 12:
-        default = ""
-        if descriptor.default_value:
-            default = f'b"{descriptor.default_value}"'
-        return "bytes", default
-    elif descriptor.type == 14:
-        # print(descriptor.type_name, file=sys.stderr)
-        return descriptor.type_name.split(".").pop(), 0
+        return "bytes"
     else:
         raise NotImplementedError(f"Unknown type {descriptor.type}")
@@ -64,6 +77,8 @@ def traverse(proto_file):
 
             if item.nested_type:
                 for n, p in _traverse(path + [i, 3], item.nested_type):
+                    # Adjust the name since we flatten the hierarchy.
+                    n.name = item.name + n.name
                     yield n, p
 
     return itertools.chain(
@@ -85,6 +100,7 @@ def get_comment(proto_file, path: List[int]) -> str:
         else:
             # This is a class
             if len(lines) == 1 and len(lines[0]) < 70:
+                lines[0] = lines[0].strip('"')
                 return f'    """{lines[0]}"""'
             else:
                 return f'    """\n{" ".join(lines)}\n    """'
@@ -100,115 +116,160 @@ def generate_code(request, response):
     )
     template = env.get_template("main.py")
 
+    # TODO: Refactor below to generate a single file per package if packages
+    # are being used, otherwise one output for each input. Figure out how to
+    # set up relative imports when needed and change the Message type refs to
+    # use the import names when not in the current module.
+    output_map = {}
     for proto_file in request.proto_file:
-        # print(proto_file.message_type, file=sys.stderr)
-        # print(proto_file.source_code_info, file=sys.stderr)
-        output = {
-            "package": proto_file.package,
-            "filename": proto_file.name,
-            "messages": [],
-            "enums": [],
-        }
+        out = proto_file.package
+        if not out:
+            out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
 
-        # Parse request
-        for item, path in traverse(proto_file):
-            # print(item, file=sys.stderr)
-            # print(path, file=sys.stderr)
-            data = {"name": item.name}
+        if out not in output_map:
+            output_map[out] = {"package": proto_file.package, "files": []}
+        output_map[out]["files"].append(proto_file)
 
-            if isinstance(item, DescriptorProto):
+    # TODO: Figure out how to handle gRPC request/response messages and add
+    # processing below for Service.
+
+    for filename, options in output_map.items():
+        package = options["package"]
+        # print(package, filename, file=sys.stderr)
+        output = {"package": package, "imports": set(), "messages": [], "enums": []}
+
+        for proto_file in options["files"]:
+            # print(proto_file.message_type, file=sys.stderr)
+            # print(proto_file.service, file=sys.stderr)
+            # print(proto_file.source_code_info, file=sys.stderr)
+
+            for item, path in traverse(proto_file):
                 # print(item, file=sys.stderr)
-                if item.options.map_entry:
-                    # Skip generated map entry messages since we just use dicts
-                    continue
+                # print(path, file=sys.stderr)
+                data = {"name": item.name}
 
-                data.update(
-                    {
-                        "type": "Message",
-                        "comment": get_comment(proto_file, path),
-                        "properties": [],
-                    }
-                )
+                if isinstance(item, DescriptorProto):
+                    # print(item, file=sys.stderr)
+                    if item.options.map_entry:
+                        # Skip generated map entry messages since we just use dicts
+                        continue
 
-                for i, f in enumerate(item.field):
-                    t, zero = py_type(item, f)
-                    repeated = False
-                    packed = False
-
-                    field_type = f.Type.Name(f.type).lower()[5:]
-                    map_types = None
-                    if f.type == 11:
-                        # This might be a map...
-                        message_type = f.type_name.split(".").pop()
-                        map_entry = f"{f.name.capitalize()}Entry"
-
-                        if message_type == map_entry:
-                            for nested in item.nested_type:
-                                if nested.name == map_entry:
-                                    if nested.options.map_entry:
-                                        print("Found a map!", file=sys.stderr)
-                                        k, _ = py_type(item, nested.field[0])
-                                        v, _ = py_type(item, nested.field[1])
-                                        t = f"Dict[{k}, {v}]"
-                                        zero = "dict"
-                                        field_type = "map"
-                                        map_types = (
-                                            f.Type.Name(nested.field[0].type),
-                                            f.Type.Name(nested.field[1].type),
-                                        )
-
-                    if f.label == 3 and field_type != "map":
-                        # Repeated field
-                        repeated = True
-                        t = f"List[{t}]"
-                        zero = "list"
-
-                    if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
-                        packed = True
-
-                    data["properties"].append(
+                    data.update(
                         {
-                            "name": f.name,
-                            "number": f.number,
-                            "comment": get_comment(proto_file, path + [2, i]),
-                            "proto_type": int(f.type),
-                            "field_type": field_type,
-                            "map_types": map_types,
-                            "type": t,
-                            "zero": zero,
-                            "repeated": repeated,
-                            "packed": packed,
+                            "type": "Message",
+                            "comment": get_comment(proto_file, path),
+                            "properties": [],
                         }
                     )
-                    # print(f, file=sys.stderr)
-                output["messages"].append(data)
+                    for i, f in enumerate(item.field):
+                        t = py_type(package, output["imports"], item, f)
 
-            elif isinstance(item, EnumDescriptorProto):
-                # print(item.name, path, file=sys.stderr)
-                data.update(
-                    {
-                        "type": "Enum",
-                        "comment": get_comment(proto_file, path),
-                        "entries": [
+                        repeated = False
+                        packed = False
+
+                        field_type = f.Type.Name(f.type).lower()[5:]
+                        map_types = None
+                        if f.type == 11:
+                            # This might be a map...
+                            message_type = f.type_name.split(".").pop()
+                            map_entry = f"{f.name.capitalize()}Entry"
+
+                            if message_type == map_entry:
+                                for nested in item.nested_type:
+                                    if nested.name == map_entry:
+                                        if nested.options.map_entry:
+                                            # print("Found a map!", file=sys.stderr)
+                                            k = py_type(
+                                                package,
+                                                output["imports"],
+                                                item,
+                                                nested.field[0],
+                                            )
+                                            v = py_type(
+                                                package,
+                                                output["imports"],
+                                                item,
+                                                nested.field[1],
+                                            )
+                                            t = f"Dict[{k}, {v}]"
+                                            field_type = "map"
+                                            map_types = (
+                                                f.Type.Name(nested.field[0].type),
+                                                f.Type.Name(nested.field[1].type),
+                                            )
+
+                        if f.label == 3 and field_type != "map":
+                            # Repeated field
+                            repeated = True
+                            t = f"List[{t}]"
+
+                        if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
+                            packed = True
+
+                        data["properties"].append(
                             {
-                                "name": v.name,
-                                "value": v.number,
+                                "name": f.name,
+                                "number": f.number,
                                 "comment": get_comment(proto_file, path + [2, i]),
+                                "proto_type": int(f.type),
+                                "field_type": field_type,
+                                "map_types": map_types,
+                                "type": t,
+                                "repeated": repeated,
+                                "packed": packed,
                             }
-                            for i, v in enumerate(item.value)
-                        ],
-                    }
-                )
+                        )
+                        # print(f, file=sys.stderr)
 
-                output["enums"].append(data)
+                    output["messages"].append(data)
+
+                elif isinstance(item, EnumDescriptorProto):
+                    # print(item.name, path, file=sys.stderr)
+                    data.update(
+                        {
+                            "type": "Enum",
+                            "comment": get_comment(proto_file, path),
+                            "entries": [
+                                {
+                                    "name": v.name,
+                                    "value": v.number,
+                                    "comment": get_comment(proto_file, path + [2, i]),
+                                }
+                                for i, v in enumerate(item.value)
+                            ],
+                        }
+                    )
+
+                    output["enums"].append(data)
+
+        output["imports"] = sorted(output["imports"])
 
         # Fill response
         f = response.file.add()
-        f.name = os.path.splitext(proto_file.name)[0] + ".py"
+        # print(filename, file=sys.stderr)
+        f.name = filename.replace(".", os.path.sep) + ".py"
+        # f.content = json.dumps(output, indent=2)
         f.content = template.render(description=output).rstrip("\n") + "\n"
+
+    inits = set([""])
+    for f in response.file:
+        # Ensure output paths exist
+        print(f.name, file=sys.stderr)
+        dirnames = os.path.dirname(f.name)
+        if dirnames:
+            os.makedirs(dirnames, exist_ok=True)
+            base = ""
+            for part in dirnames.split(os.path.sep):
+                base = os.path.join(base, part)
+                inits.add(base)
+
+    for base in inits:
+        init = response.file.add()
+        init.name = os.path.join(base, "__init__.py")
+        init.content = b""
+
 
 if __name__ == "__main__":
     # Read request message from stdin