From 67321a2c11a866854ac6381eedecf52e2284237e Mon Sep 17 00:00:00 2001 From: Mate Zoltan Date: Fri, 7 Feb 2025 19:20:46 +0000 Subject: [PATCH] Validate responses from exception handlers to avoid inconsistent OpenAPI specification In the OpenAPI specification of the request handlers, it is possible to declare response schemas for different response status, but if the response is generated by an exception handler, it was not validated like normal responses. This could lead to inconsistent OpenAPI specification and implementation. --- ninja/operation.py | 28 +++++++++++++++++++--------- tests/test_exceptions.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/ninja/operation.py b/ninja/operation.py index 0e4976d6e..6b5a15360 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -123,13 +123,17 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: temporal_response = self.api.create_temporal_response(request) values = self._get_values(request, kw, temporal_response) result = self.view_func(request, **values) + + if isinstance(result, HttpResponseBase): + return result + return self._result_to_response(request, result, temporal_response) except Exception as e: if isinstance(e, TypeError) and "required positional argument" in str(e): msg = "Did you fail to use functools.wraps() in a decorator?" msg = f"{e.args[0]}: {msg}" if e.args else msg e.args = (msg,) + e.args[1:] - return self.api.on_exception(request, e) + return self._on_exception(request, e) def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: self.api = api @@ -158,6 +162,12 @@ def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: if router.tags is not None: self.tags = router.tags + def _on_exception(self, request: HttpRequest, exc: Exception) -> HttpResponse: + temporal_response = self.api.create_temporal_response(request) + result = self.api.on_exception(request, exc) + + return self._result_to_response(request, result, temporal_response) + def _set_auth( self, auth: Optional[Union[Sequence[Callable], Callable, object]] ) -> None: @@ -196,12 +206,12 @@ def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]: else: result = callback(request) except Exception as exc: - return self.api.on_exception(request, exc) + return self._on_exception(request, exc) if result: request.auth = result # type: ignore return None - return self.api.on_exception(request, AuthenticationError()) + return self._on_exception(request, AuthenticationError()) def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]: throttle_durations = [] @@ -216,19 +226,19 @@ def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]: ] duration = max(durations, default=None) - return self.api.on_exception(request, Throttled(wait=duration)) # type: ignore + return self._on_exception(request, Throttled(wait=duration)) # type: ignore return None def _result_to_response( self, request: HttpRequest, result: Any, temporal_response: HttpResponse - ) -> HttpResponseBase: + ) -> HttpResponse: """ The protocol for results - if HttpResponse - returns as is - if tuple with 2 elements - means http_code + body - otherwise it's a body """ - if isinstance(result, HttpResponseBase): + if isinstance(result, HttpResponse): return result status: int = 200 @@ -338,7 +348,7 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ result = await self.view_func(request, **values) return self._result_to_response(request, result, temporal_response) except Exception as e: - return self.api.on_exception(request, e) + return self._on_exception(request, e) async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: # type: ignore "Runs security checks for each operation" @@ -376,12 +386,12 @@ async def _run_authentication(self, request: HttpRequest) -> Optional[HttpRespon else: result = callback(request) except Exception as exc: - return self.api.on_exception(request, exc) + return self._on_exception(request, exc) if result: request.auth = result # type: ignore return None - return self.api.on_exception(request, AuthenticationError()) + return self._on_exception(request, AuthenticationError()) class PathView: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e3e314fa1..1d27f25e0 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,7 +1,9 @@ import pytest from django.http import Http404 +from pydantic import ValidationError from ninja import NinjaAPI, Schema +from ninja.errors import ConfigError from ninja.testing import TestAsyncClient, TestClient api = NinjaAPI() @@ -97,3 +99,33 @@ def thrower(request): with pytest.raises(RuntimeError): client.get("/error") + + +def test_improper_response_body_from_exception_handler(): + @api.exception_handler(RuntimeError) + def on_runtime_error(request, exc): + return 418, {"payload": "non-proper"} + + @api.get("/error", response={418: Payload}) + def thrower(request): + raise RuntimeError + + client = TestClient(api) + + with pytest.raises(ValidationError): + client.get("/error") + + +def test_non_configured_status_code_from_exception_handler(): + @api.exception_handler(RuntimeError) + def on_runtime_error(request, exc): + return 410, Payload(test=1234) + + @api.get("/error", response={418: Payload}) + def thrower(request): + raise RuntimeError + + client = TestClient(api) + + with pytest.raises(ConfigError): + client.get("/error")