Implement __deepcopy__ for Message (#339)

This commit is contained in:
James Hilton-Balfe 2022-02-16 23:12:51 +00:00 committed by GitHub
parent 3f377e3bfd
commit 74205e3319
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 0 deletions

View File

@ -8,6 +8,7 @@ import sys
import typing import typing
from abc import ABC from abc import ABC
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from copy import deepcopy
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from dateutil.parser import isoparse from dateutil.parser import isoparse
from typing import ( from typing import (
@ -717,6 +718,14 @@ class Message(ABC):
for field_name in self._betterproto.meta_by_field_name for field_name in self._betterproto.meta_by_field_name
) )
def __deepcopy__(self: T, _: Any = {}) -> T:
kwargs = {}
for name in self._betterproto.sorted_field_names:
value = self.__raw_get(name)
if value is not PLACEHOLDER:
kwargs[name] = deepcopy(value)
return self.__class__(**kwargs) # type: ignore
@property @property
def _betterproto(self) -> ProtoClassMetadata: def _betterproto(self) -> ProtoClassMetadata:
""" """

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from copy import copy, deepcopy
from datetime import datetime from datetime import datetime
from inspect import Parameter, signature from inspect import Parameter, signature
from typing import Dict, List, Optional from typing import Dict, List, Optional
@ -485,3 +486,22 @@ def test_service_argument__expected_parameter():
do_thing_request_parameter = sig.parameters["do_thing_request"] do_thing_request_parameter = sig.parameters["do_thing_request"]
assert do_thing_request_parameter.default is Parameter.empty assert do_thing_request_parameter.default is Parameter.empty
assert do_thing_request_parameter.annotation == "DoThingRequest" assert do_thing_request_parameter.annotation == "DoThingRequest"
def test_copyability():
@dataclass
class Spam(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
spam = Spam(bar=12, baz=["hello"])
copied = copy(spam)
assert spam == copied
assert spam is not copied
assert spam.baz is copied.baz
deepcopied = deepcopy(spam)
assert spam == deepcopied
assert spam is not deepcopied
assert spam.baz is not deepcopied.baz