Skip to content

Commit

Permalink
fix(exception-handling): Fix #2147 - setting app debug does not propa…
Browse files Browse the repository at this point in the history
…gate to exception handling middleware (#2153)

* fix(exception-handling): Fix #2147 - setting app debug dynamically does not propagate to exception handling middleware.

Change ExceptionHandlerMiddleware to use the current application's debug value rather than assigning it statically at creation / startup time.
  • Loading branch information
provinzkraut authored Aug 12, 2023
1 parent edbe1c9 commit 5dbbedf
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 35 deletions.
4 changes: 2 additions & 2 deletions litestar/cli/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uvicorn
from rich.tree import Tree

from litestar.cli._utils import RICH_CLICK_INSTALLED, console, show_app_info
from litestar.cli._utils import RICH_CLICK_INSTALLED, LitestarEnv, console, show_app_info
from litestar.routes import HTTPRoute, WebSocketRoute
from litestar.utils.helpers import unwrap_partial

Expand Down Expand Up @@ -118,7 +118,7 @@ def run_command(
if pdb:
ctx.obj.app.pdb_on_exception = True

env = ctx.obj
env: LitestarEnv = ctx.obj
app = env.app

reload_dirs = env.reload_dirs or reload_dir
Expand Down
26 changes: 22 additions & 4 deletions litestar/middleware/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from litestar.middleware.exceptions._debug_response import create_debug_response
from litestar.serialization import encode_json
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from litestar.utils.deprecation import warn_deprecation

__all__ = ("ExceptionHandlerMiddleware", "ExceptionResponseContent", "create_exception_response")

Expand Down Expand Up @@ -147,17 +148,33 @@ class ExceptionHandlerMiddleware:
This used in multiple layers of Litestar.
"""

def __init__(self, app: ASGIApp, debug: bool, exception_handlers: ExceptionHandlersMap) -> None:
def __init__(self, app: ASGIApp, debug: bool | None, exception_handlers: ExceptionHandlersMap) -> None:
"""Initialize ``ExceptionHandlerMiddleware``.
Args:
app: The ``next`` ASGI app to call.
debug: Whether ``debug`` mode is enabled
debug: Whether ``debug`` mode is enabled. Deprecated. Debug mode will be inferred from the request scope
exception_handlers: A dictionary mapping status codes and/or exception types to handler functions.
.. deprecated:: 2.0.0
The ``debug`` parameter is deprecated. It will be inferred from the request scope
"""
self.app = app
self.exception_handlers = exception_handlers
self.debug = debug
if debug is not None:
warn_deprecation(
"2.0.0",
deprecated_name="debug",
kind="parameter",
info="Debug mode will be inferred from the request scope",
)

self._get_debug = self._get_debug_scope if debug is None else lambda *a: debug

@staticmethod
def _get_debug_scope(scope: Scope) -> bool:
return scope["app"].debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI-callable.
Expand Down Expand Up @@ -249,7 +266,7 @@ def default_http_exception_handler(self, request: Request, exc: Exception) -> Re
An HTTP response.
"""
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self.debug:
if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self._get_debug_scope(request.scope):
return create_debug_response(request=request, exc=exc)
return create_exception_response(request=request, exc=exc)

Expand All @@ -265,6 +282,7 @@ def handle_exception_logging(self, logger: Logger, logging_config: BaseLoggingCo
None
"""
if (
logging_config.log_exceptions == "always" or (logging_config.log_exceptions == "debug" and self.debug)
logging_config.log_exceptions == "always"
or (logging_config.log_exceptions == "debug" and self._get_debug_scope(scope))
) and logging_config.exception_logging_handler:
logging_config.exception_logging_handler(logger, scope, format_exception(*exc_info()))
91 changes: 62 additions & 29 deletions tests/unit/test_middleware/test_exception_handler_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional

import pytest
from _pytest.capture import CaptureFixture
Expand All @@ -7,17 +7,14 @@
from structlog.testing import capture_logs

from litestar import Litestar, Request, Response, get
from litestar.exceptions import (
HTTPException,
InternalServerException,
ValidationException,
)
from litestar.exceptions import HTTPException, InternalServerException, ValidationException
from litestar.logging.config import LoggingConfig, StructLoggingConfig
from litestar.middleware.exceptions import ExceptionHandlerMiddleware
from litestar.middleware.exceptions.middleware import get_exception_handler
from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing import TestClient, create_test_client
from litestar.types import ExceptionHandlersMap
from litestar.types.asgi_types import HTTPScope

if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
Expand All @@ -30,13 +27,26 @@ async def dummy_app(scope: Any, receive: Any, send: Any) -> None:
return None


middleware = ExceptionHandlerMiddleware(dummy_app, False, {})
@pytest.fixture()
def app() -> Litestar:
return Litestar()


@pytest.fixture()
def middleware() -> ExceptionHandlerMiddleware:
return ExceptionHandlerMiddleware(dummy_app, False, {})


@pytest.fixture()
def scope(create_scope: Callable[..., HTTPScope], app: Litestar) -> HTTPScope:
return create_scope(app=app)


def test_default_handle_http_exception_handling_extra_object() -> None:
def test_default_handle_http_exception_handling_extra_object(
scope: HTTPScope, middleware: ExceptionHandlerMiddleware
) -> None:
response = middleware.default_http_exception_handler(
Request(scope={"type": "http", "method": "GET"}), # type: ignore
HTTPException(detail="litestar_exception", extra={"key": "value"}),
Request(scope=scope), HTTPException(detail="litestar_exception", extra={"key": "value"})
)
assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert response.content == {
Expand All @@ -46,28 +56,31 @@ def test_default_handle_http_exception_handling_extra_object() -> None:
}


def test_default_handle_http_exception_handling_extra_none() -> None:
def test_default_handle_http_exception_handling_extra_none(
scope: HTTPScope, middleware: ExceptionHandlerMiddleware
) -> None:
response = middleware.default_http_exception_handler(
Request(scope={"type": "http", "method": "GET"}), # type: ignore
HTTPException(detail="litestar_exception"),
Request(scope=scope), HTTPException(detail="litestar_exception")
)
assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert response.content == {"detail": "Internal Server Error", "status_code": 500}


def test_default_handle_litestar_http_exception_handling() -> None:
def test_default_handle_litestar_http_exception_handling(
scope: HTTPScope, middleware: ExceptionHandlerMiddleware
) -> None:
response = middleware.default_http_exception_handler(
Request(scope={"type": "http", "method": "GET"}), # type: ignore
HTTPException(detail="litestar_exception"),
Request(scope=scope), HTTPException(detail="litestar_exception")
)
assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert response.content == {"detail": "Internal Server Error", "status_code": 500}


def test_default_handle_litestar_http_exception_extra_list() -> None:
def test_default_handle_litestar_http_exception_extra_list(
scope: HTTPScope, middleware: ExceptionHandlerMiddleware
) -> None:
response = middleware.default_http_exception_handler(
Request(scope={"type": "http", "method": "GET"}), # type: ignore
HTTPException(detail="litestar_exception", extra=["extra-1", "extra-2"]),
Request(scope=scope), HTTPException(detail="litestar_exception", extra=["extra-1", "extra-2"])
)
assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert response.content == {
Expand All @@ -77,22 +90,21 @@ def test_default_handle_litestar_http_exception_extra_list() -> None:
}


def test_default_handle_starlette_http_exception_handling() -> None:
def test_default_handle_starlette_http_exception_handling(
scope: HTTPScope, middleware: ExceptionHandlerMiddleware
) -> None:
response = middleware.default_http_exception_handler(
Request(scope={"type": "http", "method": "GET"}), # type: ignore
Request(scope=scope),
StarletteHTTPException(detail="litestar_exception", status_code=HTTP_500_INTERNAL_SERVER_ERROR),
)
assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert response.content == {
"detail": "Internal Server Error",
"status_code": 500,
}
assert response.content == {"detail": "Internal Server Error", "status_code": 500}


def test_default_handle_python_http_exception_handling() -> None:
response = middleware.default_http_exception_handler(
Request(scope={"type": "http", "method": "GET"}), AttributeError("oops") # type: ignore
)
def test_default_handle_python_http_exception_handling(
scope: HTTPScope, middleware: ExceptionHandlerMiddleware
) -> None:
response = middleware.default_http_exception_handler(Request(scope=scope), AttributeError("oops"))
assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert response.content == {
"detail": "Internal Server Error",
Expand Down Expand Up @@ -307,3 +319,24 @@ def handler() -> None:

assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
mock_post_mortem.assert_called_once()


def test_get_debug_from_scope(get_logger: "GetLogger", caplog: "LogCaptureFixture") -> None:
@get("/test")
def handler() -> None:
raise ValueError("Test debug exception")

app = Litestar([handler], debug=False)
app.debug = True

with caplog.at_level("ERROR", "litestar"), TestClient(app=app) as client:
client.app.logger = get_logger("litestar")
response = client.get("/test")

assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
assert "Test debug exception" in response.text
assert len(caplog.records) == 1
assert caplog.records[0].levelname == "ERROR"
assert caplog.records[0].message.startswith(
"exception raised on http connection to route /test\n\nTraceback (most recent call last):\n"
)

0 comments on commit 5dbbedf

Please sign in to comment.