Raise AttributeError on attempts to access unset oneof fields (#510)

This commit is contained in:
Alexander Khabarov 2023-07-21 13:26:30 +01:00 committed by GitHub
parent 098989e9e9
commit 6faac1d1ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 116 additions and 29 deletions

View File

@ -8,9 +8,10 @@ repos:
- id: isort - id: isort
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 22.3.0 rev: 23.1.0
hooks: hooks:
- id: black - id: black
args: ["--target-version", "py310"]
- repo: https://github.com/PyCQA/doc8 - repo: https://github.com/PyCQA/doc8
rev: 0.10.1 rev: 0.10.1

2
poetry.lock generated
View File

@ -1858,4 +1858,4 @@ compiler = ["black", "isort", "jinja2"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.7" python-versions = "^3.7"
content-hash = "62d298634665ebd06f69ec8ea543c3d7720184ec9d833c32575de8d965332aec" content-hash = "8f733a72705d31633a7f198a7a7dd6e3170876a1ccb8ca75b7d94b6379384a8f"

View File

@ -13,7 +13,7 @@ packages = [
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.7" python = "^3.7"
black = { version = ">=19.3b0", optional = true } black = { version = ">=23.1.0", optional = true }
grpclib = "^0.4.1" grpclib = "^0.4.1"
importlib-metadata = { version = ">=1.6.0", python = "<3.8" } importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true } jinja2 = { version = ">=3.0.3", optional = true }
@ -62,7 +62,7 @@ cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy" help = "Check types with mypy"
[tool.poe.tasks.format] [tool.poe.tasks.format]
cmd = "black . --exclude tests/output_" cmd = "black . --exclude tests/output_ --target-version py310"
help = "Apply black formatting to source code" help = "Apply black formatting to source code"
[tool.poe.tasks.docs] [tool.poe.tasks.docs]

View File

@ -693,8 +693,28 @@ class Message(ABC):
def __getattribute__(self, name: str) -> Any: def __getattribute__(self, name: str) -> Any:
""" """
Lazily initialize default values to avoid infinite recursion for recursive Lazily initialize default values to avoid infinite recursion for recursive
message types message types.
Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
""" """
try:
group_current = super().__getattribute__("_group_current")
except AttributeError:
pass
else:
if name not in {"__class__", "_betterproto"}:
group = self._betterproto.oneof_group_by_field.get(name)
if group is not None and group_current[group] != name:
if sys.version_info < (3, 10):
raise AttributeError(
f"{group!r} is set to {group_current[group]!r}, not {name!r}"
)
else:
raise AttributeError(
f"{group!r} is set to {group_current[group]!r}, not {name!r}",
name=name,
obj=self,
)
value = super().__getattribute__(name) value = super().__getattribute__(name)
if value is not PLACEHOLDER: if value is not PLACEHOLDER:
return value return value
@ -761,7 +781,10 @@ class Message(ABC):
""" """
output = bytearray() output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
try:
value = getattr(self, field_name) value = getattr(self, field_name)
except AttributeError:
continue
if value is None: if value is None:
# Optional items should be skipped. This is used for the Google # Optional items should be skipped. This is used for the Google
@ -775,9 +798,7 @@ class Message(ABC):
# Note that proto3 field presence/optional fields are put in a # Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we # synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value. # send the value even if the value is the default zero value.
selected_in_group = ( selected_in_group = bool(meta.group)
meta.group and self._group_current[meta.group] == field_name
)
# Empty messages can still be sent on the wire if they were # Empty messages can still be sent on the wire if they were
# set (or received empty). # set (or received empty).
@ -1016,7 +1037,12 @@ class Message(ABC):
parsed.wire_type, meta, field_name, parsed.value parsed.wire_type, meta, field_name, parsed.value
) )
try:
current = getattr(self, field_name) current = getattr(self, field_name)
except AttributeError:
current = self._get_field_default(field_name)
setattr(self, field_name, current)
if meta.proto_type == TYPE_MAP: if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map. # Value represents a single key/value pair entry in the map.
current[value.key] = value.value current[value.key] = value.value
@ -1077,7 +1103,10 @@ class Message(ABC):
defaults = self._betterproto.default_gen defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list field_is_repeated = defaults[field_name] is list
try:
value = getattr(self, field_name) value = getattr(self, field_name)
except AttributeError:
value = self._get_field_default(field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE: if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime): if isinstance(value, datetime):
@ -1209,7 +1238,7 @@ class Message(ABC):
if value[key] is not None: if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE: if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name) v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name] cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list): if isinstance(v, list):
if cls == datetime: if cls == datetime:
@ -1486,7 +1515,6 @@ class Message(ABC):
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
for group, field_set in group_to_one_ofs.items(): for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1: if len(field_set) == 1:
(field,) = field_set (field,) = field_set
field_name = field.name field_name = field.name

View File

@ -21,7 +21,6 @@ class ServiceBase(ABC):
stream: grpclib.server.Stream, stream: grpclib.server.Stream,
request: Any, request: Any,
) -> None: ) -> None:
response_iter = handler(request) response_iter = handler(request)
# check if response is actually an AsyncIterator # check if response is actually an AsyncIterator
# this might be false if the method just returns without # this might be false if the method just returns without

View File

@ -21,7 +21,6 @@ from .models import OutputTemplate
def outputfile_compiler(output_file: OutputTemplate) -> str: def outputfile_compiler(output_file: OutputTemplate) -> str:
templates_folder = os.path.abspath( templates_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "templates") os.path.join(os.path.dirname(__file__), "..", "templates")
) )

View File

@ -159,7 +159,6 @@ def _make_one_of_field_compiler(
proto_obj: "FieldDescriptorProto", proto_obj: "FieldDescriptorProto",
path: List[int], path: List[int],
) -> FieldCompiler: ) -> FieldCompiler:
pydantic = output_package.pydantic_dataclasses pydantic = output_package.pydantic_dataclasses
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
return Cls( return Cls(

View File

@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof():
# None of these fields were explicitly set BUT they should not actually be null # None of these fields were explicitly set BUT they should not actually be null
# themselves # themselves
assert isinstance(message.foo, Foo) assert not hasattr(message, "foo")
assert isinstance(message2.foo, Foo) assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
assert not hasattr(message2, "foo")
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER
assert isinstance(message_reference.foo, ReferenceFoo) assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo) assert isinstance(message_reference2.foo, ReferenceFoo)

View File

@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value():
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
) )
assert message.move == Move( assert not hasattr(message, "move")
x=0, y=0 assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
) # Proto3 will default this as there is no null
assert message.signal == Signal.PASS assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)
@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value():
message.from_json( message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
) )
assert message.move == Move( assert not hasattr(message, "move")
x=0, y=0 assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
) # Proto3 will default this as there is no null
assert message.signal == Signal.RESIGN assert message.signal == Signal.RESIGN
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)
@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0].json) message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3) assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS assert not hasattr(message, "signal")
assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

View File

@ -0,0 +1,46 @@
from dataclasses import dataclass
import pytest
import betterproto
def test_oneof_pattern_matching():
@dataclass
class Sub(betterproto.Message):
val: int = betterproto.int32_field(1)
@dataclass
class Foo(betterproto.Message):
bar: int = betterproto.int32_field(1, group="group1")
baz: str = betterproto.string_field(2, group="group1")
sub: Sub = betterproto.message_field(3, group="group2")
abc: str = betterproto.string_field(4, group="group2")
foo = Foo(baz="test1", abc="test2")
match foo:
case Foo(bar=_):
pytest.fail("Matched 'bar' instead of 'baz'")
case Foo(baz=v):
assert v == "test1"
case _:
pytest.fail("Matched neither 'bar' nor 'baz'")
match foo:
case Foo(sub=_):
pytest.fail("Matched 'sub' instead of 'abc'")
case Foo(abc=v):
assert v == "test2"
case _:
pytest.fail("Matched neither 'sub' nor 'abc'")
foo.sub = Sub(val=1)
match foo:
case Foo(sub=Sub(val=v)):
assert v == 1
case Foo(abc=v):
pytest.fail("Matched 'abc' instead of 'sub'")
case _:
pytest.fail("Matched neither 'sub' nor 'abc'")

View File

@ -1,4 +1,5 @@
import json import json
import sys
from copy import ( from copy import (
copy, copy,
deepcopy, deepcopy,
@ -18,6 +19,8 @@ from typing import (
Optional, Optional,
) )
import pytest
import betterproto import betterproto
@ -151,17 +154,18 @@ def test_oneof_support():
foo.baz = "test" foo.baz = "test"
# Other oneof fields should now be unset # Other oneof fields should now be unset
assert foo.bar == 0 assert not hasattr(foo, "bar")
assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group1")[0] == "baz" assert betterproto.which_one_of(foo, "group1")[0] == "baz"
foo.sub.val = 1 foo.sub = Sub(val=1)
assert betterproto.serialized_on_wire(foo.sub) assert betterproto.serialized_on_wire(foo.sub)
foo.abc = "test" foo.abc = "test"
# Group 1 shouldn't be touched, group 2 should have reset # Group 1 shouldn't be touched, group 2 should have reset
assert foo.sub.val == 0 assert not hasattr(foo, "sub")
assert betterproto.serialized_on_wire(foo.sub) is False assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group2")[0] == "abc" assert betterproto.which_one_of(foo, "group2")[0] == "abc"
# Zero value should always serialize for one-of # Zero value should always serialize for one-of
@ -176,6 +180,16 @@ def test_oneof_support():
assert betterproto.which_one_of(foo2, "group2")[0] == "" assert betterproto.which_one_of(foo2, "group2")[0] == ""
@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="pattern matching is only supported in python3.10+",
)
def test_oneof_pattern_matching():
from .oneof_pattern_matching import test_oneof_pattern_matching
test_oneof_pattern_matching()
def test_json_casing(): def test_json_casing():
@dataclass @dataclass
class CasingTest(betterproto.Message): class CasingTest(betterproto.Message):