diff --git a/aiohttp_pydantic/__init__.py b/aiohttp_pydantic/__init__.py index 9604bdd..ac08759 100644 --- a/aiohttp_pydantic/__init__.py +++ b/aiohttp_pydantic/__init__.py @@ -1,5 +1,5 @@ from .view import PydanticView -__version__ = "1.9.0" +__version__ = "1.9.1" __all__ = ("PydanticView", "__version__") diff --git a/aiohttp_pydantic/oas/struct.py b/aiohttp_pydantic/oas/struct.py index 576079a..cf01e9c 100644 --- a/aiohttp_pydantic/oas/struct.py +++ b/aiohttp_pydantic/oas/struct.py @@ -305,7 +305,10 @@ class Components: class OpenApiSpec3: def __init__(self): - self._spec = {"openapi": "3.0.0", "info": {"version": "1.0.0", "title": "Aiohttp pydantic application"}} + self._spec = { + "openapi": "3.0.0", + "info": {"version": "1.0.0", "title": "Aiohttp pydantic application"}, + } @property def info(self) -> Info: diff --git a/aiohttp_pydantic/oas/view.py b/aiohttp_pydantic/oas/view.py index d91bb8e..b2fc9e5 100644 --- a/aiohttp_pydantic/oas/view.py +++ b/aiohttp_pydantic/oas/view.py @@ -147,7 +147,11 @@ def _add_http_method_to_oas( ) -def generate_oas(apps: List[Application], version_spec: Optional[str] = None, title_spec: Optional[str] = None) -> dict: +def generate_oas( + apps: List[Application], + version_spec: Optional[str] = None, + title_spec: Optional[str] = None, +) -> dict: """ Generate and return Open Api Specification from PydanticView in application. """ diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index 3030c3b..6196ddd 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -1,6 +1,7 @@ from functools import update_wrapper from inspect import iscoroutinefunction -from typing import Any, Callable, Generator, Iterable +from typing import Any, Callable, Generator, Iterable, Set, ClassVar +import warnings from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL @@ -24,30 +25,46 @@ class PydanticView(AbstractView): An AIOHTTP View that validate request using function annotations. """ + # Allowed HTTP methods; overridden when subclassed. + allowed_methods: ClassVar[Set[str]] = {} + async def _iter(self) -> StreamResponse: - method = getattr(self, self.request.method.lower(), None) - resp = await method() - return resp + if (method_name := self.request.method) not in self.allowed_methods: + self._raise_allowed_methods() + return await getattr(self, method_name.lower())() def __await__(self) -> Generator[Any, None, StreamResponse]: return self._iter().__await__() - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: + """Define allowed methods and decorate handlers. + + Handlers are decorated if and only if they directly bound on the PydanticView class or + PydanticView subclass. This prevents that methods are decorated multiple times and that method + defined in aiohttp.View parent class is decorated. + """ + cls.allowed_methods = { meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower()) } for meth_name in METH_ALL: - if meth_name not in cls.allowed_methods: - setattr(cls, meth_name.lower(), cls.raise_not_allowed) - else: + if meth_name.lower() in vars(cls): handler = getattr(cls, meth_name.lower()) decorated_handler = inject_params(handler, cls.parse_func_signature) setattr(cls, meth_name.lower(), decorated_handler) - async def raise_not_allowed(self): + def _raise_allowed_methods(self) -> None: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + def raise_not_allowed(self) -> None: + warnings.warn( + "PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods", + DeprecationWarning, + stacklevel=2, + ) + self._raise_allowed_methods() + @staticmethod def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]: path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py new file mode 100644 index 0000000..d759c0d --- /dev/null +++ b/tests/test_inheritance.py @@ -0,0 +1,71 @@ +from typing import Any + +from aiohttp_pydantic import PydanticView +from aiohttp.web import View + + +def count_wrappers(obj: Any) -> int: + """Count the number of times that an object is wrapped.""" + i = 0 + while i < 10: + try: + obj = obj.__wrapped__ + except AttributeError: + return i + else: + i += 1 + raise RuntimeError("Too many wrappers") + + +class AiohttpViewParent(View): + async def put(self): + pass + + +class PydanticViewParent(PydanticView): + async def get(self, id: int, /): + pass + + +def test_allowed_methods_get_decorated_exactly_once(): + class ChildView(PydanticViewParent): + async def post(self, id: int, /): + pass + + class SubChildView(ChildView): + async def get(self, id: int, /): + return super().get(id) + + assert count_wrappers(ChildView.post) == 1 + assert count_wrappers(ChildView.get) == 1 + assert count_wrappers(SubChildView.post) == 1 + assert count_wrappers(SubChildView.get) == 1 + + +def test_methods_inherited_from_aiohttp_view_should_not_be_decorated(): + class ChildView(AiohttpViewParent, PydanticView): + async def post(self, id: int, /): + pass + + assert count_wrappers(ChildView.put) == 0 + assert count_wrappers(ChildView.post) == 1 + + +def test_allowed_methods_are_set_correctly(): + class ChildView(AiohttpViewParent, PydanticView): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "PUT"} + + class ChildView(PydanticViewParent): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "GET"} + + class ChildView(AiohttpViewParent, PydanticViewParent): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "PUT", "GET"} diff --git a/tests/test_oas/test_cmd/test_cmd.py b/tests/test_oas/test_cmd/test_cmd.py index bd50f27..f9d78ca 100644 --- a/tests/test_oas/test_cmd/test_cmd.py +++ b/tests/test_oas/test_cmd/test_cmd.py @@ -22,7 +22,7 @@ def test_show_oas_of_app(cmd_line): args.func(args) expected = dedent( - """ + """ { "info": { "title": "Aiohttp pydantic application", @@ -73,7 +73,7 @@ def test_show_oas_of_sub_app(cmd_line): args.output = StringIO() args.func(args) expected = dedent( - """ + """ { "info": { "title": "Aiohttp pydantic application", diff --git a/tests/test_oas/test_struct/test_info.py b/tests/test_oas/test_struct/test_info.py index 937782c..642072b 100644 --- a/tests/test_oas/test_struct/test_info.py +++ b/tests/test_oas/test_struct/test_info.py @@ -37,7 +37,10 @@ def test_info_version(): assert oas.info.version == "1.0.0" oas.info.version = "3.14" assert oas.info.version == "3.14" - assert oas.spec == {"info": {"version": "3.14", "title": "Aiohttp pydantic application"}, "openapi": "3.0.0"} + assert oas.spec == { + "info": {"version": "3.14", "title": "Aiohttp pydantic application"}, + "openapi": "3.0.0", + } def test_info_terms_of_service(): diff --git a/tests/test_oas/test_view.py b/tests/test_oas/test_view.py index 364e8eb..b95c60b 100644 --- a/tests/test_oas/test_view.py +++ b/tests/test_oas/test_view.py @@ -318,22 +318,32 @@ async def test_simple_type_route_should_have_get_method(generated_oas): }, } + async def test_generated_view_info_default(): apps = (web.Application(),) spec = generate_oas(apps) - assert spec == {'info': {'title': 'Aiohttp pydantic application', 'version': '1.0.0'}, 'openapi': '3.0.0'} + assert spec == { + "info": {"title": "Aiohttp pydantic application", "version": "1.0.0"}, + "openapi": "3.0.0", + } async def test_generated_view_info_as_version(): apps = (web.Application(),) spec = generate_oas(apps, version_spec="test version") - assert spec == {'info': {'title': 'Aiohttp pydantic application', 'version': 'test version'}, 'openapi': '3.0.0'} + assert spec == { + "info": {"title": "Aiohttp pydantic application", "version": "test version"}, + "openapi": "3.0.0", + } async def test_generated_view_info_as_title(): apps = (web.Application(),) spec = generate_oas(apps, title_spec="test title") - assert spec == {'info': {'title': 'test title', 'version': '1.0.0'}, 'openapi': '3.0.0'} + assert spec == { + "info": {"title": "test title", "version": "1.0.0"}, + "openapi": "3.0.0", + } diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index 57d886e..4b8913f 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -11,7 +11,7 @@ class ArticleView(PydanticView): with_comments: bool, age: Optional[int] = None, nb_items: int = 7, - tags: List[str] = Field(default_factory=list) + tags: List[str] = Field(default_factory=list), ): return web.json_response( {