diff --git a/README.rst b/README.rst index 815665d..6cea178 100644 --- a/README.rst +++ b/README.rst @@ -182,8 +182,59 @@ on the same route, you must use *apps_to_expose* parameters oas.setup(app, apps_to_expose=[app, sub_app_1]) + +Add annotation to define response content +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The module aiohttp_pydantic.oas.typing provides class to annotate a +response content. + +For example *r200[List[Pet]]* means the server responses with +the status code 200 and the response content is a List of Pet where Pet will be +defined using a pydantic.BaseModel + + +.. code-block:: python3 + + from aiohttp_pydantic import PydanticView + from aiohttp_pydantic.oas.typing import r200, r201, r204, r404 + + + class Pet(BaseModel): + id: int + name: str + + + class Error(BaseModel): + error: str + + + class PetCollectionView(PydanticView): + async def get(self) -> r200[List[Pet]]: + pets = self.request.app["model"].list_pets() + return web.json_response([pet.dict() for pet in pets]) + + async def post(self, pet: Pet) -> r201[Pet]: + self.request.app["model"].add_pet(pet) + return web.json_response(pet.dict()) + + + class PetItemView(PydanticView): + async def get(self, id: int, /) -> Union[r200[Pet], r404[Error]]: + pet = self.request.app["model"].find_pet(id) + return web.json_response(pet.dict()) + + async def put(self, id: int, /, pet: Pet) -> r200[Pet]: + self.request.app["model"].update_pet(id, pet) + return web.json_response(pet.dict()) + + async def delete(self, id: int, /) -> r204: + self.request.app["model"].remove_pet(id) + return web.Response(status=204) + + Demo -==== +---- Have a look at `demo`_ for a complete example @@ -197,5 +248,4 @@ Have a look at `demo`_ for a complete example Go to http://127.0.0.1:8080/oas - .. _demo: https://github.com/Maillol/aiohttp-pydantic/tree/main/demo diff --git a/aiohttp_pydantic/injectors.py b/aiohttp_pydantic/injectors.py index f96fccc..8594459 100644 --- a/aiohttp_pydantic/injectors.py +++ b/aiohttp_pydantic/injectors.py @@ -1,11 +1,9 @@ +import abc +from inspect import signature from typing import Callable, Tuple from aiohttp.web_request import BaseRequest from pydantic import BaseModel -from inspect import signature - - -import abc class AbstractInjector(metaclass=abc.ABCMeta): diff --git a/aiohttp_pydantic/oas/__init__.py b/aiohttp_pydantic/oas/__init__.py index 939ec13..ef45dcf 100644 --- a/aiohttp_pydantic/oas/__init__.py +++ b/aiohttp_pydantic/oas/__init__.py @@ -1,11 +1,12 @@ -from typing import Iterable from importlib import resources +from typing import Iterable import jinja2 from aiohttp import web -from .view import get_oas, oas_ui from swagger_ui_bundle import swagger_ui_path +from .view import get_oas, oas_ui + def setup( app: web.Application, diff --git a/aiohttp_pydantic/oas/struct.py b/aiohttp_pydantic/oas/struct.py index 4d62a19..6751c3f 100644 --- a/aiohttp_pydantic/oas/struct.py +++ b/aiohttp_pydantic/oas/struct.py @@ -1,3 +1,6 @@ +from typing import Union + + class Info: def __init__(self, spec: dict): self._spec = spec.setdefault("info", {}) @@ -115,6 +118,39 @@ class Parameters: return Parameter(spec) +class Response: + def __init__(self, spec: dict): + self._spec = spec + + @property + def description(self) -> str: + return self._spec["description"] + + @description.setter + def description(self, description: str): + self._spec["description"] = description + + @property + def content(self): + return self._spec["content"] + + @content.setter + def content(self, content: dict): + self._spec["content"] = content + + +class Responses: + def __init__(self, spec: dict): + self._spec = spec.setdefault("responses", {}) + + def __getitem__(self, status_code: Union[int, str]) -> Response: + if not (100 <= int(status_code) < 600): + raise ValueError("status_code must be between 100 and 599") + + spec = self._spec.setdefault(str(status_code), {}) + return Response(spec) + + class OperationObject: def __init__(self, spec: dict): self._spec = spec @@ -143,6 +179,10 @@ class OperationObject: def parameters(self) -> Parameters: return Parameters(self._spec) + @property + def responses(self) -> Responses: + return Responses(self._spec) + class PathItem: def __init__(self, spec: dict): diff --git a/aiohttp_pydantic/oas/typing.py b/aiohttp_pydantic/oas/typing.py new file mode 100644 index 0000000..5be1cd7 --- /dev/null +++ b/aiohttp_pydantic/oas/typing.py @@ -0,0 +1,47 @@ +""" +This module provides type to annotate the content of web.Response returned by +the HTTP handlers. + +The type are: r100, r101, ..., r599 + +Example: + + class PetCollectionView(PydanticView): + async def get(self) -> Union[r200[List[Pet]], r404]: + ... +""" + +from functools import lru_cache +from types import new_class +from typing import Protocol, TypeVar + +RespContents = TypeVar("RespContents", covariant=True) + +_status_code = frozenset(f"r{code}" for code in range(100, 600)) + + +@lru_cache(maxsize=len(_status_code)) +def _make_status_code_type(status_code): + if status_code in _status_code: + return new_class(status_code, (Protocol[RespContents],)) + + +def is_status_code_type(obj): + """ + Return True if obj is a status code type such as _200 or _404. + """ + name = getattr(obj, "__name__", None) + if name not in _status_code: + return False + + return obj is _make_status_code_type(name) + + +def __getattr__(name): + if (status_code_type := _make_status_code_type(name)) is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + return status_code_type + + +__all__ = list(_status_code) +__all__.append("is_status_code_type") diff --git a/aiohttp_pydantic/oas/view.py b/aiohttp_pydantic/oas/view.py index 52e0408..7daa830 100644 --- a/aiohttp_pydantic/oas/view.py +++ b/aiohttp_pydantic/oas/view.py @@ -1,44 +1,109 @@ -from aiohttp.web import json_response, Response +import typing +from typing import Type + +from aiohttp.web import Response, json_response +from pydantic import BaseModel from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem -from typing import Type from ..injectors import _parse_func_signature from ..view import PydanticView, is_pydantic_view - +from .typing import is_status_code_type JSON_SCHEMA_TYPES = {float: "number", str: "string", int: "integer"} -def _add_http_method_to_oas(oas_path: PathItem, method: str, view: Type[PydanticView]): - method = method.lower() - mtd: OperationObject = getattr(oas_path, method) - handler = getattr(view, method) +def _is_pydantic_base_model(obj): + """ + Return true is obj is a pydantic.BaseModel subclass. + """ + try: + return issubclass(obj, BaseModel) + except TypeError: + return False + + +class _OASResponseBuilder: + """ + Parse the type annotated as returned by a function and + generate the OAS operation response. + """ + + def __init__(self, oas_operation): + self._oas_operation = oas_operation + + @staticmethod + def _handle_pydantic_base_model(obj): + if _is_pydantic_base_model(obj): + return obj.schema() + return {} + + def _handle_list(self, obj): + if typing.get_origin(obj) is list: + return { + "type": "array", + "items": self._handle_pydantic_base_model(typing.get_args(obj)[0]), + } + return self._handle_pydantic_base_model(obj) + + def _handle_status_code_type(self, obj): + if is_status_code_type(typing.get_origin(obj)): + status_code = typing.get_origin(obj).__name__[1:] + self._oas_operation.responses[status_code].content = { + "application/json": { + "schema": self._handle_list(typing.get_args(obj)[0]) + } + } + + elif is_status_code_type(obj): + status_code = obj.__name__[1:] + self._oas_operation.responses[status_code].content = {} + + def _handle_union(self, obj): + if typing.get_origin(obj) is typing.Union: + for arg in typing.get_args(obj): + self._handle_status_code_type(arg) + self._handle_status_code_type(obj) + + def build(self, obj): + self._handle_union(obj) + + +def _add_http_method_to_oas( + oas_path: PathItem, http_method: str, view: Type[PydanticView] +): + http_method = http_method.lower() + oas_operation: OperationObject = getattr(oas_path, http_method) + handler = getattr(view, http_method) path_args, body_args, qs_args, header_args = _parse_func_signature(handler) if body_args: - mtd.request_body.content = { + oas_operation.request_body.content = { "application/json": {"schema": next(iter(body_args.values())).schema()} } i = 0 for i, (name, type_) in enumerate(path_args.items()): - mtd.parameters[i].required = True - mtd.parameters[i].in_ = "path" - mtd.parameters[i].name = name - mtd.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]} + oas_operation.parameters[i].required = True + oas_operation.parameters[i].in_ = "path" + oas_operation.parameters[i].name = name + oas_operation.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]} for i, (name, type_) in enumerate(qs_args.items(), i + 1): - mtd.parameters[i].required = False - mtd.parameters[i].in_ = "query" - mtd.parameters[i].name = name - mtd.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]} + oas_operation.parameters[i].required = False + oas_operation.parameters[i].in_ = "query" + oas_operation.parameters[i].name = name + oas_operation.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]} for i, (name, type_) in enumerate(header_args.items(), i + 1): - mtd.parameters[i].required = False - mtd.parameters[i].in_ = "header" - mtd.parameters[i].name = name - mtd.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]} + oas_operation.parameters[i].required = False + oas_operation.parameters[i].in_ = "header" + oas_operation.parameters[i].name = name + oas_operation.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]} + + return_type = handler.__annotations__.get("return") + if return_type is not None: + _OASResponseBuilder(oas_operation).build(return_type) async def get_oas(request): diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index e5b8587..a899542 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -1,22 +1,16 @@ +from functools import update_wrapper from inspect import iscoroutinefunction +from typing import Any, Callable, Generator, Iterable + from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL +from aiohttp.web import json_response from aiohttp.web_exceptions import HTTPMethodNotAllowed from aiohttp.web_response import StreamResponse from pydantic import ValidationError -from typing import Generator, Any, Callable, Type, Iterable -from aiohttp.web import json_response -from functools import update_wrapper - -from .injectors import ( - MatchInfoGetter, - HeadersGetter, - QueryGetter, - BodyGetter, - AbstractInjector, - _parse_func_signature, -) +from .injectors import (AbstractInjector, BodyGetter, HeadersGetter, + MatchInfoGetter, QueryGetter, _parse_func_signature) class PydanticView(AbstractView): diff --git a/demo/__main__.py b/demo/__main__.py index 58025fb..5ca79a0 100644 --- a/demo/__main__.py +++ b/demo/__main__.py @@ -1,10 +1,10 @@ from aiohttp import web - -from aiohttp_pydantic import oas from aiohttp.web import middleware -from .view import PetItemView, PetCollectionView +from aiohttp_pydantic import oas + from .model import Model +from .view import PetCollectionView, PetItemView @middleware diff --git a/demo/model.py b/demo/model.py index 3e1d4c7..a6ef842 100644 --- a/demo/model.py +++ b/demo/model.py @@ -6,6 +6,10 @@ class Pet(BaseModel): name: str +class Error(BaseModel): + error: str + + class Model: """ To keep simple this demo, we use a simple dict as database to diff --git a/demo/view.py b/demo/view.py index fe12ffc..b897f00 100644 --- a/demo/view.py +++ b/demo/view.py @@ -1,28 +1,32 @@ -from aiohttp_pydantic import PydanticView +from typing import List, Union + from aiohttp import web -from .model import Pet +from aiohttp_pydantic import PydanticView +from aiohttp_pydantic.oas.typing import r200, r201, r204, r404 + +from .model import Error, Pet class PetCollectionView(PydanticView): - async def get(self): + async def get(self) -> r200[List[Pet]]: pets = self.request.app["model"].list_pets() return web.json_response([pet.dict() for pet in pets]) - async def post(self, pet: Pet): + async def post(self, pet: Pet) -> r201[Pet]: self.request.app["model"].add_pet(pet) return web.json_response(pet.dict()) class PetItemView(PydanticView): - async def get(self, id: int, /): + async def get(self, id: int, /) -> Union[r200[Pet], r404[Error]]: pet = self.request.app["model"].find_pet(id) return web.json_response(pet.dict()) - async def put(self, id: int, /, pet: Pet): + async def put(self, id: int, /, pet: Pet) -> r200[Pet]: self.request.app["model"].update_pet(id, pet) return web.json_response(pet.dict()) - async def delete(self, id: int, /): + async def delete(self, id: int, /) -> r204: self.request.app["model"].remove_pet(id) - return web.json_response(id) + return web.Response(status=204) diff --git a/tests/test_oas/test_view.py b/tests/test_oas/test_view.py index 78f39ec..17332d0 100644 --- a/tests/test_oas/test_view.py +++ b/tests/test_oas/test_view.py @@ -1,8 +1,11 @@ -from pydantic.main import BaseModel -from aiohttp_pydantic import PydanticView, oas -from aiohttp import web +from typing import List, Union import pytest +from aiohttp import web +from pydantic.main import BaseModel + +from aiohttp_pydantic import PydanticView, oas +from aiohttp_pydantic.oas.typing import r200, r201, r204, r404 class Pet(BaseModel): @@ -11,21 +14,21 @@ class Pet(BaseModel): class PetCollectionView(PydanticView): - async def get(self): + async def get(self) -> r200[List[Pet]]: return web.json_response() - async def post(self, pet: Pet): + async def post(self, pet: Pet) -> r201[Pet]: return web.json_response() class PetItemView(PydanticView): - async def get(self, id: int, /): + async def get(self, id: int, /) -> Union[r200[Pet], r404]: return web.json_response() async def put(self, id: int, /, pet: Pet): return web.json_response() - async def delete(self, id: int, /): + async def delete(self, id: int, /) -> r204: return web.json_response() @@ -48,7 +51,28 @@ async def test_generated_oas_should_have_pets_paths(generated_oas): async def test_pets_route_should_have_get_method(generated_oas): - assert generated_oas["paths"]["/pets"]["get"] == {} + assert generated_oas["paths"]["/pets"]["get"] == { + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "title": "Pet", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "integer"}, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + }, + } + } + } + } + } + } async def test_pets_route_should_have_post_method(generated_oas): @@ -57,17 +81,34 @@ async def test_pets_route_should_have_post_method(generated_oas): "content": { "application/json": { "schema": { + "title": "Pet", + "type": "object", "properties": { "id": {"title": "Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}, }, "required": ["id", "name"], - "title": "Pet", - "type": "object", } } } - } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "title": "Pet", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "integer"}, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + } + } + } + } + }, } @@ -79,12 +120,13 @@ async def test_pets_id_route_should_have_delete_method(generated_oas): assert generated_oas["paths"]["/pets/{id}"]["delete"] == { "parameters": [ { + "required": True, "in": "path", "name": "id", - "required": True, "schema": {"type": "integer"}, } - ] + ], + "responses": {"204": {"content": {}}}, } @@ -97,7 +139,25 @@ async def test_pets_id_route_should_have_get_method(generated_oas): "required": True, "schema": {"type": "integer"}, } - ] + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "properties": { + "id": {"title": "Id", "type": "integer"}, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "Pet", + "type": "object", + } + } + } + }, + "404": {"content": {}}, + }, } diff --git a/tests/test_parse_func_signature.py b/tests/test_parse_func_signature.py index 3c91e9f..c2cc47b 100644 --- a/tests/test_parse_func_signature.py +++ b/tests/test_parse_func_signature.py @@ -1,7 +1,9 @@ -from aiohttp_pydantic.injectors import _parse_func_signature -from pydantic import BaseModel from uuid import UUID +from pydantic import BaseModel + +from aiohttp_pydantic.injectors import _parse_func_signature + class User(BaseModel): firstname: str diff --git a/tests/test_validation_body.py b/tests/test_validation_body.py index 33951fb..18cb97f 100644 --- a/tests/test_validation_body.py +++ b/tests/test_validation_body.py @@ -1,6 +1,8 @@ -from pydantic import BaseModel from typing import Optional + from aiohttp import web +from pydantic import BaseModel + from aiohttp_pydantic import PydanticView diff --git a/tests/test_validation_header.py b/tests/test_validation_header.py index 511f45c..83558e1 100644 --- a/tests/test_validation_header.py +++ b/tests/test_validation_header.py @@ -1,7 +1,9 @@ -from aiohttp import web -from aiohttp_pydantic import PydanticView -from datetime import datetime import json +from datetime import datetime + +from aiohttp import web + +from aiohttp_pydantic import PydanticView class JSONEncoder(json.JSONEncoder): diff --git a/tests/test_validation_path.py b/tests/test_validation_path.py index fe7466f..0e74e65 100644 --- a/tests/test_validation_path.py +++ b/tests/test_validation_path.py @@ -1,4 +1,5 @@ from aiohttp import web + from aiohttp_pydantic import PydanticView diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index 3df2336..441cf6e 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -1,4 +1,5 @@ from aiohttp import web + from aiohttp_pydantic import PydanticView