diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index a4f1c59..f75cb9d 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -1,6 +1,7 @@ from functools import update_wrapper from inspect import iscoroutinefunction -from typing import Any, Callable, Generator, Iterable, Set +from typing import Any, Callable, Generator, Iterable, Set, ClassVar +import warnings from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL @@ -25,7 +26,7 @@ class PydanticView(AbstractView): """ # Allowed HTTP methods; overridden when subclassed. - allowed_methods: Set[str] = {} + allowed_methods: ClassVar[Set[str]] = {} async def _iter(self) -> StreamResponse: if (method_name := self.request.method) not in self.allowed_methods: @@ -38,30 +39,30 @@ class PydanticView(AbstractView): def __init_subclass__(cls, **kwargs) -> None: """Define allowed methods and decorate handlers. - Handlers are decorated if and only if they meet the following conditions: - - the handler corresponds to an allowed method; - - the handler method was not inherited from a :class:`PydanticView` base - class. This prevents that methods are decorated multiple times. + Handlers are decorated if and only if they directly bound on the PydanticView class or + PydanticView subclass. This prevents that methods are decorated multiple times and that method + defined in aiohttp.View parent class is decorated. """ + cls.allowed_methods = { meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower()) } for meth_name in METH_ALL: - if meth_name in cls.allowed_methods: + if meth_name.lower() in vars(cls): handler = getattr(cls, meth_name.lower()) - for base_class in cls.__bases__: - if is_pydantic_view(base_class): - parent_handler = getattr(base_class, meth_name.lower(), None) - if handler == parent_handler: - break - else: - decorated_handler = inject_params(handler, cls.parse_func_signature) - setattr(cls, meth_name.lower(), decorated_handler) + decorated_handler = inject_params(handler, cls.parse_func_signature) + setattr(cls, meth_name.lower(), decorated_handler) def _raise_allowed_methods(self) -> None: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + def raise_not_allowed(self) -> None: + warnings.warn( + "PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods", + DeprecationWarning, stacklevel=2) + self._raise_allowed_methods() + @staticmethod def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]: path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index 0375c01..e6214a1 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -1,6 +1,7 @@ from typing import Any from aiohttp_pydantic import PydanticView +from aiohttp.web import View def count_wrappers(obj: Any) -> int: @@ -16,43 +17,58 @@ def count_wrappers(obj: Any) -> int: raise RuntimeError("Too many wrappers") -class ViewParent(PydanticView): +class AiohttpViewParent(View): async def put(self): pass - async def delete(self): + +class PydanticViewParent(PydanticView): + async def get(self, id: int, /): pass -class ViewParentNonPydantic: - async def post(self): - pass - - -class ViewChild(ViewParent, ViewParentNonPydantic): - async def get(self): - pass - - async def delete(self): - pass - - async def not_allowed(self): - pass - - -def test_allowed_methods_are_set_correctly(): - assert ViewParent.allowed_methods == {"PUT", "DELETE"} - assert ViewChild.allowed_methods == {"GET", "POST", "PUT", "DELETE"} - - def test_allowed_methods_get_decorated_exactly_once(): - assert count_wrappers(ViewParent.put) == 1 - assert count_wrappers(ViewParent.delete) == 1 - assert count_wrappers(ViewChild.get) == 1 - assert count_wrappers(ViewChild.post) == 1 - assert count_wrappers(ViewChild.put) == 1 - assert count_wrappers(ViewChild.post) == 1 - assert count_wrappers(ViewChild.put) == 1 - assert count_wrappers(ViewChild.not_allowed) == 0 - assert count_wrappers(ViewParentNonPydantic.post) == 0 + class ChildView(PydanticViewParent): + async def post(self, id: int, /): + pass + + class SubChildView(ChildView): + async def get(self, id: int, /): + return super().get(id) + + assert count_wrappers(ChildView.post) == 1 + assert count_wrappers(ChildView.get) == 1 + assert count_wrappers(SubChildView.post) == 1 + assert count_wrappers(SubChildView.get) == 1 + + +def test_methods_inherited_from_aiohttp_view_should_not_be_decorated(): + + class ChildView(AiohttpViewParent, PydanticView): + async def post(self, id: int, /): + pass + + assert count_wrappers(ChildView.put) == 0 + assert count_wrappers(ChildView.post) == 1 + + +def test_allowed_methods_are_set_correctly(): + + class ChildView(AiohttpViewParent, PydanticView): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "PUT"} + + class ChildView(PydanticViewParent): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "GET"} + + class ChildView(AiohttpViewParent, PydanticViewParent): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "PUT", "GET"}