63
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										63
									
								
								README.rst
									
									
									
									
									
								
							| @@ -316,6 +316,69 @@ Open Api Specification. | |||||||
|             return web.Response(status=204) |             return web.Response(status=204) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Group parameters | ||||||
|  | ---------------- | ||||||
|  |  | ||||||
|  | If your method has lot of parameters you can group them together inside one or several Groups. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | .. code-block:: python3 | ||||||
|  |  | ||||||
|  |     class Pagination(Group): | ||||||
|  |         page_num: int = 1 | ||||||
|  |         page_size: int = 15 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     class ArticleView(PydanticView): | ||||||
|  |  | ||||||
|  |         async def get(self, page: Pagination): | ||||||
|  |             articles = Article.get(page.page_num, page.page_size) | ||||||
|  |             ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | The parameters page_num and page_size are expected in the query string, and | ||||||
|  | set inside a Pagination object passed as page parameter. | ||||||
|  |  | ||||||
|  | The code above is equivalent to: | ||||||
|  |  | ||||||
|  |  | ||||||
|  | .. code-block:: python3 | ||||||
|  |  | ||||||
|  |     class ArticleView(PydanticView): | ||||||
|  |  | ||||||
|  |         async def get(self, page_num: int = 1, page_size: int = 15): | ||||||
|  |             articles = Article.get(page_num, page_size) | ||||||
|  |             ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | You can add methods or properties to your Group. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | .. code-block:: python3 | ||||||
|  |  | ||||||
|  |     class Pagination(Group): | ||||||
|  |         page_num: int = 1 | ||||||
|  |         page_size: int = 15 | ||||||
|  |  | ||||||
|  |         @property | ||||||
|  |         def num(self): | ||||||
|  |             return self.page_num | ||||||
|  |  | ||||||
|  |         @property | ||||||
|  |         def size(self): | ||||||
|  |             return self.page_size | ||||||
|  |  | ||||||
|  |         def slice(self): | ||||||
|  |             return slice(self.num, self.size) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     class ArticleView(PydanticView): | ||||||
|  |  | ||||||
|  |         async def get(self, page: Pagination): | ||||||
|  |             articles = Article.get(page.num, page.size) | ||||||
|  |             ... | ||||||
|  |  | ||||||
|  |  | ||||||
| Custom Validation error | Custom Validation error | ||||||
| ----------------------- | ----------------------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| from .view import PydanticView | from .view import PydanticView | ||||||
|  |  | ||||||
| __version__ = "1.11.0" | __version__ = "1.12.0" | ||||||
|  |  | ||||||
| __all__ = ("PydanticView", "__version__") | __all__ = ("PydanticView", "__version__") | ||||||
|   | |||||||
| @@ -1,16 +1,16 @@ | |||||||
| import abc | import abc | ||||||
| import typing | import typing | ||||||
| from inspect import signature | from inspect import signature, getmro | ||||||
| from json.decoder import JSONDecodeError | from json.decoder import JSONDecodeError | ||||||
| from typing import Callable, Tuple, Literal | from types import SimpleNamespace | ||||||
|  | from typing import Callable, Tuple, Literal, Type | ||||||
|  |  | ||||||
| from aiohttp.web_exceptions import HTTPBadRequest | from aiohttp.web_exceptions import HTTPBadRequest | ||||||
| from aiohttp.web_request import BaseRequest | from aiohttp.web_request import BaseRequest | ||||||
| from multidict import MultiDict | from multidict import MultiDict | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
|  |  | ||||||
| from .utils import is_pydantic_base_model | from .utils import is_pydantic_base_model, robuste_issubclass | ||||||
|  |  | ||||||
|  |  | ||||||
| CONTEXT = Literal["body", "headers", "path", "query string"] | CONTEXT = Literal["body", "headers", "path", "query string"] | ||||||
|  |  | ||||||
| @@ -20,6 +20,8 @@ class AbstractInjector(metaclass=abc.ABCMeta): | |||||||
|     An injector parse HTTP request and inject params to the view. |     An injector parse HTTP request and inject params to the view. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     model: Type[BaseModel] | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def context(self) -> CONTEXT: |     def context(self) -> CONTEXT: | ||||||
| @@ -96,8 +98,17 @@ class QueryGetter(AbstractInjector): | |||||||
|     context = "query string" |     context = "query string" | ||||||
|  |  | ||||||
|     def __init__(self, args_spec: dict, default_values: dict): |     def __init__(self, args_spec: dict, default_values: dict): | ||||||
|  |         args_spec = args_spec.copy() | ||||||
|  |  | ||||||
|  |         self._groups = {} | ||||||
|  |         for group_name, group in args_spec.items(): | ||||||
|  |             if robuste_issubclass(group, Group): | ||||||
|  |                 self._groups[group_name] = (group, _get_group_signature(group)[0]) | ||||||
|  |  | ||||||
|  |         _unpack_group_in_signature(args_spec, default_values) | ||||||
|         attrs = {"__annotations__": args_spec} |         attrs = {"__annotations__": args_spec} | ||||||
|         attrs.update(default_values) |         attrs.update(default_values) | ||||||
|  |  | ||||||
|         self.model = type("QueryModel", (BaseModel,), attrs) |         self.model = type("QueryModel", (BaseModel,), attrs) | ||||||
|         self.args_spec = args_spec |         self.args_spec = args_spec | ||||||
|         self._is_multiple = frozenset( |         self._is_multiple = frozenset( | ||||||
| @@ -105,7 +116,14 @@ class QueryGetter(AbstractInjector): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): |     def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): | ||||||
|         kwargs_view.update(self.model(**self._query_to_dict(request.query)).dict()) |         data = self._query_to_dict(request.query) | ||||||
|  |         cleaned = self.model(**data).dict() | ||||||
|  |         for group_name, (group_cls, group_attrs) in self._groups.items(): | ||||||
|  |             group = group_cls() | ||||||
|  |             for attr_name in group_attrs: | ||||||
|  |                 setattr(group, attr_name, cleaned.pop(attr_name)) | ||||||
|  |             cleaned[group_name] = group | ||||||
|  |         kwargs_view.update(**cleaned) | ||||||
|  |  | ||||||
|     def _query_to_dict(self, query: MultiDict): |     def _query_to_dict(self, query: MultiDict): | ||||||
|         """ |         """ | ||||||
| @@ -130,18 +148,74 @@ class HeadersGetter(AbstractInjector): | |||||||
|     context = "headers" |     context = "headers" | ||||||
|  |  | ||||||
|     def __init__(self, args_spec: dict, default_values: dict): |     def __init__(self, args_spec: dict, default_values: dict): | ||||||
|  |         args_spec = args_spec.copy() | ||||||
|  |  | ||||||
|  |         self._groups = {} | ||||||
|  |         for group_name, group in args_spec.items(): | ||||||
|  |             if robuste_issubclass(group, Group): | ||||||
|  |                 self._groups[group_name] = (group, _get_group_signature(group)[0]) | ||||||
|  |  | ||||||
|  |         _unpack_group_in_signature(args_spec, default_values) | ||||||
|  |  | ||||||
|         attrs = {"__annotations__": args_spec} |         attrs = {"__annotations__": args_spec} | ||||||
|         attrs.update(default_values) |         attrs.update(default_values) | ||||||
|         self.model = type("HeaderModel", (BaseModel,), attrs) |         self.model = type("HeaderModel", (BaseModel,), attrs) | ||||||
|  |  | ||||||
|     def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): |     def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict): | ||||||
|         header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()} |         header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()} | ||||||
|         kwargs_view.update(self.model(**header).dict()) |         cleaned = self.model(**header).dict() | ||||||
|  |         for group_name, (group_cls, group_attrs) in self._groups.items(): | ||||||
|  |             group = group_cls() | ||||||
|  |             for attr_name in group_attrs: | ||||||
|  |                 setattr(group, attr_name, cleaned.pop(attr_name)) | ||||||
|  |             cleaned[group_name] = group | ||||||
|  |         kwargs_view.update(cleaned) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]: | class Group(SimpleNamespace): | ||||||
|     """ |     """ | ||||||
|     Analyse function signature and returns 4-tuple: |     Class to group header or query string parameters. | ||||||
|  |  | ||||||
|  |     The parameter from query string or header will be set in the group | ||||||
|  |     and the group will be passed as function parameter. | ||||||
|  |  | ||||||
|  |     Example: | ||||||
|  |  | ||||||
|  |     class Pagination(Group): | ||||||
|  |         current_page: int = 1 | ||||||
|  |         page_size: int = 15 | ||||||
|  |  | ||||||
|  |     class PetView(PydanticView): | ||||||
|  |         def get(self, page: Pagination): | ||||||
|  |             ... | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_group_signature(cls) -> Tuple[dict, dict]: | ||||||
|  |     """ | ||||||
|  |     Analyse Group subclass annotations and return them with default values. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     sig = {} | ||||||
|  |     defaults = {} | ||||||
|  |     mro = getmro(cls) | ||||||
|  |     for base in reversed(mro[: mro.index(Group)]): | ||||||
|  |         attrs = vars(base) | ||||||
|  |         for attr_name, type_ in base.__annotations__.items(): | ||||||
|  |             sig[attr_name] = type_ | ||||||
|  |             if (default := attrs.get(attr_name)) is None: | ||||||
|  |                 defaults.pop(attr_name, None) | ||||||
|  |             else: | ||||||
|  |                 defaults[attr_name] = default | ||||||
|  |  | ||||||
|  |     return sig, defaults | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _parse_func_signature( | ||||||
|  |     func: Callable, unpack_group: bool = False | ||||||
|  | ) -> Tuple[dict, dict, dict, dict, dict]: | ||||||
|  |     """ | ||||||
|  |     Analyse function signature and returns 5-tuple: | ||||||
|         0 - arguments will be set from the url path |         0 - arguments will be set from the url path | ||||||
|         1 - argument will be set from the request body. |         1 - argument will be set from the request body. | ||||||
|         2 - argument will be set from the query string. |         2 - argument will be set from the query string. | ||||||
| @@ -178,4 +252,46 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict] | |||||||
|         else: |         else: | ||||||
|             raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters") |             raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters") | ||||||
|  |  | ||||||
|  |     if unpack_group: | ||||||
|  |         try: | ||||||
|  |             _unpack_group_in_signature(qs_args, defaults) | ||||||
|  |             _unpack_group_in_signature(header_args, defaults) | ||||||
|  |         except DuplicateNames as error: | ||||||
|  |             raise TypeError( | ||||||
|  |                 f"Parameters conflict in function {func}," | ||||||
|  |                 f" the group {error.group} has an attribute named {error.attr_name}" | ||||||
|  |             ) from None | ||||||
|  |  | ||||||
|     return path_args, body_args, qs_args, header_args, defaults |     return path_args, body_args, qs_args, header_args, defaults | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DuplicateNames(Exception): | ||||||
|  |     """ | ||||||
|  |     Raised when a same parameter name is used in group and function signature. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     group: Type[Group] | ||||||
|  |     attr_name: str | ||||||
|  |  | ||||||
|  |     def __init__(self, group: Type[Group], attr_name: str): | ||||||
|  |         self.group = group | ||||||
|  |         self.attr_name = attr_name | ||||||
|  |         super().__init__( | ||||||
|  |             f"Conflict with {group}.{attr_name} and function parameter name" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _unpack_group_in_signature(args: dict, defaults: dict) -> None: | ||||||
|  |     """ | ||||||
|  |     Unpack in place each Group found in args. | ||||||
|  |     """ | ||||||
|  |     for group_name, group in args.copy().items(): | ||||||
|  |         if robuste_issubclass(group, Group): | ||||||
|  |             group_sig, group_default = _get_group_signature(group) | ||||||
|  |             for attr_name in group_sig: | ||||||
|  |                 if attr_name in args and attr_name != group_name: | ||||||
|  |                     raise DuplicateNames(group, attr_name) | ||||||
|  |  | ||||||
|  |             del args[group_name] | ||||||
|  |             args.update(group_sig) | ||||||
|  |             defaults.update(group_default) | ||||||
|   | |||||||
| @@ -81,7 +81,7 @@ def _add_http_method_to_oas( | |||||||
|     oas_operation: OperationObject = getattr(oas_path, http_method) |     oas_operation: OperationObject = getattr(oas_path, http_method) | ||||||
|     handler = getattr(view, http_method) |     handler = getattr(view, http_method) | ||||||
|     path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( |     path_args, body_args, qs_args, header_args, defaults = _parse_func_signature( | ||||||
|         handler |         handler, unpack_group=True | ||||||
|     ) |     ) | ||||||
|     description = getdoc(handler) |     description = getdoc(handler) | ||||||
|     if description: |     if description: | ||||||
|   | |||||||
| @@ -5,7 +5,15 @@ def is_pydantic_base_model(obj): | |||||||
|     """ |     """ | ||||||
|     Return true is obj is a pydantic.BaseModel subclass. |     Return true is obj is a pydantic.BaseModel subclass. | ||||||
|     """ |     """ | ||||||
|  |     return robuste_issubclass(obj, BaseModel) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def robuste_issubclass(cls1, cls2): | ||||||
|  |     """ | ||||||
|  |     function likes issubclass but returns False instead of raise type error | ||||||
|  |     if first parameter is not a class. | ||||||
|  |     """ | ||||||
|     try: |     try: | ||||||
|         return issubclass(obj, BaseModel) |         return issubclass(cls1, cls2) | ||||||
|     except TypeError: |     except TypeError: | ||||||
|         return False |         return False | ||||||
|   | |||||||
| @@ -18,6 +18,7 @@ from .injectors import ( | |||||||
|     QueryGetter, |     QueryGetter, | ||||||
|     _parse_func_signature, |     _parse_func_signature, | ||||||
|     CONTEXT, |     CONTEXT, | ||||||
|  |     Group, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -142,3 +143,14 @@ def is_pydantic_view(obj) -> bool: | |||||||
|         return issubclass(obj, PydanticView) |         return issubclass(obj, PydanticView) | ||||||
|     except TypeError: |     except TypeError: | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | __all__ = ( | ||||||
|  |     "AbstractInjector", | ||||||
|  |     "BodyGetter", | ||||||
|  |     "HeadersGetter", | ||||||
|  |     "MatchInfoGetter", | ||||||
|  |     "QueryGetter", | ||||||
|  |     "CONTEXT", | ||||||
|  |     "Group", | ||||||
|  | ) | ||||||
|   | |||||||
							
								
								
									
										74
									
								
								tests/test_group.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								tests/test_group.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,74 @@ | |||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | from aiohttp_pydantic.injectors import ( | ||||||
|  |     Group, | ||||||
|  |     _get_group_signature, | ||||||
|  |     _unpack_group_in_signature, | ||||||
|  |     DuplicateNames, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_get_group_signature_with_a2b2(): | ||||||
|  |     class A(Group): | ||||||
|  |         a: int = 1 | ||||||
|  |  | ||||||
|  |     class B(Group): | ||||||
|  |         b: str = "b" | ||||||
|  |  | ||||||
|  |     class B2(B): | ||||||
|  |         b: str = "b2"  # Overwrite default value | ||||||
|  |  | ||||||
|  |     class A2(A): | ||||||
|  |         a: int  # Remove default value | ||||||
|  |  | ||||||
|  |     class A2B2(A2, B2): | ||||||
|  |         ab2: float | ||||||
|  |  | ||||||
|  |     assert ({"ab2": float, "a": int, "b": str}, {"b": "b2"}) == _get_group_signature( | ||||||
|  |         A2B2 | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_unpack_group_in_signature(): | ||||||
|  |     class PaginationGroup(Group): | ||||||
|  |         page: int | ||||||
|  |         page_size: int = 20 | ||||||
|  |  | ||||||
|  |     args = {"pagination": PaginationGroup, "name": str, "age": int} | ||||||
|  |  | ||||||
|  |     default = {"age": 18} | ||||||
|  |  | ||||||
|  |     _unpack_group_in_signature(args, default) | ||||||
|  |  | ||||||
|  |     assert args == {"page": int, "page_size": int, "name": str, "age": int} | ||||||
|  |  | ||||||
|  |     assert default == {"age": 18, "page_size": 20} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_unpack_group_in_signature_with_duplicate_error(): | ||||||
|  |     class PaginationGroup(Group): | ||||||
|  |         page: int | ||||||
|  |         page_size: int = 20 | ||||||
|  |  | ||||||
|  |     args = {"pagination": PaginationGroup, "page": int, "age": int} | ||||||
|  |  | ||||||
|  |     with pytest.raises(DuplicateNames) as e_info: | ||||||
|  |         _unpack_group_in_signature(args, {}) | ||||||
|  |  | ||||||
|  |     assert e_info.value.group is PaginationGroup | ||||||
|  |     assert e_info.value.attr_name == "page" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_unpack_group_in_signature_with_parameters_overwrite(): | ||||||
|  |     class PaginationGroup(Group): | ||||||
|  |         page: int = 0 | ||||||
|  |         page_size: int = 20 | ||||||
|  |  | ||||||
|  |     args = {"page": PaginationGroup, "age": int} | ||||||
|  |  | ||||||
|  |     default = {} | ||||||
|  |     _unpack_group_in_signature(args, default) | ||||||
|  |  | ||||||
|  |     assert args == {"page": int, "page_size": int, "age": int} | ||||||
|  |  | ||||||
|  |     assert default == {"page": 0, "page_size": 20} | ||||||
| @@ -123,18 +123,12 @@ def test_paths_operation_tags(): | |||||||
|     oas = OpenApiSpec3() |     oas = OpenApiSpec3() | ||||||
|     operation = oas.paths["/users/{petId}"].get |     operation = oas.paths["/users/{petId}"].get | ||||||
|     assert operation.tags == [] |     assert operation.tags == [] | ||||||
|     operation.tags = ['pets'] |     operation.tags = ["pets"] | ||||||
|  |  | ||||||
|     assert oas.spec['paths']['/users/{petId}'] == { |     assert oas.spec["paths"]["/users/{petId}"] == {"get": {"tags": ["pets"]}} | ||||||
|         'get': { |  | ||||||
|             'tags': ['pets'] |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     operation.tags = [] |     operation.tags = [] | ||||||
|     assert oas.spec['paths']['/users/{petId}'] == { |     assert oas.spec["paths"]["/users/{petId}"] == {"get": {}} | ||||||
|         'get': {} |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_paths_operation_responses(): | def test_paths_operation_responses(): | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ from aiohttp import web | |||||||
| from pydantic.main import BaseModel | from pydantic.main import BaseModel | ||||||
|  |  | ||||||
| from aiohttp_pydantic import PydanticView, oas | from aiohttp_pydantic import PydanticView, oas | ||||||
|  | from aiohttp_pydantic.injectors import Group | ||||||
| from aiohttp_pydantic.oas.typing import r200, r201, r204, r404 | from aiohttp_pydantic.oas.typing import r200, r201, r204, r404 | ||||||
| from aiohttp_pydantic.oas.view import generate_oas | from aiohttp_pydantic.oas.view import generate_oas | ||||||
|  |  | ||||||
| @@ -76,6 +77,24 @@ class ViewResponseReturnASimpleType(PydanticView): | |||||||
|         return web.json_response() |         return web.json_response() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def ensure_content_durability(client): | ||||||
|  |     """ | ||||||
|  |     Reload the page 2 times to ensure that content is always the same | ||||||
|  |     note: pydantic can return a cached dict, if a view updates the dict the | ||||||
|  |     output will be incoherent | ||||||
|  |     """ | ||||||
|  |     response_1 = await client.get("/oas/spec") | ||||||
|  |     assert response_1.status == 200 | ||||||
|  |     assert response_1.content_type == "application/json" | ||||||
|  |     content_1 = await response_1.json() | ||||||
|  |  | ||||||
|  |     response_2 = await client.get("/oas/spec") | ||||||
|  |     content_2 = await response_2.json() | ||||||
|  |     assert content_1 == content_2 | ||||||
|  |  | ||||||
|  |     return content_2 | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| async def generated_oas(aiohttp_client, loop) -> web.Application: | async def generated_oas(aiohttp_client, loop) -> web.Application: | ||||||
|     app = web.Application() |     app = web.Application() | ||||||
| @@ -84,20 +103,7 @@ async def generated_oas(aiohttp_client, loop) -> web.Application: | |||||||
|     app.router.add_view("/simple-type", ViewResponseReturnASimpleType) |     app.router.add_view("/simple-type", ViewResponseReturnASimpleType) | ||||||
|     oas.setup(app) |     oas.setup(app) | ||||||
|  |  | ||||||
|     client = await aiohttp_client(app) |     return await ensure_content_durability(await aiohttp_client(app)) | ||||||
|     response_1 = await client.get("/oas/spec") |  | ||||||
|     assert response_1.content_type == "application/json" |  | ||||||
|     assert response_1.status == 200 |  | ||||||
|     content_1 = await response_1.json() |  | ||||||
|  |  | ||||||
|     # Reload the page to ensure that content is always the same |  | ||||||
|     # note: pydantic can return a cached dict, if a view updates |  | ||||||
|     # the dict the output will be incoherent |  | ||||||
|     response_2 = await client.get("/oas/spec") |  | ||||||
|     content_2 = await response_2.json() |  | ||||||
|     assert content_1 == content_2 |  | ||||||
|  |  | ||||||
|     return content_2 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| async def test_generated_oas_should_have_components_schemas(generated_oas): | async def test_generated_oas_should_have_components_schemas(generated_oas): | ||||||
| @@ -377,3 +383,29 @@ async def test_generated_view_info_as_title(): | |||||||
|         "info": {"title": "test title", "version": "1.0.0"}, |         "info": {"title": "test title", "version": "1.0.0"}, | ||||||
|         "openapi": "3.0.0", |         "openapi": "3.0.0", | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_use_parameters_group_should_not_impact_the_oas(aiohttp_client): | ||||||
|  |     class PetCollectionView1(PydanticView): | ||||||
|  |         async def get(self, page: int = 1, page_size: int = 20) -> r200[List[Pet]]: | ||||||
|  |             return web.json_response() | ||||||
|  |  | ||||||
|  |     class Pagination(Group): | ||||||
|  |         page: int = 1 | ||||||
|  |         page_size: int = 20 | ||||||
|  |  | ||||||
|  |     class PetCollectionView2(PydanticView): | ||||||
|  |         async def get(self, pagination: Pagination) -> r200[List[Pet]]: | ||||||
|  |             return web.json_response() | ||||||
|  |  | ||||||
|  |     app1 = web.Application() | ||||||
|  |     app1.router.add_view("/pets", PetCollectionView1) | ||||||
|  |     oas.setup(app1) | ||||||
|  |  | ||||||
|  |     app2 = web.Application() | ||||||
|  |     app2.router.add_view("/pets", PetCollectionView2) | ||||||
|  |     oas.setup(app2) | ||||||
|  |  | ||||||
|  |     assert await ensure_content_durability( | ||||||
|  |         await aiohttp_client(app1) | ||||||
|  |     ) == await ensure_content_durability(await aiohttp_client(app2)) | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ from enum import Enum | |||||||
| from aiohttp import web | from aiohttp import web | ||||||
|  |  | ||||||
| from aiohttp_pydantic import PydanticView | from aiohttp_pydantic import PydanticView | ||||||
|  | from aiohttp_pydantic.injectors import Group | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONEncoder(json.JSONEncoder): | class JSONEncoder(json.JSONEncoder): | ||||||
| @@ -32,6 +33,31 @@ class ViewWithEnumType(PydanticView): | |||||||
|         return web.json_response({"format": format}, dumps=JSONEncoder().encode) |         return web.json_response({"format": format}, dumps=JSONEncoder().encode) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Signature(Group): | ||||||
|  |     signature_expired: datetime | ||||||
|  |     signature_scope: str = "read" | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def expired(self) -> datetime: | ||||||
|  |         return self.signature_expired | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def scope(self) -> str: | ||||||
|  |         return self.signature_scope | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ArticleViewWithSignatureGroup(PydanticView): | ||||||
|  |     async def get( | ||||||
|  |         self, | ||||||
|  |         *, | ||||||
|  |         signature: Signature, | ||||||
|  |     ): | ||||||
|  |         return web.json_response( | ||||||
|  |             {"expired": signature.expired, "scope": signature.scope}, | ||||||
|  |             dumps=JSONEncoder().encode, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def test_get_article_without_required_header_should_return_an_error_message( | async def test_get_article_without_required_header_should_return_an_error_message( | ||||||
|     aiohttp_client, loop |     aiohttp_client, loop | ||||||
| ): | ): | ||||||
| @@ -134,3 +160,21 @@ async def test_correct_value_to_header_defined_with_str_enum(aiohttp_client, loo | |||||||
|     assert await resp.json() == {"format": "UMT"} |     assert await resp.json() == {"format": "UMT"} | ||||||
|     assert resp.status == 200 |     assert resp.status == 200 | ||||||
|     assert resp.content_type == "application/json" |     assert resp.content_type == "application/json" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_with_signature_group(aiohttp_client, loop): | ||||||
|  |     app = web.Application() | ||||||
|  |     app.router.add_view("/article", ArticleViewWithSignatureGroup) | ||||||
|  |  | ||||||
|  |     client = await aiohttp_client(app) | ||||||
|  |     resp = await client.get( | ||||||
|  |         "/article", | ||||||
|  |         headers={ | ||||||
|  |             "signature_expired": "2020-10-04T18:01:00", | ||||||
|  |             "signature.scope": "write", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     assert await resp.json() == {"expired": "2020-10-04T18:01:00", "scope": "read"} | ||||||
|  |     assert resp.status == 200 | ||||||
|  |     assert resp.content_type == "application/json" | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ from pydantic import Field | |||||||
| from aiohttp import web | from aiohttp import web | ||||||
|  |  | ||||||
| from aiohttp_pydantic import PydanticView | from aiohttp_pydantic import PydanticView | ||||||
|  | from aiohttp_pydantic.injectors import Group | ||||||
|  |  | ||||||
|  |  | ||||||
| class ArticleView(PydanticView): | class ArticleView(PydanticView): | ||||||
| @@ -23,6 +24,34 @@ class ArticleView(PydanticView): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Pagination(Group): | ||||||
|  |     page_num: int | ||||||
|  |     page_size: int = 20 | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def num(self) -> int: | ||||||
|  |         return self.page_num | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def size(self) -> int: | ||||||
|  |         return self.page_size | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ArticleViewWithPaginationGroup(PydanticView): | ||||||
|  |     async def get( | ||||||
|  |         self, | ||||||
|  |         with_comments: bool, | ||||||
|  |         page: Pagination, | ||||||
|  |     ): | ||||||
|  |         return web.json_response( | ||||||
|  |             { | ||||||
|  |                 "with_comments": with_comments, | ||||||
|  |                 "page_num": page.num, | ||||||
|  |                 "page_size": page.size, | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def test_get_article_without_required_qs_should_return_an_error_message( | async def test_get_article_without_required_qs_should_return_an_error_message( | ||||||
|     aiohttp_client, loop |     aiohttp_client, loop | ||||||
| ): | ): | ||||||
| @@ -158,3 +187,69 @@ async def test_get_article_with_one_value_of_tags_must_be_a_list(aiohttp_client, | |||||||
|     } |     } | ||||||
|     assert resp.status == 200 |     assert resp.status == 200 | ||||||
|     assert resp.content_type == "application/json" |     assert resp.content_type == "application/json" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_get_article_without_required_field_page(aiohttp_client, loop): | ||||||
|  |     app = web.Application() | ||||||
|  |     app.router.add_view("/article", ArticleViewWithPaginationGroup) | ||||||
|  |  | ||||||
|  |     client = await aiohttp_client(app) | ||||||
|  |  | ||||||
|  |     resp = await client.get("/article", params={"with_comments": 1}) | ||||||
|  |     assert await resp.json() == [ | ||||||
|  |         { | ||||||
|  |             "in": "query string", | ||||||
|  |             "loc": ["page_num"], | ||||||
|  |             "msg": "field required", | ||||||
|  |             "type": "value_error.missing", | ||||||
|  |         } | ||||||
|  |     ] | ||||||
|  |     assert resp.status == 400 | ||||||
|  |     assert resp.content_type == "application/json" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_get_article_with_page(aiohttp_client, loop): | ||||||
|  |     app = web.Application() | ||||||
|  |     app.router.add_view("/article", ArticleViewWithPaginationGroup) | ||||||
|  |  | ||||||
|  |     client = await aiohttp_client(app) | ||||||
|  |  | ||||||
|  |     resp = await client.get("/article", params={"with_comments": 1, "page_num": 2}) | ||||||
|  |     assert await resp.json() == {"page_num": 2, "page_size": 20, "with_comments": True} | ||||||
|  |     assert resp.status == 200 | ||||||
|  |     assert resp.content_type == "application/json" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_get_article_with_page_and_page_size(aiohttp_client, loop): | ||||||
|  |     app = web.Application() | ||||||
|  |     app.router.add_view("/article", ArticleViewWithPaginationGroup) | ||||||
|  |  | ||||||
|  |     client = await aiohttp_client(app) | ||||||
|  |  | ||||||
|  |     resp = await client.get( | ||||||
|  |         "/article", params={"with_comments": 1, "page_num": 1, "page_size": 10} | ||||||
|  |     ) | ||||||
|  |     assert await resp.json() == {"page_num": 1, "page_size": 10, "with_comments": True} | ||||||
|  |     assert resp.status == 200 | ||||||
|  |     assert resp.content_type == "application/json" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_get_article_with_page_and_wrong_page_size(aiohttp_client, loop): | ||||||
|  |     app = web.Application() | ||||||
|  |     app.router.add_view("/article", ArticleViewWithPaginationGroup) | ||||||
|  |  | ||||||
|  |     client = await aiohttp_client(app) | ||||||
|  |  | ||||||
|  |     resp = await client.get( | ||||||
|  |         "/article", params={"with_comments": 1, "page_num": 1, "page_size": "large"} | ||||||
|  |     ) | ||||||
|  |     assert await resp.json() == [ | ||||||
|  |         { | ||||||
|  |             "in": "query string", | ||||||
|  |             "loc": ["page_size"], | ||||||
|  |             "msg": "value is not a valid integer", | ||||||
|  |             "type": "type_error.integer", | ||||||
|  |         } | ||||||
|  |     ] | ||||||
|  |     assert resp.status == 400 | ||||||
|  |     assert resp.content_type == "application/json" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user