increase pydantic integration with headers, query string and url path

This commit is contained in:
Vincent Maillol
2020-11-28 12:39:09 +01:00
parent 93ec0f6c80
commit c4c18ee4a1
12 changed files with 199 additions and 61 deletions

View File

@@ -15,8 +15,16 @@ class AbstractInjector(metaclass=abc.ABCMeta):
An injector parse HTTP request and inject params to the view.
"""
@property
@abc.abstractmethod
def __init__(self, args_spec: dict):
def context(self) -> str:
"""
The name of part of parsed request
i.e "HTTP header", "URL path", ...
"""
@abc.abstractmethod
def __init__(self, args_spec: dict, default_values: dict):
"""
args_spec - ordered mapping: arg_name -> type
"""
@@ -33,8 +41,12 @@ class MatchInfoGetter(AbstractInjector):
Validates and injects the part of URL path inside the view positional args.
"""
def __init__(self, args_spec: dict):
self.model = type("PathModel", (BaseModel,), {"__annotations__": args_spec})
context = "path"
def __init__(self, args_spec: dict, default_values: dict):
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("PathModel", (BaseModel,), attrs)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
args_view.extend(self.model(**request.match_info).dict().values())
@@ -45,7 +57,9 @@ class BodyGetter(AbstractInjector):
Validates and injects the content of request body inside the view kwargs.
"""
def __init__(self, args_spec: dict):
context = "body"
def __init__(self, args_spec: dict, default_values: dict):
self.arg_name, self.model = next(iter(args_spec.items()))
async def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
@@ -64,8 +78,12 @@ class QueryGetter(AbstractInjector):
Validates and injects the query string inside the view kwargs.
"""
def __init__(self, args_spec: dict):
self.model = type("QueryModel", (BaseModel,), {"__annotations__": args_spec})
context = "query string"
def __init__(self, args_spec: dict, default_values: dict):
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("QueryModel", (BaseModel,), attrs)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
kwargs_view.update(self.model(**request.query).dict())
@@ -76,27 +94,33 @@ class HeadersGetter(AbstractInjector):
Validates and injects the HTTP headers inside the view kwargs.
"""
def __init__(self, args_spec: dict):
self.model = type("HeaderModel", (BaseModel,), {"__annotations__": args_spec})
context = "headers"
def __init__(self, args_spec: dict, default_values: dict):
attrs = {"__annotations__": args_spec}
attrs.update(default_values)
self.model = type("HeaderModel", (BaseModel,), attrs)
def inject(self, request: BaseRequest, args_view: list, kwargs_view: dict):
header = {k.lower().replace("-", "_"): v for k, v in request.headers.items()}
kwargs_view.update(self.model(**header).dict())
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]:
def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict, dict]:
"""
Analyse function signature and returns 4-tuple:
0 - arguments will be set from the url path
1 - argument will be set from the request body.
2 - argument will be set from the query string.
3 - argument will be set from the HTTP headers.
4 - Default value for each parameters
"""
path_args = {}
body_args = {}
qs_args = {}
header_args = {}
defaults = {}
for param_name, param_spec in signature(func).parameters.items():
if param_name == "self":
@@ -105,8 +129,12 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]:
if param_spec.annotation == param_spec.empty:
raise RuntimeError(f"The parameter {param_name} must have an annotation")
if param_spec.default is not param_spec.empty:
defaults[param_name] = param_spec.default
if param_spec.kind is param_spec.POSITIONAL_ONLY:
path_args[param_name] = param_spec.annotation
elif param_spec.kind is param_spec.POSITIONAL_OR_KEYWORD:
if is_pydantic_base_model(param_spec.annotation):
body_args[param_name] = param_spec.annotation
@@ -117,4 +145,4 @@ def _parse_func_signature(func: Callable) -> Tuple[dict, dict, dict, dict]:
else:
raise RuntimeError(f"You cannot use {param_spec.VAR_POSITIONAL} parameters")
return path_args, body_args, qs_args, header_args
return path_args, body_args, qs_args, header_args, defaults

View File

@@ -13,7 +13,7 @@ Example:
from functools import lru_cache
from types import new_class
from typing import Protocol, TypeVar, Optional, Type
from typing import Protocol, TypeVar
RespContents = TypeVar("RespContents", covariant=True)

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from aiohttp.web import Response, json_response
from aiohttp.web_app import Application
from pydantic import BaseModel
from aiohttp_pydantic.oas.struct import OpenApiSpec3, OperationObject, PathItem
@@ -93,7 +94,9 @@ def _add_http_method_to_oas(
http_method = http_method.lower()
oas_operation: OperationObject = getattr(oas_path, http_method)
handler = getattr(view, http_method)
path_args, body_args, qs_args, header_args = _parse_func_signature(handler)
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
handler
)
description = getdoc(handler)
if description:
oas_operation.description = description
@@ -114,12 +117,15 @@ def _add_http_method_to_oas(
oas_operation.parameters[i].in_ = args_location
oas_operation.parameters[i].name = name
optional_type = _handle_optional(type_)
if optional_type is None:
oas_operation.parameters[i].schema = JSON_SCHEMA_TYPES[type_]
oas_operation.parameters[i].required = True
else:
oas_operation.parameters[i].schema = JSON_SCHEMA_TYPES[optional_type]
oas_operation.parameters[i].required = False
attrs = {"__annotations__": {"__root__": type_}}
if name in defaults:
attrs["__root__"] = defaults[name]
oas_operation.parameters[i].schema = type(
name, (BaseModel,), attrs
).schema()
oas_operation.parameters[i].required = optional_type is None
return_type = handler.__annotations__.get("return")
if return_type is not None:
@@ -134,15 +140,17 @@ def generate_oas(apps: List[Application]) -> dict:
for app in apps:
for resources in app.router.resources():
for resource_route in resources:
if is_pydantic_view(resource_route.handler):
view: Type[PydanticView] = resource_route.handler
info = resource_route.get_info()
path = oas.paths[info.get("path", info.get("formatter"))]
if resource_route.method == "*":
for method_name in view.allowed_methods:
_add_http_method_to_oas(path, method_name, view)
else:
_add_http_method_to_oas(path, resource_route.method, view)
if not is_pydantic_view(resource_route.handler):
continue
view: Type[PydanticView] = resource_route.handler
info = resource_route.get_info()
path = oas.paths[info.get("path", info.get("formatter"))]
if resource_route.method == "*":
for method_name in view.allowed_methods:
_add_http_method_to_oas(path, method_name, view)
else:
_add_http_method_to_oas(path, resource_route.method, view)
return oas.spec

View File

@@ -9,14 +9,8 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aiohttp.web_response import StreamResponse
from pydantic import ValidationError
from .injectors import (
AbstractInjector,
BodyGetter,
HeadersGetter,
MatchInfoGetter,
QueryGetter,
_parse_func_signature,
)
from .injectors import (AbstractInjector, BodyGetter, HeadersGetter,
MatchInfoGetter, QueryGetter, _parse_func_signature)
class PydanticView(AbstractView):
@@ -50,16 +44,25 @@ class PydanticView(AbstractView):
@staticmethod
def parse_func_signature(func: Callable) -> Iterable[AbstractInjector]:
path_args, body_args, qs_args, header_args = _parse_func_signature(func)
path_args, body_args, qs_args, header_args, defaults = _parse_func_signature(
func
)
injectors = []
def default_value(args: dict) -> dict:
"""
Returns the default values of args.
"""
return {name: defaults[name] for name in args if name in defaults}
if path_args:
injectors.append(MatchInfoGetter(path_args))
injectors.append(MatchInfoGetter(path_args, default_value(path_args)))
if body_args:
injectors.append(BodyGetter(body_args))
injectors.append(BodyGetter(body_args, default_value(body_args)))
if qs_args:
injectors.append(QueryGetter(qs_args))
injectors.append(QueryGetter(qs_args, default_value(qs_args)))
if header_args:
injectors.append(HeadersGetter(header_args))
injectors.append(HeadersGetter(header_args, default_value(header_args)))
return injectors
@@ -83,7 +86,11 @@ def inject_params(
else:
injector.inject(self.request, args, kwargs)
except ValidationError as error:
return json_response(text=error.json(), status=400)
errors = error.errors()
for error in errors:
error["in"] = injector.context
return json_response(data=errors, status=400)
return await handler(self, *args, **kwargs)