add HTTP header parsing
This commit is contained in:
parent
88fd7b9270
commit
cc0cbfbb5e
@ -42,13 +42,28 @@ def inject_qs(handler):
|
|||||||
Decorator to unpack the query string in the parameters of the web handler
|
Decorator to unpack the query string in the parameters of the web handler
|
||||||
regarding annotations.
|
regarding annotations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
nb_header_params = handler.__code__.co_kwonlyargcount
|
||||||
|
if nb_header_params:
|
||||||
|
from_qs = frozenset(handler.__code__.co_varnames[:-nb_header_params])
|
||||||
|
from_header = frozenset(handler.__code__.co_varnames[-nb_header_params:])
|
||||||
|
else:
|
||||||
|
from_qs = frozenset(handler.__code__.co_varnames)
|
||||||
|
from_header = frozenset()
|
||||||
|
|
||||||
qs_model_class = type(
|
qs_model_class = type(
|
||||||
'QSModel', (BaseModel,),
|
'QSModel', (BaseModel,),
|
||||||
{'__annotations__': handler.__annotations__})
|
{'__annotations__': {k: v for k, v in handler.__annotations__.items() if k in from_qs and k != 'self'}})
|
||||||
|
|
||||||
|
header_model_class = type(
|
||||||
|
'HeaderModel', (BaseModel,),
|
||||||
|
{'__annotations__': {k: v for k, v in handler.__annotations__.items() if k in from_header and k != 'self'}})
|
||||||
|
|
||||||
async def wrapped_handler(self):
|
async def wrapped_handler(self):
|
||||||
try:
|
try:
|
||||||
qs = qs_model_class(**self.request.query)
|
qs = qs_model_class(**self.request.query)
|
||||||
|
header = header_model_class(**self.request.headers)
|
||||||
|
|
||||||
except ValidationError as error:
|
except ValidationError as error:
|
||||||
return json_response(text=error.json(), status=400)
|
return json_response(text=error.json(), status=400)
|
||||||
# raise HTTPBadRequest(
|
# raise HTTPBadRequest(
|
||||||
@ -56,7 +71,7 @@ def inject_qs(handler):
|
|||||||
# f'Error with query string parameter {", ".join(err["loc"])}:'
|
# f'Error with query string parameter {", ".join(err["loc"])}:'
|
||||||
# f' {err["msg"]}' for err in error.errors()))
|
# f' {err["msg"]}' for err in error.errors()))
|
||||||
|
|
||||||
return await handler(self, **qs.dict())
|
return await handler(self, **qs.dict(), **header.dict())
|
||||||
|
|
||||||
return wrapped_handler
|
return wrapped_handler
|
||||||
|
|
||||||
|
56
tests/test_validation_header.py
Normal file
56
tests/test_validation_header.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from aiohttp_pydantic import PydanticView
|
||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class JSONEncoder(json.JSONEncoder):
|
||||||
|
|
||||||
|
def default(self, o):
|
||||||
|
if isinstance(o, datetime):
|
||||||
|
return o.isoformat()
|
||||||
|
|
||||||
|
return json.JSONEncoder.default(self, o)
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleView(PydanticView):
|
||||||
|
|
||||||
|
async def get(self, *, signature_expired: datetime):
|
||||||
|
return web.json_response({'signature': signature_expired}, dumps=JSONEncoder().encode)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_article_without_required_header_should_return_an_error_message(aiohttp_client, loop):
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_view('/article', ArticleView)
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get('/article', headers={})
|
||||||
|
assert resp.status == 400
|
||||||
|
assert resp.content_type == 'application/json'
|
||||||
|
assert await resp.json() == [{'loc': ['signature_expired'],
|
||||||
|
'msg': 'field required',
|
||||||
|
'type': 'value_error.missing'}]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_article_with_wrong_header_type_should_return_an_error_message(aiohttp_client, loop):
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_view('/article', ArticleView)
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get('/article', headers={'signature_expired': 'foo'})
|
||||||
|
assert resp.status == 400
|
||||||
|
assert resp.content_type == 'application/json'
|
||||||
|
assert await resp.json() == [{'loc': ['signature_expired'],
|
||||||
|
'msg': 'invalid datetime format',
|
||||||
|
'type': 'value_error.datetime'}]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_article_with_valid_header_should_return_the_parsed_type(aiohttp_client, loop):
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_view('/article', ArticleView)
|
||||||
|
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get('/article', headers={'signature_expired': '2020-10-04T18:01:00'})
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.content_type == 'application/json'
|
||||||
|
assert await resp.json() == {'signature': '2020-10-04T18:01:00'}
|
@ -34,7 +34,7 @@ async def test_get_article_with_wrong_qs_type_should_return_an_error_message(aio
|
|||||||
'type': 'type_error.bool'}]
|
'type': 'type_error.bool'}]
|
||||||
|
|
||||||
|
|
||||||
async def test_get_article_with_valide_qs_should_return_the_parsed_type(aiohttp_client, loop):
|
async def test_get_article_with_valid_qs_should_return_the_parsed_type(aiohttp_client, loop):
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_view('/article', ArticleView)
|
app.router.add_view('/article', ArticleView)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user