Skip to content

Commit

Permalink
Add support for aborting arq jobs
Browse files Browse the repository at this point in the history
Add a new `abort_job` method to instances of `safir.arq.ArqQueue`,
which tries to abort a job. If the job has already started, this
requires the worker to enable support for aborting jobs.

The mock implementation deletes the job entirely if it hasn't been
started, and treats it as if it failed with asyncio.CancelledError
if it were already running. This seems to match what arq does with
a Redis queue.
  • Loading branch information
rra committed Jul 11, 2024
1 parent 5d197a2 commit 3e4aade
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
3 changes: 3 additions & 0 deletions changelog.d/20240711_120038_rra_DM_45138.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### New features

- Add new `abort_job` method to instances of `safir.arq.ArqQueue`, which tells arq to abort a job that has been queued or is in progress. To successfully abort a job that has already started, the arq worker must enable support for aborting jobs.
76 changes: 76 additions & 0 deletions src/safir/arq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import abc
import asyncio
import uuid
from dataclasses import dataclass
from datetime import datetime
Expand Down Expand Up @@ -326,6 +327,31 @@ async def enqueue(
"""
raise NotImplementedError

@abc.abstractmethod
async def abort_job(
    self,
    job_id: str,
    queue_name: str | None = None,
    *,
    timeout: float | None = None,
) -> bool:
    """Abort a queued or running job.

    The worker must be configured to allow aborting jobs for this to
    succeed.

    Parameters
    ----------
    job_id
        The job's identifier.
    queue_name
        Name of the queue.
    timeout
        How long to wait for the job result before raising `TimeoutError`.
        If `None`, waits forever.

    Returns
    -------
    bool
        Whether the job was successfully aborted.
    """
    raise NotImplementedError

@abc.abstractmethod
async def get_job_metadata(
self, job_id: str, queue_name: str | None = None
Expand Down Expand Up @@ -433,6 +459,16 @@ def _get_job(self, job_id: str, queue_name: str | None = None) -> Job:
_queue_name=queue_name or self.default_queue_name,
)

async def abort_job(
    self,
    job_id: str,
    queue_name: str | None = None,
    *,
    timeout: float | None = None,
) -> bool:
    """Abort a queued or running job.

    Delegates to arq's `Job.abort`, which asks the worker to cancel the
    job; the worker must be configured to allow aborting jobs for an
    in-progress job to be cancelled.

    Parameters
    ----------
    job_id
        The job's identifier.
    queue_name
        Name of the queue.
    timeout
        How long to wait for the job result before raising `TimeoutError`.
        If `None`, waits forever.

    Returns
    -------
    bool
        Whether the job was successfully aborted.
    """
    arq_job = self._get_job(job_id, queue_name=queue_name)
    aborted = await arq_job.abort(timeout=timeout)
    return aborted

async def get_job_metadata(
self, job_id: str, queue_name: str | None = None
) -> JobMetadata:
Expand Down Expand Up @@ -483,6 +519,46 @@ async def enqueue(
self._job_metadata[queue_name][new_job.id] = new_job
return new_job

async def abort_job(
    self,
    job_id: str,
    queue_name: str | None = None,
    *,
    timeout: float | None = None,
) -> bool:
    """Pretend to abort a job, mimicking arq's Redis behavior.

    A job that has not yet started is deleted outright; a running job is
    recorded as having completed unsuccessfully with
    `asyncio.CancelledError` as its result. A job that already finished
    cannot be aborted.

    Parameters
    ----------
    job_id
        The job's identifier.
    queue_name
        Name of the queue.
    timeout
        Ignored by the mock; present for interface compatibility.

    Returns
    -------
    bool
        Whether the job was successfully aborted.
    """
    queue = self._resolve_queue_name(queue_name)
    metadata = self._job_metadata.get(queue, {}).get(job_id)
    if metadata is None:
        # Unknown queue or unknown job: nothing to abort.
        return False

    # A job that hasn't started yet is simply discarded.
    if metadata.status in (JobStatus.deferred, JobStatus.queued):
        del self._job_metadata[queue][job_id]
        return True

    # A running job is marked complete and failed, as if the worker
    # cancelled it with asyncio.CancelledError.
    if metadata.status == JobStatus.in_progress:
        metadata.status = JobStatus.complete
        self._job_results[queue][job_id] = JobResult(
            id=metadata.id,
            name=metadata.name,
            args=metadata.args,
            kwargs=metadata.kwargs,
            status=metadata.status,
            enqueue_time=metadata.enqueue_time,
            start_time=current_datetime(microseconds=True),
            finish_time=current_datetime(microseconds=True),
            result=asyncio.CancelledError(),
            success=False,
            queue_name=queue,
        )
        return True

    # Anything else has already completed and can no longer be aborted.
    return False

async def get_job_metadata(
self, job_id: str, queue_name: str | None = None
) -> JobMetadata:
Expand Down
53 changes: 53 additions & 0 deletions tests/dependencies/arq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ async def get_result(
)
except (JobNotFound, JobResultUnavailable) as e:
raise HTTPException(status_code=404, detail=str(e)) from e

# For testing purposes, turn exceptions into something serializable.
if isinstance(job_result.result, BaseException):
job_result.result = f"EXCEPTION {type(job_result.result).__name__}"
return {
"job_id": job_result.id,
"job_status": job_result.status,
Expand Down Expand Up @@ -121,6 +125,19 @@ async def post_job_complete(
except JobNotFound as e:
raise HTTPException(status_code=404, detail=str(e)) from e

@app.post("/jobs/{job_id}/abort")
async def abort_job(
    *,
    job_id: str,
    queue_name: str | None = None,
    arq_queue: Annotated[MockArqQueue, Depends(arq_dependency)],
) -> None:
    """Test-only route that asks the mock queue to abort a job."""
    try:
        await arq_queue.abort_job(job_id, queue_name=queue_name)
    except JobNotFound as exc:
        # Surface a missing job as a 404, matching the other routes.
        raise HTTPException(status_code=404, detail=str(exc)) from exc

transport = ASGITransport(app=app) # type: ignore[arg-type]
base_url = "http://example.com"
async with LifespanManager(app):
Expand Down Expand Up @@ -165,3 +182,39 @@ async def post_job_complete(
assert data["job_status"] == "complete"
assert data["job_result"] == "done"
assert data["job_success"] is True

# Aborting a completed job does nothing.
r = await c.post(f"/jobs/{job_id}/abort")
r = await c.get(f"/results/{job_id}")
assert r.status_code == 200
data = r.json()
assert data["job_status"] == "complete"

# Create a new job and abort it before starting it, which should
# delete the job.
r = await c.post("/")
assert r.status_code == 200
data = r.json()
job_id = data["job_id"]
r = await c.get(f"/jobs/{job_id}")
assert r.status_code == 200
r = await c.post(f"/jobs/{job_id}/abort")
r = await c.get(f"/results/{job_id}")
assert r.status_code == 404

# Create a new job, start it, and then abort it. This should keep
# the job but mark it complete and failed.
r = await c.post("/")
assert r.status_code == 200
data = r.json()
job_id = data["job_id"]
r = await c.get(f"/jobs/{job_id}")
assert r.status_code == 200
r = await c.post(f"/jobs/{job_id}/inprogress")
r = await c.post(f"/jobs/{job_id}/abort")
r = await c.get(f"/results/{job_id}")
assert r.status_code == 200
data = r.json()
assert data["job_status"] == "complete"
assert data["job_result"] == "EXCEPTION CancelledError"
assert data["job_success"] is False

0 comments on commit 3e4aade

Please sign in to comment.