Fix typing and datetime imports not being present for service method type annotations (#183)
				
					
				
			This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							8a215367ad
						
					
				
				
					commit
					2f62189346
				
			| @@ -324,17 +324,7 @@ class FieldCompiler(MessageCompiler): | ||||
|         # Add field to message | ||||
|         self.parent.fields.append(self) | ||||
|         # Check for new imports | ||||
|         annotation = self.annotation | ||||
|         if "Optional[" in annotation: | ||||
|             self.output_file.typing_imports.add("Optional") | ||||
|         if "List[" in annotation: | ||||
|             self.output_file.typing_imports.add("List") | ||||
|         if "Dict[" in annotation: | ||||
|             self.output_file.typing_imports.add("Dict") | ||||
|         if "timedelta" in annotation: | ||||
|             self.output_file.datetime_imports.add("timedelta") | ||||
|         if "datetime" in annotation: | ||||
|             self.output_file.datetime_imports.add("datetime") | ||||
|         self.add_imports_to(self.output_file) | ||||
|         super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__ | ||||
|  | ||||
|     def get_field_string(self, indent: int = 4) -> str: | ||||
| @@ -356,6 +346,33 @@ class FieldCompiler(MessageCompiler): | ||||
|             args.append(f"wraps={self.field_wraps}") | ||||
|         return args | ||||
|  | ||||
|     @property | ||||
|     def datetime_imports(self) -> Set[str]: | ||||
|         imports = set() | ||||
|         annotation = self.annotation | ||||
|         # FIXME: false positives - e.g. `MyDatetimedelta` | ||||
|         if "timedelta" in annotation: | ||||
|             imports.add("timedelta") | ||||
|         if "datetime" in annotation: | ||||
|             imports.add("datetime") | ||||
|         return imports | ||||
|  | ||||
|     @property | ||||
|     def typing_imports(self) -> Set[str]: | ||||
|         imports = set() | ||||
|         annotation = self.annotation | ||||
|         if "Optional[" in annotation: | ||||
|             imports.add("Optional") | ||||
|         if "List[" in annotation: | ||||
|             imports.add("List") | ||||
|         if "Dict[" in annotation: | ||||
|             imports.add("Dict") | ||||
|         return imports | ||||
|  | ||||
|     def add_imports_to(self, output_file: OutputTemplate) -> None: | ||||
|         output_file.datetime_imports.update(self.datetime_imports) | ||||
|         output_file.typing_imports.update(self.typing_imports) | ||||
|  | ||||
|     @property | ||||
|     def field_wraps(self) -> Optional[str]: | ||||
|         """Returns betterproto wrapped field type or None.""" | ||||
| @@ -577,11 +594,10 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|         # Add method to service | ||||
|         self.parent.methods.append(self) | ||||
|  | ||||
|         # Check for Optional import | ||||
|         # Check for imports | ||||
|         if self.py_input_message: | ||||
|             for f in self.py_input_message.fields: | ||||
|                 if f.default_value_string == "None": | ||||
|                     self.output_file.typing_imports.add("Optional") | ||||
|                 f.add_imports_to(self.output_file) | ||||
|         if "Optional" in self.py_output_message_type: | ||||
|             self.output_file.typing_imports.add("Optional") | ||||
|         self.mutable_default_args  # ensure this is called before rendering | ||||
|   | ||||
		Reference in New Issue
	
	Block a user