From 5f86e1efdaac5ef5d2852e66f2ee65f01ebb67d8 Mon Sep 17 00:00:00 2001 From: Daan de Ruiter <30779179+drderuiter@users.noreply.github.com> Date: Thu, 13 May 2021 10:59:01 +0200 Subject: [PATCH 1/4] Improve compatibility with web.View and support subclassing Views --- aiohttp_pydantic/view.py | 34 ++++++++++++++++------- tests/test_inheritance.py | 58 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 10 deletions(-) create mode 100644 tests/test_inheritance.py diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index 3030c3b..42b3664 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -24,28 +24,42 @@ class PydanticView(AbstractView): An AIOHTTP View that validate request using function annotations. """ + # Allowed HTTP methods; overridden when subclassed. + allowed_methods: set = {} + async def _iter(self) -> StreamResponse: - method = getattr(self, self.request.method.lower(), None) - resp = await method() - return resp + if (method_name := self.request.method) not in self.allowed_methods: + self._raise_allowed_methods() + return await getattr(self, method_name.lower())() def __await__(self) -> Generator[Any, None, StreamResponse]: return self._iter().__await__() - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: + """Define allowed methods and decorate handlers. + + Handlers are decorated if and only if they meet the following conditions: + - the handler corresponds to an allowed method; + - the handler method was not inherited from a :class:`PydanticView` base + class. This prevents that methods are decorated multiple times. + """ cls.allowed_methods = { meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower()) } for meth_name in METH_ALL: - if meth_name not in cls.allowed_methods: - setattr(cls, meth_name.lower(), cls.raise_not_allowed) - else: + if meth_name in cls.allowed_methods: handler = getattr(cls, meth_name.lower()) - decorated_handler = inject_params(handler, cls.parse_func_signature) - setattr(cls, meth_name.lower(), decorated_handler) + for base_class in cls.__bases__: + if is_pydantic_view(base_class): + parent_handler = getattr(base_class, meth_name.lower(), None) + if handler == parent_handler: + break + else: + decorated_handler = inject_params(handler, cls.parse_func_signature) + setattr(cls, meth_name.lower(), decorated_handler) - async def raise_not_allowed(self): + def _raise_allowed_methods(self) -> None: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) @staticmethod diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py new file mode 100644 index 0000000..0375c01 --- /dev/null +++ b/tests/test_inheritance.py @@ -0,0 +1,58 @@ +from typing import Any + +from aiohttp_pydantic import PydanticView + + +def count_wrappers(obj: Any) -> int: + """Count the number of times that an object is wrapped.""" + i = 0 + while i < 10: + try: + obj = obj.__wrapped__ + except AttributeError: + return i + else: + i += 1 + raise RuntimeError("Too many wrappers") + + +class ViewParent(PydanticView): + async def put(self): + pass + + async def delete(self): + pass + + +class ViewParentNonPydantic: + async def post(self): + pass + + +class ViewChild(ViewParent, ViewParentNonPydantic): + async def get(self): + pass + + async def delete(self): + pass + + async def not_allowed(self): + pass + + +def test_allowed_methods_are_set_correctly(): + assert ViewParent.allowed_methods == {"PUT", "DELETE"} + assert ViewChild.allowed_methods == {"GET", "POST", "PUT", "DELETE"} + + +def test_allowed_methods_get_decorated_exactly_once(): + assert count_wrappers(ViewParent.put) == 1 + assert count_wrappers(ViewParent.delete) == 1 + assert count_wrappers(ViewChild.get) == 1 + assert count_wrappers(ViewChild.post) == 1 + assert count_wrappers(ViewChild.put) == 1 + assert count_wrappers(ViewChild.post) == 1 + assert count_wrappers(ViewChild.put) == 1 + + assert count_wrappers(ViewChild.not_allowed) == 0 + assert count_wrappers(ViewParentNonPydantic.post) == 0 From c92437c624e000edb3d09c66facad62f052409fd Mon Sep 17 00:00:00 2001 From: Daan de Ruiter <30779179+drderuiter@users.noreply.github.com> Date: Thu, 13 May 2021 11:06:54 +0200 Subject: [PATCH 2/4] Further specify allowed_methods type hint --- aiohttp_pydantic/view.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index 42b3664..a4f1c59 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -1,6 +1,6 @@ from functools import update_wrapper from inspect import iscoroutinefunction -from typing import Any, Callable, Generator, Iterable +from typing import Any, Callable, Generator, Iterable, Set from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL @@ -25,7 +25,7 @@ class PydanticView(AbstractView): """ # Allowed HTTP methods; overridden when subclassed. - allowed_methods: set = {} + allowed_methods: Set[str] = {} async def _iter(self) -> StreamResponse: if (method_name := self.request.method) not in self.allowed_methods: From 08ab4d2610ea30a63c6202d6e89cadabc3e7fad6 Mon Sep 17 00:00:00 2001 From: Vincent Maillol Date: Sat, 10 Jul 2021 08:16:27 +0200 Subject: [PATCH 3/4] refactoring --- aiohttp_pydantic/view.py | 31 +++++++-------- tests/test_inheritance.py | 80 +++++++++++++++++++++++---------------- 2 files changed, 64 insertions(+), 47 deletions(-) diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index a4f1c59..f75cb9d 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -1,6 +1,7 @@ from functools import update_wrapper from inspect import iscoroutinefunction -from typing import Any, Callable, Generator, Iterable, Set +from typing import Any, Callable, Generator, Iterable, Set, ClassVar +import warnings from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL @@ -25,7 +26,7 @@ class PydanticView(AbstractView): """ # Allowed HTTP methods; overridden when subclassed. - allowed_methods: Set[str] = {} + allowed_methods: ClassVar[Set[str]] = {} async def _iter(self) -> StreamResponse: if (method_name := self.request.method) not in self.allowed_methods: @@ -38,30 +39,30 @@ class PydanticView(AbstractView): def __init_subclass__(cls, **kwargs) -> None: """Define allowed methods and decorate handlers. - Handlers are decorated if and only if they meet the following conditions: - - the handler corresponds to an allowed method; - - the handler method was not inherited from a :class:`PydanticView` base - class. This prevents that methods are decorated multiple times. + Handlers are decorated if and only if they directly bound on the PydanticView class or + PydanticView subclass. This prevents that methods are decorated multiple times and that method + defined in aiohttp.View parent class is decorated. """ + cls.allowed_methods = { meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower()) } for meth_name in METH_ALL: - if meth_name in cls.allowed_methods: + if meth_name.lower() in vars(cls): handler = getattr(cls, meth_name.lower()) - for base_class in cls.__bases__: - if is_pydantic_view(base_class): - parent_handler = getattr(base_class, meth_name.lower(), None) - if handler == parent_handler: - break - else: - decorated_handler = inject_params(handler, cls.parse_func_signature) - setattr(cls, meth_name.lower(), decorated_handler) + decorated_handler = inject_params(handler, cls.parse_func_signature) + setattr(cls, meth_name.lower(), decorated_handler) def _raise_allowed_methods(self) -> None: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + def raise_not_allowed(self) -> None: + warnings.warn( + "PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods", + DeprecationWarning, stacklevel=2) + self._raise_allowed_methods() + @staticmethod def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]: path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index 0375c01..e6214a1 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -1,6 +1,7 @@ from typing import Any from aiohttp_pydantic import PydanticView +from aiohttp.web import View def count_wrappers(obj: Any) -> int: @@ -16,43 +17,58 @@ def count_wrappers(obj: Any) -> int: raise RuntimeError("Too many wrappers") -class ViewParent(PydanticView): +class AiohttpViewParent(View): async def put(self): pass - async def delete(self): + +class PydanticViewParent(PydanticView): + async def get(self, id: int, /): pass -class ViewParentNonPydantic: - async def post(self): - pass - - -class ViewChild(ViewParent, ViewParentNonPydantic): - async def get(self): - pass - - async def delete(self): - pass - - async def not_allowed(self): - pass - - -def test_allowed_methods_are_set_correctly(): - assert ViewParent.allowed_methods == {"PUT", "DELETE"} - assert ViewChild.allowed_methods == {"GET", "POST", "PUT", "DELETE"} - - def test_allowed_methods_get_decorated_exactly_once(): - assert count_wrappers(ViewParent.put) == 1 - assert count_wrappers(ViewParent.delete) == 1 - assert count_wrappers(ViewChild.get) == 1 - assert count_wrappers(ViewChild.post) == 1 - assert count_wrappers(ViewChild.put) == 1 - assert count_wrappers(ViewChild.post) == 1 - assert count_wrappers(ViewChild.put) == 1 - assert count_wrappers(ViewChild.not_allowed) == 0 - assert count_wrappers(ViewParentNonPydantic.post) == 0 + class ChildView(PydanticViewParent): + async def post(self, id: int, /): + pass + + class SubChildView(ChildView): + async def get(self, id: int, /): + return super().get(id) + + assert count_wrappers(ChildView.post) == 1 + assert count_wrappers(ChildView.get) == 1 + assert count_wrappers(SubChildView.post) == 1 + assert count_wrappers(SubChildView.get) == 1 + + +def test_methods_inherited_from_aiohttp_view_should_not_be_decorated(): + + class ChildView(AiohttpViewParent, PydanticView): + async def post(self, id: int, /): + pass + + assert count_wrappers(ChildView.put) == 0 + assert count_wrappers(ChildView.post) == 1 + + +def test_allowed_methods_are_set_correctly(): + + class ChildView(AiohttpViewParent, PydanticView): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "PUT"} + + class ChildView(PydanticViewParent): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "GET"} + + class ChildView(AiohttpViewParent, PydanticViewParent): + async def post(self, id: int, /): + pass + + assert ChildView.allowed_methods == {"POST", "PUT", "GET"} From 89a22f2fcdd77c1f6551dfad9c023f2360809524 Mon Sep 17 00:00:00 2001 From: Vincent Maillol Date: Sun, 11 Jul 2021 07:31:37 +0200 Subject: [PATCH 4/4] code reformatting --- aiohttp_pydantic/__init__.py | 2 +- aiohttp_pydantic/oas/struct.py | 5 ++++- aiohttp_pydantic/oas/view.py | 6 +++++- aiohttp_pydantic/view.py | 4 +++- tests/test_inheritance.py | 3 --- tests/test_oas/test_cmd/test_cmd.py | 4 ++-- tests/test_oas/test_struct/test_info.py | 5 ++++- tests/test_oas/test_view.py | 16 +++++++++++++--- tests/test_validation_query_string.py | 2 +- 9 files changed, 33 insertions(+), 14 deletions(-) diff --git a/aiohttp_pydantic/__init__.py b/aiohttp_pydantic/__init__.py index 9604bdd..ac08759 100644 --- a/aiohttp_pydantic/__init__.py +++ b/aiohttp_pydantic/__init__.py @@ -1,5 +1,5 @@ from .view import PydanticView -__version__ = "1.9.0" +__version__ = "1.9.1" __all__ = ("PydanticView", "__version__") diff --git a/aiohttp_pydantic/oas/struct.py b/aiohttp_pydantic/oas/struct.py index 576079a..cf01e9c 100644 --- a/aiohttp_pydantic/oas/struct.py +++ b/aiohttp_pydantic/oas/struct.py @@ -305,7 +305,10 @@ class Components: class OpenApiSpec3: def __init__(self): - self._spec = {"openapi": "3.0.0", "info": {"version": "1.0.0", "title": "Aiohttp pydantic application"}} + self._spec = { + "openapi": "3.0.0", + "info": {"version": "1.0.0", "title": "Aiohttp pydantic application"}, + } @property def info(self) -> Info: diff --git a/aiohttp_pydantic/oas/view.py b/aiohttp_pydantic/oas/view.py index d91bb8e..b2fc9e5 100644 --- a/aiohttp_pydantic/oas/view.py +++ b/aiohttp_pydantic/oas/view.py @@ -147,7 +147,11 @@ def _add_http_method_to_oas( ) -def generate_oas(apps: List[Application], version_spec: Optional[str] = None, title_spec: Optional[str] = None) -> dict: +def generate_oas( + apps: List[Application], + version_spec: Optional[str] = None, + title_spec: Optional[str] = None, +) -> dict: """ Generate and return Open Api Specification from PydanticView in application. """ diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index f75cb9d..6196ddd 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -60,7 +60,9 @@ class PydanticView(AbstractView): def raise_not_allowed(self) -> None: warnings.warn( "PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods", - DeprecationWarning, stacklevel=2) + DeprecationWarning, + stacklevel=2, + ) self._raise_allowed_methods() @staticmethod diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index e6214a1..d759c0d 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -28,7 +28,6 @@ class PydanticViewParent(PydanticView): def test_allowed_methods_get_decorated_exactly_once(): - class ChildView(PydanticViewParent): async def post(self, id: int, /): pass @@ -44,7 +43,6 @@ def test_allowed_methods_get_decorated_exactly_once(): def test_methods_inherited_from_aiohttp_view_should_not_be_decorated(): - class ChildView(AiohttpViewParent, PydanticView): async def post(self, id: int, /): pass @@ -54,7 +52,6 @@ def test_methods_inherited_from_aiohttp_view_should_not_be_decorated(): def test_allowed_methods_are_set_correctly(): - class ChildView(AiohttpViewParent, PydanticView): async def post(self, id: int, /): pass diff --git a/tests/test_oas/test_cmd/test_cmd.py b/tests/test_oas/test_cmd/test_cmd.py index bd50f27..f9d78ca 100644 --- a/tests/test_oas/test_cmd/test_cmd.py +++ b/tests/test_oas/test_cmd/test_cmd.py @@ -22,7 +22,7 @@ def test_show_oas_of_app(cmd_line): args.func(args) expected = dedent( - """ + """ { "info": { "title": "Aiohttp pydantic application", @@ -73,7 +73,7 @@ def test_show_oas_of_sub_app(cmd_line): args.output = StringIO() args.func(args) expected = dedent( - """ + """ { "info": { "title": "Aiohttp pydantic application", diff --git a/tests/test_oas/test_struct/test_info.py b/tests/test_oas/test_struct/test_info.py index 937782c..642072b 100644 --- a/tests/test_oas/test_struct/test_info.py +++ b/tests/test_oas/test_struct/test_info.py @@ -37,7 +37,10 @@ def test_info_version(): assert oas.info.version == "1.0.0" oas.info.version = "3.14" assert oas.info.version == "3.14" - assert oas.spec == {"info": {"version": "3.14", "title": "Aiohttp pydantic application"}, "openapi": "3.0.0"} + assert oas.spec == { + "info": {"version": "3.14", "title": "Aiohttp pydantic application"}, + "openapi": "3.0.0", + } def test_info_terms_of_service(): diff --git a/tests/test_oas/test_view.py b/tests/test_oas/test_view.py index 364e8eb..b95c60b 100644 --- a/tests/test_oas/test_view.py +++ b/tests/test_oas/test_view.py @@ -318,22 +318,32 @@ async def test_simple_type_route_should_have_get_method(generated_oas): }, } + async def test_generated_view_info_default(): apps = (web.Application(),) spec = generate_oas(apps) - assert spec == {'info': {'title': 'Aiohttp pydantic application', 'version': '1.0.0'}, 'openapi': '3.0.0'} + assert spec == { + "info": {"title": "Aiohttp pydantic application", "version": "1.0.0"}, + "openapi": "3.0.0", + } async def test_generated_view_info_as_version(): apps = (web.Application(),) spec = generate_oas(apps, version_spec="test version") - assert spec == {'info': {'title': 'Aiohttp pydantic application', 'version': 'test version'}, 'openapi': '3.0.0'} + assert spec == { + "info": {"title": "Aiohttp pydantic application", "version": "test version"}, + "openapi": "3.0.0", + } async def test_generated_view_info_as_title(): apps = (web.Application(),) spec = generate_oas(apps, title_spec="test title") - assert spec == {'info': {'title': 'test title', 'version': '1.0.0'}, 'openapi': '3.0.0'} + assert spec == { + "info": {"title": "test title", "version": "1.0.0"}, + "openapi": "3.0.0", + } diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index 57d886e..4b8913f 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -11,7 +11,7 @@ class ArticleView(PydanticView): with_comments: bool, age: Optional[int] = None, nb_items: int = 7, - tags: List[str] = Field(default_factory=list) + tags: List[str] = Field(default_factory=list), ): return web.json_response( {