Add group parameter feature

This commit is contained in:
Vincent Maillol
2021-09-26 19:08:39 +02:00
parent 4a49d3b53d
commit 799080bbd0
11 changed files with 472 additions and 34 deletions

View File

@@ -1,5 +1,5 @@
from .view import PydanticView
__version__ = "1.11.0"
__version__ = "1.12.0"
__all__ = ("PydanticView", "__version__")

View File

@@ -1,16 +1,16 @@
import abc
import typing
from inspect import signature
from inspect import signature, getmro
from json.decoder import JSONDecodeError
from typing import Callable, Tuple, Literal
from types import SimpleNamespace
from typing import Callable, Tuple, Literal, Type
from aiohttp.web_exceptions import HTTPBadRequest
from aiohttp.web_request import BaseRequest
from multidict import MultiDict
from pydantic import BaseModel
from .utils import is_pydantic_base_model
from .utils import is_pydantic_base_model, robuste_issubclass
CONTEXT = Literal["body", "headers", "path", "query string"]
@@ -20,6 +20,8 @@ class AbstractInjector(metaclass=abc.ABCMeta):
An injector parse HTTP request and inject params to the view.
"""
model: Type[BaseModel]
@property
@abc.abstractmethod
def context(self) -> CONTEXT:
@@ -96,8 +98,17 @@ class QueryGetter(AbstractInjector):
context = "query string"
def __init__(self, args_spec: dict, default_values: dict):
args_spec = args_spec.copy()
self._groups = {}
for group_name, group in args_spec.items():
if robuste_issubclass(group, Group):
self._groups[group_name] = (group, _get_group_signature(group)[0])
_unpack_group_in_signature(args_spec, default_values)
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("QueryModel", (BaseModel,), attrs)
self.args_spec = args_spec
self._is_multiple = frozenset(
@@ -105,7 +116,14 @@ class QueryGetter(AbstractInjector):
)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
kwargs_view.update(self.model(**self._query_to_dict(request.query)).dict())
data = self._query_to_dict(request.query)
cleaned = self.model(**data).dict()
for group_name, (group_cls, group_attrs) in self._groups.items():
group = group_cls()
for attr_name in group_attrs:
setattr(group, attr_name, cleaned.pop(attr_name))
cleaned[group_name] = group
kwargs_view.update(**cleaned)
def _query_to_dict(self, query: MultiDict):
"""
@@ -130,18 +148,74 @@ class HeadersGetter(AbstractInjector):
context = "headers"
def __init__(self, args_spec: dict, default_values: dict):
args_spec = args_spec.copy()
self._groups = {}
for group_name, group in args_spec.items():
if robuste_issubclass(group, Group):
self._groups[group_name] = (group, _get_group_signature(group)[0])
_unpack_group_in_signature(args_spec, default_values)
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("HeaderModel", (BaseModel,), attrs)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()}
kwargs_view.update(self.model(**header).dict())
cleaned = self.model(**header).dict()
for group_name, (group_cls, group_attrs) in self._groups.items():
group = group_cls()
for attr_name in group_attrs:
setattr(group, attr_name, cleaned.pop(attr_name))
cleaned[group_name] = group
kwargs_view.update(cleaned)
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]:
class Group(SimpleNamespace):
"""
Analyse function signature and returns 4-tuple:
Class to group header or query string parameters.
The parameter from query string or header will be set in the group
and the group will be passed as function parameter.
Example:
class Pagination(Group):
current_page: int = 1
page_size: int = 15
class PetView(PydanticView):
def get(self, page: Pagination):
...
"""
def _get_group_signature(cls) -> Tuple[dict, dict]:
"""
Analyse Group subclass annotations and return them with default values.
"""
sig = {}
defaults = {}
mro = getmro(cls)
for base in reversed(mro[: mro.index(Group)]):
attrs = vars(base)
for attr_name, type_ in base.__annotations__.items():
sig[attr_name] = type_
if (default := attrs.get(attr_name)) is None:
defaults.pop(attr_name, None)
else:
defaults[attr_name] = default
return sig, defaults
def _parse_func_signature(
func: Callable, unpack_group: bool = False
) -> Tuple[dict, dict, dict, dict, dict]:
"""
Analyse function signature and returns 5-tuple:
0 - arguments will be set from the url path
1 - argument will be set from the request body.
2 - argument will be set from the query string.
@@ -178,4 +252,46 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]
else:
raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters")
if unpack_group:
try:
_unpack_group_in_signature(qs_args, defaults)
_unpack_group_in_signature(header_args, defaults)
except DuplicateNames as error:
raise TypeError(
f"Parameters conflict in function {func},"
f" the group {error.group} has an attribute named {error.attr_name}"
) from None
return path_args, body_args, qs_args, header_args, defaults
class DuplicateNames(Exception):
"""
Raised when a same parameter name is used in group and function signature.
"""
group: Type[Group]
attr_name: str
def __init__(self, group: Type[Group], attr_name: str):
self.group = group
self.attr_name = attr_name
super().__init__(
f"Conflict with {group}.{attr_name} and function parameter name"
)
def _unpack_group_in_signature(args: dict, defaults: dict) -> None:
"""
Unpack in place each Group found in args.
"""
for group_name, group in args.copy().items():
if robuste_issubclass(group, Group):
group_sig, group_default = _get_group_signature(group)
for attr_name in group_sig:
if attr_name in args and attr_name != group_name:
raise DuplicateNames(group, attr_name)
del args[group_name]
args.update(group_sig)
defaults.update(group_default)

View File

@@ -81,7 +81,7 @@ def _add_http_method_to_oas(
oas_operation: OperationObject = getattr(oas_path, http_method)
handler = getattr(view, http_method)
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
handler
handler, unpack_group=True
)
description = getdoc(handler)
if description:

View File

@@ -5,7 +5,15 @@ def is_pydantic_base_model(obj):
"""
Return true is obj is a pydantic.BaseModel subclass.
"""
return robuste_issubclass(obj, BaseModel)
def robuste_issubclass(cls1, cls2):
"""
function likes issubclass but returns False instead of raise type error
if first parameter is not a class.
"""
try:
return issubclass(obj, BaseModel)
return issubclass(cls1, cls2)
except TypeError:
return False

View File

@@ -18,6 +18,7 @@ from .injectors import (
QueryGetter,
_parse_func_signature,
CONTEXT,
Group,
)
@@ -142,3 +143,14 @@ def is_pydantic_view(obj) -> bool:
return issubclass(obj, PydanticView)
except TypeError:
return False
__all__ = (
"AbstractInjector",
"BodyGetter",
"HeadersGetter",
"MatchInfoGetter",
"QueryGetter",
"CONTEXT",
"Group",
)