parse path and header
This commit is contained in:
parent
cc0cbfbb5e
commit
b37c03e9d9
@ -1,102 +1,3 @@
|
|||||||
from aiohttp.abc import AbstractView
|
from .view import PydanticView
|
||||||
from aiohttp.hdrs import METH_ALL
|
|
||||||
from aiohttp.web_exceptions import HTTPMethodNotAllowed
|
|
||||||
from aiohttp.web_response import StreamResponse
|
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
from typing import Generator, Any
|
|
||||||
from aiohttp.web import json_response
|
|
||||||
|
|
||||||
|
__all__ = ("PydanticView",)
|
||||||
class PydanticView(AbstractView):
|
|
||||||
|
|
||||||
async def _iter(self) -> StreamResponse:
|
|
||||||
method = getattr(self, self.request.method.lower(), None)
|
|
||||||
resp = await method()
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def __await__(self) -> Generator[Any, None, StreamResponse]:
|
|
||||||
return self._iter().__await__()
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
|
||||||
allowed_methods = {
|
|
||||||
meth_name for meth_name in METH_ALL
|
|
||||||
if hasattr(cls, meth_name.lower())}
|
|
||||||
|
|
||||||
async def raise_not_allowed(self):
|
|
||||||
raise HTTPMethodNotAllowed(self.request.method, allowed_methods)
|
|
||||||
|
|
||||||
if 'GET' in allowed_methods:
|
|
||||||
cls.get = inject_qs(cls.get)
|
|
||||||
if 'POST' in allowed_methods:
|
|
||||||
cls.post = inject_body(cls.post)
|
|
||||||
if 'PUT' in allowed_methods:
|
|
||||||
cls.put = inject_body(cls.put)
|
|
||||||
|
|
||||||
for meth_name in METH_ALL:
|
|
||||||
if meth_name not in allowed_methods:
|
|
||||||
setattr(cls, meth_name.lower(), raise_not_allowed)
|
|
||||||
|
|
||||||
|
|
||||||
def inject_qs(handler):
|
|
||||||
"""
|
|
||||||
Decorator to unpack the query string in the parameters of the web handler
|
|
||||||
regarding annotations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
nb_header_params = handler.__code__.co_kwonlyargcount
|
|
||||||
if nb_header_params:
|
|
||||||
from_qs = frozenset(handler.__code__.co_varnames[:-nb_header_params])
|
|
||||||
from_header = frozenset(handler.__code__.co_varnames[-nb_header_params:])
|
|
||||||
else:
|
|
||||||
from_qs = frozenset(handler.__code__.co_varnames)
|
|
||||||
from_header = frozenset()
|
|
||||||
|
|
||||||
qs_model_class = type(
|
|
||||||
'QSModel', (BaseModel,),
|
|
||||||
{'__annotations__': {k: v for k, v in handler.__annotations__.items() if k in from_qs and k != 'self'}})
|
|
||||||
|
|
||||||
header_model_class = type(
|
|
||||||
'HeaderModel', (BaseModel,),
|
|
||||||
{'__annotations__': {k: v for k, v in handler.__annotations__.items() if k in from_header and k != 'self'}})
|
|
||||||
|
|
||||||
async def wrapped_handler(self):
|
|
||||||
try:
|
|
||||||
qs = qs_model_class(**self.request.query)
|
|
||||||
header = header_model_class(**self.request.headers)
|
|
||||||
|
|
||||||
except ValidationError as error:
|
|
||||||
return json_response(text=error.json(), status=400)
|
|
||||||
# raise HTTPBadRequest(
|
|
||||||
# reason='\n'.join(
|
|
||||||
# f'Error with query string parameter {", ".join(err["loc"])}:'
|
|
||||||
# f' {err["msg"]}' for err in error.errors()))
|
|
||||||
|
|
||||||
return await handler(self, **qs.dict(), **header.dict())
|
|
||||||
|
|
||||||
return wrapped_handler
|
|
||||||
|
|
||||||
|
|
||||||
def inject_body(handler):
|
|
||||||
"""
|
|
||||||
Decorator to inject the request body as parameter of the web handler
|
|
||||||
regarding annotations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
arg_name, model_class = next(
|
|
||||||
((arg_name, arg_type)
|
|
||||||
for arg_name, arg_type in handler.__annotations__.items()
|
|
||||||
if issubclass(arg_type, BaseModel)), (None, None))
|
|
||||||
|
|
||||||
if arg_name is None:
|
|
||||||
return handler
|
|
||||||
|
|
||||||
async def wrapped_handler(self):
|
|
||||||
body = await self.request.json()
|
|
||||||
try:
|
|
||||||
model = model_class(**body)
|
|
||||||
except ValidationError as error:
|
|
||||||
return json_response(text=error.json(), status=400)
|
|
||||||
|
|
||||||
return await handler(self, **{arg_name: model})
|
|
||||||
|
|
||||||
return wrapped_handler
|
|
||||||
|
109
aiohttp_pydantic/injectors.py
Normal file
109
aiohttp_pydantic/injectors.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
from aiohttp.web_request import BaseRequest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
|
||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractInjector(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
An injector parse HTTP request and inject params to the view.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(self, args_spec: dict):
|
||||||
|
"""
|
||||||
|
args_spec - ordered mapping: arg_name -> type
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||||
|
"""
|
||||||
|
Get elements in request and inject them in args_view or kwargs_view.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MatchInfoGetter(AbstractInjector):
|
||||||
|
"""
|
||||||
|
Validates and injects the part of URL path inside the view positional args.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args_spec: dict):
|
||||||
|
self.model = type("PathModel", (BaseModel,), {"__annotations__": args_spec})
|
||||||
|
|
||||||
|
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||||
|
args_view.extend(self.model(**request.match_info).dict().values())
|
||||||
|
|
||||||
|
|
||||||
|
class BodyGetter(AbstractInjector):
|
||||||
|
"""
|
||||||
|
Validates and injects the content of request body inside the view kwargs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args_spec: dict):
|
||||||
|
self.arg_name, self.model = next(iter(args_spec.items()))
|
||||||
|
|
||||||
|
async def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||||
|
body = await request.json()
|
||||||
|
kwargs_view[self.arg_name] = self.model(**body)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryGetter(AbstractInjector):
|
||||||
|
"""
|
||||||
|
Validates and injects the query string inside the view kwargs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args_spec: dict):
|
||||||
|
self.model = type("QueryModel", (BaseModel,), {"__annotations__": args_spec})
|
||||||
|
|
||||||
|
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||||
|
kwargs_view.update(self.model(**request.query).dict())
|
||||||
|
|
||||||
|
|
||||||
|
class HeadersGetter(AbstractInjector):
|
||||||
|
"""
|
||||||
|
Validates and injects the HTTP headers inside the view kwargs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args_spec: dict):
|
||||||
|
self.model = type("HeaderModel", (BaseModel,), {"__annotations__": args_spec})
|
||||||
|
|
||||||
|
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||||
|
header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()}
|
||||||
|
kwargs_view.update(self.model(**header).dict())
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]:
|
||||||
|
"""
|
||||||
|
Analyse function signature and returns 4-tuple:
|
||||||
|
0 - arguments will be set from the url path
|
||||||
|
1 - argument will be set from the request body.
|
||||||
|
2 - argument will be set from the query string.
|
||||||
|
3 - argument will be set from the HTTP headers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path_args = {}
|
||||||
|
body_args = {}
|
||||||
|
qs_args = {}
|
||||||
|
header_args = {}
|
||||||
|
|
||||||
|
for param_name, param_spec in signature(func).parameters.items():
|
||||||
|
if param_name == "self":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if param_spec.kind is param_spec.POSITIONAL_ONLY:
|
||||||
|
path_args[param_name] = param_spec.annotation
|
||||||
|
elif param_spec.kind is param_spec.POSITIONAL_OR_KEYWORD:
|
||||||
|
if issubclass(param_spec.annotation, BaseModel):
|
||||||
|
body_args[param_name] = param_spec.annotation
|
||||||
|
else:
|
||||||
|
qs_args[param_name] = param_spec.annotation
|
||||||
|
elif param_spec.kind is param_spec.KEYWORD_ONLY:
|
||||||
|
header_args[param_name] = param_spec.annotation
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters")
|
||||||
|
|
||||||
|
return path_args, body_args, qs_args, header_args
|
92
aiohttp_pydantic/view.py
Normal file
92
aiohttp_pydantic/view.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from inspect import iscoroutinefunction
|
||||||
|
|
||||||
|
from aiohttp.abc import AbstractView
|
||||||
|
from aiohttp.hdrs import METH_ALL
|
||||||
|
from aiohttp.web_exceptions import HTTPMethodNotAllowed
|
||||||
|
from aiohttp.web_response import StreamResponse
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from typing import Generator, Any, Callable, List, Iterable
|
||||||
|
from aiohttp.web import json_response
|
||||||
|
from functools import update_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
from .injectors import (
|
||||||
|
MatchInfoGetter,
|
||||||
|
HeadersGetter,
|
||||||
|
QueryGetter,
|
||||||
|
BodyGetter,
|
||||||
|
AbstractInjector,
|
||||||
|
_parse_func_signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticView(AbstractView):
|
||||||
|
"""
|
||||||
|
An AIOHTTP View that validate request using function annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _iter(self) -> StreamResponse:
|
||||||
|
method = getattr(self, self.request.method.lower(), None)
|
||||||
|
resp = await method()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def __await__(self) -> Generator[Any, None, StreamResponse]:
|
||||||
|
return self._iter().__await__()
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
allowed_methods = {
|
||||||
|
meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower())
|
||||||
|
}
|
||||||
|
|
||||||
|
async def raise_not_allowed(self):
|
||||||
|
raise HTTPMethodNotAllowed(self.request.method, allowed_methods)
|
||||||
|
|
||||||
|
for meth_name in METH_ALL:
|
||||||
|
if meth_name not in allowed_methods:
|
||||||
|
setattr(cls, meth_name.lower(), raise_not_allowed)
|
||||||
|
else:
|
||||||
|
handler = getattr(cls, meth_name.lower())
|
||||||
|
decorated_handler = inject_params(handler, cls.parse_func_signature)
|
||||||
|
setattr(cls, meth_name.lower(), decorated_handler)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]:
|
||||||
|
path_args, body_args, qs_args, header_args = _parse_func_signature(func)
|
||||||
|
injectors = []
|
||||||
|
if path_args:
|
||||||
|
injectors.append(MatchInfoGetter(path_args))
|
||||||
|
if body_args:
|
||||||
|
injectors.append(BodyGetter(body_args))
|
||||||
|
if qs_args:
|
||||||
|
injectors.append(QueryGetter(qs_args))
|
||||||
|
if header_args:
|
||||||
|
injectors.append(HeadersGetter(header_args))
|
||||||
|
return injectors
|
||||||
|
|
||||||
|
|
||||||
|
def inject_params(
|
||||||
|
handler, parse_func_signature: Callable[[Callable], Iterable[AbstractInjector]]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decorator to unpack the query string, route path, body and http header in
|
||||||
|
the parameters of the web handler regarding annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
injectors = parse_func_signature(handler)
|
||||||
|
|
||||||
|
async def wrapped_handler(self):
|
||||||
|
args = []
|
||||||
|
kwargs = {}
|
||||||
|
for injector in injectors:
|
||||||
|
try:
|
||||||
|
if iscoroutinefunction(injector.inject):
|
||||||
|
await injector.inject(self.request, args, kwargs)
|
||||||
|
else:
|
||||||
|
injector.inject(self.request, args, kwargs)
|
||||||
|
except ValidationError as error:
|
||||||
|
return json_response(text=error.json(), status=400)
|
||||||
|
|
||||||
|
return await handler(self, *args, **kwargs)
|
||||||
|
|
||||||
|
update_wrapper(wrapped_handler, handler)
|
||||||
|
return wrapped_handler
|
1
setup.py
1
setup.py
@ -17,7 +17,6 @@ setup(
|
|||||||
'Programming Language :: Python',
|
'Programming Language :: Python',
|
||||||
'Programming Language :: Python :: 3',
|
'Programming Language :: Python :: 3',
|
||||||
'Programming Language :: Python :: 3 :: Only',
|
'Programming Language :: Python :: 3 :: Only',
|
||||||
'Programming Language :: Python :: 3.6',
|
|
||||||
'Programming Language :: Python :: 3.7',
|
'Programming Language :: Python :: 3.7',
|
||||||
'Programming Language :: Python :: 3.8',
|
'Programming Language :: Python :: 3.8',
|
||||||
'Programming Language :: Python :: 3.9',
|
'Programming Language :: Python :: 3.9',
|
||||||
|
49
tests/test_parse_func_signature.py
Normal file
49
tests/test_parse_func_signature.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from aiohttp_pydantic.injectors import _parse_func_signature
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
firstname: str
|
||||||
|
lastname: str
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_func_signature():
|
||||||
|
|
||||||
|
def body_only(self, user: User):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def path_only(self, id: str, /):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def qs_only(self, page: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def header_only(self, *, auth: UUID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def path_and_qs(self, id: str, /, page: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def path_and_header(self, id: str, /, *, auth: UUID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def qs_and_header(self, page: int, *, auth: UUID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def path_qs_and_header(self, id: str, /, page: int, *, auth: UUID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def path_body_qs_and_header(self, id: str, /, user: User, page: int, *, auth: UUID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert _parse_func_signature(body_only) == ({}, {'user': User}, {}, {})
|
||||||
|
assert _parse_func_signature(path_only) == ({'id': str}, {}, {}, {})
|
||||||
|
assert _parse_func_signature(qs_only) == ({}, {}, {'page': int}, {})
|
||||||
|
assert _parse_func_signature(header_only) == ({}, {}, {}, {'auth': UUID})
|
||||||
|
assert _parse_func_signature(path_and_qs) == ({'id': str}, {}, {'page': int}, {})
|
||||||
|
assert _parse_func_signature(path_and_header) == ({'id': str}, {}, {}, {'auth': UUID})
|
||||||
|
assert _parse_func_signature(qs_and_header) == ({}, {}, {'page': int}, {'auth': UUID})
|
||||||
|
assert _parse_func_signature(path_qs_and_header) == ({'id': str}, {}, {'page': int}, {'auth': UUID})
|
||||||
|
assert _parse_func_signature(path_body_qs_and_header) == ({'id': str}, {'user': User}, {'page': int}, {'auth': UUID})
|
||||||
|
|
@ -54,3 +54,14 @@ async def test_get_article_with_valid_header_should_return_the_parsed_type(aioht
|
|||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert resp.content_type == 'application/json'
|
assert resp.content_type == 'application/json'
|
||||||
assert await resp.json() == {'signature': '2020-10-04T18:01:00'}
|
assert await resp.json() == {'signature': '2020-10-04T18:01:00'}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_article_with_valid_header_containing_hyphen_should_be_returned(aiohttp_client, loop):
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_view('/article', ArticleView)
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get('/article', headers={'Signature-Expired': '2020-10-04T18:01:00'})
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.content_type == 'application/json'
|
||||||
|
assert await resp.json() == {'signature': '2020-10-04T18:01:00'}
|
||||||
|
20
tests/test_validation_path.py
Normal file
20
tests/test_validation_path.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from aiohttp_pydantic import PydanticView
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleView(PydanticView):
|
||||||
|
|
||||||
|
async def get(self, author_id: str, tag: str, date: int, /):
|
||||||
|
return web.json_response({'path': [author_id, tag, date]})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_article_without_required_qs_should_return_an_error_message(aiohttp_client, loop):
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_view('/article/{author_id}/tag/{tag}/before/{date}', ArticleView)
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get('/article/1234/tag/music/before/1980')
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.content_type == 'application/json'
|
||||||
|
assert await resp.json() == {'path': ['1234', 'music', 1980]}
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user