Review revision of the "retryable exceptions" task-queue patch.

Notes for whoever applies this:
  * Indentation and blank context lines were reconstructed from a
    whitespace-collapsed copy of the patch; if a hunk is rejected, retry with
    `git apply --ignore-whitespace`.
  * The blob hashes on the `index` lines belong to the original revision of
    this patch and no longer match the post-image produced here.
  * Changes relative to the original revision:
      - TaskQueueConfig.retryable_exceptions defaults to an empty list so
        configs written before this field existed still validate.
      - AanaSDK.__init__ annotates retryable_exceptions as exception classes
        or names, and normalises the list in a single assignment.
      - fetch_unprocessed_tasks passes exception names as an expanding bind
        parameter instead of formatting them into the SQL text, and builds
        one query instead of two near-identical copies.

diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py
index 6213ff1f..eef68cdd 100644
--- a/aana/deployments/task_queue_deployment.py
+++ b/aana/deployments/task_queue_deployment.py
@@ -17,6 +17,10 @@ class TaskQueueConfig(BaseModel):
     """The configuration for the task queue deployment."""
 
     app_name: str = Field(description="The name of the Aana app")
+    retryable_exceptions: list[str] = Field(
+        default_factory=list,
+        description="The list of exceptions that should be retried",
+    )
 
 
 @serve.deployment
@@ -72,6 +76,7 @@ async def apply_config(self, config: dict[str, Any]):
         """
         config_obj = TaskQueueConfig(**config)
         self.app_name = config_obj.app_name
+        self.retryable_exceptions = config_obj.retryable_exceptions
 
     async def loop(self):  # noqa: C901
         """The main loop for the task queue deployment.
@@ -141,7 +146,9 @@ async def loop(self):  # noqa: C901
                     self.running_task_ids
                 )
                 tasks = TaskRepository(session).fetch_unprocessed_tasks(
-                    limit=num_tasks_to_assign
+                    limit=num_tasks_to_assign,
+                    max_retries=aana_settings.task_queue.max_retries,
+                    retryable_exceptions=self.retryable_exceptions,
                 )
 
                 # If there are no tasks, wait and retry
diff --git a/aana/sdk.py b/aana/sdk.py
index bd12fcec..a49317e6 100644
--- a/aana/sdk.py
+++ b/aana/sdk.py
@@ -23,6 +23,7 @@
     DeploymentException,
     EmptyMigrationsException,
     FailedDeployment,
+    InferenceException,
     InsufficientResources,
 )
 from aana.storage.op import run_alembic_migrations
@@ -32,18 +33,34 @@
 class AanaSDK:
     """Aana SDK to deploy and manage Aana deployments and endpoints."""
 
-    def __init__(self, name: str = "app", migration_func: Callable | None = None):
+    def __init__(
+        self,
+        name: str = "app",
+        migration_func: Callable | None = None,
+        retryable_exceptions: list[type[Exception] | str] | None = None,
+    ):
         """Aana SDK to deploy and manage Aana deployments and endpoints.
 
         Args:
             name (str, optional): The name of the application. Defaults to "app".
             migration_func (Callable | None): The migration function to run. Defaults to None.
+            retryable_exceptions (list[type[Exception] | str] | None): The exceptions (classes or class names)
+                that can be retried in the task queue. Defaults to ['InferenceException'].
         """
         self.name = name
         self.migration_func = migration_func
         self.endpoints: dict[str, Endpoint] = {}
         self.deployments: dict[str, Deployment] = {}
 
+        if retryable_exceptions is None:
+            retryable_exceptions = [InferenceException]
+        # Keep exception names rather than classes: the list is dumped to JSON
+        # as part of the task queue deployment config.
+        self.retryable_exceptions = [
+            exc if isinstance(exc, str) else exc.__name__
+            for exc in retryable_exceptions
+        ]
+
         if aana_settings.task_queue.enabled:
             self.add_task_queue(deploy=False)
 
@@ -151,6 +168,7 @@ def add_task_queue(self, deploy: bool = False):
             num_replicas=1,
             user_config=TaskQueueConfig(
                 app_name=self.name,
+                retryable_exceptions=self.retryable_exceptions,
             ).model_dump(mode="json"),
         )
         self.register_deployment(
diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py
index 0ff0d48e..7923f439 100644
--- a/aana/storage/repository/task.py
+++ b/aana/storage/repository/task.py
@@ -2,7 +2,7 @@
 from typing import Any
 from uuid import UUID
 
-from sqlalchemy import and_, desc
+from sqlalchemy import and_, bindparam, desc, or_, text
 from sqlalchemy.orm import Session
 
 from aana.storage.models.task import Status as TaskStatus
@@ -67,7 +67,12 @@ def save(self, endpoint: str, data: Any, priority: int = 0):
         self.session.commit()
         return task
 
-    def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
+    def fetch_unprocessed_tasks(
+        self,
+        limit: int | None = None,
+        max_retries: int = 1,
+        retryable_exceptions: list[str] | None = None,
+    ) -> list[TaskEntity]:
         """Fetches unprocessed tasks and marks them as ASSIGNED.
 
         The task is considered unprocessed if it is in CREATED or NOT_FINISHED state.
@@ -81,21 +86,48 @@ def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
 
         Args:
             limit (int | None): The maximum number of tasks to fetch. If None, fetch all.
+            max_retries (int): The maximum number of retries for a task.
+            retryable_exceptions (list[str] | None): The list of exceptions that should be retried.
 
         Returns:
             list[TaskEntity]: the unprocessed tasks.
         """
+        # Unprocessed tasks (CREATED or NOT_FINISHED) are always eligible.
+        eligible = TaskEntity.status.in_(
+            [TaskStatus.CREATED, TaskStatus.NOT_FINISHED]
+        )
+        if retryable_exceptions:
+            # FAILED tasks are eligible again while they have retries left and
+            # their recorded error is retryable. The names go through an
+            # expanding bind parameter instead of being formatted into the SQL
+            # text, so a quote in a name cannot break or inject into the query.
+            retryable_error = text(
+                "result->>'error' IN :retryable_exceptions"
+            ).bindparams(
+                bindparam(
+                    "retryable_exceptions",
+                    value=list(retryable_exceptions),
+                    expanding=True,
+                )
+            )
+            eligible = or_(
+                eligible,
+                and_(
+                    TaskEntity.status == TaskStatus.FAILED,
+                    retryable_error,
+                    TaskEntity.num_retries < max_retries,
+                ),
+            )
         tasks = (
             self.session.query(TaskEntity)
-            .filter(
-                TaskEntity.status.in_([TaskStatus.CREATED, TaskStatus.NOT_FINISHED])
-            )
+            .filter(eligible)
             .order_by(desc(TaskEntity.priority), TaskEntity.created_at)
             .limit(limit)
             .populate_existing()
             .with_for_update(skip_locked=True)
             .all()
         )
+
         for task in tasks:
             self.update_status(
                 task_id=task.id,