Skip to content

Commit

Permalink
Added retry on certain exception to the task queue.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Oct 30, 2024
1 parent 91de5b5 commit 8be6fa2
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 14 deletions.
8 changes: 7 additions & 1 deletion aana/deployments/task_queue_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class TaskQueueConfig(BaseModel):
"""The configuration for the task queue deployment."""

app_name: str = Field(description="The name of the Aana app")
retryable_exceptions: list[str] = Field(
description="The list of exceptions that should be retried"
)


@serve.deployment
Expand Down Expand Up @@ -72,6 +75,7 @@ async def apply_config(self, config: dict[str, Any]):
"""
config_obj = TaskQueueConfig(**config)
self.app_name = config_obj.app_name
self.retryable_exceptions = config_obj.retryable_exceptions

async def loop(self): # noqa: C901
"""The main loop for the task queue deployment.
Expand Down Expand Up @@ -141,7 +145,9 @@ async def loop(self): # noqa: C901
self.running_task_ids
)
tasks = TaskRepository(session).fetch_unprocessed_tasks(
limit=num_tasks_to_assign
limit=num_tasks_to_assign,
max_retries=aana_settings.task_queue.max_retries,
retryable_exceptions=self.retryable_exceptions,
)

# If there are no tasks, wait and retry
Expand Down
22 changes: 21 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DeploymentException,
EmptyMigrationsException,
FailedDeployment,
InferenceException,
InsufficientResources,
)
from aana.storage.op import run_alembic_migrations
Expand All @@ -32,18 +33,36 @@
class AanaSDK:
"""Aana SDK to deploy and manage Aana deployments and endpoints."""

def __init__(self, name: str = "app", migration_func: Callable | None = None):
def __init__(
self,
name: str = "app",
migration_func: Callable | None = None,
retryable_exceptions: list[Exception, str] | None = None,
):
"""Aana SDK to deploy and manage Aana deployments and endpoints.
Args:
name (str, optional): The name of the application. Defaults to "app".
migration_func (Callable | None): The migration function to run. Defaults to None.
retryable_exceptions (list[Exception, str] | None): The exceptions that can be retried in the task queue.
Defaults to ['InferenceException'].
"""
self.name = name
self.migration_func = migration_func
self.endpoints: dict[str, Endpoint] = {}
self.deployments: dict[str, Deployment] = {}

if retryable_exceptions is None:
self.retryable_exceptions = [InferenceException]
else:
self.retryable_exceptions = retryable_exceptions
# Convert exceptions to string if they are not already
# to avoid serialization issues
self.retryable_exceptions = [
exc if isinstance(exc, str) else exc.__name__
for exc in self.retryable_exceptions
]

if aana_settings.task_queue.enabled:
self.add_task_queue(deploy=False)

Expand Down Expand Up @@ -151,6 +170,7 @@ def add_task_queue(self, deploy: bool = False):
num_replicas=1,
user_config=TaskQueueConfig(
app_name=self.name,
retryable_exceptions=self.retryable_exceptions,
).model_dump(mode="json"),
)
self.register_deployment(
Expand Down
59 changes: 47 additions & 12 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any
from uuid import UUID

from sqlalchemy import and_, desc
from sqlalchemy import and_, desc, or_, text
from sqlalchemy.orm import Session

from aana.storage.models.task import Status as TaskStatus
Expand Down Expand Up @@ -67,7 +67,12 @@ def save(self, endpoint: str, data: Any, priority: int = 0):
self.session.commit()
return task

def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
def fetch_unprocessed_tasks(
self,
limit: int | None = None,
max_retries: int = 1,
retryable_exceptions: list[str] | None = None,
) -> list[TaskEntity]:
"""Fetches unprocessed tasks and marks them as ASSIGNED.
The task is considered unprocessed if it is in CREATED or NOT_FINISHED state.
Expand All @@ -81,21 +86,51 @@ def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
Args:
limit (int | None): The maximum number of tasks to fetch. If None, fetch all.
max_retries (int): The maximum number of retries for a task.
retryable_exceptions (list[str] | None): The list of exceptions that should be retried.
Returns:
list[TaskEntity]: the unprocessed tasks.
"""
tasks = (
self.session.query(TaskEntity)
.filter(
TaskEntity.status.in_([TaskStatus.CREATED, TaskStatus.NOT_FINISHED])
if retryable_exceptions:
# Convert the list of exceptions to a string for the query:
# e.g., ["InferenceException", "ValueError"] -> "'InferenceException', 'ValueError'"
exceptions_str = ", ".join([f"'{ex}'" for ex in retryable_exceptions])
tasks = (
self.session.query(TaskEntity)
.filter(
or_(
TaskEntity.status.in_(
[TaskStatus.CREATED, TaskStatus.NOT_FINISHED]
),
and_(
TaskEntity.status == TaskStatus.FAILED,
text(f"result->>'error' IN ({exceptions_str})"),
TaskEntity.num_retries < max_retries,
),
)
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)
else:
tasks = (
self.session.query(TaskEntity)
.filter(
TaskEntity.status.in_(
[TaskStatus.CREATED, TaskStatus.NOT_FINISHED]
),
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)

for task in tasks:
self.update_status(
task_id=task.id,
Expand Down

0 comments on commit 8be6fa2

Please sign in to comment.