diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index d546579..8207c40 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -8,6 +8,7 @@ import sys import typing from abc import ABC from base64 import b64decode, b64encode +from copy import deepcopy from datetime import datetime, timedelta, timezone from dateutil.parser import isoparse from typing import ( @@ -717,6 +718,14 @@ class Message(ABC): for field_name in self._betterproto.meta_by_field_name ) + def __deepcopy__(self: T, _: Any = {}) -> T: + kwargs = {} + for name in self._betterproto.sorted_field_names: + value = self.__raw_get(name) + if value is not PLACEHOLDER: + kwargs[name] = deepcopy(value) + return self.__class__(**kwargs) # type: ignore + @property def _betterproto(self) -> ProtoClassMetadata: """ diff --git a/tests/test_features.py b/tests/test_features.py index 787520d..0fedce2 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from copy import copy, deepcopy from datetime import datetime from inspect import Parameter, signature from typing import Dict, List, Optional @@ -485,3 +486,22 @@ def test_service_argument__expected_parameter(): do_thing_request_parameter = sig.parameters["do_thing_request"] assert do_thing_request_parameter.default is Parameter.empty assert do_thing_request_parameter.annotation == "DoThingRequest" + + +def test_copyability(): + @dataclass + class Spam(betterproto.Message): + foo: bool = betterproto.bool_field(1) + bar: int = betterproto.int32_field(2) + baz: List[str] = betterproto.string_field(3) + + spam = Spam(bar=12, baz=["hello"]) + copied = copy(spam) + assert spam == copied + assert spam is not copied + assert spam.baz is copied.baz + + deepcopied = deepcopy(spam) + assert spam == deepcopied + assert spam is not deepcopied + assert spam.baz is not deepcopied.baz