Merge pull request #176 from tobymao/toby/priority_groups
feat!: add priorities and groups to postgres
tobymao authored Oct 13, 2024
2 parents 29ab970 + 4d0139d commit cbaf0c7
Showing 8 changed files with 91 additions and 43 deletions.
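
In short: every job gains a `priority` (lower values dequeue first) and an optional `group_key` (at most one job per group may be active at a time), and a `PostgresQueue` can be told which priority band to dequeue. A minimal usage sketch, assuming `Queue.enqueue` forwards these keyword arguments to `Job` the way the tests below do:

```python
# Sketch only: "sleeper" stands in for any registered worker function.
from saq import Queue

queue = Queue.from_url("postgres://postgres@localhost")

async def demo() -> None:
    # Lower priority values are dequeued first (ORDER BY priority, scheduled).
    await queue.enqueue("sleeper", priority=1)
    # Jobs sharing a group_key execute one at a time per group.
    await queue.enqueue("sleeper", group_key="tenant-42")
```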
Empty file added benchmarks/__init__.py
4 changes: 3 additions & 1 deletion benchmarks/simple.py
@@ -2,7 +2,7 @@
import sys
import time

from benchmarks.funcs import *
from funcs import *


SEM = asyncio.Semaphore(20)
@@ -68,6 +68,8 @@ async def enqueue(func):
while await queue.count("incomplete"):
await asyncio.sleep(0.1)
print(f"SAQ process {N} sleep {time.time() - now}")
await worker.stop()
await queue.disconnect()


def bench_rq():
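The two added lines give the benchmark a clean shutdown instead of leaving the worker polling at exit; the pattern in isolation (a sketch using the names from the diff above):

```python
# Teardown sketch: always stop the worker before dropping the queue connection.
async def run_bench(queue, worker) -> None:
    try:
        ...  # enqueue jobs and wait for completion here
    finally:
        await worker.stop()       # stop polling and cancel outstanding tasks
        await queue.disconnect()  # release the queue's connections
```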
4 changes: 3 additions & 1 deletion examples/simple.py
@@ -19,7 +19,10 @@ async def cron_job(ctx):
print("executing cron job")


queue = Queue.from_url("postgres://postgres@localhost")

settings = {
"queue": queue,
"functions": [sleeper, adder],
"concurrency": 100,
"cron_jobs": [CronJob(cron_job, cron="* * * * * */5")],
@@ -33,7 +36,6 @@ async def cron_job(ctx):


async def enqueue(func, **kwargs):
queue = Queue.from_url("redis://localhost")
for _ in range(10000):
await queue.enqueue(func, **{k: v() for k, v in kwargs.items()})

18 changes: 11 additions & 7 deletions saq/job.py
@@ -101,14 +101,16 @@ class Job:
Don't set these, but you can read them.
Parameters:
attempts (int): number of attempts a job has had
completed (int): job completion time epoch seconds
queued (int): job enqueued time epoch seconds
started (int): job started time epoch seconds
touched (int): job touched/updated time epoch seconds
attempts: number of attempts a job has had
completed: job completion time epoch seconds
queued: job enqueued time epoch seconds
started: job started time epoch seconds
touched: job touched/updated time epoch seconds
result: payload containing the results, this is the return of the function provided, must be serializable, defaults to json
error (str | None): stack trace if a runtime error occurs
status (Status): Status Enum, default to Status.New
error: stack trace if a runtime error occurs
status: Status Enum, default to Status.New
priority: The priority of a job, only available in postgres.
group_key: Only one job per group can be active at any time, only available in postgres.
"""

function: str
@@ -131,6 +133,8 @@ class Job:
result: t.Any = None
error: str | None = None
status: Status = Status.NEW
priority: int = 0
group_key: str | None = None
meta: dict[t.Any, t.Any] = dataclasses.field(default_factory=dict)

_EXCLUDE_NON_FULL = {
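For illustration, a `Job` carrying the two new fields (a sketch; the function name is made up):

```python
from saq.job import Job

job = Job(
    function="send_email",  # name of a registered worker function
    priority=10,            # dequeued only after lower-priority queued jobs
    group_key="user:123",   # serialized with other jobs in the same group
)
```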
66 changes: 47 additions & 19 deletions saq/queue/postgres.py
@@ -72,6 +72,7 @@ class PostgresQueue(Queue):
saq_lock_keyspace: The first of two advisory lock keys used by SAQ. (default 0)
SAQ uses advisory locks for coordinating tasks between its workers, e.g. sweeping.
job_lock_keyspace: The first of two advisory lock keys used for jobs. (default 1)
priorities: The priority range to dequeue (default (0, 32767))
"""

@classmethod
@@ -95,6 +96,7 @@ def __init__(
poll_interval: int = 1,
saq_lock_keyspace: int = 0,
job_lock_keyspace: int = 1,
priorities: tuple[int, int] = (0, 32767),
) -> None:
super().__init__(name=name, dump=dump, load=load)

@@ -106,6 +108,7 @@ def __init__(
self.poll_interval = poll_interval
self.saq_lock_keyspace = saq_lock_keyspace
self.job_lock_keyspace = job_lock_keyspace
self._priorities = priorities

self._job_queue: asyncio.Queue = asyncio.Queue()
self._waiting = 0 # Internal counter of worker tasks waiting for dequeue
@@ -165,11 +168,10 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> Qu
dedent(
"""
SELECT worker_id, stats FROM {stats_table}
WHERE %(now)s <= expire_at
WHERE NOW() <= TO_TIMESTAMP(expire_at)
"""
)
).format(stats_table=self.stats_table),
{"now": seconds(now())},
)
results = await cursor.fetchall()
workers: dict[str, dict[str, t.Any]] = dict(results)
@@ -212,14 +214,15 @@ async def count(self, kind: CountKind) -> int:
SQL(
dedent(
"""
SELECT count(*) FROM {jobs_table}
SELECT count(*)
FROM {jobs_table}
WHERE status = 'queued'
AND queue = %(queue)s
AND %(now)s >= scheduled
AND NOW() >= TO_TIMESTAMP(scheduled)
"""
)
).format(jobs_table=self.jobs_table),
{"queue": self.name, "now": seconds(now())},
{"queue": self.name},
)
elif kind == "active":
await cursor.execute(
@@ -287,7 +290,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
DELETE FROM {jobs_table}
WHERE queue = %(queue)s
AND status IN ('aborted', 'complete', 'failed')
AND %(now)s >= expire_at;
AND NOW() >= TO_TIMESTAMP(expire_at);
"""
)
).format(
Expand All @@ -296,7 +299,6 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
),
{
"queue": self.name,
"now": seconds(now()),
},
)

@@ -306,16 +308,13 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
"""
-- Delete expired stats
DELETE FROM {stats_table}
WHERE %(now)s >= expire_at;
WHERE NOW() >= TO_TIMESTAMP(expire_at);
"""
)
).format(
jobs_table=self.jobs_table,
stats_table=self.stats_table,
),
{
"now": seconds(now()),
},
)

await cursor.execute(
@@ -571,8 +570,16 @@ async def _dequeue(self) -> None:
FROM {jobs_table}
WHERE status = 'queued'
AND queue = %(queue)s
AND %(now)s >= scheduled
ORDER BY scheduled
AND NOW() >= TO_TIMESTAMP(scheduled)
AND priority BETWEEN %(plow)s AND %(phigh)s
AND group_key NOT IN (
SELECT DISTINCT group_key
FROM {jobs_table}
WHERE status = 'active'
AND queue = %(queue)s
AND group_key IS NOT NULL
)
ORDER BY priority, scheduled
LIMIT %(limit)s
FOR UPDATE SKIP LOCKED
)
@@ -589,8 +596,9 @@ async def _dequeue(self) -> None:
),
{
"queue": self.name,
"now": seconds(now()),
"limit": self._waiting,
"plow": self._priorities[0],
"phigh": self._priorities[1],
},
)
results = await cursor.fetchall()
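
The rewritten query encodes three rules: only jobs inside the worker's priority band are candidates, lower `priority` wins before earlier `scheduled`, and any group with an active job is skipped. (One SQL subtlety: `NOT IN` against a non-empty set is never true for a NULL `group_key`, so as written, ungrouped jobs pass this filter only while no grouped job is active.) The intended per-group rule, restated as plain Python for illustration only:

```python
def is_candidate(job, active_group_keys: set[str],
                 plow: int = 0, phigh: int = 32767, now: float = 0.0) -> bool:
    """Mirror of the WHERE clause in the dequeue query above."""
    return (
        job.status == "queued"
        and now >= job.scheduled
        and plow <= job.priority <= phigh
        and (job.group_key is None or job.group_key not in active_group_keys)
    )
```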
@@ -607,13 +615,31 @@ async def _enqueue(self, job: Job) -> Job | None:
SQL(
dedent(
"""
INSERT INTO {jobs_table} (key, job, queue, status, scheduled)
VALUES (%(key)s, %(job)s, %(queue)s, %(status)s, %(scheduled)s)
INSERT INTO {jobs_table} (
key,
job,
queue,
status,
priority,
group_key,
scheduled
)
VALUES (
%(key)s,
%(job)s,
%(queue)s,
%(status)s,
%(priority)s,
%(group_key)s,
%(scheduled)s
)
ON CONFLICT (key) DO UPDATE
SET
job = %(job)s,
queue = %(queue)s,
status = %(status)s,
priority = %(priority)s,
group_key = %(group_key)s,
scheduled = %(scheduled)s,
expire_at = null
WHERE
@@ -628,6 +654,8 @@ async def _enqueue(self, job: Job) -> Job | None:
"job": self.serialize(job),
"queue": self.name,
"status": job.status,
"priority": job.priority,
"group_key": job.group_key,
"scheduled": job.scheduled or int(seconds(now())),
},
)
@@ -645,16 +673,16 @@ async def write_stats(self, stats: QueueStats, ttl: int) -> None:
dedent(
"""
INSERT INTO {stats_table} (worker_id, stats, expire_at)
VALUES (%(worker_id)s, %(stats)s, %(expire_at)s)
VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET stats = %(stats)s, expire_at = %(expire_at)s
SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(stats_table=self.stats_table),
{
"worker_id": self.uuid,
"stats": json.dumps(stats),
"expire_at": seconds(now()) + ttl,
"ttl": ttl,
},
)

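Because `priorities` is a constructor argument, separate worker pools can carve up the priority space. A sketch, assuming `from_url` forwards extra keyword arguments to `__init__` as the classmethod above suggests:

```python
from saq.queue.postgres import PostgresQueue

DSN = "postgres://postgres@localhost"

# One pool drains urgent work; another is reserved for bulk/background jobs.
urgent = PostgresQueue.from_url(DSN, priorities=(0, 9))
bulk = PostgresQueue.from_url(DSN, priorities=(10, 32767))
```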
5 changes: 3 additions & 2 deletions saq/queue/postgres_ddl.py
@@ -2,11 +2,12 @@
CREATE TABLE IF NOT EXISTS {jobs_table} (
key TEXT PRIMARY KEY,
lock_key SERIAL NOT NULL,
queued BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM now()),
job BYTEA NOT NULL,
queue TEXT NOT NULL,
status TEXT NOT NULL,
scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM now()),
priority SMALLINT NOT NULL DEFAULT 0,
group_key TEXT,
scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
expire_at BIGINT
);
"""
29 changes: 21 additions & 8 deletions tests/test_queue.py
@@ -718,14 +718,9 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None:
result = await cursor.fetchone()
self.assertIsNone(result)

@mock.patch("saq.utils.time")
async def test_cron_job_close_to_target(self, mock_time: MagicMock) -> None:
mock_time.time.return_value = 1000.5
await self.enqueue("test", scheduled=1001)

# The job is scheduled to run at 1001, but we're running at 1000.5
# so it should not be picked up
job = await self.queue.dequeue(timeout=1)
async def test_cron_job_close_to_target(self) -> None:
await self.enqueue("test", scheduled=time.time() + 0.5)
job = await self.queue.dequeue(timeout=0.1)
assert not job

async def test_bad_connection(self) -> None:
@@ -741,3 +736,21 @@ async def test_bad_connection(self) -> None:
self.assertNotEqual(original_connection, self.queue._dequeue_conn)

await self.queue.pool.putconn(original_connection)

async def test_group_key(self) -> None:
job1 = await self.enqueue("test", group_key=1)
assert job1
job2 = await self.enqueue("test", group_key=1)
assert job2
self.assertEqual(await self.count("queued"), 2)

assert await self.dequeue()
self.assertEqual(await self.count("queued"), 1)
assert not await self.queue.dequeue(0.01)
await job1.update(status="finished")
assert await self.dequeue()

async def test_priority(self) -> None:
assert await self.enqueue("test", priority=-1)
self.assertEqual(await self.count("queued"), 1)
assert not await self.queue.dequeue(0.01)
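
The priority test passes because `-1` falls outside the default dequeue band of `(0, 32767)`, so a default worker never sees the job; a queue whose band covers negative values would. Sketch:

```python
from saq.queue.postgres import PostgresQueue

# Sketch: a queue that only dequeues negative-priority jobs.
negatives = PostgresQueue.from_url(
    "postgres://postgres@localhost", priorities=(-32768, -1)
)
```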
8 changes: 3 additions & 5 deletions tests/test_worker.py
@@ -3,7 +3,6 @@
import asyncio
import contextvars
import logging
import time
import typing as t
import unittest
from unittest import mock
@@ -522,8 +521,7 @@ async def test_schedule(self, mock_time: MagicMock) -> None:
self.skipTest("Not implemented")

@mock.patch("saq.worker.logger")
@mock.patch("saq.utils.time")
async def test_cron(self, mock_time: MagicMock, mock_logger: MagicMock) -> None:
async def test_cron(self, mock_logger: MagicMock) -> None:
with self.assertRaises(ValueError):
Worker(
self.queue,
@@ -534,15 +532,15 @@ async def test_cron(self, mock_logger: MagicMock) -> None:
worker = Worker(
self.queue,
functions=functions,
cron_jobs=[CronJob(cron, cron="* * * * *")],
cron_jobs=[CronJob(cron, cron="* * * * * *")],
)
self.assertEqual(await self.queue.count("queued"), 0)
self.assertEqual(await self.queue.count("incomplete"), 0)
await worker.schedule()
self.assertEqual(await self.queue.count("queued"), 0)
self.assertEqual(await self.queue.count("incomplete"), 1)
await asyncio.sleep(1)

mock_time.time.return_value = time.time() + 60
self.assertEqual(await self.queue.count("queued"), 1)
self.assertEqual(await self.queue.count("incomplete"), 1)

