From 2f621893463ce11812b78905dec83c393d4bf380 Mon Sep 17 00:00:00 2001 From: Vladimir Solomatin Date: Sat, 13 Mar 2021 00:15:15 +0300 Subject: [PATCH] Fix `typing` and `datetime` imports not being present for service method type annotations (#183) --- src/betterproto/plugin/models.py | 44 +++++++++++++------ tests/inputs/config.py | 1 + .../service_separate_packages/messages.proto | 31 +++++++++++++ .../service_separate_packages/service.proto | 12 +++++ 4 files changed, 74 insertions(+), 14 deletions(-) create mode 100644 tests/inputs/service_separate_packages/messages.proto create mode 100644 tests/inputs/service_separate_packages/service.proto diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 09217b9..097c991 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -324,17 +324,7 @@ class FieldCompiler(MessageCompiler): # Add field to message self.parent.fields.append(self) # Check for new imports - annotation = self.annotation - if "Optional[" in annotation: - self.output_file.typing_imports.add("Optional") - if "List[" in annotation: - self.output_file.typing_imports.add("List") - if "Dict[" in annotation: - self.output_file.typing_imports.add("Dict") - if "timedelta" in annotation: - self.output_file.datetime_imports.add("timedelta") - if "datetime" in annotation: - self.output_file.datetime_imports.add("datetime") + self.add_imports_to(self.output_file) super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ def get_field_string(self, indent: int = 4) -> str: @@ -356,6 +346,33 @@ class FieldCompiler(MessageCompiler): args.append(f"wraps={self.field_wraps}") return args + @property + def datetime_imports(self) -> Set[str]: + imports = set() + annotation = self.annotation + # FIXME: false positives - e.g. `MyDatetimedelta` + if "timedelta" in annotation: + imports.add("timedelta") + if "datetime" in annotation: + imports.add("datetime") + return imports + + @property + def typing_imports(self) -> Set[str]: + imports = set() + annotation = self.annotation + if "Optional[" in annotation: + imports.add("Optional") + if "List[" in annotation: + imports.add("List") + if "Dict[" in annotation: + imports.add("Dict") + return imports + + def add_imports_to(self, output_file: OutputTemplate) -> None: + output_file.datetime_imports.update(self.datetime_imports) + output_file.typing_imports.update(self.typing_imports) + @property def field_wraps(self) -> Optional[str]: """Returns betterproto wrapped field type or None.""" @@ -577,11 +594,10 @@ class ServiceMethodCompiler(ProtoContentBase): # Add method to service self.parent.methods.append(self) - # Check for Optional import + # Check for imports if self.py_input_message: for f in self.py_input_message.fields: - if f.default_value_string == "None": - self.output_file.typing_imports.add("Optional") + f.add_imports_to(self.output_file) if "Optional" in self.py_output_message_type: self.output_file.typing_imports.add("Optional") self.mutable_default_args # ensure this is called before rendering diff --git a/tests/inputs/config.py b/tests/inputs/config.py index 7f2024a..f95aad2 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -14,6 +14,7 @@ services = { "googletypes_response", "googletypes_response_embedded", "service", + "service_separate_packages", "import_service_input_message", "googletypes_service_returns_empty", "googletypes_service_returns_googletype", diff --git a/tests/inputs/service_separate_packages/messages.proto b/tests/inputs/service_separate_packages/messages.proto new file mode 100644 index 0000000..add0ed8 --- /dev/null +++ b/tests/inputs/service_separate_packages/messages.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + +package things.messages; + +message DoThingRequest { + string name = 1; + + // use `repeated` so we can check if `List` is correctly imported + repeated string comments = 2; + + // use google types `timestamp` and `duration` so we can check + // if everything from `datetime` is correctly imported + google.protobuf.Timestamp when = 3; + google.protobuf.Duration duration = 4; +} + +message DoThingResponse { + repeated string names = 1; +} + +message GetThingRequest { + string name = 1; +} + +message GetThingResponse { + string name = 1; + int32 version = 2; +} diff --git a/tests/inputs/service_separate_packages/service.proto b/tests/inputs/service_separate_packages/service.proto new file mode 100644 index 0000000..48acc25 --- /dev/null +++ b/tests/inputs/service_separate_packages/service.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +import "messages.proto"; + +package things.service; + +service Test { + rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse); + rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse); + rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse); + rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse); +}