Skip to content

Commit

Permalink
Give riverqueue.Job fully defined properties + timestamps as UTC (#26)
Browse files Browse the repository at this point in the history
Previously, the internal sqlc `RiverJob` row was fully typed by virtue
of being generated by sqlc, but the `riverqueue.Job` type was undefined,
with typechecks working by using a `cast`.

Here, give `riverqueue.Job` a full set of defined properties. This is
better for things like conveying type information and autocomplete, but
has a few other side benefits:

* Make sure to return all timestamps as UTC. Previously, they'd be in
  whatever your local timezone is.

* Give some fields like `args`, `metadata`, and `state` better types
  (the first two were previously `Any`).

Lastly, modify `InsertResult` somewhat to make `job` non-optional. A job is
always returned, even if the insert was skipped, because in that case we
look the existing job up via a select query.
  • Loading branch information
brandur authored Jul 6, 2024
1 parent 8b5b76a commit a258b1f
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/riverqueue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
AsyncClient as AsyncClient,
JobArgs as JobArgs,
JobArgsWithInsertOpts as JobArgsWithInsertOpts,
JobState as JobState,
Client as Client,
InsertManyParams as InsertManyParams,
InsertOpts as InsertOpts,
Expand All @@ -12,4 +11,5 @@
from .model import (
InsertResult as InsertResult,
Job as Job,
JobState as JobState,
)
14 changes: 1 addition & 13 deletions src/riverqueue/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from enum import Enum
import re
from typing import (
Any,
Expand All @@ -16,21 +15,10 @@

from .driver import GetParams, JobInsertParams, DriverProtocol, ExecutorProtocol
from .driver.driver_protocol import AsyncDriverProtocol, AsyncExecutorProtocol
from .model import InsertResult
from .model import InsertResult, JobState
from .fnv import fnv1_hash


class JobState(str, Enum):
AVAILABLE = "available"
CANCELLED = "cancelled"
COMPLETED = "completed"
DISCARDED = "discarded"
PENDING = "pending"
RETRYABLE = "retryable"
RUNNING = "running"
SCHEDULED = "scheduled"


MAX_ATTEMPTS_DEFAULT: int = 25
PRIORITY_DEFAULT: int = 1
QUEUE_DEFAULT: str = "default"
Expand Down
71 changes: 51 additions & 20 deletions src/riverqueue/driver/riversqlalchemy/sql_alchemy_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

from ...driver import DriverProtocol, ExecutorProtocol, GetParams, JobInsertParams
from ...model import Job
from ...model import Job, JobState
from .dbsqlc import models, river_job, pg_misc


Expand All @@ -30,11 +30,13 @@ async def advisory_lock(self, key: int) -> None:
await self.pg_misc_querier.pg_advisory_xact_lock(key=key)

async def job_insert(self, insert_params: JobInsertParams) -> Job:
return cast(
Job,
await self.job_querier.job_insert_fast(
cast(river_job.JobInsertFastParams, insert_params)
),
return _job_from_row(
cast( # drop Optional[] because insert always returns a row
models.RiverJob,
await self.job_querier.job_insert_fast(
cast(river_job.JobInsertFastParams, insert_params)
),
)
)

async def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
Expand All @@ -46,12 +48,10 @@ async def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
async def job_get_by_kind_and_unique_properties(
self, get_params: GetParams
) -> Optional[Job]:
return cast(
Optional[Job],
await self.job_querier.job_get_by_kind_and_unique_properties(
cast(river_job.JobGetByKindAndUniquePropertiesParams, get_params)
),
row = await self.job_querier.job_get_by_kind_and_unique_properties(
cast(river_job.JobGetByKindAndUniquePropertiesParams, get_params)
)
return _job_from_row(row) if row else None

@asynccontextmanager
async def transaction(self) -> AsyncGenerator:
Expand Down Expand Up @@ -91,10 +91,12 @@ def advisory_lock(self, key: int) -> None:
self.pg_misc_querier.pg_advisory_xact_lock(key=key)

def job_insert(self, insert_params: JobInsertParams) -> Job:
return cast(
Job,
self.job_querier.job_insert_fast(
cast(river_job.JobInsertFastParams, insert_params)
return _job_from_row(
cast( # drop Optional[] because insert always returns a row
models.RiverJob,
self.job_querier.job_insert_fast(
cast(river_job.JobInsertFastParams, insert_params)
),
),
)

Expand All @@ -105,12 +107,10 @@ def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
def job_get_by_kind_and_unique_properties(
self, get_params: GetParams
) -> Optional[Job]:
return cast(
Optional[Job],
self.job_querier.job_get_by_kind_and_unique_properties(
cast(river_job.JobGetByKindAndUniquePropertiesParams, get_params)
),
row = self.job_querier.job_get_by_kind_and_unique_properties(
cast(river_job.JobGetByKindAndUniquePropertiesParams, get_params)
)
return _job_from_row(row) if row else None

@contextmanager
def transaction(self) -> Iterator[None]:
Expand Down Expand Up @@ -169,3 +169,34 @@ def _build_insert_many_params(
insert_many_params.tags.append(",".join(insert_params.tags))

return insert_many_params


def _job_from_row(row: models.RiverJob) -> Job:
    """
    Builds the top level `Job` type from an internal sqlc generated row.

    Most fields are copied over untouched. Timestamps are changed from local
    timezone to UTC, with the optional `attempted_at` and `finalized_at` left
    as `None` when the row has no value for them.
    """

    utc = timezone.utc

    # Optional timestamps are only converted when the row carries a value.
    attempted_at = None
    if row.attempted_at:
        attempted_at = row.attempted_at.astimezone(utc)

    finalized_at = None
    if row.finalized_at:
        finalized_at = row.finalized_at.astimezone(utc)

    return Job(
        id=row.id,
        args=row.args,
        attempt=row.attempt,
        attempted_at=attempted_at,
        attempted_by=row.attempted_by,
        created_at=row.created_at.astimezone(utc),
        errors=row.errors,
        finalized_at=finalized_at,
        kind=row.kind,
        max_attempts=row.max_attempts,
        metadata=row.metadata,
        priority=row.priority,
        queue=row.queue,
        state=cast(JobState, row.state),
        scheduled_at=row.scheduled_at.astimezone(utc),
        tags=row.tags,
    )
34 changes: 31 additions & 3 deletions src/riverqueue/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
from dataclasses import dataclass, field
from typing import Optional
import datetime
from enum import Enum
from typing import Any, Optional


class JobState(str, Enum):
    """
    The set of states a job can be in.

    Members mix in `str`, so each one compares equal to its plain lowercase
    string value (e.g. `JobState.AVAILABLE == "available"`).
    """

    AVAILABLE = "available"
    CANCELLED = "cancelled"
    COMPLETED = "completed"
    DISCARDED = "discarded"
    PENDING = "pending"
    RETRYABLE = "retryable"
    RUNNING = "running"
    SCHEDULED = "scheduled"


@dataclass
class InsertResult:
job: Optional["Job"] = field(default=None)
job: "Job"
unique_skipped_as_duplicated: bool = field(default=False)


@dataclass
class Job:
pass
id: int
args: dict[str, Any]
attempt: int
attempted_at: Optional[datetime.datetime]
attempted_by: Optional[list[str]]
created_at: datetime.datetime
errors: Optional[list[Any]]
finalized_at: Optional[datetime.datetime]
kind: str
max_attempts: int
metadata: dict[str, Any]
priority: int
queue: str
state: JobState
scheduled_at: datetime.datetime
tags: list[str]
4 changes: 2 additions & 2 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,14 @@ def test_tag_validation(client):
with pytest.raises(AssertionError) as ex:
client.insert(SimpleArgs(), insert_opts=InsertOpts(tags=["commas,bad"]))
assert (
"tags should be less than 255 characters in length and match regex \A[\w][\w\-]+[\w]\Z"
r"tags should be less than 255 characters in length and match regex \A[\w][\w\-]+[\w]\Z"
== str(ex.value)
)

with pytest.raises(AssertionError) as ex:
client.insert(SimpleArgs(), insert_opts=InsertOpts(tags=["a" * 256]))
assert (
"tags should be less than 255 characters in length and match regex \A[\w][\w\-]+[\w]\Z"
r"tags should be less than 255 characters in length and match regex \A[\w][\w\-]+[\w]\Z"
== str(ex.value)
)

Expand Down
28 changes: 27 additions & 1 deletion tests/driver/riversqlalchemy/sqlalchemy_driver_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import pytest
import pytest_asyncio
from riverqueue.model import JobState
import sqlalchemy
import sqlalchemy.ext.asyncio
from datetime import datetime, timezone
from typing import AsyncIterator, Iterator
from unittest.mock import patch

from riverqueue import Client, InsertOpts, UniqueOpts
from riverqueue.client import AsyncClient, InsertManyParams
from riverqueue.client import (
MAX_ATTEMPTS_DEFAULT,
PRIORITY_DEFAULT,
QUEUE_DEFAULT,
AsyncClient,
InsertManyParams,
)
from riverqueue.driver import riversqlalchemy
from riverqueue.driver.driver_protocol import GetParams

Expand Down Expand Up @@ -45,6 +52,25 @@ async def client_async(
return AsyncClient(driver_async)


def test_insert_job_from_row(client, driver):
    """
    Inserts a job without any insert options and checks the properties of the
    `Job` that comes back on the insert result.
    """

    # `driver` is requested as a fixture but not referenced in the body.
    job = client.insert(SimpleArgs()).job
    assert job

    # JSON-like fields come back as dicts.
    assert isinstance(job.args, dict)

    # Nothing has worked the job yet.
    assert job.attempt == 0
    assert job.attempted_by is None

    # Timestamps are reported in UTC.
    assert job.created_at.tzinfo == timezone.utc

    assert job.errors is None
    assert job.kind == "simple"

    # No insert options were given, so client defaults apply.
    assert job.max_attempts == MAX_ATTEMPTS_DEFAULT
    assert isinstance(job.metadata, dict)
    assert job.priority == PRIORITY_DEFAULT
    assert job.queue == QUEUE_DEFAULT

    assert job.scheduled_at.tzinfo == timezone.utc
    assert job.state == JobState.AVAILABLE
    assert job.tags == []


def test_insert_with_only_args_sync(client, driver):
insert_res = client.insert(SimpleArgs())
assert insert_res.job
Expand Down

0 comments on commit a258b1f

Please sign in to comment.