Compare commits

...

5 Commits

Author SHA1 Message Date
jar3b
6ba671532b feat: add raise_validation_errors parameter to setup()
To control how pydantic.ValidationError will be handled, own handler (return json) or raise exception to allow intercept in aiohttp middleware
2020-11-27 02:02:54 +03:00
jar3b
efbaaa5e6f fix: to push 2020-11-25 16:59:10 +03:00
jar3b
6211c71875 fix: detect x-forwarded-proto if deployed behind proxy
For static files handling, was set up in "oas_ui()"
2020-11-25 16:25:04 +03:00
jar3b
5567d73952 fix: return copy of schema without 'definitions' key
Instead of delete key 'definitions' from schema, bc schema was "cached" and if you try to load swagger twice, you got "no definitions" exception
2020-11-25 01:55:01 +03:00
jar3b
67a95ec9c9 fix: move response definitions to top level of oas 2020-11-25 01:23:10 +03:00
4 changed files with 32 additions and 9 deletions

View File

@ -13,10 +13,13 @@ def setup(
apps_to_expose: Iterable[web.Application] = (), apps_to_expose: Iterable[web.Application] = (),
url_prefix: str = "/oas", url_prefix: str = "/oas",
enable: bool = True, enable: bool = True,
raise_validation_errors: bool = False,
): ):
if enable: if enable:
oas_app = web.Application() oas_app = web.Application()
oas_app["apps to expose"] = tuple(apps_to_expose) or (app,) oas_app["apps to expose"] = tuple(apps_to_expose) or (app,)
for a in oas_app["apps to expose"]:
a['raise_validation_errors'] = raise_validation_errors
oas_app["index template"] = jinja2.Template( oas_app["index template"] = jinja2.Template(
resources.read_text("aiohttp_pydantic.oas", "index.j2") resources.read_text("aiohttp_pydantic.oas", "index.j2")
) )

View File

@ -312,3 +312,8 @@ class OpenApiSpec3:
@property @property
def spec(self): def spec(self):
return self._spec return self._spec
@property
def definitions(self):
self._spec.setdefault('definitions', {})
return self._spec['definitions']

View File

@ -47,13 +47,21 @@ class _OASResponseBuilder:
generate the OAS operation response. generate the OAS operation response.
""" """
def __init__(self, oas_operation): def __init__(self, oas_operation, definitions):
self._oas_operation = oas_operation self._oas_operation = oas_operation
self._definitions = definitions
@staticmethod def _process_definitions(self, schema):
def _handle_pydantic_base_model(obj): if 'definitions' in schema:
for k, v in schema['definitions'].items():
self._definitions[k] = v
return {i:schema[i] for i in schema if i!='definitions'}
def _handle_pydantic_base_model(self, obj):
if is_pydantic_base_model(obj): if is_pydantic_base_model(obj):
return obj.schema() return self._process_definitions(obj.schema())
return {} return {}
def _handle_list(self, obj): def _handle_list(self, obj):
@ -88,7 +96,7 @@ class _OASResponseBuilder:
def _add_http_method_to_oas( def _add_http_method_to_oas(
oas_path: PathItem, http_method: str, view: Type[PydanticView] oas_path: PathItem, http_method: str, view: Type[PydanticView], definitions: dict
): ):
http_method = http_method.lower() http_method = http_method.lower()
oas_operation: OperationObject = getattr(oas_path, http_method) oas_operation: OperationObject = getattr(oas_path, http_method)
@ -123,7 +131,7 @@ def _add_http_method_to_oas(
return_type = handler.__annotations__.get("return") return_type = handler.__annotations__.get("return")
if return_type is not None: if return_type is not None:
_OASResponseBuilder(oas_operation).build(return_type) _OASResponseBuilder(oas_operation, definitions).build(return_type)
def generate_oas(apps: List[Application]) -> dict: def generate_oas(apps: List[Application]) -> dict:
@ -131,6 +139,7 @@ def generate_oas(apps: List[Application]) -> dict:
Generate and return Open Api Specification from PydanticView in application. Generate and return Open Api Specification from PydanticView in application.
""" """
oas = OpenApiSpec3() oas = OpenApiSpec3()
for app in apps: for app in apps:
for resources in app.router.resources(): for resources in app.router.resources():
for resource_route in resources: for resource_route in resources:
@ -140,9 +149,9 @@ def generate_oas(apps: List[Application]) -> dict:
path = oas.paths[info.get("path", info.get("formatter"))] path = oas.paths[info.get("path", info.get("formatter"))]
if resource_route.method == "*": if resource_route.method == "*":
for method_name in view.allowed_methods: for method_name in view.allowed_methods:
_add_http_method_to_oas(path, method_name, view) _add_http_method_to_oas(path, method_name, view, oas.definitions)
else: else:
_add_http_method_to_oas(path, resource_route.method, view) _add_http_method_to_oas(path, resource_route.method, view, oas.definitions)
return oas.spec return oas.spec
@ -163,6 +172,9 @@ async def oas_ui(request):
static_url = request.app.router["static"].url_for(filename="") static_url = request.app.router["static"].url_for(filename="")
spec_url = request.app.router["spec"].url_for() spec_url = request.app.router["spec"].url_for()
if request.scheme != request.headers.get('x-forwarded-proto', request.scheme):
request = request.clone(scheme=request.headers['x-forwarded-proto'])
host = request.url.origin() host = request.url.origin()
return Response( return Response(

View File

@ -83,7 +83,10 @@ def inject_params(
else: else:
injector.inject(self.request, args, kwargs) injector.inject(self.request, args, kwargs)
except ValidationError as error: except ValidationError as error:
return json_response(text=error.json(), status=400) if self.request.app['raise_validation_errors']:
raise
else:
return json_response(text=error.json(), status=400)
return await handler(self, *args, **kwargs) return await handler(self, *args, **kwargs)