Skip to content

Commit

Permalink
Move UWS test support code to safir.testing.uws
Browse files Browse the repository at this point in the history
To allow the UWS support code to be reused by UWS applications, move
the mock job runner to a new safir.testing.uws module. Add some
additional support to the UWSApplication object so that test suites
don't need access to private members of the safir.uws module.
  • Loading branch information
rra committed Jul 24, 2024
1 parent 4c23a86 commit c08ad80
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 154 deletions.
3 changes: 3 additions & 0 deletions changelog.d/20240723_170311_rra_DM_45281_queue.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### New features

- Add new `safir.testing.uws` module that provides a mock UWS job runner for testing UWS applications.
3 changes: 3 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,8 @@ API reference
.. automodapi:: safir.testing.uvicorn
:include-all-objects:

.. automodapi:: safir.testing.uws
:include-all-objects:

.. automodapi:: safir.uws
:include-all-objects:
175 changes: 175 additions & 0 deletions safir/src/safir/testing/uws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Mock UWS job executor for testing."""

from __future__ import annotations

import asyncio
from datetime import UTC, datetime
from types import TracebackType
from typing import Literal, Self

import structlog
from sqlalchemy.ext.asyncio import AsyncEngine

from safir.arq import JobMetadata, JobResult, MockArqQueue
from safir.database import create_async_session, create_database_engine
from safir.uws import UWSConfig, UWSJob, UWSJobResult
from safir.uws._service import JobService
from safir.uws._storage import JobStore

__all__ = ["MockUWSJobRunner"]


class MockUWSJobRunner:
"""Simulate execution of jobs with a mock queue.
When running the test suite, the arq queue is replaced with a mock queue
that doesn't execute workers. That execution has to be simulated by
manually updating state in the mock queue and running the UWS database
worker functions that normally would be run automatically by the queue.
This class wraps that functionality in an async context manager. An
instance of it is normally provided as a fixture, initialized with the
same test objects as the test suite.
Parameters
----------
config
UWS configuration.
arq_queue
Mock arq queue for testing.
"""

def __init__(self, config: UWSConfig, arq_queue: MockArqQueue) -> None:
self._config = config
self._arq = arq_queue
self._engine: AsyncEngine
self._store: JobStore
self._service: JobService

async def __aenter__(self) -> Self:
"""Create a database session and the underlying service."""
# This duplicates some of the code in UWSDependency to avoid needing
# to set up the result store or to expose UWSFactory outside of the
# Safir package internals.
self._engine = create_database_engine(
self._config.database_url,
self._config.database_password,
isolation_level="REPEATABLE READ",
)
session = await create_async_session(self._engine)
self._store = JobStore(session)
self._service = JobService(
config=self._config,
arq_queue=self._arq,
storage=self._store,
logger=structlog.get_logger("uws"),
)
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
"""Close the database engine and session."""
await self._engine.dispose()
return False

async def get_job_metadata(
self, username: str, job_id: str
) -> JobMetadata:
"""Get the arq job metadata for a job.
Parameters
----------
job_id
UWS job ID.
Returns
-------
JobMetadata
arq job metadata.
"""
job = await self._service.get(username, job_id)
assert job.message_id
return await self._arq.get_job_metadata(job.message_id)

async def get_job_result(self, username: str, job_id: str) -> JobResult:
"""Get the arq job result for a job.
Parameters
----------
job_id
UWS job ID.
Returns
-------
JobMetadata
arq job metadata.
"""
job = await self._service.get(username, job_id)
assert job.message_id
return await self._arq.get_job_result(job.message_id)

async def mark_in_progress(
self, username: str, job_id: str, *, delay: float | None = None
) -> UWSJob:
"""Mark a queued job in progress.
Parameters
----------
username
Owner of job.
job_id
Job ID.
delay
How long to delay in seconds before marking the job as complete.
Returns
-------
UWSJob
Record of the job.
"""
if delay:
await asyncio.sleep(delay)
job = await self._service.get(username, job_id)
assert job.message_id
await self._arq.set_in_progress(job.message_id)
await self._store.mark_executing(job_id, datetime.now(tz=UTC))
return await self._service.get(username, job_id)

async def mark_complete(
self,
username: str,
job_id: str,
results: list[UWSJobResult] | Exception,
*,
delay: float | None = None,
) -> UWSJob:
"""Mark an in progress job as complete.
Parameters
----------
username
Owner of job.
job_id
Job ID.
results
Results to return. May be an exception to simulate a job failure.
delay
How long to delay in seconds before marking the job as complete.
Returns
-------
UWSJob
Record of the job.
"""
if delay:
await asyncio.sleep(delay)
job = await self._service.get(username, job_id)
assert job.message_id
await self._arq.set_complete(job.message_id, result=results)
job_result = await self._arq.get_job_result(job.message_id)
await self._store.mark_completed(job_id, job_result)
return await self._service.get(username, job_id)
13 changes: 13 additions & 0 deletions safir/src/safir/uws/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fastapi.responses import PlainTextResponse
from structlog.stdlib import BoundLogger

from safir.arq import ArqQueue
from safir.arq.uws import UWS_QUEUE_NAME, WorkerSettings
from safir.database import create_database_engine, initialize_database

Expand Down Expand Up @@ -193,6 +194,18 @@ def install_handlers(self, router: APIRouter) -> None:
# FastAPI. This problem was last verified in FastAPI 0.111.0.
install_async_post_handler(router, self._config.async_post_route)

def override_arq_queue(self, arq_queue: ArqQueue) -> None:
"""Change the arq used by the FastAPI route handlers.
This method is probably only useful for the test suite.
Parameters
----------
arq
New arq queue.
"""
uws_dependency.override_arq_queue(arq_queue)

async def shutdown_fastapi(self) -> None:
"""Shut down the UWS subsystem for FastAPI applications.
Expand Down
140 changes: 3 additions & 137 deletions safir/tests/support/uws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,17 @@

from __future__ import annotations

import asyncio
from datetime import UTC, datetime, timedelta
from datetime import timedelta
from typing import Annotated, Self

from arq.connections import RedisSettings
from fastapi import Form, Query
from pydantic import BaseModel, SecretStr

from safir.arq import ArqMode, JobMetadata, JobResult, MockArqQueue
from safir.uws import (
ParametersModel,
UWSConfig,
UWSJob,
UWSJobParameter,
UWSJobResult,
UWSRoute,
)
from safir.uws._dependencies import UWSFactory
from safir.arq import ArqMode
from safir.uws import ParametersModel, UWSConfig, UWSJobParameter, UWSRoute

__all__ = [
"MockJobRunner",
"SimpleParameters",
"build_uws_config",
]
Expand Down Expand Up @@ -80,127 +70,3 @@ def build_uws_config(database_url: str, database_password: str) -> UWSConfig:
),
worker="hello",
)


class MockJobRunner:
"""Simulate execution of jobs with a mock queue.
When running the test suite, the arq queue is replaced with a mock queue
that doesn't execute workers. That execution has to be simulated by
manually updating state in the mock queue and running the UWS database
worker functions that normally would be run automatically by the queue.
This class wraps that functionality. An instance of it is normally
provided as a fixture, initialized with the same test objects as the test
suite.
Parameters
----------
factory
Factory for UWS components.
arq_queue
Mock arq queue for testing.
"""

def __init__(self, factory: UWSFactory, arq_queue: MockArqQueue) -> None:
self._service = factory.create_job_service()
self._store = factory.create_job_store()
self._arq = arq_queue

async def get_job_metadata(
self, username: str, job_id: str
) -> JobMetadata:
"""Get the arq job metadata for a job.
Parameters
----------
job_id
UWS job ID.
Returns
-------
JobMetadata
arq job metadata.
"""
job = await self._service.get(username, job_id)
assert job.message_id
return await self._arq.get_job_metadata(job.message_id)

async def get_job_result(self, username: str, job_id: str) -> JobResult:
"""Get the arq job result for a job.
Parameters
----------
job_id
UWS job ID.
Returns
-------
JobMetadata
arq job metadata.
"""
job = await self._service.get(username, job_id)
assert job.message_id
return await self._arq.get_job_result(job.message_id)

async def mark_in_progress(
self, username: str, job_id: str, *, delay: float | None = None
) -> UWSJob:
"""Mark a queued job in progress.
Parameters
----------
username
Owner of job.
job_id
Job ID.
delay
How long to delay in seconds before marking the job as complete.
Returns
-------
UWSJob
Record of the job.
"""
if delay:
await asyncio.sleep(delay)
job = await self._service.get(username, job_id)
assert job.message_id
await self._arq.set_in_progress(job.message_id)
await self._store.mark_executing(job_id, datetime.now(tz=UTC))
return await self._service.get(username, job_id)

async def mark_complete(
self,
username: str,
job_id: str,
results: list[UWSJobResult] | Exception,
*,
delay: float | None = None,
) -> UWSJob:
"""Mark an in progress job as complete.
Parameters
----------
username
Owner of job.
job_id
Job ID.
results
Results to return. May be an exception to simulate a job failure.
delay
How long to delay in seconds before marking the job as complete.
Returns
-------
UWSJob
Record of the job.
"""
if delay:
await asyncio.sleep(delay)
job = await self._service.get(username, job_id)
assert job.message_id
await self._arq.set_complete(job.message_id, result=results)
job_result = await self._arq.get_job_result(job.message_id)
await self._store.mark_completed(job_id, job_result)
return await self._service.get(username, job_id)
Loading

0 comments on commit c08ad80

Please sign in to comment.