Implement __deepcopy__ for Message (#339)

This commit is contained in:
James Hilton-Balfe 2022-02-16 23:12:51 +00:00 committed by GitHub
parent 3f377e3bfd
commit 74205e3319
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 0 deletions

View File

@ -8,6 +8,7 @@ import sys
import typing
from abc import ABC
from base64 import b64decode, b64encode
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from dateutil.parser import isoparse
from typing import (
@ -717,6 +718,14 @@ class Message(ABC):
for field_name in self._betterproto.meta_by_field_name
)
def __deepcopy__(self: T, _: Any = {}) -> T:
kwargs = {}
for name in self._betterproto.sorted_field_names:
value = self.__raw_get(name)
if value is not PLACEHOLDER:
kwargs[name] = deepcopy(value)
return self.__class__(**kwargs) # type: ignore
@property
def _betterproto(self) -> ProtoClassMetadata:
"""

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from copy import copy, deepcopy
from datetime import datetime
from inspect import Parameter, signature
from typing import Dict, List, Optional
@ -485,3 +486,22 @@ def test_service_argument__expected_parameter():
do_thing_request_parameter = sig.parameters["do_thing_request"]
assert do_thing_request_parameter.default is Parameter.empty
assert do_thing_request_parameter.annotation == "DoThingRequest"
def test_copyability():
@dataclass
class Spam(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
spam = Spam(bar=12, baz=["hello"])
copied = copy(spam)
assert spam == copied
assert spam is not copied
assert spam.baz is copied.baz
deepcopied = deepcopy(spam)
assert spam == deepcopied
assert spam is not deepcopied
assert spam.baz is not deepcopied.baz