Add type to define OAS responses
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
from typing import Iterable
|
||||
from importlib import resources
|
||||
from typing import Iterable
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
from .view import get_oas, oas_ui
|
||||
from swagger_ui_bundle import swagger_ui_path
|
||||
|
||||
from .view import get_oas, oas_ui
|
||||
|
||||
|
||||
def setup(
|
||||
app: web.Application,
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Info:
|
||||
def __init__(self, spec: dict):
|
||||
self._spec = spec.setdefault("info", {})
|
||||
@@ -115,6 +118,39 @@ class Parameters:
|
||||
return Parameter(spec)
|
||||
|
||||
|
||||
class Response:
|
||||
def __init__(self, spec: dict):
|
||||
self._spec = spec
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._spec["description"]
|
||||
|
||||
@description.setter
|
||||
def description(self, description: str):
|
||||
self._spec["description"] = description
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
return self._spec["content"]
|
||||
|
||||
@content.setter
|
||||
def content(self, content: dict):
|
||||
self._spec["content"] = content
|
||||
|
||||
|
||||
class Responses:
|
||||
def __init__(self, spec: dict):
|
||||
self._spec = spec.setdefault("responses", {})
|
||||
|
||||
def __getitem__(self, status_code: Union[int, str]) -> Response:
|
||||
if not (100 <= int(status_code) < 600):
|
||||
raise ValueError("status_code must be between 100 and 599")
|
||||
|
||||
spec = self._spec.setdefault(str(status_code), {})
|
||||
return Response(spec)
|
||||
|
||||
|
||||
class OperationObject:
|
||||
def __init__(self, spec: dict):
|
||||
self._spec = spec
|
||||
@@ -143,6 +179,10 @@ class OperationObject:
|
||||
def parameters(self) -> Parameters:
|
||||
return Parameters(self._spec)
|
||||
|
||||
@property
|
||||
def responses(self) -> Responses:
|
||||
return Responses(self._spec)
|
||||
|
||||
|
||||
class PathItem:
|
||||
def __init__(self, spec: dict):
|
||||
|
||||
47
aiohttp_pydantic/oas/typing.py
Normal file
47
aiohttp_pydantic/oas/typing.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
This module provides type to annotate the content of web.Response returned by
|
||||
the HTTP handlers.
|
||||
|
||||
The type are: r100, r101, ..., r599
|
||||
|
||||
Example:
|
||||
|
||||
class PetCollectionView(PydanticView):
|
||||
async def get(self) -> Union[r200[List[Pet]], r404]:
|
||||
...
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
from types import new_class
|
||||
from typing import Protocol, TypeVar
|
||||
|
||||
RespContents = TypeVar("RespContents", covariant=True)
|
||||
|
||||
_status_code = frozenset(f"r{code}" for code in range(100, 600))
|
||||
|
||||
|
||||
@lru_cache(maxsize=len(_status_code))
|
||||
def _make_status_code_type(status_code):
|
||||
if status_code in _status_code:
|
||||
return new_class(status_code, (Protocol[RespContents],))
|
||||
|
||||
|
||||
def is_status_code_type(obj):
|
||||
"""
|
||||
Return True if obj is a status code type such as _200 or _404.
|
||||
"""
|
||||
name = getattr(obj, "__name__", None)
|
||||
if name not in _status_code:
|
||||
return False
|
||||
|
||||
return obj is _make_status_code_type(name)
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if (status_code_type := _make_status_code_type(name)) is None:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
return status_code_type
|
||||
|
||||
|
||||
__all__ = list(_status_code)
|
||||
__all__.append("is_status_code_type")
|
||||
@@ -1,44 +1,109 @@
|
||||
from aiohttp.web import json_response, Response
|
||||
import typing
|
||||
from typing import Type
|
||||
|
||||
from aiohttp.web import Response, json_response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem
|
||||
from typing import Type
|
||||
|
||||
from ..injectors import _parse_func_signature
|
||||
from ..view import PydanticView, is_pydantic_view
|
||||
|
||||
from .typing import is_status_code_type
|
||||
|
||||
JSON_SCHEMA_TYPES = {float: "number", str: "string", int: "integer"}
|
||||
|
||||
|
||||
def _add_http_method_to_oas(oas_path: PathItem, method: str, view: Type[PydanticView]):
|
||||
method = method.lower()
|
||||
mtd: OperationObject = getattr(oas_path, method)
|
||||
handler = getattr(view, method)
|
||||
def _is_pydantic_base_model(obj):
|
||||
"""
|
||||
Return true is obj is a pydantic.BaseModel subclass.
|
||||
"""
|
||||
try:
|
||||
return issubclass(obj, BaseModel)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
class _OASResponseBuilder:
|
||||
"""
|
||||
Parse the type annotated as returned by a function and
|
||||
generate the OAS operation response.
|
||||
"""
|
||||
|
||||
def __init__(self, oas_operation):
|
||||
self._oas_operation = oas_operation
|
||||
|
||||
@staticmethod
|
||||
def _handle_pydantic_base_model(obj):
|
||||
if _is_pydantic_base_model(obj):
|
||||
return obj.schema()
|
||||
return {}
|
||||
|
||||
def _handle_list(self, obj):
|
||||
if typing.get_origin(obj) is list:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": self._handle_pydantic_base_model(typing.get_args(obj)[0]),
|
||||
}
|
||||
return self._handle_pydantic_base_model(obj)
|
||||
|
||||
def _handle_status_code_type(self, obj):
|
||||
if is_status_code_type(typing.get_origin(obj)):
|
||||
status_code = typing.get_origin(obj).__name__[1:]
|
||||
self._oas_operation.responses[status_code].content = {
|
||||
"application/json": {
|
||||
"schema": self._handle_list(typing.get_args(obj)[0])
|
||||
}
|
||||
}
|
||||
|
||||
elif is_status_code_type(obj):
|
||||
status_code = obj.__name__[1:]
|
||||
self._oas_operation.responses[status_code].content = {}
|
||||
|
||||
def _handle_union(self, obj):
|
||||
if typing.get_origin(obj) is typing.Union:
|
||||
for arg in typing.get_args(obj):
|
||||
self._handle_status_code_type(arg)
|
||||
self._handle_status_code_type(obj)
|
||||
|
||||
def build(self, obj):
|
||||
self._handle_union(obj)
|
||||
|
||||
|
||||
def _add_http_method_to_oas(
|
||||
oas_path: PathItem, http_method: str, view: Type[PydanticView]
|
||||
):
|
||||
http_method = http_method.lower()
|
||||
oas_operation: OperationObject = getattr(oas_path, http_method)
|
||||
handler = getattr(view, http_method)
|
||||
path_args, body_args, qs_args, header_args = _parse_func_signature(handler)
|
||||
|
||||
if body_args:
|
||||
mtd.request_body.content = {
|
||||
oas_operation.request_body.content = {
|
||||
"application/json": {"schema": next(iter(body_args.values())).schema()}
|
||||
}
|
||||
|
||||
i = 0
|
||||
for i, (name, type_) in enumerate(path_args.items()):
|
||||
mtd.parameters[i].required = True
|
||||
mtd.parameters[i].in_ = "path"
|
||||
mtd.parameters[i].name = name
|
||||
mtd.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]}
|
||||
oas_operation.parameters[i].required = True
|
||||
oas_operation.parameters[i].in_ = "path"
|
||||
oas_operation.parameters[i].name = name
|
||||
oas_operation.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]}
|
||||
|
||||
for i, (name, type_) in enumerate(qs_args.items(), i + 1):
|
||||
mtd.parameters[i].required = False
|
||||
mtd.parameters[i].in_ = "query"
|
||||
mtd.parameters[i].name = name
|
||||
mtd.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]}
|
||||
oas_operation.parameters[i].required = False
|
||||
oas_operation.parameters[i].in_ = "query"
|
||||
oas_operation.parameters[i].name = name
|
||||
oas_operation.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]}
|
||||
|
||||
for i, (name, type_) in enumerate(header_args.items(), i + 1):
|
||||
mtd.parameters[i].required = False
|
||||
mtd.parameters[i].in_ = "header"
|
||||
mtd.parameters[i].name = name
|
||||
mtd.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]}
|
||||
oas_operation.parameters[i].required = False
|
||||
oas_operation.parameters[i].in_ = "header"
|
||||
oas_operation.parameters[i].name = name
|
||||
oas_operation.parameters[i].schema = {"type": JSON_SCHEMA_TYPES[type_]}
|
||||
|
||||
return_type = handler.__annotations__.get("return")
|
||||
if return_type is not None:
|
||||
_OASResponseBuilder(oas_operation).build(return_type)
|
||||
|
||||
|
||||
async def get_oas(request):
|
||||
|
||||
Reference in New Issue
Block a user