From 7dc721975baeaa4dc6f4d1fd2ee850a1e4b654a0 Mon Sep 17 00:00:00 2001 From: aranvir <75439739+aranvir@users.noreply.github.com> Date: Mon, 4 Mar 2024 04:00:45 +0100 Subject: [PATCH] feat: add API for getting session ID (#3127) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * have auth middleware wrapper use backend middleware instead of hardcoded middleware * reworking session id to be generated before the route handler * adding session_id to ScopeState instead of connection scope * fixed docstring * adjusted session_id type and handling. added middleware tests. * test linting * adjustments for test coverage --------- Co-authored-by: Janek Nouvertné Co-authored-by: guacs <126393040+guacs@users.noreply.github.com> --- .../examples/testing/test_set_session_data.py | 8 ++- litestar/connection/base.py | 7 ++- litestar/middleware/session/base.py | 12 ++++ litestar/middleware/session/client_side.py | 3 + litestar/middleware/session/server_side.py | 24 +++++++- litestar/security/session_auth/middleware.py | 3 +- litestar/testing/client/base.py | 19 +++---- litestar/utils/scope/state.py | 3 + .../test_session/test_middleware.py | 56 ++++++++++++++++++- 9 files changed, 112 insertions(+), 23 deletions(-) diff --git a/docs/examples/testing/test_set_session_data.py b/docs/examples/testing/test_set_session_data.py index 864f921aa5..913c690aa8 100644 --- a/docs/examples/testing/test_set_session_data.py +++ b/docs/examples/testing/test_set_session_data.py @@ -14,6 +14,8 @@ def get_session_data(request: Request) -> Dict[str, Any]: app = Litestar(route_handlers=[get_session_data], middleware=[session_config.middleware]) -with TestClient(app=app, session_config=session_config) as client: - client.set_session_data({"foo": "bar"}) - assert client.get("/test").json() == {"foo": "bar"} + +def test_get_session_data() -> None: + with TestClient(app=app, session_config=session_config) as client: + client.set_session_data({"foo": "bar"}) + assert client.get("/test").json() == {"foo": "bar"} diff --git a/litestar/connection/base.py b/litestar/connection/base.py index 7fb7098101..d14c6620e5 100644 --- a/litestar/connection/base.py +++ b/litestar/connection/base.py @@ -9,6 +9,7 @@ from litestar.datastructures.url import URL, Address, make_absolute_url from litestar.exceptions import ImproperlyConfiguredException from litestar.types.empty import Empty +from litestar.utils.empty import value_or_default from litestar.utils.scope.state import ScopeState if TYPE_CHECKING: @@ -287,7 +288,7 @@ def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) -> value: Dictionary or pydantic model instance for the session data. Returns: - None. + None """ self.scope["session"] = value @@ -301,6 +302,10 @@ def clear_session(self) -> None: None. """ self.scope["session"] = Empty + self._connection_state.session_id = Empty + + def get_session_id(self) -> str | None: + return value_or_default(value=self._connection_state.session_id, default=None) def url_for(self, name: str, **path_parameters: Any) -> str: """Return the url for a given route handler name. diff --git a/litestar/middleware/session/base.py b/litestar/middleware/session/base.py index b7e2e1586a..bb39fa682b 100644 --- a/litestar/middleware/session/base.py +++ b/litestar/middleware/session/base.py @@ -145,6 +145,17 @@ def deserialize_data(data: Any) -> dict[str, Any]: """ return cast("dict[str, Any]", decode_json(value=data)) + @abstractmethod + def get_session_id(self, connection: ASGIConnection) -> str | None: + """Try to fetch session id from connection ScopeState. If one does not exist, generate one. + + Args: + connection: Originating ASGIConnection containing the scope + + Returns: + Session id str or None if the concept of a session id does not apply. + """ + @abstractmethod async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: """Store the necessary information in the outgoing ``Message`` @@ -241,5 +252,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: connection = ASGIConnection[Any, Any, Any, Any](scope, receive=receive, send=send) scope["session"] = await self.backend.load_from_connection(connection) + connection._connection_state.session_id = self.backend.get_session_id(connection) # pyright: ignore [reportGeneralTypeIssues] await self.app(scope, receive, self.create_send_wrapper(connection)) diff --git a/litestar/middleware/session/client_side.py b/litestar/middleware/session/client_side.py index cce502f64d..f709410478 100644 --- a/litestar/middleware/session/client_side.py +++ b/litestar/middleware/session/client_side.py @@ -206,6 +206,9 @@ async def load_from_connection(self, connection: ASGIConnection) -> dict[str, An return self.load_data(data) return {} + def get_session_id(self, connection: ASGIConnection) -> str | None: + return None + @dataclass class CookieBackendConfig(BaseBackendConfig[ClientSideSessionBackend]): # pyright: ignore diff --git a/litestar/middleware/session/server_side.py b/litestar/middleware/session/server_side.py index cec0011d80..91708ac80d 100644 --- a/litestar/middleware/session/server_side.py +++ b/litestar/middleware/session/server_side.py @@ -77,6 +77,26 @@ async def delete(self, session_id: str, store: Store) -> None: """ await store.delete(session_id) + def get_session_id(self, connection: ASGIConnection) -> str: + """Try to fetch session id from the connection. If one does not exist, generate one. + + If a session ID already exists in the cookies, it is returned. + If there is no ID in the cookies but one in the connection state, then the session exists but has not yet + been returned to the user. + Otherwise, a new session must be created. + + Args: + connection: Originating ASGIConnection containing the scope + Returns: + Session id str or None if the concept of a session id does not apply. + """ + session_id = connection.cookies.get(self.config.key) + if not session_id or session_id == "null": + session_id = connection.get_session_id() + if not session_id: + session_id = self.generate_session_id() + return session_id + def generate_session_id(self) -> str: """Generate a new session-ID, with n=:attr:`session_id_bytes ` random bytes. @@ -104,9 +124,7 @@ async def store_in_message(self, scope_session: ScopeSession, message: Message, scope = connection.scope store = self.config.get_store_from_app(scope["app"]) headers = MutableScopeHeaders.from_message(message) - session_id = connection.cookies.get(self.config.key) - if not session_id or session_id == "null": - session_id = self.generate_session_id() + session_id = self.get_session_id(connection) cookie_params = dict(extract_dataclass_items(self.config, exclude_none=True, include=Cookie.__dict__.keys())) diff --git a/litestar/security/session_auth/middleware.py b/litestar/security/session_auth/middleware.py index 691294fddc..bb3fce4349 100644 --- a/litestar/security/session_auth/middleware.py +++ b/litestar/security/session_auth/middleware.py @@ -8,7 +8,6 @@ AuthenticationResult, ) from litestar.middleware.exceptions import ExceptionHandlerMiddleware -from litestar.middleware.session.base import SessionMiddleware from litestar.types import Empty, Method, Scopes __all__ = ("MiddlewareWrapper", "SessionAuthMiddleware") @@ -61,7 +60,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: exception_handlers=litestar_app.exception_handlers or {}, # pyright: ignore debug=None, ) - self.app = SessionMiddleware( + self.app = self.config.session_backend_config.middleware.middleware( app=exception_middleware, backend=self.config.session_backend, ) diff --git a/litestar/testing/client/base.py b/litestar/testing/client/base.py index 93f1082a97..3c25be117b 100644 --- a/litestar/testing/client/base.py +++ b/litestar/testing/client/base.py @@ -22,7 +22,6 @@ from httpx._types import CookieTypes from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend - from litestar.middleware.session.client_side import ClientSideSessionBackend from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send T = TypeVar("T", bound=ASGIApp) @@ -155,20 +154,16 @@ def portal(self) -> Generator[BlockingPortal, None, None]: ) as portal: yield portal - @staticmethod - def _create_session_cookies(backend: ClientSideSessionBackend, data: dict[str, Any]) -> dict[str, str]: - encoded_data = backend.dump_data(data=data) - return {cookie.key: cast("str", cookie.value) for cookie in backend._create_session_cookies(encoded_data)} - async def _set_session_data(self, data: dict[str, Any]) -> None: mutable_headers = MutableScopeHeaders() + connection = fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ) + session_id = self.session_backend.get_session_id(connection) + connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues] await self.session_backend.store_in_message( - scope_session=data, - message=fake_http_send_message(mutable_headers), - connection=fake_asgi_connection( - app=self.app, - cookies=dict(self.cookies), # type: ignore[arg-type] - ), + scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection ) response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers) diff --git a/litestar/utils/scope/state.py b/litestar/utils/scope/state.py index 31f6442e61..bed43940e2 100644 --- a/litestar/utils/scope/state.py +++ b/litestar/utils/scope/state.py @@ -41,6 +41,7 @@ class ScopeState: "msgpack", "parsed_query", "response_compressed", + "session_id", "url", "_compat_ns", ) @@ -62,6 +63,7 @@ def __init__(self) -> None: self.msgpack = Empty self.parsed_query = Empty self.response_compressed = Empty + self.session_id = Empty self.url = Empty self._compat_ns: dict[str, Any] = {} @@ -81,6 +83,7 @@ def __init__(self) -> None: msgpack: Any | EmptyType parsed_query: tuple[tuple[str, str], ...] | EmptyType response_compressed: bool | EmptyType + session_id: str | None | EmptyType url: URL | EmptyType _compat_ns: dict[str, Any] diff --git a/tests/unit/test_middleware/test_session/test_middleware.py b/tests/unit/test_middleware/test_session/test_middleware.py index fb60bb15e0..ef0aa89d11 100644 --- a/tests/unit/test_middleware/test_session/test_middleware.py +++ b/tests/unit/test_middleware/test_session/test_middleware.py @@ -1,13 +1,13 @@ -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Union from litestar import HttpMethod, Request, Response, get, post, route +from litestar.middleware.session.server_side import ServerSideSessionConfig from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import create_test_client from litestar.types import Empty if TYPE_CHECKING: from litestar.middleware.session.base import BaseBackendConfig - from litestar.middleware.session.server_side import ServerSideSessionConfig def test_session_middleware_not_installed_raises() -> None: @@ -37,6 +37,7 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]: with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client: response = client.get("/session") assert response.json() == {"has_session": False} + first_session_id = client.cookies.get("session") client.post("/session") @@ -52,6 +53,57 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]: response = client.get("/session") assert response.json() == {"has_session": True} + second_session_id = client.cookies.get("session") + assert first_session_id != second_session_id + + +def test_session_id_correctness(session_backend_config: "BaseBackendConfig") -> None: + # Test that `request.get_session_id()` is the same as in the cookies + @route("/session", http_method=[HttpMethod.POST]) + def session_handler(request: Request) -> Optional[Dict[str, Union[str, None]]]: + request.set_session({"foo": "bar"}) + return {"session_id": request.get_session_id()} + + with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client: + if isinstance(session_backend_config, ServerSideSessionConfig): + # Generic verification that a session id is set before entering the route handler scope + response = client.post("/session") + request_session_id = response.json()["session_id"] + cookie_session_id = client.cookies.get("session") + assert request_session_id == cookie_session_id + else: + # Client side config does not have a session id in cookies + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None + + +def test_keep_session_id(session_backend_config: "BaseBackendConfig") -> None: + # Test that session is only created if not already exists + @route("/session", http_method=[HttpMethod.POST]) + def session_handler(request: Request) -> Optional[Dict[str, Union[str, None]]]: + request.set_session({"foo": "bar"}) + return {"session_id": request.get_session_id()} + + with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client: + if isinstance(session_backend_config, ServerSideSessionConfig): + # Generic verification that a session id is set before entering the route handler scope + response = client.post("/session") + first_call_id = response.json()["session_id"] + response = client.post("/session") + second_call_id = response.json()["session_id"] + assert first_call_id == second_call_id == client.cookies.get("session") + else: + # Client side config does not have a session id in cookies + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None def test_set_empty(session_backend_config: "BaseBackendConfig") -> None: