Fix typing
and datetime
imports not being present for service method type annotations (#183)
This commit is contained in:
parent
8a215367ad
commit
2f62189346
@ -324,17 +324,7 @@ class FieldCompiler(MessageCompiler):
|
||||
# Add field to message
|
||||
self.parent.fields.append(self)
|
||||
# Check for new imports
|
||||
annotation = self.annotation
|
||||
if "Optional[" in annotation:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
if "List[" in annotation:
|
||||
self.output_file.typing_imports.add("List")
|
||||
if "Dict[" in annotation:
|
||||
self.output_file.typing_imports.add("Dict")
|
||||
if "timedelta" in annotation:
|
||||
self.output_file.datetime_imports.add("timedelta")
|
||||
if "datetime" in annotation:
|
||||
self.output_file.datetime_imports.add("datetime")
|
||||
self.add_imports_to(self.output_file)
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
def get_field_string(self, indent: int = 4) -> str:
|
||||
@ -356,6 +346,33 @@ class FieldCompiler(MessageCompiler):
|
||||
args.append(f"wraps={self.field_wraps}")
|
||||
return args
|
||||
|
||||
@property
|
||||
def datetime_imports(self) -> Set[str]:
|
||||
imports = set()
|
||||
annotation = self.annotation
|
||||
# FIXME: false positives - e.g. `MyDatetimedelta`
|
||||
if "timedelta" in annotation:
|
||||
imports.add("timedelta")
|
||||
if "datetime" in annotation:
|
||||
imports.add("datetime")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def typing_imports(self) -> Set[str]:
|
||||
imports = set()
|
||||
annotation = self.annotation
|
||||
if "Optional[" in annotation:
|
||||
imports.add("Optional")
|
||||
if "List[" in annotation:
|
||||
imports.add("List")
|
||||
if "Dict[" in annotation:
|
||||
imports.add("Dict")
|
||||
return imports
|
||||
|
||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||
output_file.datetime_imports.update(self.datetime_imports)
|
||||
output_file.typing_imports.update(self.typing_imports)
|
||||
|
||||
@property
|
||||
def field_wraps(self) -> Optional[str]:
|
||||
"""Returns betterproto wrapped field type or None."""
|
||||
@ -577,11 +594,10 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
# Add method to service
|
||||
self.parent.methods.append(self)
|
||||
|
||||
# Check for Optional import
|
||||
# Check for imports
|
||||
if self.py_input_message:
|
||||
for f in self.py_input_message.fields:
|
||||
if f.default_value_string == "None":
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
f.add_imports_to(self.output_file)
|
||||
if "Optional" in self.py_output_message_type:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
self.mutable_default_args # ensure this is called before rendering
|
||||
|
@ -14,6 +14,7 @@ services = {
|
||||
"googletypes_response",
|
||||
"googletypes_response_embedded",
|
||||
"service",
|
||||
"service_separate_packages",
|
||||
"import_service_input_message",
|
||||
"googletypes_service_returns_empty",
|
||||
"googletypes_service_returns_googletype",
|
||||
|
31
tests/inputs/service_separate_packages/messages.proto
Normal file
31
tests/inputs/service_separate_packages/messages.proto
Normal file
@ -0,0 +1,31 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
package things.messages;
|
||||
|
||||
message DoThingRequest {
|
||||
string name = 1;
|
||||
|
||||
// use `repeated` so we can check if `List` is correctly imported
|
||||
repeated string comments = 2;
|
||||
|
||||
// use google types `timestamp` and `duration` so we can check
|
||||
// if everything from `datetime` is correctly imported
|
||||
google.protobuf.Timestamp when = 3;
|
||||
google.protobuf.Duration duration = 4;
|
||||
}
|
||||
|
||||
message DoThingResponse {
|
||||
repeated string names = 1;
|
||||
}
|
||||
|
||||
message GetThingRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message GetThingResponse {
|
||||
string name = 1;
|
||||
int32 version = 2;
|
||||
}
|
12
tests/inputs/service_separate_packages/service.proto
Normal file
12
tests/inputs/service_separate_packages/service.proto
Normal file
@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "messages.proto";
|
||||
|
||||
package things.service;
|
||||
|
||||
service Test {
|
||||
rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
|
||||
rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
|
||||
rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
|
||||
rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user