Skip to content

Commit

Permalink
Use Annotated for dependency arguments
Browse files Browse the repository at this point in the history
FastAPI now recommends using Annotated to indicate arguments that
are set by FastAPI via dependency injection. Among other benefits,
this aligns Python's opinion about whether the argument has a
default with reality when the function is called outside of a
FastAPI context, which can avoid bugs in some edge case situations.
It also allows removing special Ruff configuration whitelisting the
special FastAPI dependency functions, since Annotated makes the
intent obvious.

Also make the same change to route definitions in the test suite.

Adjust the docstring for auth_delegated_token_dependency to use a
proper link with anchor text.
  • Loading branch information
rra committed Feb 1, 2024
1 parent 3e9d0ca commit 623278c
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 31 deletions.
9 changes: 0 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,6 @@ target-version = "py311"
known-first-party = ["safir", "tests"]
split-on-trailing-comma = false

[tool.ruff.flake8-bugbear]
extend-immutable-calls = [
"fastapi.Form",
"fastapi.Header",
"fastapi.Depends",
"fastapi.Path",
"fastapi.Query",
]

# These are too useful as attributes or methods to allow the conflict with the
# built-in to rule out their use.
[tool.ruff.flake8-builtins]
Expand Down
18 changes: 11 additions & 7 deletions src/safir/dependencies/gafaelfawr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Gafaelfawr authentication dependencies."""

from typing import Annotated

from fastapi import Depends, Header
from structlog.stdlib import BoundLogger

Expand All @@ -13,7 +15,7 @@


async def auth_dependency(
x_auth_request_user: str = Header(..., include_in_schema=False)
x_auth_request_user: Annotated[str, Header(include_in_schema=False)]
) -> str:
"""Retrieve authentication information from HTTP headers.
Expand All @@ -25,22 +27,24 @@ async def auth_dependency(


async def auth_delegated_token_dependency(
x_auth_request_token: str = Header(..., include_in_schema=False)
x_auth_request_token: Annotated[str, Header(include_in_schema=False)]
) -> str:
"""Retrieve Gafaelfawr delegated token from HTTP headers.
Intended for use with applications protected by Gafaelfawr, this retrieves
a delegated token from headers added to the incoming request by the
Gafaelfawr ``auth_request`` NGINX subhandler. The delegated token can
be used to make requests to other services on the user's behalf, see
https://gafaelfawr.lsst.io/user-guide/gafaelfawringress.html#requesting-delegated-tokens
Gafaelfawr ``auth_request`` NGINX subhandler. The delegated token can be
used to make requests to other services on the user's behalf. See `the
Gafaelfawr documentation
<https://gafaelfawr.lsst.io/user-guide/gafaelfawringress.html#requesting-delegated-tokens>`__
for more details.
"""
return x_auth_request_token


async def auth_logger_dependency(
user: str = Depends(auth_dependency),
logger: BoundLogger = Depends(logger_dependency),
user: Annotated[str, Depends(auth_dependency)],
logger: Annotated[BoundLogger, Depends(logger_dependency)],
) -> BoundLogger:
"""Logger bound to the authenticated user."""
return logger.bind(user=user)
18 changes: 11 additions & 7 deletions tests/dependencies/arq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from typing import Annotated, Any

import pytest
from arq.constants import default_queue_name
Expand All @@ -29,7 +29,8 @@ async def test_arq_dependency_mock() -> None:

@app.post("/")
async def post_job(
arq_queue: MockArqQueue = Depends(arq_dependency),
*,
arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
) -> dict[str, Any]:
"""Create a job."""
job = await arq_queue.enqueue("test_task", "hello", a_number=42)
Expand All @@ -44,9 +45,10 @@ async def post_job(

@app.get("/jobs/{job_id}")
async def get_metadata(
*,
job_id: str,
queue_name: str | None = None,
arq_queue: MockArqQueue = Depends(arq_dependency),
arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
) -> dict[str, Any]:
"""Get metadata about a job."""
try:
Expand All @@ -66,9 +68,10 @@ async def get_metadata(

@app.get("/results/{job_id}")
async def get_result(
*,
job_id: str,
queue_name: str | None = None,
arq_queue: MockArqQueue = Depends(arq_dependency),
arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
) -> dict[str, Any]:
"""Get the results for a job."""
try:
Expand All @@ -90,9 +93,10 @@ async def get_result(

@app.post("/jobs/{job_id}/inprogress")
async def post_job_inprogress(
*,
job_id: str,
queue_name: str | None = None,
arq_queue: MockArqQueue = Depends(arq_dependency),
arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
) -> None:
"""Toggle a job to in-progress, for testing."""
try:
Expand All @@ -102,12 +106,12 @@ async def post_job_inprogress(

@app.post("/jobs/{job_id}/complete")
async def post_job_complete(
job_id: str,
*,
job_id: str,
queue_name: str | None = None,
result: str | None = None,
success: bool = True,
arq_queue: MockArqQueue = Depends(arq_dependency),
arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
) -> None:
"""Toggle a job to complete, for testing."""
try:
Expand Down
9 changes: 7 additions & 2 deletions tests/dependencies/db_session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from typing import Annotated

import pytest
import structlog
Expand Down Expand Up @@ -48,14 +49,18 @@ async def test_session() -> None:

@app.post("/add")
async def add(
session: async_scoped_session = Depends(db_session_dependency),
session: Annotated[
async_scoped_session, Depends(db_session_dependency)
],
) -> None:
async with session.begin():
session.add(User(username="foo"))

@app.get("/list")
async def get_list(
session: async_scoped_session = Depends(db_session_dependency),
session: Annotated[
async_scoped_session, Depends(db_session_dependency)
],
) -> list[str]:
async with session.begin():
result = await session.scalars(select(User.username))
Expand Down
9 changes: 6 additions & 3 deletions tests/dependencies/gafaelfawr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
from typing import Annotated
from unittest.mock import ANY

import pytest
Expand All @@ -24,7 +25,9 @@ async def test_auth_dependency() -> None:
app = FastAPI()

@app.get("/")
async def handler(user: str = Depends(auth_dependency)) -> dict[str, str]:
async def handler(
user: Annotated[str, Depends(auth_dependency)]
) -> dict[str, str]:
return {"user": user}

async with AsyncClient(app=app, base_url="https://example.com") as client:
Expand All @@ -42,7 +45,7 @@ async def test_auth_delegated_token_dependency() -> None:

@app.get("/")
async def handler(
token: str = Depends(auth_delegated_token_dependency),
token: Annotated[str, Depends(auth_delegated_token_dependency)],
) -> dict[str, str]:
return {"token": token}

Expand All @@ -65,7 +68,7 @@ async def test_auth_logger_dependency(caplog: LogCaptureFixture) -> None:

@app.get("/")
async def handler(
logger: BoundLogger = Depends(auth_logger_dependency),
logger: Annotated[BoundLogger, Depends(auth_logger_dependency)],
) -> dict[str, str]:
logger.info("something")
return {}
Expand Down
3 changes: 2 additions & 1 deletion tests/dependencies/http_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Annotated

import pytest
import respx
Expand Down Expand Up @@ -32,7 +33,7 @@ async def test_http_client(respx_mock: respx.Router) -> None:

@app.get("/")
async def handler(
http_client: AsyncClient = Depends(http_client_dependency),
http_client: Annotated[AsyncClient, Depends(http_client_dependency)],
) -> dict[str, str]:
assert isinstance(http_client, AsyncClient)
await http_client.get("https://www.google.com")
Expand Down
5 changes: 3 additions & 2 deletions tests/dependencies/logger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
from typing import Annotated
from unittest.mock import ANY

import pytest
Expand All @@ -24,7 +25,7 @@ async def test_logger(caplog: LogCaptureFixture) -> None:

@app.get("/")
async def handler(
logger: BoundLogger = Depends(logger_dependency),
logger: Annotated[BoundLogger, Depends(logger_dependency)],
) -> dict[str, str]:
logger.info("something", param="value")
return {}
Expand Down Expand Up @@ -61,7 +62,7 @@ async def test_logger_xforwarded(caplog: LogCaptureFixture) -> None:

@app.get("/")
async def handler(
logger: BoundLogger = Depends(logger_dependency),
logger: Annotated[BoundLogger, Depends(logger_dependency)],
) -> dict[str, str]:
logger.info("something", param="value")
return {}
Expand Down

0 comments on commit 623278c

Please sign in to comment.