diff --git a/changelog.d/20240711_120038_rra_DM_45138.md b/changelog.d/20240711_120038_rra_DM_45138.md
new file mode 100644
index 00000000..e178f616
--- /dev/null
+++ b/changelog.d/20240711_120038_rra_DM_45138.md
@@ -0,0 +1,3 @@
+### New features
+
+- Add a new `abort_job` method to instances of `safir.arq.ArqQueue`, which tells arq to abort a job that has been queued or is in progress. To successfully abort a job that has already started, the arq worker must enable support for aborting jobs.
diff --git a/src/safir/arq.py b/src/safir/arq.py
index d66ae3e1..91faad40 100644
--- a/src/safir/arq.py
+++ b/src/safir/arq.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import abc
+import asyncio
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
@@ -326,6 +327,31 @@ async def enqueue(
         """
         raise NotImplementedError
 
+    @abc.abstractmethod
+    async def abort_job(
+        self,
+        job_id: str,
+        queue_name: str | None = None,
+        *,
+        timeout: float | None = None,
+    ) -> bool:
+        """Abort a queued or running job.
+
+        The worker must be configured to allow aborting jobs for this to
+        succeed.
+
+        Parameters
+        ----------
+        job_id
+            The job's identifier.
+        queue_name
+            Name of the queue.
+        timeout
+            How long to wait for the job result before raising `TimeoutError`.
+            If `None`, waits forever.
+        """
+        raise NotImplementedError
+
     @abc.abstractmethod
     async def get_job_metadata(
         self, job_id: str, queue_name: str | None = None
@@ -433,6 +459,16 @@ def _get_job(self, job_id: str, queue_name: str | None = None) -> Job:
             _queue_name=queue_name or self.default_queue_name,
         )
 
+    async def abort_job(
+        self,
+        job_id: str,
+        queue_name: str | None = None,
+        *,
+        timeout: float | None = None,
+    ) -> bool:
+        job = self._get_job(job_id, queue_name=queue_name)
+        return await job.abort(timeout=timeout)
+
     async def get_job_metadata(
         self, job_id: str, queue_name: str | None = None
     ) -> JobMetadata:
@@ -483,6 +519,46 @@ async def enqueue(
         self._job_metadata[queue_name][new_job.id] = new_job
         return new_job
 
+    async def abort_job(
+        self,
+        job_id: str,
+        queue_name: str | None = None,
+        *,
+        timeout: float | None = None,
+    ) -> bool:
+        queue_name = self._resolve_queue_name(queue_name)
+        try:
+            job_metadata = self._job_metadata[queue_name][job_id]
+        except KeyError:
+            return False
+
+        # If the job was started, simulate cancelling it.
+        if job_metadata.status == JobStatus.in_progress:
+            job_metadata.status = JobStatus.complete
+            result_info = JobResult(
+                id=job_metadata.id,
+                name=job_metadata.name,
+                args=job_metadata.args,
+                kwargs=job_metadata.kwargs,
+                status=job_metadata.status,
+                enqueue_time=job_metadata.enqueue_time,
+                start_time=current_datetime(microseconds=True),
+                finish_time=current_datetime(microseconds=True),
+                result=asyncio.CancelledError(),
+                success=False,
+                queue_name=queue_name,
+            )
+            self._job_results[queue_name][job_id] = result_info
+            return True
+
+        # If it was just queued, delete it.
+        if job_metadata.status in (JobStatus.deferred, JobStatus.queued):
+            del self._job_metadata[queue_name][job_id]
+            return True
+
+        # Otherwise, the job has already completed, so we can't abort it.
+        return False
+
     async def get_job_metadata(
         self, job_id: str, queue_name: str | None = None
     ) -> JobMetadata:
diff --git a/tests/dependencies/arq_test.py b/tests/dependencies/arq_test.py
index dbd92e2a..5b3ccc8b 100644
--- a/tests/dependencies/arq_test.py
+++ b/tests/dependencies/arq_test.py
@@ -80,6 +80,10 @@ async def get_result(
             )
         except (JobNotFound, JobResultUnavailable) as e:
             raise HTTPException(status_code=404, detail=str(e)) from e
+
+        # For testing purposes, turn exceptions into something serializable.
+        if isinstance(job_result.result, BaseException):
+            job_result.result = f"EXCEPTION {type(job_result.result).__name__}"
         return {
             "job_id": job_result.id,
             "job_status": job_result.status,
@@ -121,6 +125,19 @@ async def post_job_complete(
         except JobNotFound as e:
             raise HTTPException(status_code=404, detail=str(e)) from e
 
+    @app.post("/jobs/{job_id}/abort")
+    async def abort_job(
+        *,
+        job_id: str,
+        queue_name: str | None = None,
+        arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
+    ) -> None:
+        """Abort a job, for testing."""
+        try:
+            await arq_queue.abort_job(job_id, queue_name=queue_name)
+        except JobNotFound as e:
+            raise HTTPException(status_code=404, detail=str(e)) from e
+
     transport = ASGITransport(app=app)  # type: ignore[arg-type]
     base_url = "http://example.com"
     async with LifespanManager(app):
@@ -165,3 +182,39 @@ async def post_job_complete(
             assert data["job_status"] == "complete"
             assert data["job_result"] == "done"
             assert data["job_success"] is True
+
+            # Aborting a completed job does nothing.
+            r = await c.post(f"/jobs/{job_id}/abort")
+            r = await c.get(f"/results/{job_id}")
+            assert r.status_code == 200
+            data = r.json()
+            assert data["job_status"] == "complete"
+
+            # Create a new job and abort it before starting it, which should
+            # delete the job.
+            r = await c.post("/")
+            assert r.status_code == 200
+            data = r.json()
+            job_id = data["job_id"]
+            r = await c.get(f"/jobs/{job_id}")
+            assert r.status_code == 200
+            r = await c.post(f"/jobs/{job_id}/abort")
+            r = await c.get(f"/results/{job_id}")
+            assert r.status_code == 404
+
+            # Create a new job, start it, and then abort it. This should keep
+            # the job but mark it complete and failed.
+            r = await c.post("/")
+            assert r.status_code == 200
+            data = r.json()
+            job_id = data["job_id"]
+            r = await c.get(f"/jobs/{job_id}")
+            assert r.status_code == 200
+            r = await c.post(f"/jobs/{job_id}/inprogress")
+            r = await c.post(f"/jobs/{job_id}/abort")
+            r = await c.get(f"/results/{job_id}")
+            assert r.status_code == 200
+            data = r.json()
+            assert data["job_status"] == "complete"
+            assert data["job_result"] == "EXCEPTION CancelledError"
+            assert data["job_success"] is False
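
For context, a minimal sketch of how an application might call the new method through Safir's arq dependency, following the same pattern as the test handler above. The route path, timeout value, status code, and response shape are assumptions for illustration, not part of this change:

```python
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException

from safir.arq import ArqQueue
from safir.dependencies.arq import arq_dependency

router = APIRouter()


@router.post("/jobs/{job_id}/abort")
async def abort_job(
    job_id: str,
    arq_queue: Annotated[ArqQueue, Depends(arq_dependency)],
) -> dict[str, bool]:
    # abort_job returns False when the job cannot be aborted (for example,
    # because it has already completed), so report that to the caller
    # instead of silently succeeding. With a non-None timeout, a
    # TimeoutError may also propagate if the abort takes too long.
    aborted = await arq_queue.abort_job(job_id, timeout=5.0)
    if not aborted:
        raise HTTPException(status_code=409, detail="Job could not be aborted")
    return {"aborted": True}
```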
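As the changelog entry notes, aborting a job that has already started only works if the worker opts in, which in arq is the `allow_abort_jobs` worker setting. A sketch of such a worker configuration, where `process_data` is a placeholder task function, not something defined in this diff:

```python
from arq.connections import RedisSettings


async def process_data(ctx: dict, data: str) -> str:
    """Placeholder task function for this sketch."""
    return data.upper()


class WorkerSettings:
    functions = [process_data]
    redis_settings = RedisSettings()
    # Without this, ArqQueue.abort_job can only abort jobs that have not
    # yet started; in-progress jobs will ignore the abort request.
    allow_abort_jobs = True
```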