diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index 3030c3b..42b3664 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -24,28 +24,42 @@ class PydanticView(AbstractView): An AIOHTTP View that validate request using function annotations. """ + # Allowed HTTP methods; overridden when subclassed. + allowed_methods: set = {} + async def _iter(self) -> StreamResponse: - method = getattr(self, self.request.method.lower(), None) - resp = await method() - return resp + if (method_name := self.request.method) not in self.allowed_methods: + self._raise_allowed_methods() + return await getattr(self, method_name.lower())() def __await__(self) -> Generator[Any, None, StreamResponse]: return self._iter().__await__() - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: + """Define allowed methods and decorate handlers. + + Handlers are decorated if and only if they meet the following conditions: + - the handler corresponds to an allowed method; + - the handler method was not inherited from a :class:`PydanticView` base + class. This prevents that methods are decorated multiple times. + """ cls.allowed_methods = { meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower()) } for meth_name in METH_ALL: - if meth_name not in cls.allowed_methods: - setattr(cls, meth_name.lower(), cls.raise_not_allowed) - else: + if meth_name in cls.allowed_methods: handler = getattr(cls, meth_name.lower()) - decorated_handler = inject_params(handler, cls.parse_func_signature) - setattr(cls, meth_name.lower(), decorated_handler) + for base_class in cls.__bases__: + if is_pydantic_view(base_class): + parent_handler = getattr(base_class, meth_name.lower(), None) + if handler == parent_handler: + break + else: + decorated_handler = inject_params(handler, cls.parse_func_signature) + setattr(cls, meth_name.lower(), decorated_handler) - async def raise_not_allowed(self): + def _raise_allowed_methods(self) -> None: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) @staticmethod diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py new file mode 100644 index 0000000..0375c01 --- /dev/null +++ b/tests/test_inheritance.py @@ -0,0 +1,58 @@ +from typing import Any + +from aiohttp_pydantic import PydanticView + + +def count_wrappers(obj: Any) -> int: + """Count the number of times that an object is wrapped.""" + i = 0 + while i < 10: + try: + obj = obj.__wrapped__ + except AttributeError: + return i + else: + i += 1 + raise RuntimeError("Too many wrappers") + + +class ViewParent(PydanticView): + async def put(self): + pass + + async def delete(self): + pass + + +class ViewParentNonPydantic: + async def post(self): + pass + + +class ViewChild(ViewParent, ViewParentNonPydantic): + async def get(self): + pass + + async def delete(self): + pass + + async def not_allowed(self): + pass + + +def test_allowed_methods_are_set_correctly(): + assert ViewParent.allowed_methods == {"PUT", "DELETE"} + assert ViewChild.allowed_methods == {"GET", "POST", "PUT", "DELETE"} + + +def test_allowed_methods_get_decorated_exactly_once(): + assert count_wrappers(ViewParent.put) == 1 + assert count_wrappers(ViewParent.delete) == 1 + assert count_wrappers(ViewChild.get) == 1 + assert count_wrappers(ViewChild.post) == 1 + assert count_wrappers(ViewChild.put) == 1 + assert count_wrappers(ViewChild.post) == 1 + assert count_wrappers(ViewChild.put) == 1 + + assert count_wrappers(ViewChild.not_allowed) == 0 + assert count_wrappers(ViewParentNonPydantic.post) == 0