Fix compilation of fields with name identical to their type (#294)
* Revert "Fix compilation of fields named 'bytes' or 'str' (#226)" This reverts commit deb623ed14cea65f0a0d17e9c770426d71198ae0. * Fix compilation of fileds with name identical to their type * Added test for field-name identical to python type Co-authored-by: Guy Szweigman <guysz@nvidia.com>
This commit is contained in:
parent
a4d2d39546
commit
b0a36d12e4
@ -133,16 +133,6 @@ def lowercase_first(value: str) -> str:
|
|||||||
return value[0:1].lower() + value[1:]
|
return value[0:1].lower() + value[1:]
|
||||||
|
|
||||||
|
|
||||||
def is_reserved_name(value: str) -> bool:
|
|
||||||
if keyword.iskeyword(value):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if value in ("bytes", "str"):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_name(value: str) -> str:
|
def sanitize_name(value: str) -> str:
|
||||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||||
return f"{value}_" if is_reserved_name(value) else value
|
return f"{value}_" if keyword.iskeyword(value) else value
|
||||||
|
@ -30,6 +30,7 @@ reference to `A` to `B`'s `fields` attribute.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import builtins
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto import which_one_of
|
from betterproto import which_one_of
|
||||||
from betterproto.casing import sanitize_name
|
from betterproto.casing import sanitize_name
|
||||||
@ -237,6 +238,7 @@ class OutputTemplate:
|
|||||||
imports: Set[str] = field(default_factory=set)
|
imports: Set[str] = field(default_factory=set)
|
||||||
datetime_imports: Set[str] = field(default_factory=set)
|
datetime_imports: Set[str] = field(default_factory=set)
|
||||||
typing_imports: Set[str] = field(default_factory=set)
|
typing_imports: Set[str] = field(default_factory=set)
|
||||||
|
builtins_import: bool = False
|
||||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||||
@ -268,6 +270,8 @@ class OutputTemplate:
|
|||||||
imports = set()
|
imports = set()
|
||||||
if any(x for x in self.messages if any(x.deprecated_fields)):
|
if any(x for x in self.messages if any(x.deprecated_fields)):
|
||||||
imports.add("warnings")
|
imports.add("warnings")
|
||||||
|
if self.builtins_import:
|
||||||
|
imports.add("builtins")
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
|
||||||
@ -283,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
|
|||||||
default_factory=list
|
default_factory=list
|
||||||
)
|
)
|
||||||
deprecated: bool = field(default=False, init=False)
|
deprecated: bool = field(default=False, init=False)
|
||||||
|
builtins_types: Set[str] = field(default_factory=set)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Add message to output file
|
# Add message to output file
|
||||||
@ -376,6 +381,8 @@ class FieldCompiler(MessageCompiler):
|
|||||||
betterproto_field_type = (
|
betterproto_field_type = (
|
||||||
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
|
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
|
||||||
)
|
)
|
||||||
|
if self.py_name in dir(builtins):
|
||||||
|
self.parent.builtins_types.add(self.py_name)
|
||||||
return f"{name}{annotations} = {betterproto_field_type}"
|
return f"{name}{annotations} = {betterproto_field_type}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -408,9 +415,16 @@ class FieldCompiler(MessageCompiler):
|
|||||||
imports.add("Dict")
|
imports.add("Dict")
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_builtins(self) -> bool:
|
||||||
|
return self.py_type in self.parent.builtins_types or (
|
||||||
|
self.py_type == self.py_name and self.py_name in dir(builtins)
|
||||||
|
)
|
||||||
|
|
||||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||||
output_file.datetime_imports.update(self.datetime_imports)
|
output_file.datetime_imports.update(self.datetime_imports)
|
||||||
output_file.typing_imports.update(self.typing_imports)
|
output_file.typing_imports.update(self.typing_imports)
|
||||||
|
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_wraps(self) -> Optional[str]:
|
def field_wraps(self) -> Optional[str]:
|
||||||
@ -504,9 +518,12 @@ class FieldCompiler(MessageCompiler):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def annotation(self) -> str:
|
def annotation(self) -> str:
|
||||||
|
py_type = self.py_type
|
||||||
|
if self.use_builtins:
|
||||||
|
py_type = f"builtins.{py_type}"
|
||||||
if self.repeated:
|
if self.repeated:
|
||||||
return f"List[{self.py_type}]"
|
return f"List[{py_type}]"
|
||||||
return self.py_type
|
return py_type
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"int": 26,
|
||||||
|
"float": 26.0,
|
||||||
|
"str": "value-for-str",
|
||||||
|
"bytes": "001a",
|
||||||
|
"bool": true
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
// Tests that messages may contain fields with names that are identical to their python types (PR #294)
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
int32 int = 1;
|
||||||
|
float float = 2;
|
||||||
|
string str = 3;
|
||||||
|
bytes bytes = 4;
|
||||||
|
bool bool = 5;
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user