From c4c18ee4a10ed2160db11deda399ae9fb4d7c056 Mon Sep 17 00:00:00 2001 From: Vincent Maillol Date: Sat, 28 Nov 2020 12:39:09 +0100 Subject: [PATCH] increase pydantic integration with headers, query string and url path --- README.rst | 6 ++-- aiohttp_pydantic/injectors.py | 48 +++++++++++++++++++++------ aiohttp_pydantic/oas/typing.py | 2 +- aiohttp_pydantic/oas/view.py | 40 +++++++++++++--------- aiohttp_pydantic/view.py | 35 +++++++++++-------- tests/test_oas/test_cmd/test_cmd.py | 4 +++ tests/test_oas/test_view.py | 12 +++---- tests/test_parse_func_signature.py | 20 ++++++++--- tests/test_validation_body.py | 8 ++++- tests/test_validation_header.py | 47 ++++++++++++++++++++++++++ tests/test_validation_path.py | 22 +++++++++++- tests/test_validation_query_string.py | 16 ++++++--- 12 files changed, 199 insertions(+), 61 deletions(-) diff --git a/README.rst b/README.rst index 7ff68e5..c56fb34 100644 --- a/README.rst +++ b/README.rst @@ -68,6 +68,7 @@ Example: $ curl -X GET http://127.0.0.1:8080/article?with_comments=a [ { + "in": "query string", "loc": [ "with_comments" ], @@ -82,6 +83,7 @@ Example: $ curl -H "Content-Type: application/json" -X post http://127.0.0.1:8080/article --data '{}' [ { + "in": "body", "loc": [ "name" ], @@ -116,7 +118,7 @@ Example: Inject Query String Parameters ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To declare a query parameters, you must declare your argument as simple argument: +To declare a query parameters, you must declare your argument as a simple argument: .. code-block:: python3 @@ -131,7 +133,7 @@ To declare a query parameters, you must declare your argument as simple argument Inject Request Body ~~~~~~~~~~~~~~~~~~~ -To declare a body parameters, you must declare your argument as a simple argument annotated with `pydantic Model`_. +To declare a body parameter, you must declare your argument as a simple argument annotated with `pydantic Model`_. .. code-block:: python3 diff --git a/aiohttp_pydantic/injectors.py b/aiohttp_pydantic/injectors.py index 74e7eb9..caed0fa 100644 --- a/aiohttp_pydantic/injectors.py +++ b/aiohttp_pydantic/injectors.py @@ -15,8 +15,16 @@ class AbstractInjector(metaclass=abc.ABCMeta): An injector parse HTTP request and inject params to the view. """ + @property @abc.abstractmethod - def __init__(self, args_spec: dict): + def context(self) -> str: + """ + The name of part of parsed request + i.e "HTTP header", "URL path", ... + """ + + @abc.abstractmethod + def __init__(self, args_spec: dict, default_values: dict): """ args_spec - ordered mapping: arg_name -> type """ @@ -33,8 +41,12 @@ class MatchInfoGetter(AbstractInjector): Validates and injects the part of URL path inside the view positional args. """ - def __init__(self, args_spec: dict): - self.model = type("PathModel", (BaseModel,), {"__annotations__": args_spec}) + context = "path" + + def __init__(self, args_spec: dict, default_values: dict): + attrs = {"__annotations__": args_spec} + attrs.update(default_values) + self.model = type("PathModel", (BaseModel,), attrs) def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): args_view.extend(self.model(**request.match_info).dict().values()) @@ -45,7 +57,9 @@ class BodyGetter(AbstractInjector): Validates and injects the content of request body inside the view kwargs. """ - def __init__(self, args_spec: dict): + context = "body" + + def __init__(self, args_spec: dict, default_values: dict): self.arg_name, self.model = next(iter(args_spec.items())) async def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): @@ -64,8 +78,12 @@ class QueryGetter(AbstractInjector): Validates and injects the query string inside the view kwargs. """ - def __init__(self, args_spec: dict): - self.model = type("QueryModel", (BaseModel,), {"__annotations__": args_spec}) + context = "query string" + + def __init__(self, args_spec: dict, default_values: dict): + attrs = {"__annotations__": args_spec} + attrs.update(default_values) + self.model = type("QueryModel", (BaseModel,), attrs) def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): kwargs_view.update(self.model(**request.query).dict()) @@ -76,27 +94,33 @@ class HeadersGetter(AbstractInjector): Validates and injects the HTTP headers inside the view kwargs. """ - def __init__(self, args_spec: dict): - self.model = type("HeaderModel", (BaseModel,), {"__annotations__": args_spec}) + context = "headers" + + def __init__(self, args_spec: dict, default_values: dict): + attrs = {"__annotations__": args_spec} + attrs.update(default_values) + self.model = type("HeaderModel", (BaseModel,), attrs) def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()} kwargs_view.update(self.model(**header).dict()) -def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]: +def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]: """ Analyse function signature and returns 4-tuple: 0 - arguments will be set from the url path 1 - argument will be set from the request body. 2 - argument will be set from the query string. 3 - argument will be set from the HTTP headers. + 4 - Default value for each parameters """ path_args = {} body_args = {} qs_args = {} header_args = {} + defaults = {} for param_name, param_spec in signature(func).parameters.items(): if param_name == "self": @@ -105,8 +129,12 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]: if param_spec.annotation == param_spec.empty: raise RuntimeError(f"The parameter {param_name} must have an annotation") + if param_spec.default is not param_spec.empty: + defaults[param_name] = param_spec.default + if param_spec.kind is param_spec.POSITIONAL_ONLY: path_args[param_name] = param_spec.annotation + elif param_spec.kind is param_spec.POSITIONAL_OR_KEYWORD: if is_pydantic_base_model(param_spec.annotation): body_args[param_name] = param_spec.annotation @@ -117,4 +145,4 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]: else: raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters") - return path_args, body_args, qs_args, header_args + return path_args, body_args, qs_args, header_args, defaults diff --git a/aiohttp_pydantic/oas/typing.py b/aiohttp_pydantic/oas/typing.py index a5f4708..da9e7da 100644 --- a/aiohttp_pydantic/oas/typing.py +++ b/aiohttp_pydantic/oas/typing.py @@ -13,7 +13,7 @@ Example: from functools import lru_cache from types import new_class -from typing import Protocol, TypeVar, Optional, Type +from typing import Protocol, TypeVar RespContents = TypeVar("RespContents", covariant=True) diff --git a/aiohttp_pydantic/oas/view.py b/aiohttp_pydantic/oas/view.py index de1f222..343117b 100644 --- a/aiohttp_pydantic/oas/view.py +++ b/aiohttp_pydantic/oas/view.py @@ -7,6 +7,7 @@ from uuid import UUID from aiohttp.web import Response, json_response from aiohttp.web_app import Application +from pydantic import BaseModel from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem @@ -93,7 +94,9 @@ def _add_http_method_to_oas( http_method = http_method.lower() oas_operation: OperationObject = getattr(oas_path, http_method) handler = getattr(view, http_method) - path_args, body_args, qs_args, header_args = _parse_func_signature(handler) + path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( + handler + ) description = getdoc(handler) if description: oas_operation.description = description @@ -114,12 +117,15 @@ def _add_http_method_to_oas( oas_operation.parameters[i].in_ = args_location oas_operation.parameters[i].name = name optional_type = _handle_optional(type_) - if optional_type is None: - oas_operation.parameters[i].schema = JSON_SCHEMA_TYPES[type_] - oas_operation.parameters[i].required = True - else: - oas_operation.parameters[i].schema = JSON_SCHEMA_TYPES[optional_type] - oas_operation.parameters[i].required = False + + attrs = {"__annotations__": {"__root__": type_}} + if name in defaults: + attrs["__root__"] = defaults[name] + + oas_operation.parameters[i].schema = type( + name, (BaseModel,), attrs + ).schema() + oas_operation.parameters[i].required = optional_type is None return_type = handler.__annotations__.get("return") if return_type is not None: @@ -134,15 +140,17 @@ def generate_oas(apps: List[Application]) -> dict: for app in apps: for resources in app.router.resources(): for resource_route in resources: - if is_pydantic_view(resource_route.handler): - view: Type[PydanticView] = resource_route.handler - info = resource_route.get_info() - path = oas.paths[info.get("path", info.get("formatter"))] - if resource_route.method == "*": - for method_name in view.allowed_methods: - _add_http_method_to_oas(path, method_name, view) - else: - _add_http_method_to_oas(path, resource_route.method, view) + if not is_pydantic_view(resource_route.handler): + continue + + view: Type[PydanticView] = resource_route.handler + info = resource_route.get_info() + path = oas.paths[info.get("path", info.get("formatter"))] + if resource_route.method == "*": + for method_name in view.allowed_methods: + _add_http_method_to_oas(path, method_name, view) + else: + _add_http_method_to_oas(path, resource_route.method, view) return oas.spec diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index e04f71e..2218a3a 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -9,14 +9,8 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed from aiohttp.web_response import StreamResponse from pydantic import ValidationError -from .injectors import ( - AbstractInjector, - BodyGetter, - HeadersGetter, - MatchInfoGetter, - QueryGetter, - _parse_func_signature, -) +from .injectors import (AbstractInjector, BodyGetter, HeadersGetter, + MatchInfoGetter, QueryGetter, _parse_func_signature) class PydanticView(AbstractView): @@ -50,16 +44,25 @@ class PydanticView(AbstractView): @staticmethod def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]: - path_args, body_args, qs_args, header_args = _parse_func_signature(func) + path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( + func + ) injectors = [] + + def default_value(args: dict) -> dict: + """ + Returns the default values of args. + """ + return {name: defaults[name] for name in args if name in defaults} + if path_args: - injectors.append(MatchInfoGetter(path_args)) + injectors.append(MatchInfoGetter(path_args, default_value(path_args))) if body_args: - injectors.append(BodyGetter(body_args)) + injectors.append(BodyGetter(body_args, default_value(body_args))) if qs_args: - injectors.append(QueryGetter(qs_args)) + injectors.append(QueryGetter(qs_args, default_value(qs_args))) if header_args: - injectors.append(HeadersGetter(header_args)) + injectors.append(HeadersGetter(header_args, default_value(header_args))) return injectors @@ -83,7 +86,11 @@ def inject_params( else: injector.inject(self.request, args, kwargs) except ValidationError as error: - return json_response(text=error.json(), status=400) + errors = error.errors() + for error in errors: + error["in"] = injector.context + + return json_response(data=errors, status=400) return await handler(self, *args, **kwargs) diff --git a/tests/test_oas/test_cmd/test_cmd.py b/tests/test_oas/test_cmd/test_cmd.py index 754413a..189e46b 100644 --- a/tests/test_oas/test_cmd/test_cmd.py +++ b/tests/test_oas/test_cmd/test_cmd.py @@ -30,6 +30,7 @@ def test_show_oad_of_app(cmd_line, capfd): "name": "a", "required": true, "schema": { + "title": "a", "type": "integer" } } @@ -44,6 +45,7 @@ def test_show_oad_of_app(cmd_line, capfd): "name": "b", "required": true, "schema": { + "title": "b", "type": "integer" } } @@ -75,6 +77,7 @@ def test_show_oad_of_sub_app(cmd_line, capfd): "name": "b", "required": true, "schema": { + "title": "b", "type": "integer" } } @@ -106,6 +109,7 @@ def test_show_oad_of_a_callable(cmd_line, capfd): "name": "a", "required": true, "schema": { + "title": "a", "type": "integer" } } diff --git a/tests/test_oas/test_view.py b/tests/test_oas/test_view.py index a789435..a3144e5 100644 --- a/tests/test_oas/test_view.py +++ b/tests/test_oas/test_view.py @@ -65,19 +65,19 @@ async def test_pets_route_should_have_get_method(generated_oas): "in": "query", "name": "format", "required": True, - "schema": {"type": "string"}, + "schema": {"title": "format", "type": "string"}, }, { "in": "query", "name": "name", "required": False, - "schema": {"type": "string"}, + "schema": {"title": "name", "type": "string"}, }, { "in": "header", "name": "promo", "required": False, - "schema": {"format": "uuid", "type": "string"}, + "schema": {"title": "promo", "format": "uuid", "type": "string"}, }, ], "responses": { @@ -152,7 +152,7 @@ async def test_pets_id_route_should_have_delete_method(generated_oas): "required": True, "in": "path", "name": "id", - "schema": {"type": "integer"}, + "schema": {"title": "id", "type": "integer"}, } ], "responses": {"204": {"content": {}}}, @@ -166,7 +166,7 @@ async def test_pets_id_route_should_have_get_method(generated_oas): "in": "path", "name": "id", "required": True, - "schema": {"type": "integer"}, + "schema": {"title": "id", "type": "integer"}, } ], "responses": { @@ -197,7 +197,7 @@ async def test_pets_id_route_should_have_put_method(generated_oas): "in": "path", "name": "id", "required": True, - "schema": {"type": "integer"}, + "schema": {"title": "id", "type": "integer"}, } ], "requestBody": { diff --git a/tests/test_parse_func_signature.py b/tests/test_parse_func_signature.py index c2cc47b..6e588e7 100644 --- a/tests/test_parse_func_signature.py +++ b/tests/test_parse_func_signature.py @@ -38,32 +38,42 @@ def test_parse_func_signature(): def path_body_qs_and_header(self, id: str, /, user: User, page: int, *, auth: UUID): pass - assert _parse_func_signature(body_only) == ({}, {"user": User}, {}, {}) - assert _parse_func_signature(path_only) == ({"id": str}, {}, {}, {}) - assert _parse_func_signature(qs_only) == ({}, {}, {"page": int}, {}) - assert _parse_func_signature(header_only) == ({}, {}, {}, {"auth": UUID}) - assert _parse_func_signature(path_and_qs) == ({"id": str}, {}, {"page": int}, {}) + assert _parse_func_signature(body_only) == ({}, {"user": User}, {}, {}, {}) + assert _parse_func_signature(path_only) == ({"id": str}, {}, {}, {}, {}) + assert _parse_func_signature(qs_only) == ({}, {}, {"page": int}, {}, {}) + assert _parse_func_signature(header_only) == ({}, {}, {}, {"auth": UUID}, {}) + assert _parse_func_signature(path_and_qs) == ( + {"id": str}, + {}, + {"page": int}, + {}, + {}, + ) assert _parse_func_signature(path_and_header) == ( {"id": str}, {}, {}, {"auth": UUID}, + {}, ) assert _parse_func_signature(qs_and_header) == ( {}, {}, {"page": int}, {"auth": UUID}, + {}, ) assert _parse_func_signature(path_qs_and_header) == ( {"id": str}, {}, {"page": int}, {"auth": UUID}, + {}, ) assert _parse_func_signature(path_body_qs_and_header) == ( {"id": str}, {"user": User}, {"page": int}, {"auth": UUID}, + {}, ) diff --git a/tests/test_validation_body.py b/tests/test_validation_body.py index 18cb97f..424f136 100644 --- a/tests/test_validation_body.py +++ b/tests/test_validation_body.py @@ -27,7 +27,12 @@ async def test_post_an_article_without_required_field_should_return_an_error_mes assert resp.status == 400 assert resp.content_type == "application/json" assert await resp.json() == [ - {"loc": ["name"], "msg": "field required", "type": "value_error.missing"} + { + "in": "body", + "loc": ["name"], + "msg": "field required", + "type": "value_error.missing", + } ] @@ -43,6 +48,7 @@ async def test_post_an_article_with_wrong_type_field_should_return_an_error_mess assert resp.content_type == "application/json" assert await resp.json() == [ { + "in": "body", "loc": ["nb_page"], "msg": "value is not a valid integer", "type": "type_error.integer", diff --git a/tests/test_validation_header.py b/tests/test_validation_header.py index 83558e1..e18a863 100644 --- a/tests/test_validation_header.py +++ b/tests/test_validation_header.py @@ -1,5 +1,6 @@ import json from datetime import datetime +from enum import Enum from aiohttp import web @@ -21,6 +22,16 @@ class ArticleView(PydanticView): ) +class FormatEnum(str, Enum): + UTM = "UMT" + MGRS = "MGRS" + + +class ViewWithEnumType(PydanticView): + async def get(self, *, format: FormatEnum): + return web.json_response({"format": format}, dumps=JSONEncoder().encode) + + async def test_get_article_without_required_header_should_return_an_error_message( aiohttp_client, loop ): @@ -33,6 +44,7 @@ async def test_get_article_without_required_header_should_return_an_error_messag assert resp.content_type == "application/json" assert await resp.json() == [ { + "in": "headers", "loc": ["signature_expired"], "msg": "field required", "type": "value_error.missing", @@ -52,6 +64,7 @@ async def test_get_article_with_wrong_header_type_should_return_an_error_message assert resp.content_type == "application/json" assert await resp.json() == [ { + "in": "headers", "loc": ["signature_expired"], "msg": "invalid datetime format", "type": "value_error.datetime", @@ -87,3 +100,37 @@ async def test_get_article_with_valid_header_containing_hyphen_should_be_returne assert resp.status == 200 assert resp.content_type == "application/json" assert await resp.json() == {"signature": "2020-10-04T18:01:00"} + + +async def test_wrong_value_to_header_defined_with_str_enum(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/coord", ViewWithEnumType) + + client = await aiohttp_client(app) + resp = await client.get("/coord", headers={"format": "WGS84"}) + assert ( + await resp.json() + == [ + { + "ctx": {"enum_values": ["UMT", "MGRS"]}, + "in": "headers", + "loc": ["format"], + "msg": "value is not a valid enumeration member; permitted: 'UMT', 'MGRS'", + "type": "type_error.enum", + } + ] + != {"signature": "2020-10-04T18:01:00"} + ) + assert resp.status == 400 + assert resp.content_type == "application/json" + + +async def test_correct_value_to_header_defined_with_str_enum(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/coord", ViewWithEnumType) + + client = await aiohttp_client(app) + resp = await client.get("/coord", headers={"format": "UMT"}) + assert await resp.json() == {"format": "UMT"} + assert resp.status == 200 + assert resp.content_type == "application/json" diff --git a/tests/test_validation_path.py b/tests/test_validation_path.py index 0e74e65..bfa2016 100644 --- a/tests/test_validation_path.py +++ b/tests/test_validation_path.py @@ -8,7 +8,7 @@ class ArticleView(PydanticView): return web.json_response({"path": [author_id, tag, date]}) -async def test_get_article_without_required_qs_should_return_an_error_message( +async def test_get_article_with_correct_path_parameters_should_return_parameters_in_path( aiohttp_client, loop ): app = web.Application() @@ -19,3 +19,23 @@ async def test_get_article_without_required_qs_should_return_an_error_message( assert resp.status == 200 assert resp.content_type == "application/json" assert await resp.json() == {"path": ["1234", "music", 1980]} + + +async def test_get_article_with_wrong_path_parameters_should_return_error( + aiohttp_client, loop +): + app = web.Application() + app.router.add_view("/article/{author_id}/tag/{tag}/before/{date}", ArticleView) + + client = await aiohttp_client(app) + resp = await client.get("/article/1234/tag/music/before/now") + assert resp.status == 400 + assert resp.content_type == "application/json" + assert await resp.json() == [ + { + "in": "path", + "loc": ["date"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index cae7704..363c461 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -6,8 +6,12 @@ from aiohttp_pydantic import PydanticView class ArticleView(PydanticView): - async def get(self, with_comments: bool, age: Optional[int] = None): - return web.json_response({"with_comments": with_comments, "age": age}) + async def get( + self, with_comments: bool, age: Optional[int] = None, nb_items: int = 7 + ): + return web.json_response( + {"with_comments": with_comments, "age": age, "nb_items": nb_items} + ) async def test_get_article_without_required_qs_should_return_an_error_message( @@ -22,6 +26,7 @@ async def test_get_article_without_required_qs_should_return_an_error_message( assert resp.content_type == "application/json" assert await resp.json() == [ { + "in": "query string", "loc": ["with_comments"], "msg": "field required", "type": "value_error.missing", @@ -41,6 +46,7 @@ async def test_get_article_with_wrong_qs_type_should_return_an_error_message( assert resp.content_type == "application/json" assert await resp.json() == [ { + "in": "query string", "loc": ["with_comments"], "msg": "value could not be parsed to a boolean", "type": "type_error.bool", @@ -59,10 +65,10 @@ async def test_get_article_with_valid_qs_should_return_the_parsed_type( resp = await client.get("/article", params={"with_comments": "yes", "age": 3}) assert resp.status == 200 assert resp.content_type == "application/json" - assert await resp.json() == {"with_comments": True, "age": 3} + assert await resp.json() == {"with_comments": True, "age": 3, "nb_items": 7} -async def test_get_article_with_valid_qs_and_omitted_optional_should_return_none( +async def test_get_article_with_valid_qs_and_omitted_optional_should_return_default_value( aiohttp_client, loop ): app = web.Application() @@ -71,6 +77,6 @@ async def test_get_article_with_valid_qs_and_omitted_optional_should_return_none client = await aiohttp_client(app) resp = await client.get("/article", params={"with_comments": "yes"}) + assert await resp.json() == {"with_comments": True, "age": None, "nb_items": 7} assert resp.status == 200 assert resp.content_type == "application/json" - assert await resp.json() == {"with_comments": True, "age": None}