Merge branch 'main' into fixed
# Conflicts: # aiohttp_pydantic/oas/__init__.py # aiohttp_pydantic/view.py
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .view import PydanticView
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.12.1"
|
||||
|
||||
__all__ = ("PydanticView", "__version__")
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import abc
|
||||
from inspect import signature
|
||||
import typing
|
||||
from inspect import signature, getmro
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Callable, Tuple
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, Tuple, Literal, Type, get_type_hints
|
||||
|
||||
from aiohttp.web_exceptions import HTTPBadRequest
|
||||
from aiohttp.web_request import BaseRequest
|
||||
from multidict import MultiDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .utils import is_pydantic_base_model
|
||||
from .utils import is_pydantic_base_model, robuste_issubclass
|
||||
|
||||
CONTEXT = Literal["body", "headers", "path", "query string"]
|
||||
|
||||
|
||||
class AbstractInjector(metaclass=abc.ABCMeta):
|
||||
@@ -15,9 +20,11 @@ class AbstractInjector(metaclass=abc.ABCMeta):
|
||||
An injector parse HTTP request and inject params to the view.
|
||||
"""
|
||||
|
||||
model: Type[BaseModel]
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def context(self) -> str:
|
||||
def context(self) -> CONTEXT:
|
||||
"""
|
||||
The name of part of parsed request
|
||||
i.e "HTTP header", "URL path", ...
|
||||
@@ -61,6 +68,7 @@ class BodyGetter(AbstractInjector):
|
||||
|
||||
def __init__(self, args_spec: dict, default_values: dict):
|
||||
self.arg_name, self.model = next(iter(args_spec.items()))
|
||||
self._expect_object = self.model.schema()["type"] == "object"
|
||||
|
||||
async def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||
try:
|
||||
@@ -70,7 +78,16 @@ class BodyGetter(AbstractInjector):
|
||||
text='{"error": "Malformed JSON"}', content_type="application/json"
|
||||
) from None
|
||||
|
||||
kwargs_view[self.arg_name] = self.model(**body)
|
||||
# Pydantic tries to cast certain structures, such as a list of 2-tuples,
|
||||
# to a dict. Prevent this by requiring the body to be a dict for object models.
|
||||
if self._expect_object and not isinstance(body, dict):
|
||||
raise HTTPBadRequest(
|
||||
text='[{"in": "body", "loc": ["__root__"], "msg": "value is not a '
|
||||
'valid dict", "type": "type_error.dict"}]',
|
||||
content_type="application/json",
|
||||
) from None
|
||||
|
||||
kwargs_view[self.arg_name] = self.model.parse_obj(body)
|
||||
|
||||
|
||||
class QueryGetter(AbstractInjector):
|
||||
@@ -81,12 +98,46 @@ class QueryGetter(AbstractInjector):
|
||||
context = "query string"
|
||||
|
||||
def __init__(self, args_spec: dict, default_values: dict):
|
||||
args_spec = args_spec.copy()
|
||||
|
||||
self._groups = {}
|
||||
for group_name, group in args_spec.items():
|
||||
if robuste_issubclass(group, Group):
|
||||
self._groups[group_name] = (group, _get_group_signature(group)[0])
|
||||
|
||||
_unpack_group_in_signature(args_spec, default_values)
|
||||
attrs = {"__annotations__": args_spec}
|
||||
attrs.update(default_values)
|
||||
|
||||
self.model = type("QueryModel", (BaseModel,), attrs)
|
||||
self.args_spec = args_spec
|
||||
self._is_multiple = frozenset(
|
||||
name for name, spec in args_spec.items() if typing.get_origin(spec) is list
|
||||
)
|
||||
|
||||
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||
kwargs_view.update(self.model(**request.query).dict())
|
||||
data = self._query_to_dict(request.query)
|
||||
cleaned = self.model(**data).dict()
|
||||
for group_name, (group_cls, group_attrs) in self._groups.items():
|
||||
group = group_cls()
|
||||
for attr_name in group_attrs:
|
||||
setattr(group, attr_name, cleaned.pop(attr_name))
|
||||
cleaned[group_name] = group
|
||||
kwargs_view.update(**cleaned)
|
||||
|
||||
def _query_to_dict(self, query: MultiDict):
|
||||
"""
|
||||
Return a dict with list as value from the MultiDict.
|
||||
|
||||
The value will be wrapped in a list if the args spec is define as a list or if
|
||||
the multiple values are sent (i.e ?foo=1&foo=2)
|
||||
"""
|
||||
return {
|
||||
key: values
|
||||
if len(values := query.getall(key)) > 1 or key in self._is_multiple
|
||||
else value
|
||||
for key, value in query.items()
|
||||
}
|
||||
|
||||
|
||||
class HeadersGetter(AbstractInjector):
|
||||
@@ -97,18 +148,80 @@ class HeadersGetter(AbstractInjector):
|
||||
context = "headers"
|
||||
|
||||
def __init__(self, args_spec: dict, default_values: dict):
|
||||
args_spec = args_spec.copy()
|
||||
|
||||
self._groups = {}
|
||||
for group_name, group in args_spec.items():
|
||||
if robuste_issubclass(group, Group):
|
||||
self._groups[group_name] = (group, _get_group_signature(group)[0])
|
||||
|
||||
_unpack_group_in_signature(args_spec, default_values)
|
||||
|
||||
attrs = {"__annotations__": args_spec}
|
||||
attrs.update(default_values)
|
||||
self.model = type("HeaderModel", (BaseModel,), attrs)
|
||||
|
||||
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||
header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()}
|
||||
kwargs_view.update(self.model(**header).dict())
|
||||
cleaned = self.model(**header).dict()
|
||||
for group_name, (group_cls, group_attrs) in self._groups.items():
|
||||
group = group_cls()
|
||||
for attr_name in group_attrs:
|
||||
setattr(group, attr_name, cleaned.pop(attr_name))
|
||||
cleaned[group_name] = group
|
||||
kwargs_view.update(cleaned)
|
||||
|
||||
|
||||
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]:
|
||||
class Group(SimpleNamespace):
|
||||
"""
|
||||
Analyse function signature and returns 4-tuple:
|
||||
Class to group header or query string parameters.
|
||||
|
||||
The parameter from query string or header will be set in the group
|
||||
and the group will be passed as function parameter.
|
||||
|
||||
Example:
|
||||
|
||||
class Pagination(Group):
|
||||
current_page: int = 1
|
||||
page_size: int = 15
|
||||
|
||||
class PetView(PydanticView):
|
||||
def get(self, page: Pagination):
|
||||
...
|
||||
"""
|
||||
|
||||
|
||||
def _get_group_signature(cls) -> Tuple[dict, dict]:
|
||||
"""
|
||||
Analyse Group subclass annotations and return them with default values.
|
||||
"""
|
||||
|
||||
sig = {}
|
||||
defaults = {}
|
||||
mro = getmro(cls)
|
||||
for base in reversed(mro[: mro.index(Group)]):
|
||||
attrs = vars(base)
|
||||
|
||||
# Use __annotations__ to know if an attribute is
|
||||
# overwrite to remove the default value.
|
||||
for attr_name, type_ in base.__annotations__.items():
|
||||
if (default := attrs.get(attr_name)) is None:
|
||||
defaults.pop(attr_name, None)
|
||||
else:
|
||||
defaults[attr_name] = default
|
||||
|
||||
# Use get_type_hints to have postponed annotations.
|
||||
for attr_name, type_ in get_type_hints(base).items():
|
||||
sig[attr_name] = type_
|
||||
|
||||
return sig, defaults
|
||||
|
||||
|
||||
def _parse_func_signature(
|
||||
func: Callable, unpack_group: bool = False
|
||||
) -> Tuple[dict, dict, dict, dict, dict]:
|
||||
"""
|
||||
Analyse function signature and returns 5-tuple:
|
||||
0 - arguments will be set from the url path
|
||||
1 - argument will be set from the request body.
|
||||
2 - argument will be set from the query string.
|
||||
@@ -122,27 +235,72 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]
|
||||
header_args = {}
|
||||
defaults = {}
|
||||
|
||||
annotations = get_type_hints(func)
|
||||
for param_name, param_spec in signature(func).parameters.items():
|
||||
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
if param_spec.annotation == param_spec.empty:
|
||||
raise RuntimeError(f"The parameter {param_name} must have an annotation")
|
||||
|
||||
annotation = annotations[param_name]
|
||||
if param_spec.default is not param_spec.empty:
|
||||
defaults[param_name] = param_spec.default
|
||||
|
||||
if param_spec.kind is param_spec.POSITIONAL_ONLY:
|
||||
path_args[param_name] = param_spec.annotation
|
||||
path_args[param_name] = annotation
|
||||
|
||||
elif param_spec.kind is param_spec.POSITIONAL_OR_KEYWORD:
|
||||
if is_pydantic_base_model(param_spec.annotation):
|
||||
body_args[param_name] = param_spec.annotation
|
||||
if is_pydantic_base_model(annotation):
|
||||
body_args[param_name] = annotation
|
||||
else:
|
||||
qs_args[param_name] = param_spec.annotation
|
||||
qs_args[param_name] = annotation
|
||||
elif param_spec.kind is param_spec.KEYWORD_ONLY:
|
||||
header_args[param_name] = param_spec.annotation
|
||||
header_args[param_name] = annotation
|
||||
else:
|
||||
raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters")
|
||||
|
||||
if unpack_group:
|
||||
try:
|
||||
_unpack_group_in_signature(qs_args, defaults)
|
||||
_unpack_group_in_signature(header_args, defaults)
|
||||
except DuplicateNames as error:
|
||||
raise TypeError(
|
||||
f"Parameters conflict in function {func},"
|
||||
f" the group {error.group} has an attribute named {error.attr_name}"
|
||||
) from None
|
||||
|
||||
return path_args, body_args, qs_args, header_args, defaults
|
||||
|
||||
|
||||
class DuplicateNames(Exception):
|
||||
"""
|
||||
Raised when a same parameter name is used in group and function signature.
|
||||
"""
|
||||
|
||||
group: Type[Group]
|
||||
attr_name: str
|
||||
|
||||
def __init__(self, group: Type[Group], attr_name: str):
|
||||
self.group = group
|
||||
self.attr_name = attr_name
|
||||
super().__init__(
|
||||
f"Conflict with {group}.{attr_name} and function parameter name"
|
||||
)
|
||||
|
||||
|
||||
def _unpack_group_in_signature(args: dict, defaults: dict) -> None:
|
||||
"""
|
||||
Unpack in place each Group found in args.
|
||||
"""
|
||||
for group_name, group in args.copy().items():
|
||||
if robuste_issubclass(group, Group):
|
||||
group_sig, group_default = _get_group_signature(group)
|
||||
for attr_name in group_sig:
|
||||
if attr_name in args and attr_name != group_name:
|
||||
raise DuplicateNames(group, attr_name)
|
||||
|
||||
del args[group_name]
|
||||
args.update(group_sig)
|
||||
defaults.update(group_default)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from importlib import resources
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
@@ -13,6 +13,8 @@ def setup(
|
||||
apps_to_expose: Iterable[web.Application] = (),
|
||||
url_prefix: str = "/oas",
|
||||
enable: bool = True,
|
||||
version_spec: Optional[str] = None,
|
||||
title_spec: Optional[str] = None,
|
||||
raise_validation_errors: bool = False,
|
||||
):
|
||||
if enable:
|
||||
@@ -23,6 +25,9 @@ def setup(
|
||||
oas_app["index template"] = jinja2.Template(
|
||||
resources.read_text("aiohttp_pydantic.oas", "index.j2")
|
||||
)
|
||||
oas_app["version_spec"] = version_spec
|
||||
oas_app["title_spec"] = title_spec
|
||||
|
||||
oas_app.router.add_get("/spec", get_oas, name="spec")
|
||||
oas_app.router.add_static("/static", swagger_ui_path, name="static")
|
||||
oas_app.router.add_get("", oas_ui, name="index")
|
||||
|
||||
@@ -1,10 +1,28 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
|
||||
from typing import Dict, Protocol, Optional, Callable
|
||||
import sys
|
||||
from .view import generate_oas
|
||||
|
||||
|
||||
class YamlModule(Protocol):
|
||||
"""
|
||||
Yaml Module type hint
|
||||
"""
|
||||
|
||||
def dump(self, data) -> str:
|
||||
pass
|
||||
|
||||
|
||||
yaml: Optional[YamlModule]
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
|
||||
def application_type(value):
|
||||
"""
|
||||
Return aiohttp application defined in the value.
|
||||
@@ -26,6 +44,35 @@ def application_type(value):
|
||||
raise argparse.ArgumentTypeError(error) from error
|
||||
|
||||
|
||||
def base_oas_file_type(value) -> Dict:
|
||||
"""
|
||||
Load base oas file
|
||||
"""
|
||||
try:
|
||||
with open(value) as oas_file:
|
||||
data = oas_file.read()
|
||||
except OSError as error:
|
||||
raise argparse.ArgumentTypeError(error) from error
|
||||
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
def format_type(value) -> Callable:
|
||||
"""
|
||||
Date Dumper one of (json, yaml)
|
||||
"""
|
||||
dumpers = {"json": lambda data: json.dumps(data, sort_keys=True, indent=4)}
|
||||
if yaml is not None:
|
||||
dumpers["yaml"] = yaml.dump
|
||||
|
||||
try:
|
||||
return dumpers[value]
|
||||
except KeyError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Wrong format value. (allowed values: {tuple(dumpers.keys())})"
|
||||
) from None
|
||||
|
||||
|
||||
def setup(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"apps",
|
||||
@@ -35,11 +82,52 @@ def setup(parser: argparse.ArgumentParser):
|
||||
help="The name of the module containing the asyncio.web.Application."
|
||||
" By default the variable named 'app' is loaded but you can define"
|
||||
" an other variable name ending the name of module with : characters"
|
||||
" and the name of variable. Example: my_package.my_module:my_app",
|
||||
" and the name of variable. Example: my_package.my_module:my_app"
|
||||
" If your asyncio.web.Application is returned by a function, you can"
|
||||
" use the syntax: my_package.my_module:my_app()",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--base-oas-file",
|
||||
metavar="FILE",
|
||||
dest="base",
|
||||
type=base_oas_file_type,
|
||||
help="A file that will be used as base to generate OAS",
|
||||
default={},
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
metavar="FILE",
|
||||
type=argparse.FileType("w"),
|
||||
help="File to write the output",
|
||||
default=sys.stdout,
|
||||
)
|
||||
|
||||
if yaml:
|
||||
help_output_format = (
|
||||
"The output format, can be 'json' or 'yaml' (default is json)"
|
||||
)
|
||||
else:
|
||||
help_output_format = "The output format, only 'json' is available install pyyaml to have yaml output format"
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--format",
|
||||
metavar="FORMAT",
|
||||
dest="formatter",
|
||||
type=format_type,
|
||||
help=help_output_format,
|
||||
default=format_type("json"),
|
||||
)
|
||||
|
||||
parser.set_defaults(func=show_oas)
|
||||
|
||||
|
||||
def show_oas(args: argparse.Namespace):
|
||||
print(json.dumps(generate_oas(args.apps), sort_keys=True, indent=4))
|
||||
"""
|
||||
Display Open API Specification on the stdout.
|
||||
"""
|
||||
spec = args.base
|
||||
spec.update(generate_oas(args.apps))
|
||||
print(args.formatter(spec), file=args.output)
|
||||
|
||||
136
aiohttp_pydantic/oas/docstring_parser.py
Normal file
136
aiohttp_pydantic/oas/docstring_parser.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Utility to extract extra OAS description from docstring.
|
||||
"""
|
||||
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class LinesIterator:
|
||||
def __init__(self, lines: str):
|
||||
self._lines = lines.splitlines()
|
||||
self._i = -1
|
||||
|
||||
def next_line(self) -> str:
|
||||
if self._i == len(self._lines) - 1:
|
||||
raise StopIteration from None
|
||||
self._i += 1
|
||||
return self._lines[self._i]
|
||||
|
||||
def rewind(self) -> str:
|
||||
if self._i == -1:
|
||||
raise StopIteration from None
|
||||
self._i -= 1
|
||||
return self._lines[self._i]
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return self.next_line()
|
||||
|
||||
|
||||
def _i_extract_block(lines: LinesIterator):
|
||||
"""
|
||||
Iter the line within an indented block and dedent them.
|
||||
"""
|
||||
|
||||
# Go to the first not empty or not white space line.
|
||||
try:
|
||||
line = next(lines)
|
||||
except StopIteration:
|
||||
return # No block to extract.
|
||||
while line.strip() == "":
|
||||
try:
|
||||
line = next(lines)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
indent = re.fullmatch("( *).*", line).groups()[0]
|
||||
indentation = len(indent)
|
||||
start_of_other_block = re.compile(f" {{0,{indentation}}}[^ ].*")
|
||||
yield line[indentation:]
|
||||
|
||||
# Yield lines until the indentation is the same or is greater than
|
||||
# the first block line.
|
||||
try:
|
||||
line = next(lines)
|
||||
except StopIteration:
|
||||
return
|
||||
while not start_of_other_block.fullmatch(line):
|
||||
yield line[indentation:]
|
||||
try:
|
||||
line = next(lines)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
lines.rewind()
|
||||
|
||||
|
||||
def _dedent_under_first_line(text: str) -> str:
|
||||
"""
|
||||
Apply textwrap.dedent ignoring the first line.
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
other_lines = "\n".join(lines[1:])
|
||||
if other_lines:
|
||||
return f"{lines[0]}\n{textwrap.dedent(other_lines)}"
|
||||
return text
|
||||
|
||||
|
||||
def status_code(docstring: str) -> Dict[int, str]:
|
||||
"""
|
||||
Extract the "Status Code:" block of the docstring.
|
||||
"""
|
||||
iterator = LinesIterator(docstring)
|
||||
for line in iterator:
|
||||
if re.fullmatch("status\\s+codes?\\s*:", line, re.IGNORECASE):
|
||||
iterator.rewind()
|
||||
blocks = []
|
||||
lines = []
|
||||
i_block = _i_extract_block(iterator)
|
||||
next(i_block)
|
||||
for line_of_block in i_block:
|
||||
if re.search("^\\s*\\d{3}\\s*:", line_of_block):
|
||||
if lines:
|
||||
blocks.append("\n".join(lines))
|
||||
lines = []
|
||||
lines.append(line_of_block)
|
||||
if lines:
|
||||
blocks.append("\n".join(lines))
|
||||
|
||||
return {
|
||||
int(status.strip()): _dedent_under_first_line(desc.strip())
|
||||
for status, desc in (block.split(":", 1) for block in blocks)
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def tags(docstring: str) -> List[str]:
|
||||
"""
|
||||
Extract the "Tags:" block of the docstring.
|
||||
"""
|
||||
iterator = LinesIterator(docstring)
|
||||
for line in iterator:
|
||||
if re.fullmatch("tags\\s*:.*", line, re.IGNORECASE):
|
||||
iterator.rewind()
|
||||
lines = " ".join(_i_extract_block(iterator))
|
||||
return [" ".join(e.split()) for e in re.split("[,;]", lines.split(":")[1])]
|
||||
return []
|
||||
|
||||
|
||||
def operation(docstring: str) -> str:
|
||||
"""
|
||||
Extract all docstring except the "Status Code:" block.
|
||||
"""
|
||||
lines = LinesIterator(docstring)
|
||||
ret = []
|
||||
for line in lines:
|
||||
if re.fullmatch("status\\s+codes?\\s*:|tags\\s*:.*", line, re.IGNORECASE):
|
||||
lines.rewind()
|
||||
for _ in _i_extract_block(lines):
|
||||
pass
|
||||
else:
|
||||
ret.append(line)
|
||||
return ("\n".join(ret)).strip()
|
||||
@@ -2,7 +2,7 @@
|
||||
Utility to write Open Api Specifications using the Python language.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
class Info:
|
||||
@@ -133,6 +133,7 @@ class Parameters:
|
||||
class Response:
|
||||
def __init__(self, spec: dict):
|
||||
self._spec = spec
|
||||
self._spec.setdefault("description", "")
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -156,7 +157,7 @@ class Responses:
|
||||
self._spec = spec.setdefault("responses", {})
|
||||
|
||||
def __getitem__(self, status_code: Union[int, str]) -> Response:
|
||||
if not (100 <= int(status_code) < 600):
|
||||
if not 100 <= int(status_code) < 600:
|
||||
raise ValueError("status_code must be between 100 and 599")
|
||||
|
||||
spec = self._spec.setdefault(str(status_code), {})
|
||||
@@ -195,6 +196,17 @@ class OperationObject:
|
||||
def responses(self) -> Responses:
|
||||
return Responses(self._spec)
|
||||
|
||||
@property
|
||||
def tags(self) -> List[str]:
|
||||
return self._spec.get("tags", [])[:]
|
||||
|
||||
@tags.setter
|
||||
def tags(self, tags: List[str]):
|
||||
if tags:
|
||||
self._spec["tags"] = tags[:]
|
||||
else:
|
||||
self._spec.pop("tags", None)
|
||||
|
||||
|
||||
class PathItem:
|
||||
def __init__(self, spec: dict):
|
||||
@@ -304,7 +316,10 @@ class Components:
|
||||
|
||||
class OpenApiSpec3:
|
||||
def __init__(self):
|
||||
self._spec = {"openapi": "3.0.0"}
|
||||
self._spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"version": "1.0.0", "title": "Aiohttp pydantic application"},
|
||||
}
|
||||
|
||||
@property
|
||||
def info(self) -> Info:
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import typing
|
||||
from inspect import getdoc
|
||||
from itertools import count
|
||||
from typing import List, Type
|
||||
from typing import List, Type, Optional, get_type_hints
|
||||
|
||||
from aiohttp.web import Response, json_response
|
||||
from aiohttp.web_app import Application
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem
|
||||
from . import docstring_parser
|
||||
|
||||
from ..injectors import _parse_func_signature
|
||||
from ..utils import is_pydantic_base_model
|
||||
@@ -15,35 +16,23 @@ from ..view import PydanticView, is_pydantic_view
|
||||
from .typing import is_status_code_type
|
||||
|
||||
|
||||
def _handle_optional(type_):
|
||||
"""
|
||||
Returns the type wrapped in Optional or None.
|
||||
|
||||
>>> _handle_optional(int)
|
||||
>>> _handle_optional(Optional[str])
|
||||
<class 'str'>
|
||||
"""
|
||||
if typing.get_origin(type_) is typing.Union:
|
||||
args = typing.get_args(type_)
|
||||
if len(args) == 2 and type(None) in args:
|
||||
return next(iter(set(args) - {type(None)}))
|
||||
return None
|
||||
|
||||
|
||||
class _OASResponseBuilder:
|
||||
"""
|
||||
Parse the type annotated as returned by a function and
|
||||
generate the OAS operation response.
|
||||
"""
|
||||
|
||||
def __init__(self, oas: OpenApiSpec3, oas_operation):
|
||||
def __init__(self, oas: OpenApiSpec3, oas_operation, status_code_descriptions):
|
||||
self._oas_operation = oas_operation
|
||||
self._oas = oas
|
||||
self._status_code_descriptions = status_code_descriptions
|
||||
|
||||
def _handle_pydantic_base_model(self, obj):
|
||||
if is_pydantic_base_model(obj):
|
||||
response_schema = obj.schema(ref_template="#/components/schemas/{model}")
|
||||
if def_sub_schemas := response_schema.get("definitions", None):
|
||||
response_schema = obj.schema(
|
||||
ref_template="#/components/schemas/{model}"
|
||||
).copy()
|
||||
if def_sub_schemas := response_schema.pop("definitions", None):
|
||||
self._oas.components.schemas.update(def_sub_schemas)
|
||||
return response_schema
|
||||
return {}
|
||||
@@ -64,10 +53,16 @@ class _OASResponseBuilder:
|
||||
"schema": self._handle_list(typing.get_args(obj)[0])
|
||||
}
|
||||
}
|
||||
desc = self._status_code_descriptions.get(int(status_code))
|
||||
if desc:
|
||||
self._oas_operation.responses[status_code].description = desc
|
||||
|
||||
elif is_status_code_type(obj):
|
||||
status_code = obj.__name__[1:]
|
||||
self._oas_operation.responses[status_code].content = {}
|
||||
desc = self._status_code_descriptions.get(int(status_code))
|
||||
if desc:
|
||||
self._oas_operation.responses[status_code].description = desc
|
||||
|
||||
def _handle_union(self, obj):
|
||||
if typing.get_origin(obj) is typing.Union:
|
||||
@@ -86,17 +81,23 @@ def _add_http_method_to_oas(
|
||||
oas_operation: OperationObject = getattr(oas_path, http_method)
|
||||
handler = getattr(view, http_method)
|
||||
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
|
||||
handler
|
||||
handler, unpack_group=True
|
||||
)
|
||||
description = getdoc(handler)
|
||||
if description:
|
||||
oas_operation.description = description
|
||||
oas_operation.description = docstring_parser.operation(description)
|
||||
oas_operation.tags = docstring_parser.tags(description)
|
||||
status_code_descriptions = docstring_parser.status_code(description)
|
||||
else:
|
||||
status_code_descriptions = {}
|
||||
|
||||
if body_args:
|
||||
body_schema = next(iter(body_args.values())).schema(
|
||||
ref_template="#/components/schemas/{model}"
|
||||
body_schema = (
|
||||
next(iter(body_args.values()))
|
||||
.schema(ref_template="#/components/schemas/{model}")
|
||||
.copy()
|
||||
)
|
||||
if def_sub_schemas := body_schema.get("definitions", None):
|
||||
if def_sub_schemas := body_schema.pop("definitions", None):
|
||||
oas.components.schemas.update(def_sub_schemas)
|
||||
|
||||
oas_operation.request_body.content = {
|
||||
@@ -113,28 +114,41 @@ def _add_http_method_to_oas(
|
||||
i = next(indexes)
|
||||
oas_operation.parameters[i].in_ = args_location
|
||||
oas_operation.parameters[i].name = name
|
||||
optional_type = _handle_optional(type_)
|
||||
|
||||
attrs = {"__annotations__": {"__root__": type_}}
|
||||
if name in defaults:
|
||||
attrs["__root__"] = defaults[name]
|
||||
oas_operation.parameters[i].required = False
|
||||
else:
|
||||
oas_operation.parameters[i].required = True
|
||||
|
||||
oas_operation.parameters[i].schema = type(name, (BaseModel,), attrs).schema(
|
||||
ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
|
||||
oas_operation.parameters[i].required = optional_type is None
|
||||
|
||||
return_type = handler.__annotations__.get("return")
|
||||
return_type = get_type_hints(handler).get("return")
|
||||
if return_type is not None:
|
||||
_OASResponseBuilder(oas, oas_operation).build(return_type)
|
||||
_OASResponseBuilder(oas, oas_operation, status_code_descriptions).build(
|
||||
return_type
|
||||
)
|
||||
|
||||
|
||||
def generate_oas(apps: List[Application]) -> dict:
|
||||
def generate_oas(
|
||||
apps: List[Application],
|
||||
version_spec: Optional[str] = None,
|
||||
title_spec: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate and return Open Api Specification from PydanticView in application.
|
||||
"""
|
||||
oas = OpenApiSpec3()
|
||||
|
||||
if version_spec is not None:
|
||||
oas.info.version = version_spec
|
||||
|
||||
if title_spec is not None:
|
||||
oas.info.title = title_spec
|
||||
|
||||
for app in apps:
|
||||
for resources in app.router.resources():
|
||||
for resource_route in resources:
|
||||
@@ -158,7 +172,9 @@ async def get_oas(request):
|
||||
View to generate the Open Api Specification from PydanticView in application.
|
||||
"""
|
||||
apps = request.app["apps to expose"]
|
||||
return json_response(generate_oas(apps))
|
||||
version_spec = request.app["version_spec"]
|
||||
title_spec = request.app["title_spec"]
|
||||
return json_response(generate_oas(apps, version_spec, title_spec))
|
||||
|
||||
|
||||
async def oas_ui(request):
|
||||
|
||||
@@ -5,7 +5,15 @@ def is_pydantic_base_model(obj):
|
||||
"""
|
||||
Return true is obj is a pydantic.BaseModel subclass.
|
||||
"""
|
||||
return robuste_issubclass(obj, BaseModel)
|
||||
|
||||
|
||||
def robuste_issubclass(cls1, cls2):
|
||||
"""
|
||||
function likes issubclass but returns False instead of raise type error
|
||||
if first parameter is not a class.
|
||||
"""
|
||||
try:
|
||||
return issubclass(obj, BaseModel)
|
||||
return issubclass(cls1, cls2)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from functools import update_wrapper
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Any, Callable, Generator, Iterable
|
||||
from typing import Any, Callable, Generator, Iterable, Set, ClassVar
|
||||
import warnings
|
||||
|
||||
from aiohttp.abc import AbstractView
|
||||
from aiohttp.hdrs import METH_ALL
|
||||
@@ -9,8 +10,16 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed
|
||||
from aiohttp.web_response import StreamResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .injectors import (AbstractInjector, BodyGetter, HeadersGetter,
|
||||
MatchInfoGetter, QueryGetter, _parse_func_signature)
|
||||
from .injectors import (
|
||||
AbstractInjector,
|
||||
BodyGetter,
|
||||
HeadersGetter,
|
||||
MatchInfoGetter,
|
||||
QueryGetter,
|
||||
_parse_func_signature,
|
||||
CONTEXT,
|
||||
Group,
|
||||
)
|
||||
|
||||
|
||||
class PydanticView(AbstractView):
|
||||
@@ -18,30 +27,46 @@ class PydanticView(AbstractView):
|
||||
An AIOHTTP View that validate request using function annotations.
|
||||
"""
|
||||
|
||||
# Allowed HTTP methods; overridden when subclassed.
|
||||
allowed_methods: ClassVar[Set[str]] = {}
|
||||
|
||||
async def _iter(self) -> StreamResponse:
|
||||
method = getattr(self, self.request.method.lower(), None)
|
||||
resp = await method()
|
||||
return resp
|
||||
if (method_name := self.request.method) not in self.allowed_methods:
|
||||
self._raise_allowed_methods()
|
||||
return await getattr(self, method_name.lower())()
|
||||
|
||||
def __await__(self) -> Generator[Any, None, StreamResponse]:
|
||||
return self._iter().__await__()
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""Define allowed methods and decorate handlers.
|
||||
|
||||
Handlers are decorated if and only if they directly bound on the PydanticView class or
|
||||
PydanticView subclass. This prevents that methods are decorated multiple times and that method
|
||||
defined in aiohttp.View parent class is decorated.
|
||||
"""
|
||||
|
||||
cls.allowed_methods = {
|
||||
meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower())
|
||||
}
|
||||
|
||||
for meth_name in METH_ALL:
|
||||
if meth_name not in cls.allowed_methods:
|
||||
setattr(cls, meth_name.lower(), cls.raise_not_allowed)
|
||||
else:
|
||||
if meth_name.lower() in vars(cls):
|
||||
handler = getattr(cls, meth_name.lower())
|
||||
decorated_handler = inject_params(handler, cls.parse_func_signature)
|
||||
setattr(cls, meth_name.lower(), decorated_handler)
|
||||
|
||||
async def raise_not_allowed(self):
|
||||
def _raise_allowed_methods(self) -> None:
|
||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||
|
||||
def raise_not_allowed(self) -> None:
|
||||
warnings.warn(
|
||||
"PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._raise_allowed_methods()
|
||||
|
||||
@staticmethod
|
||||
def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]:
|
||||
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
|
||||
@@ -65,6 +90,22 @@ class PydanticView(AbstractView):
|
||||
injectors.append(HeadersGetter(header_args, default_value(header_args)))
|
||||
return injectors
|
||||
|
||||
async def on_validation_error(
|
||||
self, exception: ValidationError, context: CONTEXT
|
||||
) -> StreamResponse:
|
||||
"""
|
||||
This method is a hook to intercept ValidationError.
|
||||
|
||||
This hook can be redefined to return a custom HTTP response error.
|
||||
The exception is a pydantic.ValidationError and the context is "body",
|
||||
"headers", "path" or "query string"
|
||||
"""
|
||||
errors = exception.errors()
|
||||
for error in errors:
|
||||
error["in"] = context
|
||||
|
||||
return json_response(data=errors, status=400)
|
||||
|
||||
|
||||
def inject_params(
|
||||
handler, parse_func_signature: Callable[[Callable], Iterable[AbstractInjector]]
|
||||
@@ -89,11 +130,7 @@ def inject_params(
|
||||
if self.request.app['raise_validation_errors']:
|
||||
raise
|
||||
else:
|
||||
errors = error.errors()
|
||||
for error in errors:
|
||||
error["in"] = injector.context
|
||||
|
||||
return json_response(data=errors, status=400)
|
||||
return await self.on_validation_error(error, injector.context)
|
||||
|
||||
return await handler(self, *args, **kwargs)
|
||||
|
||||
@@ -109,3 +146,14 @@ def is_pydantic_view(obj) -> bool:
|
||||
return issubclass(obj, PydanticView)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
__all__ = (
|
||||
"AbstractInjector",
|
||||
"BodyGetter",
|
||||
"HeadersGetter",
|
||||
"MatchInfoGetter",
|
||||
"QueryGetter",
|
||||
"CONTEXT",
|
||||
"Group",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user