Fix compilation of fields with name identical to their type (#294)
* Revert "Fix compilation of fields named 'bytes' or 'str' (#226)"
This reverts commit deb623ed14.
* Fix compilation of fileds with name identical to their type
* Added test for field-name identical to python type
Co-authored-by: Guy Szweigman <guysz@nvidia.com>
			
			
This commit is contained in:
		| @@ -30,6 +30,7 @@ reference to `A` to `B`'s `fields` attribute. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import builtins | ||||
| import betterproto | ||||
| from betterproto import which_one_of | ||||
| from betterproto.casing import sanitize_name | ||||
| @@ -237,6 +238,7 @@ class OutputTemplate: | ||||
|     imports: Set[str] = field(default_factory=set) | ||||
|     datetime_imports: Set[str] = field(default_factory=set) | ||||
|     typing_imports: Set[str] = field(default_factory=set) | ||||
|     builtins_import: bool = False | ||||
|     messages: List["MessageCompiler"] = field(default_factory=list) | ||||
|     enums: List["EnumDefinitionCompiler"] = field(default_factory=list) | ||||
|     services: List["ServiceCompiler"] = field(default_factory=list) | ||||
| @@ -268,6 +270,8 @@ class OutputTemplate: | ||||
|         imports = set() | ||||
|         if any(x for x in self.messages if any(x.deprecated_fields)): | ||||
|             imports.add("warnings") | ||||
|         if self.builtins_import: | ||||
|             imports.add("builtins") | ||||
|         return imports | ||||
|  | ||||
|  | ||||
| @@ -283,6 +287,7 @@ class MessageCompiler(ProtoContentBase): | ||||
|         default_factory=list | ||||
|     ) | ||||
|     deprecated: bool = field(default=False, init=False) | ||||
|     builtins_types: Set[str] = field(default_factory=set) | ||||
|  | ||||
|     def __post_init__(self) -> None: | ||||
|         # Add message to output file | ||||
| @@ -376,6 +381,8 @@ class FieldCompiler(MessageCompiler): | ||||
|         betterproto_field_type = ( | ||||
|             f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})" | ||||
|         ) | ||||
|         if self.py_name in dir(builtins): | ||||
|             self.parent.builtins_types.add(self.py_name) | ||||
|         return f"{name}{annotations} = {betterproto_field_type}" | ||||
|  | ||||
|     @property | ||||
| @@ -408,9 +415,16 @@ class FieldCompiler(MessageCompiler): | ||||
|             imports.add("Dict") | ||||
|         return imports | ||||
|  | ||||
|     @property | ||||
|     def use_builtins(self) -> bool: | ||||
|         return self.py_type in self.parent.builtins_types or ( | ||||
|             self.py_type == self.py_name and self.py_name in dir(builtins) | ||||
|         ) | ||||
|  | ||||
|     def add_imports_to(self, output_file: OutputTemplate) -> None: | ||||
|         output_file.datetime_imports.update(self.datetime_imports) | ||||
|         output_file.typing_imports.update(self.typing_imports) | ||||
|         output_file.builtins_import = output_file.builtins_import or self.use_builtins | ||||
|  | ||||
|     @property | ||||
|     def field_wraps(self) -> Optional[str]: | ||||
| @@ -504,9 +518,12 @@ class FieldCompiler(MessageCompiler): | ||||
|  | ||||
|     @property | ||||
|     def annotation(self) -> str: | ||||
|         py_type = self.py_type | ||||
|         if self.use_builtins: | ||||
|             py_type = f"builtins.{py_type}" | ||||
|         if self.repeated: | ||||
|             return f"List[{self.py_type}]" | ||||
|         return self.py_type | ||||
|             return f"List[{py_type}]" | ||||
|         return py_type | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
|   | ||||
		Reference in New Issue
	
	Block a user