Skip to content

Commit

Permalink
feat: add API for getting session ID (#3127)
Browse files Browse the repository at this point in the history
* have auth middleware wrapper use backend middleware instead of hardcoded middleware

* reworking session id to be generated before the route handler

* adding session_id to ScopeState instead of connection scope

* fixed docstring

* adjusted session_id type and handling. added middleware tests.

* test linting

* adjustments for test coverage

---------

Co-authored-by: Janek Nouvertné <[email protected]>
Co-authored-by: guacs <[email protected]>
  • Loading branch information
3 people authored Mar 4, 2024
1 parent 55be181 commit 7dc7219
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 23 deletions.
8 changes: 5 additions & 3 deletions docs/examples/testing/test_set_session_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def get_session_data(request: Request) -> Dict[str, Any]:

app = Litestar(route_handlers=[get_session_data], middleware=[session_config.middleware])

with TestClient(app=app, session_config=session_config) as client:
client.set_session_data({"foo": "bar"})
assert client.get("/test").json() == {"foo": "bar"}

def test_get_session_data() -> None:
with TestClient(app=app, session_config=session_config) as client:
client.set_session_data({"foo": "bar"})
assert client.get("/test").json() == {"foo": "bar"}
7 changes: 6 additions & 1 deletion litestar/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from litestar.datastructures.url import URL, Address, make_absolute_url
from litestar.exceptions import ImproperlyConfiguredException
from litestar.types.empty import Empty
from litestar.utils.empty import value_or_default
from litestar.utils.scope.state import ScopeState

if TYPE_CHECKING:
Expand Down Expand Up @@ -287,7 +288,7 @@ def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) ->
value: Dictionary or pydantic model instance for the session data.
Returns:
None.
None
"""
self.scope["session"] = value

Expand All @@ -301,6 +302,10 @@ def clear_session(self) -> None:
None.
"""
self.scope["session"] = Empty
self._connection_state.session_id = Empty

def get_session_id(self) -> str | None:
return value_or_default(value=self._connection_state.session_id, default=None)

def url_for(self, name: str, **path_parameters: Any) -> str:
"""Return the url for a given route handler name.
Expand Down
12 changes: 12 additions & 0 deletions litestar/middleware/session/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ def deserialize_data(data: Any) -> dict[str, Any]:
"""
return cast("dict[str, Any]", decode_json(value=data))

@abstractmethod
def get_session_id(self, connection: ASGIConnection) -> str | None:
"""Try to fetch session id from connection ScopeState. If one does not exist, generate one.
Args:
connection: Originating ASGIConnection containing the scope
Returns:
Session id str or None if the concept of a session id does not apply.
"""

@abstractmethod
async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None:
"""Store the necessary information in the outgoing ``Message``
Expand Down Expand Up @@ -241,5 +252,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

connection = ASGIConnection[Any, Any, Any, Any](scope, receive=receive, send=send)
scope["session"] = await self.backend.load_from_connection(connection)
connection._connection_state.session_id = self.backend.get_session_id(connection) # pyright: ignore [reportGeneralTypeIssues]

await self.app(scope, receive, self.create_send_wrapper(connection))
3 changes: 3 additions & 0 deletions litestar/middleware/session/client_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ async def load_from_connection(self, connection: ASGIConnection) -> dict[str, An
return self.load_data(data)
return {}

def get_session_id(self, connection: ASGIConnection) -> str | None:
return None


@dataclass
class CookieBackendConfig(BaseBackendConfig[ClientSideSessionBackend]): # pyright: ignore
Expand Down
24 changes: 21 additions & 3 deletions litestar/middleware/session/server_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,26 @@ async def delete(self, session_id: str, store: Store) -> None:
"""
await store.delete(session_id)

def get_session_id(self, connection: ASGIConnection) -> str:
"""Try to fetch session id from the connection. If one does not exist, generate one.
If a session ID already exists in the cookies, it is returned.
If there is no ID in the cookies but one in the connection state, then the session exists but has not yet
been returned to the user.
Otherwise, a new session must be created.
Args:
connection: Originating ASGIConnection containing the scope
Returns:
Session id str or None if the concept of a session id does not apply.
"""
session_id = connection.cookies.get(self.config.key)
if not session_id or session_id == "null":
session_id = connection.get_session_id()
if not session_id:
session_id = self.generate_session_id()
return session_id

def generate_session_id(self) -> str:
"""Generate a new session-ID, with
n=:attr:`session_id_bytes <ServerSideSessionConfig.session_id_bytes>` random bytes.
Expand Down Expand Up @@ -104,9 +124,7 @@ async def store_in_message(self, scope_session: ScopeSession, message: Message,
scope = connection.scope
store = self.config.get_store_from_app(scope["app"])
headers = MutableScopeHeaders.from_message(message)
session_id = connection.cookies.get(self.config.key)
if not session_id or session_id == "null":
session_id = self.generate_session_id()
session_id = self.get_session_id(connection)

cookie_params = dict(extract_dataclass_items(self.config, exclude_none=True, include=Cookie.__dict__.keys()))

Expand Down
3 changes: 1 addition & 2 deletions litestar/security/session_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
AuthenticationResult,
)
from litestar.middleware.exceptions import ExceptionHandlerMiddleware
from litestar.middleware.session.base import SessionMiddleware
from litestar.types import Empty, Method, Scopes

__all__ = ("MiddlewareWrapper", "SessionAuthMiddleware")
Expand Down Expand Up @@ -61,7 +60,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
exception_handlers=litestar_app.exception_handlers or {}, # pyright: ignore
debug=None,
)
self.app = SessionMiddleware(
self.app = self.config.session_backend_config.middleware.middleware(
app=exception_middleware,
backend=self.config.session_backend,
)
Expand Down
19 changes: 7 additions & 12 deletions litestar/testing/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from httpx._types import CookieTypes

from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend
from litestar.middleware.session.client_side import ClientSideSessionBackend
from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send

T = TypeVar("T", bound=ASGIApp)
Expand Down Expand Up @@ -155,20 +154,16 @@ def portal(self) -> Generator[BlockingPortal, None, None]:
) as portal:
yield portal

@staticmethod
def _create_session_cookies(backend: ClientSideSessionBackend, data: dict[str, Any]) -> dict[str, str]:
encoded_data = backend.dump_data(data=data)
return {cookie.key: cast("str", cookie.value) for cookie in backend._create_session_cookies(encoded_data)}

async def _set_session_data(self, data: dict[str, Any]) -> None:
mutable_headers = MutableScopeHeaders()
connection = fake_asgi_connection(
app=self.app,
cookies=dict(self.cookies), # type: ignore[arg-type]
)
session_id = self.session_backend.get_session_id(connection)
connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues]
await self.session_backend.store_in_message(
scope_session=data,
message=fake_http_send_message(mutable_headers),
connection=fake_asgi_connection(
app=self.app,
cookies=dict(self.cookies), # type: ignore[arg-type]
),
scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection
)
response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers)

Expand Down
3 changes: 3 additions & 0 deletions litestar/utils/scope/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ScopeState:
"msgpack",
"parsed_query",
"response_compressed",
"session_id",
"url",
"_compat_ns",
)
Expand All @@ -62,6 +63,7 @@ def __init__(self) -> None:
self.msgpack = Empty
self.parsed_query = Empty
self.response_compressed = Empty
self.session_id = Empty
self.url = Empty
self._compat_ns: dict[str, Any] = {}

Expand All @@ -81,6 +83,7 @@ def __init__(self) -> None:
msgpack: Any | EmptyType
parsed_query: tuple[tuple[str, str], ...] | EmptyType
response_compressed: bool | EmptyType
session_id: str | None | EmptyType
url: URL | EmptyType
_compat_ns: dict[str, Any]

Expand Down
56 changes: 54 additions & 2 deletions tests/unit/test_middleware/test_session/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, Union

from litestar import HttpMethod, Request, Response, get, post, route
from litestar.middleware.session.server_side import ServerSideSessionConfig
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing import create_test_client
from litestar.types import Empty

if TYPE_CHECKING:
from litestar.middleware.session.base import BaseBackendConfig
from litestar.middleware.session.server_side import ServerSideSessionConfig


def test_session_middleware_not_installed_raises() -> None:
Expand Down Expand Up @@ -37,6 +37,7 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]:
with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client:
response = client.get("/session")
assert response.json() == {"has_session": False}
first_session_id = client.cookies.get("session")

client.post("/session")

Expand All @@ -52,6 +53,57 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]:

response = client.get("/session")
assert response.json() == {"has_session": True}
second_session_id = client.cookies.get("session")
assert first_session_id != second_session_id


def test_session_id_correctness(session_backend_config: "BaseBackendConfig") -> None:
# Test that `request.get_session_id()` is the same as in the cookies
@route("/session", http_method=[HttpMethod.POST])
def session_handler(request: Request) -> Optional[Dict[str, Union[str, None]]]:
request.set_session({"foo": "bar"})
return {"session_id": request.get_session_id()}

with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client:
if isinstance(session_backend_config, ServerSideSessionConfig):
# Generic verification that a session id is set before entering the route handler scope
response = client.post("/session")
request_session_id = response.json()["session_id"]
cookie_session_id = client.cookies.get("session")
assert request_session_id == cookie_session_id
else:
# Client side config does not have a session id in cookies
response = client.post("/session")
assert response.json()["session_id"] is None
assert client.cookies.get("session") is not None
response = client.post("/session")
assert response.json()["session_id"] is None
assert client.cookies.get("session") is not None


def test_keep_session_id(session_backend_config: "BaseBackendConfig") -> None:
# Test that session is only created if not already exists
@route("/session", http_method=[HttpMethod.POST])
def session_handler(request: Request) -> Optional[Dict[str, Union[str, None]]]:
request.set_session({"foo": "bar"})
return {"session_id": request.get_session_id()}

with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client:
if isinstance(session_backend_config, ServerSideSessionConfig):
# Generic verification that a session id is set before entering the route handler scope
response = client.post("/session")
first_call_id = response.json()["session_id"]
response = client.post("/session")
second_call_id = response.json()["session_id"]
assert first_call_id == second_call_id == client.cookies.get("session")
else:
# Client side config does not have a session id in cookies
response = client.post("/session")
assert response.json()["session_id"] is None
assert client.cookies.get("session") is not None
response = client.post("/session")
assert response.json()["session_id"] is None
assert client.cookies.get("session") is not None


def test_set_empty(session_backend_config: "BaseBackendConfig") -> None:
Expand Down

0 comments on commit 7dc7219

Please sign in to comment.