Fix typing and datetime imports not being present for service method type annotations (#183)

This commit is contained in:
Vladimir Solomatin 2021-03-13 00:15:15 +03:00 committed by GitHub
parent 8a215367ad
commit 2f62189346
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 14 deletions

View File

@ -324,17 +324,7 @@ class FieldCompiler(MessageCompiler):
# Add field to message
self.parent.fields.append(self)
# Check for new imports
annotation = self.annotation
if "Optional[" in annotation:
self.output_file.typing_imports.add("Optional")
if "List[" in annotation:
self.output_file.typing_imports.add("List")
if "Dict[" in annotation:
self.output_file.typing_imports.add("Dict")
if "timedelta" in annotation:
self.output_file.datetime_imports.add("timedelta")
if "datetime" in annotation:
self.output_file.datetime_imports.add("datetime")
self.add_imports_to(self.output_file)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
def get_field_string(self, indent: int = 4) -> str:
@ -356,6 +346,33 @@ class FieldCompiler(MessageCompiler):
args.append(f"wraps={self.field_wraps}")
return args
@property
def datetime_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
# FIXME: false positives - e.g. `MyDatetimedelta`
if "timedelta" in annotation:
imports.add("timedelta")
if "datetime" in annotation:
imports.add("datetime")
return imports
@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
@property
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
@ -577,11 +594,10 @@ class ServiceMethodCompiler(ProtoContentBase):
# Add method to service
self.parent.methods.append(self)
# Check for Optional import
# Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
if f.default_value_string == "None":
self.output_file.typing_imports.add("Optional")
f.add_imports_to(self.output_file)
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering

View File

@ -14,6 +14,7 @@ services = {
"googletypes_response",
"googletypes_response_embedded",
"service",
"service_separate_packages",
"import_service_input_message",
"googletypes_service_returns_empty",
"googletypes_service_returns_googletype",

View File

@ -0,0 +1,31 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
package things.messages;
message DoThingRequest {
string name = 1;
// use `repeated` so we can check if `List` is correctly imported
repeated string comments = 2;
// use google types `timestamp` and `duration` so we can check
// if everything from `datetime` is correctly imported
google.protobuf.Timestamp when = 3;
google.protobuf.Duration duration = 4;
}
message DoThingResponse {
repeated string names = 1;
}
message GetThingRequest {
string name = 1;
}
message GetThingResponse {
string name = 1;
int32 version = 2;
}

View File

@ -0,0 +1,12 @@
syntax = "proto3";
import "messages.proto";
package things.service;
service Test {
rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
}