Add group parameter feature
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
import abc
|
||||
import typing
|
||||
from inspect import signature
|
||||
from inspect import signature, getmro
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Callable, Tuple, Literal
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, Tuple, Literal, Type
|
||||
|
||||
from aiohttp.web_exceptions import HTTPBadRequest
|
||||
from aiohttp.web_request import BaseRequest
|
||||
from multidict import MultiDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .utils import is_pydantic_base_model
|
||||
|
||||
from .utils import is_pydantic_base_model, robuste_issubclass
|
||||
|
||||
CONTEXT = Literal["body", "headers", "path", "query string"]
|
||||
|
||||
@@ -20,6 +20,8 @@ class AbstractInjector(metaclass=abc.ABCMeta):
|
||||
An injector parse HTTP request and inject params to the view.
|
||||
"""
|
||||
|
||||
model: Type[BaseModel]
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def context(self) -> CONTEXT:
|
||||
@@ -96,8 +98,17 @@ class QueryGetter(AbstractInjector):
|
||||
context = "query string"
|
||||
|
||||
def __init__(self, args_spec: dict, default_values: dict):
|
||||
args_spec = args_spec.copy()
|
||||
|
||||
self._groups = {}
|
||||
for group_name, group in args_spec.items():
|
||||
if robuste_issubclass(group, Group):
|
||||
self._groups[group_name] = (group, _get_group_signature(group)[0])
|
||||
|
||||
_unpack_group_in_signature(args_spec, default_values)
|
||||
attrs = {"__annotations__": args_spec}
|
||||
attrs.update(default_values)
|
||||
|
||||
self.model = type("QueryModel", (BaseModel,), attrs)
|
||||
self.args_spec = args_spec
|
||||
self._is_multiple = frozenset(
|
||||
@@ -105,7 +116,14 @@ class QueryGetter(AbstractInjector):
|
||||
)
|
||||
|
||||
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||
kwargs_view.update(self.model(**self._query_to_dict(request.query)).dict())
|
||||
data = self._query_to_dict(request.query)
|
||||
cleaned = self.model(**data).dict()
|
||||
for group_name, (group_cls, group_attrs) in self._groups.items():
|
||||
group = group_cls()
|
||||
for attr_name in group_attrs:
|
||||
setattr(group, attr_name, cleaned.pop(attr_name))
|
||||
cleaned[group_name] = group
|
||||
kwargs_view.update(**cleaned)
|
||||
|
||||
def _query_to_dict(self, query: MultiDict):
|
||||
"""
|
||||
@@ -130,18 +148,74 @@ class HeadersGetter(AbstractInjector):
|
||||
context = "headers"
|
||||
|
||||
def __init__(self, args_spec: dict, default_values: dict):
|
||||
args_spec = args_spec.copy()
|
||||
|
||||
self._groups = {}
|
||||
for group_name, group in args_spec.items():
|
||||
if robuste_issubclass(group, Group):
|
||||
self._groups[group_name] = (group, _get_group_signature(group)[0])
|
||||
|
||||
_unpack_group_in_signature(args_spec, default_values)
|
||||
|
||||
attrs = {"__annotations__": args_spec}
|
||||
attrs.update(default_values)
|
||||
self.model = type("HeaderModel", (BaseModel,), attrs)
|
||||
|
||||
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
|
||||
header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()}
|
||||
kwargs_view.update(self.model(**header).dict())
|
||||
cleaned = self.model(**header).dict()
|
||||
for group_name, (group_cls, group_attrs) in self._groups.items():
|
||||
group = group_cls()
|
||||
for attr_name in group_attrs:
|
||||
setattr(group, attr_name, cleaned.pop(attr_name))
|
||||
cleaned[group_name] = group
|
||||
kwargs_view.update(cleaned)
|
||||
|
||||
|
||||
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]:
|
||||
class Group(SimpleNamespace):
|
||||
"""
|
||||
Analyse function signature and returns 4-tuple:
|
||||
Class to group header or query string parameters.
|
||||
|
||||
The parameter from query string or header will be set in the group
|
||||
and the group will be passed as function parameter.
|
||||
|
||||
Example:
|
||||
|
||||
class Pagination(Group):
|
||||
current_page: int = 1
|
||||
page_size: int = 15
|
||||
|
||||
class PetView(PydanticView):
|
||||
def get(self, page: Pagination):
|
||||
...
|
||||
"""
|
||||
|
||||
|
||||
def _get_group_signature(cls) -> Tuple[dict, dict]:
|
||||
"""
|
||||
Analyse Group subclass annotations and return them with default values.
|
||||
"""
|
||||
|
||||
sig = {}
|
||||
defaults = {}
|
||||
mro = getmro(cls)
|
||||
for base in reversed(mro[: mro.index(Group)]):
|
||||
attrs = vars(base)
|
||||
for attr_name, type_ in base.__annotations__.items():
|
||||
sig[attr_name] = type_
|
||||
if (default := attrs.get(attr_name)) is None:
|
||||
defaults.pop(attr_name, None)
|
||||
else:
|
||||
defaults[attr_name] = default
|
||||
|
||||
return sig, defaults
|
||||
|
||||
|
||||
def _parse_func_signature(
|
||||
func: Callable, unpack_group: bool = False
|
||||
) -> Tuple[dict, dict, dict, dict, dict]:
|
||||
"""
|
||||
Analyse function signature and returns 5-tuple:
|
||||
0 - arguments will be set from the url path
|
||||
1 - argument will be set from the request body.
|
||||
2 - argument will be set from the query string.
|
||||
@@ -178,4 +252,46 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]
|
||||
else:
|
||||
raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters")
|
||||
|
||||
if unpack_group:
|
||||
try:
|
||||
_unpack_group_in_signature(qs_args, defaults)
|
||||
_unpack_group_in_signature(header_args, defaults)
|
||||
except DuplicateNames as error:
|
||||
raise TypeError(
|
||||
f"Parameters conflict in function {func},"
|
||||
f" the group {error.group} has an attribute named {error.attr_name}"
|
||||
) from None
|
||||
|
||||
return path_args, body_args, qs_args, header_args, defaults
|
||||
|
||||
|
||||
class DuplicateNames(Exception):
|
||||
"""
|
||||
Raised when a same parameter name is used in group and function signature.
|
||||
"""
|
||||
|
||||
group: Type[Group]
|
||||
attr_name: str
|
||||
|
||||
def __init__(self, group: Type[Group], attr_name: str):
|
||||
self.group = group
|
||||
self.attr_name = attr_name
|
||||
super().__init__(
|
||||
f"Conflict with {group}.{attr_name} and function parameter name"
|
||||
)
|
||||
|
||||
|
||||
def _unpack_group_in_signature(args: dict, defaults: dict) -> None:
|
||||
"""
|
||||
Unpack in place each Group found in args.
|
||||
"""
|
||||
for group_name, group in args.copy().items():
|
||||
if robuste_issubclass(group, Group):
|
||||
group_sig, group_default = _get_group_signature(group)
|
||||
for attr_name in group_sig:
|
||||
if attr_name in args and attr_name != group_name:
|
||||
raise DuplicateNames(group, attr_name)
|
||||
|
||||
del args[group_name]
|
||||
args.update(group_sig)
|
||||
defaults.update(group_default)
|
||||
|
||||
Reference in New Issue
Block a user