diff --git a/README.rst b/README.rst index 69b005a..3c5eb03 100644 --- a/README.rst +++ b/README.rst @@ -316,6 +316,69 @@ Open Api Specification. return web.Response(status=204) +Group parameters +---------------- + +If your method has lot of parameters you can group them together inside one or several Groups. + + +.. code-block:: python3 + + class Pagination(Group): + page_num: int = 1 + page_size: int = 15 + + + class ArticleView(PydanticView): + + async def get(self, page: Pagination): + articles = Article.get(page.page_num, page.page_size) + ... + + +The parameters page_num and page_size are expected in the query string, and +set inside a Pagination object passed as page parameter. + +The code above is equivalent to: + + +.. code-block:: python3 + + class ArticleView(PydanticView): + + async def get(self, page_num: int = 1, page_size: int = 15): + articles = Article.get(page_num, page_size) + ... + + +You can add methods or properties to your Group. + + +.. code-block:: python3 + + class Pagination(Group): + page_num: int = 1 + page_size: int = 15 + + @property + def num(self): + return self.page_num + + @property + def size(self): + return self.page_size + + def slice(self): + return slice(self.num, self.size) + + + class ArticleView(PydanticView): + + async def get(self, page: Pagination): + articles = Article.get(page.num, page.size) + ... + + Custom Validation error ----------------------- diff --git a/aiohttp_pydantic/__init__.py b/aiohttp_pydantic/__init__.py index e938416..81d1052 100644 --- a/aiohttp_pydantic/__init__.py +++ b/aiohttp_pydantic/__init__.py @@ -1,5 +1,5 @@ from .view import PydanticView -__version__ = "1.11.0" +__version__ = "1.12.0" __all__ = ("PydanticView", "__version__") diff --git a/aiohttp_pydantic/injectors.py b/aiohttp_pydantic/injectors.py index 8147c4c..848474c 100644 --- a/aiohttp_pydantic/injectors.py +++ b/aiohttp_pydantic/injectors.py @@ -1,16 +1,16 @@ import abc import typing -from inspect import signature +from inspect import signature, getmro from json.decoder import JSONDecodeError -from typing import Callable, Tuple, Literal +from types import SimpleNamespace +from typing import Callable, Tuple, Literal, Type from aiohttp.web_exceptions import HTTPBadRequest from aiohttp.web_request import BaseRequest from multidict import MultiDict from pydantic import BaseModel -from .utils import is_pydantic_base_model - +from .utils import is_pydantic_base_model, robuste_issubclass CONTEXT = Literal["body", "headers", "path", "query string"] @@ -20,6 +20,8 @@ class AbstractInjector(metaclass=abc.ABCMeta): An injector parse HTTP request and inject params to the view. """ + model: Type[BaseModel] + @property @abc.abstractmethod def context(self) -> CONTEXT: @@ -96,8 +98,17 @@ class QueryGetter(AbstractInjector): context = "query string" def __init__(self, args_spec: dict, default_values: dict): + args_spec = args_spec.copy() + + self._groups = {} + for group_name, group in args_spec.items(): + if robuste_issubclass(group, Group): + self._groups[group_name] = (group, _get_group_signature(group)[0]) + + _unpack_group_in_signature(args_spec, default_values) attrs = {"__annotations__": args_spec} attrs.update(default_values) + self.model = type("QueryModel", (BaseModel,), attrs) self.args_spec = args_spec self._is_multiple = frozenset( @@ -105,7 +116,14 @@ class QueryGetter(AbstractInjector): ) def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): - kwargs_view.update(self.model(**self._query_to_dict(request.query)).dict()) + data = self._query_to_dict(request.query) + cleaned = self.model(**data).dict() + for group_name, (group_cls, group_attrs) in self._groups.items(): + group = group_cls() + for attr_name in group_attrs: + setattr(group, attr_name, cleaned.pop(attr_name)) + cleaned[group_name] = group + kwargs_view.update(**cleaned) def _query_to_dict(self, query: MultiDict): """ @@ -130,18 +148,74 @@ class HeadersGetter(AbstractInjector): context = "headers" def __init__(self, args_spec: dict, default_values: dict): + args_spec = args_spec.copy() + + self._groups = {} + for group_name, group in args_spec.items(): + if robuste_issubclass(group, Group): + self._groups[group_name] = (group, _get_group_signature(group)[0]) + + _unpack_group_in_signature(args_spec, default_values) + attrs = {"__annotations__": args_spec} attrs.update(default_values) self.model = type("HeaderModel", (BaseModel,), attrs) def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()} - kwargs_view.update(self.model(**header).dict()) + cleaned = self.model(**header).dict() + for group_name, (group_cls, group_attrs) in self._groups.items(): + group = group_cls() + for attr_name in group_attrs: + setattr(group, attr_name, cleaned.pop(attr_name)) + cleaned[group_name] = group + kwargs_view.update(cleaned) -def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]: +class Group(SimpleNamespace): """ - Analyse function signature and returns 4-tuple: + Class to group header or query string parameters. + + The parameter from query string or header will be set in the group + and the group will be passed as function parameter. + + Example: + + class Pagination(Group): + current_page: int = 1 + page_size: int = 15 + + class PetView(PydanticView): + def get(self, page: Pagination): + ... + """ + + +def _get_group_signature(cls) -> Tuple[dict, dict]: + """ + Analyse Group subclass annotations and return them with default values. + """ + + sig = {} + defaults = {} + mro = getmro(cls) + for base in reversed(mro[: mro.index(Group)]): + attrs = vars(base) + for attr_name, type_ in base.__annotations__.items(): + sig[attr_name] = type_ + if (default := attrs.get(attr_name)) is None: + defaults.pop(attr_name, None) + else: + defaults[attr_name] = default + + return sig, defaults + + +def _parse_func_signature( + func: Callable, unpack_group: bool = False +) -> Tuple[dict, dict, dict, dict, dict]: + """ + Analyse function signature and returns 5-tuple: 0 - arguments will be set from the url path 1 - argument will be set from the request body. 2 - argument will be set from the query string. @@ -178,4 +252,46 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict] else: raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters") + if unpack_group: + try: + _unpack_group_in_signature(qs_args, defaults) + _unpack_group_in_signature(header_args, defaults) + except DuplicateNames as error: + raise TypeError( + f"Parameters conflict in function {func}," + f" the group {error.group} has an attribute named {error.attr_name}" + ) from None + return path_args, body_args, qs_args, header_args, defaults + + +class DuplicateNames(Exception): + """ + Raised when a same parameter name is used in group and function signature. + """ + + group: Type[Group] + attr_name: str + + def __init__(self, group: Type[Group], attr_name: str): + self.group = group + self.attr_name = attr_name + super().__init__( + f"Conflict with {group}.{attr_name} and function parameter name" + ) + + +def _unpack_group_in_signature(args: dict, defaults: dict) -> None: + """ + Unpack in place each Group found in args. + """ + for group_name, group in args.copy().items(): + if robuste_issubclass(group, Group): + group_sig, group_default = _get_group_signature(group) + for attr_name in group_sig: + if attr_name in args and attr_name != group_name: + raise DuplicateNames(group, attr_name) + + del args[group_name] + args.update(group_sig) + defaults.update(group_default) diff --git a/aiohttp_pydantic/oas/view.py b/aiohttp_pydantic/oas/view.py index 778513f..2d10609 100644 --- a/aiohttp_pydantic/oas/view.py +++ b/aiohttp_pydantic/oas/view.py @@ -81,7 +81,7 @@ def _add_http_method_to_oas( oas_operation: OperationObject = getattr(oas_path, http_method) handler = getattr(view, http_method) path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( - handler + handler, unpack_group=True ) description = getdoc(handler) if description: diff --git a/aiohttp_pydantic/utils.py b/aiohttp_pydantic/utils.py index efee749..92644a1 100644 --- a/aiohttp_pydantic/utils.py +++ b/aiohttp_pydantic/utils.py @@ -5,7 +5,15 @@ def is_pydantic_base_model(obj): """ Return true is obj is a pydantic.BaseModel subclass. """ + return robuste_issubclass(obj, BaseModel) + + +def robuste_issubclass(cls1, cls2): + """ + function likes issubclass but returns False instead of raise type error + if first parameter is not a class. + """ try: - return issubclass(obj, BaseModel) + return issubclass(cls1, cls2) except TypeError: return False diff --git a/aiohttp_pydantic/view.py b/aiohttp_pydantic/view.py index 00093bc..3b7cd65 100644 --- a/aiohttp_pydantic/view.py +++ b/aiohttp_pydantic/view.py @@ -18,6 +18,7 @@ from .injectors import ( QueryGetter, _parse_func_signature, CONTEXT, + Group, ) @@ -142,3 +143,14 @@ def is_pydantic_view(obj) -> bool: return issubclass(obj, PydanticView) except TypeError: return False + + +__all__ = ( + "AbstractInjector", + "BodyGetter", + "HeadersGetter", + "MatchInfoGetter", + "QueryGetter", + "CONTEXT", + "Group", +) diff --git a/tests/test_group.py b/tests/test_group.py new file mode 100644 index 0000000..24a903f --- /dev/null +++ b/tests/test_group.py @@ -0,0 +1,74 @@ +import pytest + +from aiohttp_pydantic.injectors import ( + Group, + _get_group_signature, + _unpack_group_in_signature, + DuplicateNames, +) + + +def test_get_group_signature_with_a2b2(): + class A(Group): + a: int = 1 + + class B(Group): + b: str = "b" + + class B2(B): + b: str = "b2" # Overwrite default value + + class A2(A): + a: int # Remove default value + + class A2B2(A2, B2): + ab2: float + + assert ({"ab2": float, "a": int, "b": str}, {"b": "b2"}) == _get_group_signature( + A2B2 + ) + + +def test_unpack_group_in_signature(): + class PaginationGroup(Group): + page: int + page_size: int = 20 + + args = {"pagination": PaginationGroup, "name": str, "age": int} + + default = {"age": 18} + + _unpack_group_in_signature(args, default) + + assert args == {"page": int, "page_size": int, "name": str, "age": int} + + assert default == {"age": 18, "page_size": 20} + + +def test_unpack_group_in_signature_with_duplicate_error(): + class PaginationGroup(Group): + page: int + page_size: int = 20 + + args = {"pagination": PaginationGroup, "page": int, "age": int} + + with pytest.raises(DuplicateNames) as e_info: + _unpack_group_in_signature(args, {}) + + assert e_info.value.group is PaginationGroup + assert e_info.value.attr_name == "page" + + +def test_unpack_group_in_signature_with_parameters_overwrite(): + class PaginationGroup(Group): + page: int = 0 + page_size: int = 20 + + args = {"page": PaginationGroup, "age": int} + + default = {} + _unpack_group_in_signature(args, default) + + assert args == {"page": int, "page_size": int, "age": int} + + assert default == {"page": 0, "page_size": 20} diff --git a/tests/test_oas/test_struct/test_paths.py b/tests/test_oas/test_struct/test_paths.py index 2c97182..78453cd 100644 --- a/tests/test_oas/test_struct/test_paths.py +++ b/tests/test_oas/test_struct/test_paths.py @@ -123,18 +123,12 @@ def test_paths_operation_tags(): oas = OpenApiSpec3() operation = oas.paths["/users/{petId}"].get assert operation.tags == [] - operation.tags = ['pets'] + operation.tags = ["pets"] - assert oas.spec['paths']['/users/{petId}'] == { - 'get': { - 'tags': ['pets'] - } - } + assert oas.spec["paths"]["/users/{petId}"] == {"get": {"tags": ["pets"]}} operation.tags = [] - assert oas.spec['paths']['/users/{petId}'] == { - 'get': {} - } + assert oas.spec["paths"]["/users/{petId}"] == {"get": {}} def test_paths_operation_responses(): diff --git a/tests/test_oas/test_view.py b/tests/test_oas/test_view.py index 3eb4c34..af9c338 100644 --- a/tests/test_oas/test_view.py +++ b/tests/test_oas/test_view.py @@ -7,6 +7,7 @@ from aiohttp import web from pydantic.main import BaseModel from aiohttp_pydantic import PydanticView, oas +from aiohttp_pydantic.injectors import Group from aiohttp_pydantic.oas.typing import r200, r201, r204, r404 from aiohttp_pydantic.oas.view import generate_oas @@ -76,6 +77,24 @@ class ViewResponseReturnASimpleType(PydanticView): return web.json_response() +async def ensure_content_durability(client): + """ + Reload the page 2 times to ensure that content is always the same + note: pydantic can return a cached dict, if a view updates the dict the + output will be incoherent + """ + response_1 = await client.get("/oas/spec") + assert response_1.status == 200 + assert response_1.content_type == "application/json" + content_1 = await response_1.json() + + response_2 = await client.get("/oas/spec") + content_2 = await response_2.json() + assert content_1 == content_2 + + return content_2 + + @pytest.fixture async def generated_oas(aiohttp_client, loop) -> web.Application: app = web.Application() @@ -84,20 +103,7 @@ async def generated_oas(aiohttp_client, loop) -> web.Application: app.router.add_view("/simple-type", ViewResponseReturnASimpleType) oas.setup(app) - client = await aiohttp_client(app) - response_1 = await client.get("/oas/spec") - assert response_1.content_type == "application/json" - assert response_1.status == 200 - content_1 = await response_1.json() - - # Reload the page to ensure that content is always the same - # note: pydantic can return a cached dict, if a view updates - # the dict the output will be incoherent - response_2 = await client.get("/oas/spec") - content_2 = await response_2.json() - assert content_1 == content_2 - - return content_2 + return await ensure_content_durability(await aiohttp_client(app)) async def test_generated_oas_should_have_components_schemas(generated_oas): @@ -377,3 +383,29 @@ async def test_generated_view_info_as_title(): "info": {"title": "test title", "version": "1.0.0"}, "openapi": "3.0.0", } + + +async def test_use_parameters_group_should_not_impact_the_oas(aiohttp_client): + class PetCollectionView1(PydanticView): + async def get(self, page: int = 1, page_size: int = 20) -> r200[List[Pet]]: + return web.json_response() + + class Pagination(Group): + page: int = 1 + page_size: int = 20 + + class PetCollectionView2(PydanticView): + async def get(self, pagination: Pagination) -> r200[List[Pet]]: + return web.json_response() + + app1 = web.Application() + app1.router.add_view("/pets", PetCollectionView1) + oas.setup(app1) + + app2 = web.Application() + app2.router.add_view("/pets", PetCollectionView2) + oas.setup(app2) + + assert await ensure_content_durability( + await aiohttp_client(app1) + ) == await ensure_content_durability(await aiohttp_client(app2)) diff --git a/tests/test_validation_header.py b/tests/test_validation_header.py index e18a863..134c6ca 100644 --- a/tests/test_validation_header.py +++ b/tests/test_validation_header.py @@ -5,6 +5,7 @@ from enum import Enum from aiohttp import web from aiohttp_pydantic import PydanticView +from aiohttp_pydantic.injectors import Group class JSONEncoder(json.JSONEncoder): @@ -32,6 +33,31 @@ class ViewWithEnumType(PydanticView): return web.json_response({"format": format}, dumps=JSONEncoder().encode) +class Signature(Group): + signature_expired: datetime + signature_scope: str = "read" + + @property + def expired(self) -> datetime: + return self.signature_expired + + @property + def scope(self) -> str: + return self.signature_scope + + +class ArticleViewWithSignatureGroup(PydanticView): + async def get( + self, + *, + signature: Signature, + ): + return web.json_response( + {"expired": signature.expired, "scope": signature.scope}, + dumps=JSONEncoder().encode, + ) + + async def test_get_article_without_required_header_should_return_an_error_message( aiohttp_client, loop ): @@ -134,3 +160,21 @@ async def test_correct_value_to_header_defined_with_str_enum(aiohttp_client, loo assert await resp.json() == {"format": "UMT"} assert resp.status == 200 assert resp.content_type == "application/json" + + +async def test_with_signature_group(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/article", ArticleViewWithSignatureGroup) + + client = await aiohttp_client(app) + resp = await client.get( + "/article", + headers={ + "signature_expired": "2020-10-04T18:01:00", + "signature.scope": "write", + }, + ) + + assert await resp.json() == {"expired": "2020-10-04T18:01:00", "scope": "read"} + assert resp.status == 200 + assert resp.content_type == "application/json" diff --git a/tests/test_validation_query_string.py b/tests/test_validation_query_string.py index 4b8913f..344cea0 100644 --- a/tests/test_validation_query_string.py +++ b/tests/test_validation_query_string.py @@ -3,6 +3,7 @@ from pydantic import Field from aiohttp import web from aiohttp_pydantic import PydanticView +from aiohttp_pydantic.injectors import Group class ArticleView(PydanticView): @@ -23,6 +24,34 @@ class ArticleView(PydanticView): ) +class Pagination(Group): + page_num: int + page_size: int = 20 + + @property + def num(self) -> int: + return self.page_num + + @property + def size(self) -> int: + return self.page_size + + +class ArticleViewWithPaginationGroup(PydanticView): + async def get( + self, + with_comments: bool, + page: Pagination, + ): + return web.json_response( + { + "with_comments": with_comments, + "page_num": page.num, + "page_size": page.size, + } + ) + + async def test_get_article_without_required_qs_should_return_an_error_message( aiohttp_client, loop ): @@ -158,3 +187,69 @@ async def test_get_article_with_one_value_of_tags_must_be_a_list(aiohttp_client, } assert resp.status == 200 assert resp.content_type == "application/json" + + +async def test_get_article_without_required_field_page(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/article", ArticleViewWithPaginationGroup) + + client = await aiohttp_client(app) + + resp = await client.get("/article", params={"with_comments": 1}) + assert await resp.json() == [ + { + "in": "query string", + "loc": ["page_num"], + "msg": "field required", + "type": "value_error.missing", + } + ] + assert resp.status == 400 + assert resp.content_type == "application/json" + + +async def test_get_article_with_page(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/article", ArticleViewWithPaginationGroup) + + client = await aiohttp_client(app) + + resp = await client.get("/article", params={"with_comments": 1, "page_num": 2}) + assert await resp.json() == {"page_num": 2, "page_size": 20, "with_comments": True} + assert resp.status == 200 + assert resp.content_type == "application/json" + + +async def test_get_article_with_page_and_page_size(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/article", ArticleViewWithPaginationGroup) + + client = await aiohttp_client(app) + + resp = await client.get( + "/article", params={"with_comments": 1, "page_num": 1, "page_size": 10} + ) + assert await resp.json() == {"page_num": 1, "page_size": 10, "with_comments": True} + assert resp.status == 200 + assert resp.content_type == "application/json" + + +async def test_get_article_with_page_and_wrong_page_size(aiohttp_client, loop): + app = web.Application() + app.router.add_view("/article", ArticleViewWithPaginationGroup) + + client = await aiohttp_client(app) + + resp = await client.get( + "/article", params={"with_comments": 1, "page_num": 1, "page_size": "large"} + ) + assert await resp.json() == [ + { + "in": "query string", + "loc": ["page_size"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + assert resp.status == 400 + assert resp.content_type == "application/json"