From cc0cbfbb5e171f09d60ecdc7b582612b85a48f67 Mon Sep 17 00:00:00 2001 From: Vincent Maillol Date: Sun, 4 Oct 2020 20:55:13 +0200 Subject: [PATCH] add HTTP header parsing --- aiohttp_pydantic/__init__.py | 19 ++++++++- tests/test_validation_header.py | 56 +++++++++++++++++++++++++++ tests/test_validation_query_string.py | 2 +- 3 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 tests/test_validation_header.py diff --git a/aiohttp_pydantic/__init__.py b/aiohttp_pydantic/__init__.py index 3acbe6e..7ac156b 100644 --- a/aiohttp_pydantic/__init__.py +++ b/aiohttp_pydantic/__init__.py @@ -42,13 +42,28 @@ def inject_qs(handler): Decorator to unpack the query string in the parameters of the web handler regarding annotations. """ + + nb_header_params = handler.__code__.co_kwonlyargcount + if nb_header_params: + from_qs = frozenset(handler.__code__.co_varnames[:-nb_header_params]) + from_header = frozenset(handler.__code__.co_varnames[-nb_header_params:]) + else: + from_qs = frozenset(handler.__code__.co_varnames) + from_header = frozenset() + qs_model_class = type( 'QSModel', (BaseModel,), - {'__annotations__': handler.__annotations__}) + {'__annotations__': {k: v for k, v in handler.__annotations__.items() if k in from_qs and k != 'self'}}) + + header_model_class = type( + 'HeaderModel', (BaseModel,), + {'__annotations__': {k: v for k, v in handler.__annotations__.items() if k in from_header and k != 'self'}}) async def wrapped_handler(self): try: qs = qs_model_class(**self.request.query) + header = header_model_class(**self.request.headers) + except ValidationError as error: return json_response(text=error.json(), status=400) # raise HTTPBadRequest( @@ -56,7 +71,7 @@ def inject_qs(handler): # f'Error with query string parameter {", ".join(err["loc"])}:' # f' {err["msg"]}' for err in error.errors())) - return await handler(self, **qs.dict()) + return await handler(self, **qs.dict(), **header.dict()) return wrapped_handler diff --git a/tests/test_validation_header.py b/tests/test_validation_header.py new file mode 100644 index 0000000..2ccc1c3 --- /dev/null +++ b/tests/test_validation_header.py @@ -0,0 +1,56 @@ +from aiohttp import web +from aiohttp_pydantic import PydanticView +from datetime import datetime +import json + + +class JSONEncoder(json.JSONEncoder): + + def default(self, o): + if isinstance(o, datetime): + return o.isoformat() + + return json.JSONEncoder.default(self, o) + + +class ArticleView(PydanticView): + + async def get(self, *, signature_expired: datetime): + return web.json_response({'signature': signature_expired}, dumps=JSONEncoder().encode) + + +async def test_get_article_without_required_header_should_return_an_error_message(aiohttp_client, loop): + app = web.Application() + app.router.add_view('/article', ArticleView) + + client = await aiohttp_client(app) + resp = await client.get('/article', headers={}) + assert resp.status == 400 + assert resp.content_type == 'application/json' + assert await resp.json() == [{'loc': ['signature_expired'], + 'msg': 'field required', + 'type': 'value_error.missing'}] + + +async def test_get_article_with_wrong_header_type_should_return_an_error_message(aiohttp_client, loop): + app = web.Application() + app.router.add_view('/article', ArticleView) + + client = await aiohttp_client(app) + resp = await client.get('/article', headers={'signature_expired': 'foo'}) + assert resp.status == 400 + assert resp.content_type == 'application/json' + assert await resp.json() == [{'loc': ['signature_expired'], + 'msg': 'invalid datetime format', + 'type': 'value_error.datetime'}] + + +async def test_get_article_with_valid_header_should_return_the_parsed_type(aiohttp_client, loop): + app = web.Application() + app.router.add_view('/article', ArticleView) + + client = await aiohttp_client(app) + resp = await client.get('/article', headers={'signature_expired': '2020-10-04T18:01:00'}) + assert resp.status == 200 + assert resp.content_type == 'application/json' + assert await resp.json() == {'signature': '2020-10-04T18:01:00'} diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index 905c4da..bd43366 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -34,7 +34,7 @@ async def test_get_article_with_wrong_qs_type_should_return_an_error_message(aio 'type': 'type_error.bool'}] -async def test_get_article_with_valide_qs_should_return_the_parsed_type(aiohttp_client, loop): +async def test_get_article_with_valid_qs_should_return_the_parsed_type(aiohttp_client, loop): app = web.Application() app.router.add_view('/article', ArticleView)