diff --git a/aiohttp_pydantic/injectors.py b/aiohttp_pydantic/injectors.py index 8594459..a653fd1 100644 --- a/aiohttp_pydantic/injectors.py +++ b/aiohttp_pydantic/injectors.py @@ -5,6 +5,8 @@ from typing import Callable, Tuple from aiohttp.web_request import BaseRequest from pydantic import BaseModel +from .utils import is_pydantic_base_model + class AbstractInjector(metaclass=abc.ABCMeta): """ @@ -98,7 +100,7 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]: if param_spec.kind is param_spec.POSITIONAL_ONLY: path_args[param_name] = param_spec.annotation elif param_spec.kind is param_spec.POSITIONAL_OR_KEYWORD: - if issubclass(param_spec.annotation, BaseModel): + if is_pydantic_base_model(param_spec.annotation): body_args[param_name] = param_spec.annotation else: qs_args[param_name] = param_spec.annotation diff --git a/aiohttp_pydantic/oas/view.py b/aiohttp_pydantic/oas/view.py index 56d2467..b53c167 100644 --- a/aiohttp_pydantic/oas/view.py +++ b/aiohttp_pydantic/oas/view.py @@ -3,27 +3,17 @@ from typing import List, Type from aiohttp.web import Response, json_response from aiohttp.web_app import Application -from pydantic import BaseModel from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem from ..injectors import _parse_func_signature +from ..utils import is_pydantic_base_model from ..view import PydanticView, is_pydantic_view from .typing import is_status_code_type JSON_SCHEMA_TYPES = {float: "number", str: "string", int: "integer"} -def _is_pydantic_base_model(obj): - """ - Return true is obj is a pydantic.BaseModel subclass. - """ - try: - return issubclass(obj, BaseModel) - except TypeError: - return False - - class _OASResponseBuilder: """ Parse the type annotated as returned by a function and @@ -35,7 +25,7 @@ class _OASResponseBuilder: @staticmethod def _handle_pydantic_base_model(obj): - if _is_pydantic_base_model(obj): + if is_pydantic_base_model(obj): return obj.schema() return {} diff --git a/aiohttp_pydantic/utils.py b/aiohttp_pydantic/utils.py new file mode 100644 index 0000000..efee749 --- /dev/null +++ b/aiohttp_pydantic/utils.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + + +def is_pydantic_base_model(obj): + """ + Return true is obj is a pydantic.BaseModel subclass. + """ + try: + return issubclass(obj, BaseModel) + except TypeError: + return False diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index e04f71e..a899542 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -9,14 +9,8 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed from aiohttp.web_response import StreamResponse from pydantic import ValidationError -from .injectors import ( - AbstractInjector, - BodyGetter, - HeadersGetter, - MatchInfoGetter, - QueryGetter, - _parse_func_signature, -) +from .injectors import (AbstractInjector, BodyGetter, HeadersGetter, + MatchInfoGetter, QueryGetter, _parse_func_signature) class PydanticView(AbstractView): diff --git a/demo/model.py b/demo/model.py index a6ef842..5a5b425 100644 --- a/demo/model.py +++ b/demo/model.py @@ -4,6 +4,7 @@ from pydantic import BaseModel class Pet(BaseModel): id: int name: str + age: int class Error(BaseModel): diff --git a/demo/view.py b/demo/view.py index b897f00..bed6260 100644 --- a/demo/view.py +++ b/demo/view.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union from aiohttp import web @@ -9,9 +9,11 @@ from .model import Error, Pet class PetCollectionView(PydanticView): - async def get(self) -> r200[List[Pet]]: + async def get(self, age: Optional[int] = None) -> r200[List[Pet]]: pets = self.request.app["model"].list_pets() - return web.json_response([pet.dict() for pet in pets]) + return web.json_response( + [pet.dict() for pet in pets if age is None or age == pet.age] + ) async def post(self, pet: Pet) -> r201[Pet]: self.request.app["model"].add_pet(pet) diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index 441cf6e..cae7704 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -1,11 +1,13 @@ +from typing import Optional + from aiohttp import web from aiohttp_pydantic import PydanticView class ArticleView(PydanticView): - async def get(self, with_comments: bool): - return web.json_response({"with_comments": with_comments}) + async def get(self, with_comments: bool, age: Optional[int] = None): + return web.json_response({"with_comments": with_comments, "age": age}) async def test_get_article_without_required_qs_should_return_an_error_message( @@ -53,7 +55,22 @@ async def test_get_article_with_valid_qs_should_return_the_parsed_type( app.router.add_view("/article", ArticleView) client = await aiohttp_client(app) + + resp = await client.get("/article", params={"with_comments": "yes", "age": 3}) + assert resp.status == 200 + assert resp.content_type == "application/json" + assert await resp.json() == {"with_comments": True, "age": 3} + + +async def test_get_article_with_valid_qs_and_omitted_optional_should_return_none( + aiohttp_client, loop +): + app = web.Application() + app.router.add_view("/article", ArticleView) + + client = await aiohttp_client(app) + resp = await client.get("/article", params={"with_comments": "yes"}) assert resp.status == 200 assert resp.content_type == "application/json" - assert await resp.json() == {"with_comments": True} + assert await resp.json() == {"with_comments": True, "age": None}