Raise AttributeError on attempts to access unset oneof fields (#510)
				
					
				
			This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							098989e9e9
						
					
				
				
					commit
					6faac1d1ca
				
			| @@ -693,8 +693,28 @@ class Message(ABC): | ||||
|         def __getattribute__(self, name: str) -> Any: | ||||
|             """ | ||||
|             Lazily initialize default values to avoid infinite recursion for recursive | ||||
|             message types | ||||
|             message types. | ||||
|             Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields. | ||||
|             """ | ||||
|             try: | ||||
|                 group_current = super().__getattribute__("_group_current") | ||||
|             except AttributeError: | ||||
|                 pass | ||||
|             else: | ||||
|                 if name not in {"__class__", "_betterproto"}: | ||||
|                     group = self._betterproto.oneof_group_by_field.get(name) | ||||
|                     if group is not None and group_current[group] != name: | ||||
|                         if sys.version_info < (3, 10): | ||||
|                             raise AttributeError( | ||||
|                                 f"{group!r} is set to {group_current[group]!r}, not {name!r}" | ||||
|                             ) | ||||
|                         else: | ||||
|                             raise AttributeError( | ||||
|                                 f"{group!r} is set to {group_current[group]!r}, not {name!r}", | ||||
|                                 name=name, | ||||
|                                 obj=self, | ||||
|                             ) | ||||
|  | ||||
|             value = super().__getattribute__(name) | ||||
|             if value is not PLACEHOLDER: | ||||
|                 return value | ||||
| @@ -761,7 +781,10 @@ class Message(ABC): | ||||
|         """ | ||||
|         output = bytearray() | ||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||
|             value = getattr(self, field_name) | ||||
|             try: | ||||
|                 value = getattr(self, field_name) | ||||
|             except AttributeError: | ||||
|                 continue | ||||
|  | ||||
|             if value is None: | ||||
|                 # Optional items should be skipped. This is used for the Google | ||||
| @@ -775,9 +798,7 @@ class Message(ABC): | ||||
|             # Note that proto3 field presence/optional fields are put in a | ||||
|             # synthetic single-item oneof by protoc, which helps us ensure we | ||||
|             # send the value even if the value is the default zero value. | ||||
|             selected_in_group = ( | ||||
|                 meta.group and self._group_current[meta.group] == field_name | ||||
|             ) | ||||
|             selected_in_group = bool(meta.group) | ||||
|  | ||||
|             # Empty messages can still be sent on the wire if they were | ||||
|             # set (or received empty). | ||||
| @@ -1016,7 +1037,12 @@ class Message(ABC): | ||||
|                     parsed.wire_type, meta, field_name, parsed.value | ||||
|                 ) | ||||
|  | ||||
|             current = getattr(self, field_name) | ||||
|             try: | ||||
|                 current = getattr(self, field_name) | ||||
|             except AttributeError: | ||||
|                 current = self._get_field_default(field_name) | ||||
|                 setattr(self, field_name, current) | ||||
|  | ||||
|             if meta.proto_type == TYPE_MAP: | ||||
|                 # Value represents a single key/value pair entry in the map. | ||||
|                 current[value.key] = value.value | ||||
| @@ -1077,7 +1103,10 @@ class Message(ABC): | ||||
|         defaults = self._betterproto.default_gen | ||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||
|             field_is_repeated = defaults[field_name] is list | ||||
|             value = getattr(self, field_name) | ||||
|             try: | ||||
|                 value = getattr(self, field_name) | ||||
|             except AttributeError: | ||||
|                 value = self._get_field_default(field_name) | ||||
|             cased_name = casing(field_name).rstrip("_")  # type: ignore | ||||
|             if meta.proto_type == TYPE_MESSAGE: | ||||
|                 if isinstance(value, datetime): | ||||
| @@ -1209,7 +1238,7 @@ class Message(ABC): | ||||
|  | ||||
|             if value[key] is not None: | ||||
|                 if meta.proto_type == TYPE_MESSAGE: | ||||
|                     v = getattr(self, field_name) | ||||
|                     v = self._get_field_default(field_name) | ||||
|                     cls = self._betterproto.cls_by_field[field_name] | ||||
|                     if isinstance(v, list): | ||||
|                         if cls == datetime: | ||||
| @@ -1486,7 +1515,6 @@ class Message(ABC): | ||||
|         field_name_to_meta = cls._betterproto_meta.meta_by_field_name  # type: ignore | ||||
|  | ||||
|         for group, field_set in group_to_one_ofs.items(): | ||||
|  | ||||
|             if len(field_set) == 1: | ||||
|                 (field,) = field_set | ||||
|                 field_name = field.name | ||||
|   | ||||
| @@ -21,7 +21,6 @@ class ServiceBase(ABC): | ||||
|         stream: grpclib.server.Stream, | ||||
|         request: Any, | ||||
|     ) -> None: | ||||
|  | ||||
|         response_iter = handler(request) | ||||
|         # check if response is actually an AsyncIterator | ||||
|         # this might be false if the method just returns without | ||||
|   | ||||
| @@ -21,7 +21,6 @@ from .models import OutputTemplate | ||||
|  | ||||
|  | ||||
| def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
|  | ||||
|     templates_folder = os.path.abspath( | ||||
|         os.path.join(os.path.dirname(__file__), "..", "templates") | ||||
|     ) | ||||
|   | ||||
| @@ -159,7 +159,6 @@ def _make_one_of_field_compiler( | ||||
|     proto_obj: "FieldDescriptorProto", | ||||
|     path: List[int], | ||||
| ) -> FieldCompiler: | ||||
|  | ||||
|     pydantic = output_package.pydantic_dataclasses | ||||
|     Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler | ||||
|     return Cls( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user