refactoring

This commit is contained in:
Vincent Maillol 2021-07-10 08:16:27 +02:00
parent c92437c624
commit 08ab4d2610
2 changed files with 64 additions and 47 deletions

View File

@ -1,6 +1,7 @@
from functools import update_wrapper
from inspect import iscoroutinefunction
from typing import Any, Callable, Generator, Iterable, Set
from typing import Any, Callable, Generator, Iterable, Set, ClassVar
import warnings
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL
@ -25,7 +26,7 @@ class PydanticView(AbstractView):
"""
# Allowed HTTP methods; overridden when subclassed.
allowed_methods: Set[str] = {}
allowed_methods: ClassVar[Set[str]] = {}
async def _iter(self) -> StreamResponse:
if (method_name := self.request.method) not in self.allowed_methods:
@ -38,30 +39,30 @@ class PydanticView(AbstractView):
def __init_subclass__(cls, **kwargs) -> None:
"""Define allowed methods and decorate handlers.
Handlers are decorated if and only if they meet the following conditions:
- the handler corresponds to an allowed method;
- the handler method was not inherited from a :class:`PydanticView` base
class. This prevents that methods are decorated multiple times.
Handlers are decorated if and only if they directly bound on the PydanticView class or
PydanticView subclass. This prevents that methods are decorated multiple times and that method
defined in aiohttp.View parent class is decorated.
"""
cls.allowed_methods = {
meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower())
}
for meth_name in METH_ALL:
if meth_name in cls.allowed_methods:
if meth_name.lower() in vars(cls):
handler = getattr(cls, meth_name.lower())
for base_class in cls.__bases__:
if is_pydantic_view(base_class):
parent_handler = getattr(base_class, meth_name.lower(), None)
if handler == parent_handler:
break
else:
decorated_handler = inject_params(handler, cls.parse_func_signature)
setattr(cls, meth_name.lower(), decorated_handler)
decorated_handler = inject_params(handler, cls.parse_func_signature)
setattr(cls, meth_name.lower(), decorated_handler)
def _raise_allowed_methods(self) -> None:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
def raise_not_allowed(self) -> None:
warnings.warn(
"PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods",
DeprecationWarning, stacklevel=2)
self._raise_allowed_methods()
@staticmethod
def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]:
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(

View File

@ -1,6 +1,7 @@
from typing import Any
from aiohttp_pydantic import PydanticView
from aiohttp.web import View
def count_wrappers(obj: Any) -> int:
@ -16,43 +17,58 @@ def count_wrappers(obj: Any) -> int:
raise RuntimeError("Too many wrappers")
class ViewParent(PydanticView):
class AiohttpViewParent(View):
async def put(self):
pass
async def delete(self):
class PydanticViewParent(PydanticView):
async def get(self, id: int, /):
pass
class ViewParentNonPydantic:
async def post(self):
pass
class ViewChild(ViewParent, ViewParentNonPydantic):
async def get(self):
pass
async def delete(self):
pass
async def not_allowed(self):
pass
def test_allowed_methods_are_set_correctly():
assert ViewParent.allowed_methods == {"PUT", "DELETE"}
assert ViewChild.allowed_methods == {"GET", "POST", "PUT", "DELETE"}
def test_allowed_methods_get_decorated_exactly_once():
assert count_wrappers(ViewParent.put) == 1
assert count_wrappers(ViewParent.delete) == 1
assert count_wrappers(ViewChild.get) == 1
assert count_wrappers(ViewChild.post) == 1
assert count_wrappers(ViewChild.put) == 1
assert count_wrappers(ViewChild.post) == 1
assert count_wrappers(ViewChild.put) == 1
assert count_wrappers(ViewChild.not_allowed) == 0
assert count_wrappers(ViewParentNonPydantic.post) == 0
class ChildView(PydanticViewParent):
async def post(self, id: int, /):
pass
class SubChildView(ChildView):
async def get(self, id: int, /):
return super().get(id)
assert count_wrappers(ChildView.post) == 1
assert count_wrappers(ChildView.get) == 1
assert count_wrappers(SubChildView.post) == 1
assert count_wrappers(SubChildView.get) == 1
def test_methods_inherited_from_aiohttp_view_should_not_be_decorated():
class ChildView(AiohttpViewParent, PydanticView):
async def post(self, id: int, /):
pass
assert count_wrappers(ChildView.put) == 0
assert count_wrappers(ChildView.post) == 1
def test_allowed_methods_are_set_correctly():
class ChildView(AiohttpViewParent, PydanticView):
async def post(self, id: int, /):
pass
assert ChildView.allowed_methods == {"POST", "PUT"}
class ChildView(PydanticViewParent):
async def post(self, id: int, /):
pass
assert ChildView.allowed_methods == {"POST", "GET"}
class ChildView(AiohttpViewParent, PydanticViewParent):
async def post(self, id: int, /):
pass
assert ChildView.allowed_methods == {"POST", "PUT", "GET"}