Merge branch 'main' into fixed

# Conflicts:
#	aiohttp_pydantic/oas/__init__.py
#	aiohttp_pydantic/view.py
This commit is contained in:
Georg K
2022-03-30 17:30:15 +03:00
36 changed files with 2068 additions and 247 deletions

View File

@@ -1,5 +1,5 @@
from .view import PydanticView
__version__ = "1.6.1"
__version__ = "1.12.1"
__all__ = ("PydanticView", "__version__")

View File

@@ -1,13 +1,18 @@
import abc
from inspect import signature
import typing
from inspect import signature, getmro
from json.decoder import JSONDecodeError
from typing import Callable, Tuple
from types import SimpleNamespace
from typing import Callable, Tuple, Literal, Type, get_type_hints
from aiohttp.web_exceptions import HTTPBadRequest
from aiohttp.web_request import BaseRequest
from multidict import MultiDict
from pydantic import BaseModel
from .utils import is_pydantic_base_model
from .utils import is_pydantic_base_model, robuste_issubclass
CONTEXT = Literal["body", "headers", "path", "query string"]
class AbstractInjector(metaclass=abc.ABCMeta):
@@ -15,9 +20,11 @@ class AbstractInjector(metaclass=abc.ABCMeta):
An injector parse HTTP request and inject params to the view.
"""
model: Type[BaseModel]
@property
@abc.abstractmethod
def context(self) -> str:
def context(self) -> CONTEXT:
"""
The name of part of parsed request
i.e "HTTP header", "URL path", ...
@@ -61,6 +68,7 @@ class BodyGetter(AbstractInjector):
def __init__(self, args_spec: dict, default_values: dict):
self.arg_name, self.model = next(iter(args_spec.items()))
self._expect_object = self.model.schema()["type"] == "object"
async def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
try:
@@ -70,7 +78,16 @@ class BodyGetter(AbstractInjector):
text='{"error": "Malformed JSON"}', content_type="application/json"
) from None
kwargs_view[self.arg_name] = self.model(**body)
# Pydantic tries to cast certain structures, such as a list of 2-tuples,
# to a dict. Prevent this by requiring the body to be a dict for object models.
if self._expect_object and not isinstance(body, dict):
raise HTTPBadRequest(
text='[{"in": "body", "loc": ["__root__"], "msg": "value is not a '
'valid dict", "type": "type_error.dict"}]',
content_type="application/json",
) from None
kwargs_view[self.arg_name] = self.model.parse_obj(body)
class QueryGetter(AbstractInjector):
@@ -81,12 +98,46 @@ class QueryGetter(AbstractInjector):
context = "query string"
def __init__(self, args_spec: dict, default_values: dict):
args_spec = args_spec.copy()
self._groups = {}
for group_name, group in args_spec.items():
if robuste_issubclass(group, Group):
self._groups[group_name] = (group, _get_group_signature(group)[0])
_unpack_group_in_signature(args_spec, default_values)
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("QueryModel", (BaseModel,), attrs)
self.args_spec = args_spec
self._is_multiple = frozenset(
name for name, spec in args_spec.items() if typing.get_origin(spec) is list
)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
kwargs_view.update(self.model(**request.query).dict())
data = self._query_to_dict(request.query)
cleaned = self.model(**data).dict()
for group_name, (group_cls, group_attrs) in self._groups.items():
group = group_cls()
for attr_name in group_attrs:
setattr(group, attr_name, cleaned.pop(attr_name))
cleaned[group_name] = group
kwargs_view.update(**cleaned)
def _query_to_dict(self, query: MultiDict):
"""
Return a dict with list as value from the MultiDict.
The value will be wrapped in a list if the args spec is define as a list or if
the multiple values are sent (i.e ?foo=1&foo=2)
"""
return {
key: values
if len(values := query.getall(key)) > 1 or key in self._is_multiple
else value
for key, value in query.items()
}
class HeadersGetter(AbstractInjector):
@@ -97,18 +148,80 @@ class HeadersGetter(AbstractInjector):
context = "headers"
def __init__(self, args_spec: dict, default_values: dict):
args_spec = args_spec.copy()
self._groups = {}
for group_name, group in args_spec.items():
if robuste_issubclass(group, Group):
self._groups[group_name] = (group, _get_group_signature(group)[0])
_unpack_group_in_signature(args_spec, default_values)
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("HeaderModel", (BaseModel,), attrs)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()}
kwargs_view.update(self.model(**header).dict())
cleaned = self.model(**header).dict()
for group_name, (group_cls, group_attrs) in self._groups.items():
group = group_cls()
for attr_name in group_attrs:
setattr(group, attr_name, cleaned.pop(attr_name))
cleaned[group_name] = group
kwargs_view.update(cleaned)
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]:
class Group(SimpleNamespace):
"""
Analyse function signature and returns 4-tuple:
Class to group header or query string parameters.
The parameter from query string or header will be set in the group
and the group will be passed as function parameter.
Example:
class Pagination(Group):
current_page: int = 1
page_size: int = 15
class PetView(PydanticView):
def get(self, page: Pagination):
...
"""
def _get_group_signature(cls) -> Tuple[dict, dict]:
"""
Analyse Group subclass annotations and return them with default values.
"""
sig = {}
defaults = {}
mro = getmro(cls)
for base in reversed(mro[: mro.index(Group)]):
attrs = vars(base)
# Use __annotations__ to know if an attribute is
# overwrite to remove the default value.
for attr_name, type_ in base.__annotations__.items():
if (default := attrs.get(attr_name)) is None:
defaults.pop(attr_name, None)
else:
defaults[attr_name] = default
# Use get_type_hints to have postponed annotations.
for attr_name, type_ in get_type_hints(base).items():
sig[attr_name] = type_
return sig, defaults
def _parse_func_signature(
func: Callable, unpack_group: bool = False
) -> Tuple[dict, dict, dict, dict, dict]:
"""
Analyse function signature and returns 5-tuple:
0 - arguments will be set from the url path
1 - argument will be set from the request body.
2 - argument will be set from the query string.
@@ -122,27 +235,72 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]
header_args = {}
defaults = {}
annotations = get_type_hints(func)
for param_name, param_spec in signature(func).parameters.items():
if param_name == "self":
continue
if param_spec.annotation == param_spec.empty:
raise RuntimeError(f"The parameter {param_name} must have an annotation")
annotation = annotations[param_name]
if param_spec.default is not param_spec.empty:
defaults[param_name] = param_spec.default
if param_spec.kind is param_spec.POSITIONAL_ONLY:
path_args[param_name] = param_spec.annotation
path_args[param_name] = annotation
elif param_spec.kind is param_spec.POSITIONAL_OR_KEYWORD:
if is_pydantic_base_model(param_spec.annotation):
body_args[param_name] = param_spec.annotation
if is_pydantic_base_model(annotation):
body_args[param_name] = annotation
else:
qs_args[param_name] = param_spec.annotation
qs_args[param_name] = annotation
elif param_spec.kind is param_spec.KEYWORD_ONLY:
header_args[param_name] = param_spec.annotation
header_args[param_name] = annotation
else:
raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters")
if unpack_group:
try:
_unpack_group_in_signature(qs_args, defaults)
_unpack_group_in_signature(header_args, defaults)
except DuplicateNames as error:
raise TypeError(
f"Parameters conflict in function {func},"
f" the group {error.group} has an attribute named {error.attr_name}"
) from None
return path_args, body_args, qs_args, header_args, defaults
class DuplicateNames(Exception):
"""
Raised when a same parameter name is used in group and function signature.
"""
group: Type[Group]
attr_name: str
def __init__(self, group: Type[Group], attr_name: str):
self.group = group
self.attr_name = attr_name
super().__init__(
f"Conflict with {group}.{attr_name} and function parameter name"
)
def _unpack_group_in_signature(args: dict, defaults: dict) -> None:
"""
Unpack in place each Group found in args.
"""
for group_name, group in args.copy().items():
if robuste_issubclass(group, Group):
group_sig, group_default = _get_group_signature(group)
for attr_name in group_sig:
if attr_name in args and attr_name != group_name:
raise DuplicateNames(group, attr_name)
del args[group_name]
args.update(group_sig)
defaults.update(group_default)

View File

@@ -1,5 +1,5 @@
from importlib import resources
from typing import Iterable
from typing import Iterable, Optional
import jinja2
from aiohttp import web
@@ -13,6 +13,8 @@ def setup(
apps_to_expose: Iterable[web.Application] = (),
url_prefix: str = "/oas",
enable: bool = True,
version_spec: Optional[str] = None,
title_spec: Optional[str] = None,
raise_validation_errors: bool = False,
):
if enable:
@@ -23,6 +25,9 @@ def setup(
oas_app["index template"] = jinja2.Template(
resources.read_text("aiohttp_pydantic.oas", "index.j2")
)
oas_app["version_spec"] = version_spec
oas_app["title_spec"] = title_spec
oas_app.router.add_get("/spec", get_oas, name="spec")
oas_app.router.add_static("/static", swagger_ui_path, name="static")
oas_app.router.add_get("", oas_ui, name="index")

View File

@@ -1,10 +1,28 @@
import argparse
import importlib
import json
from typing import Dict, Protocol, Optional, Callable
import sys
from .view import generate_oas
class YamlModule(Protocol):
"""
Yaml Module type hint
"""
def dump(self, data) -> str:
pass
yaml: Optional[YamlModule]
try:
import yaml
except ImportError:
yaml = None
def application_type(value):
"""
Return aiohttp application defined in the value.
@@ -26,6 +44,35 @@ def application_type(value):
raise argparse.ArgumentTypeError(error) from error
def base_oas_file_type(value) -> Dict:
"""
Load base oas file
"""
try:
with open(value) as oas_file:
data = oas_file.read()
except OSError as error:
raise argparse.ArgumentTypeError(error) from error
return json.loads(data)
def format_type(value) -> Callable:
"""
Date Dumper one of (json, yaml)
"""
dumpers = {"json": lambda data: json.dumps(data, sort_keys=True, indent=4)}
if yaml is not None:
dumpers["yaml"] = yaml.dump
try:
return dumpers[value]
except KeyError:
raise argparse.ArgumentTypeError(
f"Wrong format value. (allowed values: {tuple(dumpers.keys())})"
) from None
def setup(parser: argparse.ArgumentParser):
parser.add_argument(
"apps",
@@ -35,11 +82,52 @@ def setup(parser: argparse.ArgumentParser):
help="The name of the module containing the asyncio.web.Application."
" By default the variable named 'app' is loaded but you can define"
" an other variable name ending the name of module with : characters"
" and the name of variable. Example: my_package.my_module:my_app",
" and the name of variable. Example: my_package.my_module:my_app"
" If your asyncio.web.Application is returned by a function, you can"
" use the syntax: my_package.my_module:my_app()",
)
parser.add_argument(
"-b",
"--base-oas-file",
metavar="FILE",
dest="base",
type=base_oas_file_type,
help="A file that will be used as base to generate OAS",
default={},
)
parser.add_argument(
"-o",
"--output",
metavar="FILE",
type=argparse.FileType("w"),
help="File to write the output",
default=sys.stdout,
)
if yaml:
help_output_format = (
"The output format, can be 'json' or 'yaml' (default is json)"
)
else:
help_output_format = "The output format, only 'json' is available install pyyaml to have yaml output format"
parser.add_argument(
"-f",
"--format",
metavar="FORMAT",
dest="formatter",
type=format_type,
help=help_output_format,
default=format_type("json"),
)
parser.set_defaults(func=show_oas)
def show_oas(args: argparse.Namespace):
print(json.dumps(generate_oas(args.apps), sort_keys=True, indent=4))
"""
Display Open API Specification on the stdout.
"""
spec = args.base
spec.update(generate_oas(args.apps))
print(args.formatter(spec), file=args.output)

View File

@@ -0,0 +1,136 @@
"""
Utility to extract extra OAS description from docstring.
"""
import re
import textwrap
from typing import Dict, List
class LinesIterator:
def __init__(self, lines: str):
self._lines = lines.splitlines()
self._i = -1
def next_line(self) -> str:
if self._i == len(self._lines) - 1:
raise StopIteration from None
self._i += 1
return self._lines[self._i]
def rewind(self) -> str:
if self._i == -1:
raise StopIteration from None
self._i -= 1
return self._lines[self._i]
def __iter__(self):
return self
def __next__(self):
return self.next_line()
def _i_extract_block(lines: LinesIterator):
"""
Iter the line within an indented block and dedent them.
"""
# Go to the first not empty or not white space line.
try:
line = next(lines)
except StopIteration:
return # No block to extract.
while line.strip() == "":
try:
line = next(lines)
except StopIteration:
return
indent = re.fullmatch("( *).*", line).groups()[0]
indentation = len(indent)
start_of_other_block = re.compile(f" {{0,{indentation}}}[^ ].*")
yield line[indentation:]
# Yield lines until the indentation is the same or is greater than
# the first block line.
try:
line = next(lines)
except StopIteration:
return
while not start_of_other_block.fullmatch(line):
yield line[indentation:]
try:
line = next(lines)
except StopIteration:
return
lines.rewind()
def _dedent_under_first_line(text: str) -> str:
"""
Apply textwrap.dedent ignoring the first line.
"""
lines = text.splitlines()
other_lines = "\n".join(lines[1:])
if other_lines:
return f"{lines[0]}\n{textwrap.dedent(other_lines)}"
return text
def status_code(docstring: str) -> Dict[int, str]:
"""
Extract the "Status Code:" block of the docstring.
"""
iterator = LinesIterator(docstring)
for line in iterator:
if re.fullmatch("status\\s+codes?\\s*:", line, re.IGNORECASE):
iterator.rewind()
blocks = []
lines = []
i_block = _i_extract_block(iterator)
next(i_block)
for line_of_block in i_block:
if re.search("^\\s*\\d{3}\\s*:", line_of_block):
if lines:
blocks.append("\n".join(lines))
lines = []
lines.append(line_of_block)
if lines:
blocks.append("\n".join(lines))
return {
int(status.strip()): _dedent_under_first_line(desc.strip())
for status, desc in (block.split(":", 1) for block in blocks)
}
return {}
def tags(docstring: str) -> List[str]:
"""
Extract the "Tags:" block of the docstring.
"""
iterator = LinesIterator(docstring)
for line in iterator:
if re.fullmatch("tags\\s*:.*", line, re.IGNORECASE):
iterator.rewind()
lines = " ".join(_i_extract_block(iterator))
return [" ".join(e.split()) for e in re.split("[,;]", lines.split(":")[1])]
return []
def operation(docstring: str) -> str:
"""
Extract all docstring except the "Status Code:" block.
"""
lines = LinesIterator(docstring)
ret = []
for line in lines:
if re.fullmatch("status\\s+codes?\\s*:|tags\\s*:.*", line, re.IGNORECASE):
lines.rewind()
for _ in _i_extract_block(lines):
pass
else:
ret.append(line)
return ("\n".join(ret)).strip()

View File

@@ -2,7 +2,7 @@
Utility to write Open Api Specifications using the Python language.
"""
from typing import Union
from typing import Union, List
class Info:
@@ -133,6 +133,7 @@ class Parameters:
class Response:
def __init__(self, spec: dict):
self._spec = spec
self._spec.setdefault("description", "")
@property
def description(self) -> str:
@@ -156,7 +157,7 @@ class Responses:
self._spec = spec.setdefault("responses", {})
def __getitem__(self, status_code: Union[int, str]) -> Response:
if not (100 <= int(status_code) < 600):
if not 100 <= int(status_code) < 600:
raise ValueError("status_code must be between 100 and 599")
spec = self._spec.setdefault(str(status_code), {})
@@ -195,6 +196,17 @@ class OperationObject:
def responses(self) -> Responses:
return Responses(self._spec)
@property
def tags(self) -> List[str]:
return self._spec.get("tags", [])[:]
@tags.setter
def tags(self, tags: List[str]):
if tags:
self._spec["tags"] = tags[:]
else:
self._spec.pop("tags", None)
class PathItem:
def __init__(self, spec: dict):
@@ -304,7 +316,10 @@ class Components:
class OpenApiSpec3:
def __init__(self):
self._spec = {"openapi": "3.0.0"}
self._spec = {
"openapi": "3.0.0",
"info": {"version": "1.0.0", "title": "Aiohttp pydantic application"},
}
@property
def info(self) -> Info:

View File

@@ -1,13 +1,14 @@
import typing
from inspect import getdoc
from itertools import count
from typing import List, Type
from typing import List, Type, Optional, get_type_hints
from aiohttp.web import Response, json_response
from aiohttp.web_app import Application
from pydantic import BaseModel
from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem
from . import docstring_parser
from ..injectors import _parse_func_signature
from ..utils import is_pydantic_base_model
@@ -15,35 +16,23 @@ from ..view import PydanticView, is_pydantic_view
from .typing import is_status_code_type
def _handle_optional(type_):
"""
Returns the type wrapped in Optional or None.
>>> _handle_optional(int)
>>> _handle_optional(Optional[str])
<class 'str'>
"""
if typing.get_origin(type_) is typing.Union:
args = typing.get_args(type_)
if len(args) == 2 and type(None) in args:
return next(iter(set(args) - {type(None)}))
return None
class _OASResponseBuilder:
"""
Parse the type annotated as returned by a function and
generate the OAS operation response.
"""
def __init__(self, oas: OpenApiSpec3, oas_operation):
def __init__(self, oas: OpenApiSpec3, oas_operation, status_code_descriptions):
self._oas_operation = oas_operation
self._oas = oas
self._status_code_descriptions = status_code_descriptions
def _handle_pydantic_base_model(self, obj):
if is_pydantic_base_model(obj):
response_schema = obj.schema(ref_template="#/components/schemas/{model}")
if def_sub_schemas := response_schema.get("definitions", None):
response_schema = obj.schema(
ref_template="#/components/schemas/{model}"
).copy()
if def_sub_schemas := response_schema.pop("definitions", None):
self._oas.components.schemas.update(def_sub_schemas)
return response_schema
return {}
@@ -64,10 +53,16 @@ class _OASResponseBuilder:
"schema": self._handle_list(typing.get_args(obj)[0])
}
}
desc = self._status_code_descriptions.get(int(status_code))
if desc:
self._oas_operation.responses[status_code].description = desc
elif is_status_code_type(obj):
status_code = obj.__name__[1:]
self._oas_operation.responses[status_code].content = {}
desc = self._status_code_descriptions.get(int(status_code))
if desc:
self._oas_operation.responses[status_code].description = desc
def _handle_union(self, obj):
if typing.get_origin(obj) is typing.Union:
@@ -86,17 +81,23 @@ def _add_http_method_to_oas(
oas_operation: OperationObject = getattr(oas_path, http_method)
handler = getattr(view, http_method)
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
handler
handler, unpack_group=True
)
description = getdoc(handler)
if description:
oas_operation.description = description
oas_operation.description = docstring_parser.operation(description)
oas_operation.tags = docstring_parser.tags(description)
status_code_descriptions = docstring_parser.status_code(description)
else:
status_code_descriptions = {}
if body_args:
body_schema = next(iter(body_args.values())).schema(
ref_template="#/components/schemas/{model}"
body_schema = (
next(iter(body_args.values()))
.schema(ref_template="#/components/schemas/{model}")
.copy()
)
if def_sub_schemas := body_schema.get("definitions", None):
if def_sub_schemas := body_schema.pop("definitions", None):
oas.components.schemas.update(def_sub_schemas)
oas_operation.request_body.content = {
@@ -113,28 +114,41 @@ def _add_http_method_to_oas(
i = next(indexes)
oas_operation.parameters[i].in_ = args_location
oas_operation.parameters[i].name = name
optional_type = _handle_optional(type_)
attrs = {"__annotations__": {"__root__": type_}}
if name in defaults:
attrs["__root__"] = defaults[name]
oas_operation.parameters[i].required = False
else:
oas_operation.parameters[i].required = True
oas_operation.parameters[i].schema = type(name, (BaseModel,), attrs).schema(
ref_template="#/components/schemas/{model}"
)
oas_operation.parameters[i].required = optional_type is None
return_type = handler.__annotations__.get("return")
return_type = get_type_hints(handler).get("return")
if return_type is not None:
_OASResponseBuilder(oas, oas_operation).build(return_type)
_OASResponseBuilder(oas, oas_operation, status_code_descriptions).build(
return_type
)
def generate_oas(apps: List[Application]) -> dict:
def generate_oas(
apps: List[Application],
version_spec: Optional[str] = None,
title_spec: Optional[str] = None,
) -> dict:
"""
Generate and return Open Api Specification from PydanticView in application.
"""
oas = OpenApiSpec3()
if version_spec is not None:
oas.info.version = version_spec
if title_spec is not None:
oas.info.title = title_spec
for app in apps:
for resources in app.router.resources():
for resource_route in resources:
@@ -158,7 +172,9 @@ async def get_oas(request):
View to generate the Open Api Specification from PydanticView in application.
"""
apps = request.app["apps to expose"]
return json_response(generate_oas(apps))
version_spec = request.app["version_spec"]
title_spec = request.app["title_spec"]
return json_response(generate_oas(apps, version_spec, title_spec))
async def oas_ui(request):

View File

@@ -5,7 +5,15 @@ def is_pydantic_base_model(obj):
"""
Return true is obj is a pydantic.BaseModel subclass.
"""
return robuste_issubclass(obj, BaseModel)
def robuste_issubclass(cls1, cls2):
"""
function likes issubclass but returns False instead of raise type error
if first parameter is not a class.
"""
try:
return issubclass(obj, BaseModel)
return issubclass(cls1, cls2)
except TypeError:
return False

View File

@@ -1,6 +1,7 @@
from functools import update_wrapper
from inspect import iscoroutinefunction
from typing import Any, Callable, Generator, Iterable
from typing import Any, Callable, Generator, Iterable, Set, ClassVar
import warnings
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL
@@ -9,8 +10,16 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aiohttp.web_response import StreamResponse
from pydantic import ValidationError
from .injectors import (AbstractInjector, BodyGetter, HeadersGetter,
MatchInfoGetter, QueryGetter, _parse_func_signature)
from .injectors import (
AbstractInjector,
BodyGetter,
HeadersGetter,
MatchInfoGetter,
QueryGetter,
_parse_func_signature,
CONTEXT,
Group,
)
class PydanticView(AbstractView):
@@ -18,30 +27,46 @@ class PydanticView(AbstractView):
An AIOHTTP View that validate request using function annotations.
"""
# Allowed HTTP methods; overridden when subclassed.
allowed_methods: ClassVar[Set[str]] = {}
async def _iter(self) -> StreamResponse:
method = getattr(self, self.request.method.lower(), None)
resp = await method()
return resp
if (method_name := self.request.method) not in self.allowed_methods:
self._raise_allowed_methods()
return await getattr(self, method_name.lower())()
def __await__(self) -> Generator[Any, None, StreamResponse]:
return self._iter().__await__()
def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls, **kwargs) -> None:
"""Define allowed methods and decorate handlers.
Handlers are decorated if and only if they directly bound on the PydanticView class or
PydanticView subclass. This prevents that methods are decorated multiple times and that method
defined in aiohttp.View parent class is decorated.
"""
cls.allowed_methods = {
meth_name for meth_name in METH_ALL if hasattr(cls, meth_name.lower())
}
for meth_name in METH_ALL:
if meth_name not in cls.allowed_methods:
setattr(cls, meth_name.lower(), cls.raise_not_allowed)
else:
if meth_name.lower() in vars(cls):
handler = getattr(cls, meth_name.lower())
decorated_handler = inject_params(handler, cls.parse_func_signature)
setattr(cls, meth_name.lower(), decorated_handler)
async def raise_not_allowed(self):
def _raise_allowed_methods(self) -> None:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
def raise_not_allowed(self) -> None:
warnings.warn(
"PydanticView.raise_not_allowed is deprecated and renamed _raise_allowed_methods",
DeprecationWarning,
stacklevel=2,
)
self._raise_allowed_methods()
@staticmethod
def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]:
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
@@ -65,6 +90,22 @@ class PydanticView(AbstractView):
injectors.append(HeadersGetter(header_args, default_value(header_args)))
return injectors
async def on_validation_error(
self, exception: ValidationError, context: CONTEXT
) -> StreamResponse:
"""
This method is a hook to intercept ValidationError.
This hook can be redefined to return a custom HTTP response error.
The exception is a pydantic.ValidationError and the context is "body",
"headers", "path" or "query string"
"""
errors = exception.errors()
for error in errors:
error["in"] = context
return json_response(data=errors, status=400)
def inject_params(
handler, parse_func_signature: Callable[[Callable], Iterable[AbstractInjector]]
@@ -89,11 +130,7 @@ def inject_params(
if self.request.app['raise_validation_errors']:
raise
else:
errors = error.errors()
for error in errors:
error["in"] = injector.context
return json_response(data=errors, status=400)
return await self.on_validation_error(error, injector.context)
return await handler(self, *args, **kwargs)
@@ -109,3 +146,14 @@ def is_pydantic_view(obj) -> bool:
return issubclass(obj, PydanticView)
except TypeError:
return False
__all__ = (
"AbstractInjector",
"BodyGetter",
"HeadersGetter",
"MatchInfoGetter",
"QueryGetter",
"CONTEXT",
"Group",
)