Raise AttributeError on attempts to access unset oneof fields (#510)
				
					
				
			This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							098989e9e9
						
					
				
				
					commit
					6faac1d1ca
				
			| @@ -8,9 +8,10 @@ repos: | |||||||
|       - id: isort |       - id: isort | ||||||
|  |  | ||||||
|   - repo: https://github.com/psf/black |   - repo: https://github.com/psf/black | ||||||
|     rev: 22.3.0 |     rev: 23.1.0 | ||||||
|     hooks: |     hooks: | ||||||
|       - id: black |       - id: black | ||||||
|  |         args: ["--target-version", "py310"] | ||||||
|  |  | ||||||
|   - repo: https://github.com/PyCQA/doc8 |   - repo: https://github.com/PyCQA/doc8 | ||||||
|     rev: 0.10.1 |     rev: 0.10.1 | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							| @@ -1858,4 +1858,4 @@ compiler = ["black", "isort", "jinja2"] | |||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = "^3.7" | python-versions = "^3.7" | ||||||
| content-hash = "62d298634665ebd06f69ec8ea543c3d7720184ec9d833c32575de8d965332aec" | content-hash = "8f733a72705d31633a7f198a7a7dd6e3170876a1ccb8ca75b7d94b6379384a8f" | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ packages = [ | |||||||
|  |  | ||||||
| [tool.poetry.dependencies] | [tool.poetry.dependencies] | ||||||
| python = "^3.7" | python = "^3.7" | ||||||
| black = { version = ">=19.3b0", optional = true } | black = { version = ">=23.1.0", optional = true } | ||||||
| grpclib = "^0.4.1" | grpclib = "^0.4.1" | ||||||
| importlib-metadata = { version = ">=1.6.0", python = "<3.8" } | importlib-metadata = { version = ">=1.6.0", python = "<3.8" } | ||||||
| jinja2 = { version = ">=3.0.3", optional = true } | jinja2 = { version = ">=3.0.3", optional = true } | ||||||
| @@ -62,7 +62,7 @@ cmd  = "mypy src --ignore-missing-imports" | |||||||
| help = "Check types with mypy" | help = "Check types with mypy" | ||||||
|  |  | ||||||
| [tool.poe.tasks.format] | [tool.poe.tasks.format] | ||||||
| cmd  = "black . --exclude tests/output_" | cmd  = "black . --exclude tests/output_ --target-version py310" | ||||||
| help = "Apply black formatting to source code" | help = "Apply black formatting to source code" | ||||||
|  |  | ||||||
| [tool.poe.tasks.docs] | [tool.poe.tasks.docs] | ||||||
|   | |||||||
| @@ -693,8 +693,28 @@ class Message(ABC): | |||||||
|         def __getattribute__(self, name: str) -> Any: |         def __getattribute__(self, name: str) -> Any: | ||||||
|             """ |             """ | ||||||
|             Lazily initialize default values to avoid infinite recursion for recursive |             Lazily initialize default values to avoid infinite recursion for recursive | ||||||
|             message types |             message types. | ||||||
|  |             Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields. | ||||||
|             """ |             """ | ||||||
|  |             try: | ||||||
|  |                 group_current = super().__getattribute__("_group_current") | ||||||
|  |             except AttributeError: | ||||||
|  |                 pass | ||||||
|  |             else: | ||||||
|  |                 if name not in {"__class__", "_betterproto"}: | ||||||
|  |                     group = self._betterproto.oneof_group_by_field.get(name) | ||||||
|  |                     if group is not None and group_current[group] != name: | ||||||
|  |                         if sys.version_info < (3, 10): | ||||||
|  |                             raise AttributeError( | ||||||
|  |                                 f"{group!r} is set to {group_current[group]!r}, not {name!r}" | ||||||
|  |                             ) | ||||||
|  |                         else: | ||||||
|  |                             raise AttributeError( | ||||||
|  |                                 f"{group!r} is set to {group_current[group]!r}, not {name!r}", | ||||||
|  |                                 name=name, | ||||||
|  |                                 obj=self, | ||||||
|  |                             ) | ||||||
|  |  | ||||||
|             value = super().__getattribute__(name) |             value = super().__getattribute__(name) | ||||||
|             if value is not PLACEHOLDER: |             if value is not PLACEHOLDER: | ||||||
|                 return value |                 return value | ||||||
| @@ -761,7 +781,10 @@ class Message(ABC): | |||||||
|         """ |         """ | ||||||
|         output = bytearray() |         output = bytearray() | ||||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): |         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||||
|             value = getattr(self, field_name) |             try: | ||||||
|  |                 value = getattr(self, field_name) | ||||||
|  |             except AttributeError: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|             if value is None: |             if value is None: | ||||||
|                 # Optional items should be skipped. This is used for the Google |                 # Optional items should be skipped. This is used for the Google | ||||||
| @@ -775,9 +798,7 @@ class Message(ABC): | |||||||
|             # Note that proto3 field presence/optional fields are put in a |             # Note that proto3 field presence/optional fields are put in a | ||||||
|             # synthetic single-item oneof by protoc, which helps us ensure we |             # synthetic single-item oneof by protoc, which helps us ensure we | ||||||
|             # send the value even if the value is the default zero value. |             # send the value even if the value is the default zero value. | ||||||
|             selected_in_group = ( |             selected_in_group = bool(meta.group) | ||||||
|                 meta.group and self._group_current[meta.group] == field_name |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             # Empty messages can still be sent on the wire if they were |             # Empty messages can still be sent on the wire if they were | ||||||
|             # set (or received empty). |             # set (or received empty). | ||||||
| @@ -1016,7 +1037,12 @@ class Message(ABC): | |||||||
|                     parsed.wire_type, meta, field_name, parsed.value |                     parsed.wire_type, meta, field_name, parsed.value | ||||||
|                 ) |                 ) | ||||||
|  |  | ||||||
|             current = getattr(self, field_name) |             try: | ||||||
|  |                 current = getattr(self, field_name) | ||||||
|  |             except AttributeError: | ||||||
|  |                 current = self._get_field_default(field_name) | ||||||
|  |                 setattr(self, field_name, current) | ||||||
|  |  | ||||||
|             if meta.proto_type == TYPE_MAP: |             if meta.proto_type == TYPE_MAP: | ||||||
|                 # Value represents a single key/value pair entry in the map. |                 # Value represents a single key/value pair entry in the map. | ||||||
|                 current[value.key] = value.value |                 current[value.key] = value.value | ||||||
| @@ -1077,7 +1103,10 @@ class Message(ABC): | |||||||
|         defaults = self._betterproto.default_gen |         defaults = self._betterproto.default_gen | ||||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): |         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||||
|             field_is_repeated = defaults[field_name] is list |             field_is_repeated = defaults[field_name] is list | ||||||
|             value = getattr(self, field_name) |             try: | ||||||
|  |                 value = getattr(self, field_name) | ||||||
|  |             except AttributeError: | ||||||
|  |                 value = self._get_field_default(field_name) | ||||||
|             cased_name = casing(field_name).rstrip("_")  # type: ignore |             cased_name = casing(field_name).rstrip("_")  # type: ignore | ||||||
|             if meta.proto_type == TYPE_MESSAGE: |             if meta.proto_type == TYPE_MESSAGE: | ||||||
|                 if isinstance(value, datetime): |                 if isinstance(value, datetime): | ||||||
| @@ -1209,7 +1238,7 @@ class Message(ABC): | |||||||
|  |  | ||||||
|             if value[key] is not None: |             if value[key] is not None: | ||||||
|                 if meta.proto_type == TYPE_MESSAGE: |                 if meta.proto_type == TYPE_MESSAGE: | ||||||
|                     v = getattr(self, field_name) |                     v = self._get_field_default(field_name) | ||||||
|                     cls = self._betterproto.cls_by_field[field_name] |                     cls = self._betterproto.cls_by_field[field_name] | ||||||
|                     if isinstance(v, list): |                     if isinstance(v, list): | ||||||
|                         if cls == datetime: |                         if cls == datetime: | ||||||
| @@ -1486,7 +1515,6 @@ class Message(ABC): | |||||||
|         field_name_to_meta = cls._betterproto_meta.meta_by_field_name  # type: ignore |         field_name_to_meta = cls._betterproto_meta.meta_by_field_name  # type: ignore | ||||||
|  |  | ||||||
|         for group, field_set in group_to_one_ofs.items(): |         for group, field_set in group_to_one_ofs.items(): | ||||||
|  |  | ||||||
|             if len(field_set) == 1: |             if len(field_set) == 1: | ||||||
|                 (field,) = field_set |                 (field,) = field_set | ||||||
|                 field_name = field.name |                 field_name = field.name | ||||||
|   | |||||||
| @@ -21,7 +21,6 @@ class ServiceBase(ABC): | |||||||
|         stream: grpclib.server.Stream, |         stream: grpclib.server.Stream, | ||||||
|         request: Any, |         request: Any, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  |  | ||||||
|         response_iter = handler(request) |         response_iter = handler(request) | ||||||
|         # check if response is actually an AsyncIterator |         # check if response is actually an AsyncIterator | ||||||
|         # this might be false if the method just returns without |         # this might be false if the method just returns without | ||||||
|   | |||||||
| @@ -21,7 +21,6 @@ from .models import OutputTemplate | |||||||
|  |  | ||||||
|  |  | ||||||
| def outputfile_compiler(output_file: OutputTemplate) -> str: | def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||||
|  |  | ||||||
|     templates_folder = os.path.abspath( |     templates_folder = os.path.abspath( | ||||||
|         os.path.join(os.path.dirname(__file__), "..", "templates") |         os.path.join(os.path.dirname(__file__), "..", "templates") | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -159,7 +159,6 @@ def _make_one_of_field_compiler( | |||||||
|     proto_obj: "FieldDescriptorProto", |     proto_obj: "FieldDescriptorProto", | ||||||
|     path: List[int], |     path: List[int], | ||||||
| ) -> FieldCompiler: | ) -> FieldCompiler: | ||||||
|  |  | ||||||
|     pydantic = output_package.pydantic_dataclasses |     pydantic = output_package.pydantic_dataclasses | ||||||
|     Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler |     Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler | ||||||
|     return Cls( |     return Cls( | ||||||
|   | |||||||
| @@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof(): | |||||||
|  |  | ||||||
|     # None of these fields were explicitly set BUT they should not actually be null |     # None of these fields were explicitly set BUT they should not actually be null | ||||||
|     # themselves |     # themselves | ||||||
|     assert isinstance(message.foo, Foo) |     assert not hasattr(message, "foo") | ||||||
|     assert isinstance(message2.foo, Foo) |     assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER | ||||||
|  |     assert not hasattr(message2, "foo") | ||||||
|  |     assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER | ||||||
|  |  | ||||||
|     assert isinstance(message_reference.foo, ReferenceFoo) |     assert isinstance(message_reference.foo, ReferenceFoo) | ||||||
|     assert isinstance(message_reference2.foo, ReferenceFoo) |     assert isinstance(message_reference2.foo, ReferenceFoo) | ||||||
|   | |||||||
| @@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value(): | |||||||
|         get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json |         get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     assert message.move == Move( |     assert not hasattr(message, "move") | ||||||
|         x=0, y=0 |     assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER | ||||||
|     )  # Proto3 will default this as there is no null |  | ||||||
|     assert message.signal == Signal.PASS |     assert message.signal == Signal.PASS | ||||||
|     assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) |     assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) | ||||||
|  |  | ||||||
| @@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value(): | |||||||
|     message.from_json( |     message.from_json( | ||||||
|         get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json |         get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json | ||||||
|     ) |     ) | ||||||
|     assert message.move == Move( |     assert not hasattr(message, "move") | ||||||
|         x=0, y=0 |     assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER | ||||||
|     )  # Proto3 will default this as there is no null |  | ||||||
|     assert message.signal == Signal.RESIGN |     assert message.signal == Signal.RESIGN | ||||||
|     assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) |     assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) | ||||||
|  |  | ||||||
| @@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set(): | |||||||
|     message = Test() |     message = Test() | ||||||
|     message.from_json(get_test_case_json_data("oneof_enum")[0].json) |     message.from_json(get_test_case_json_data("oneof_enum")[0].json) | ||||||
|     assert message.move == Move(x=2, y=3) |     assert message.move == Move(x=2, y=3) | ||||||
|     assert message.signal == Signal.PASS |     assert not hasattr(message, "signal") | ||||||
|  |     assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER | ||||||
|     assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) |     assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) | ||||||
|   | |||||||
							
								
								
									
										46
									
								
								tests/oneof_pattern_matching.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/oneof_pattern_matching.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | from dataclasses import dataclass | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | import betterproto | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_oneof_pattern_matching(): | ||||||
|  |     @dataclass | ||||||
|  |     class Sub(betterproto.Message): | ||||||
|  |         val: int = betterproto.int32_field(1) | ||||||
|  |  | ||||||
|  |     @dataclass | ||||||
|  |     class Foo(betterproto.Message): | ||||||
|  |         bar: int = betterproto.int32_field(1, group="group1") | ||||||
|  |         baz: str = betterproto.string_field(2, group="group1") | ||||||
|  |         sub: Sub = betterproto.message_field(3, group="group2") | ||||||
|  |         abc: str = betterproto.string_field(4, group="group2") | ||||||
|  |  | ||||||
|  |     foo = Foo(baz="test1", abc="test2") | ||||||
|  |  | ||||||
|  |     match foo: | ||||||
|  |         case Foo(bar=_): | ||||||
|  |             pytest.fail("Matched 'bar' instead of 'baz'") | ||||||
|  |         case Foo(baz=v): | ||||||
|  |             assert v == "test1" | ||||||
|  |         case _: | ||||||
|  |             pytest.fail("Matched neither 'bar' nor 'baz'") | ||||||
|  |  | ||||||
|  |     match foo: | ||||||
|  |         case Foo(sub=_): | ||||||
|  |             pytest.fail("Matched 'sub' instead of 'abc'") | ||||||
|  |         case Foo(abc=v): | ||||||
|  |             assert v == "test2" | ||||||
|  |         case _: | ||||||
|  |             pytest.fail("Matched neither 'sub' nor 'abc'") | ||||||
|  |  | ||||||
|  |     foo.sub = Sub(val=1) | ||||||
|  |  | ||||||
|  |     match foo: | ||||||
|  |         case Foo(sub=Sub(val=v)): | ||||||
|  |             assert v == 1 | ||||||
|  |         case Foo(abc=v): | ||||||
|  |             pytest.fail("Matched 'abc' instead of 'sub'") | ||||||
|  |         case _: | ||||||
|  |             pytest.fail("Matched neither 'sub' nor 'abc'") | ||||||
| @@ -1,4 +1,5 @@ | |||||||
| import json | import json | ||||||
|  | import sys | ||||||
| from copy import ( | from copy import ( | ||||||
|     copy, |     copy, | ||||||
|     deepcopy, |     deepcopy, | ||||||
| @@ -18,6 +19,8 @@ from typing import ( | |||||||
|     Optional, |     Optional, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -151,17 +154,18 @@ def test_oneof_support(): | |||||||
|     foo.baz = "test" |     foo.baz = "test" | ||||||
|  |  | ||||||
|     # Other oneof fields should now be unset |     # Other oneof fields should now be unset | ||||||
|     assert foo.bar == 0 |     assert not hasattr(foo, "bar") | ||||||
|  |     assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER | ||||||
|     assert betterproto.which_one_of(foo, "group1")[0] == "baz" |     assert betterproto.which_one_of(foo, "group1")[0] == "baz" | ||||||
|  |  | ||||||
|     foo.sub.val = 1 |     foo.sub = Sub(val=1) | ||||||
|     assert betterproto.serialized_on_wire(foo.sub) |     assert betterproto.serialized_on_wire(foo.sub) | ||||||
|  |  | ||||||
|     foo.abc = "test" |     foo.abc = "test" | ||||||
|  |  | ||||||
|     # Group 1 shouldn't be touched, group 2 should have reset |     # Group 1 shouldn't be touched, group 2 should have reset | ||||||
|     assert foo.sub.val == 0 |     assert not hasattr(foo, "sub") | ||||||
|     assert betterproto.serialized_on_wire(foo.sub) is False |     assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER | ||||||
|     assert betterproto.which_one_of(foo, "group2")[0] == "abc" |     assert betterproto.which_one_of(foo, "group2")[0] == "abc" | ||||||
|  |  | ||||||
|     # Zero value should always serialize for one-of |     # Zero value should always serialize for one-of | ||||||
| @@ -176,6 +180,16 @@ def test_oneof_support(): | |||||||
|     assert betterproto.which_one_of(foo2, "group2")[0] == "" |     assert betterproto.which_one_of(foo2, "group2")[0] == "" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.skipif( | ||||||
|  |     sys.version_info < (3, 10), | ||||||
|  |     reason="pattern matching is only supported in python3.10+", | ||||||
|  | ) | ||||||
|  | def test_oneof_pattern_matching(): | ||||||
|  |     from .oneof_pattern_matching import test_oneof_pattern_matching | ||||||
|  |  | ||||||
|  |     test_oneof_pattern_matching() | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_json_casing(): | def test_json_casing(): | ||||||
|     @dataclass |     @dataclass | ||||||
|     class CasingTest(betterproto.Message): |     class CasingTest(betterproto.Message): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user