diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py
index 6213ff1f..392d889a 100644
--- a/aana/deployments/task_queue_deployment.py
+++ b/aana/deployments/task_queue_deployment.py
@@ -17,6 +17,9 @@ class TaskQueueConfig(BaseModel):
     """The configuration for the task queue deployment."""

     app_name: str = Field(description="The name of the Aana app")
+    retryable_exceptions: list[str] = Field(
+        description="The list of exceptions that should be retried"
+    )


 @serve.deployment
@@ -63,6 +66,29 @@ def __del__(self):
             progress=0,
         )

+    async def app_health_check(self) -> bool:
+        """Check the health of the app.
+
+        The app is considered healthy if, for every deployment, at least 50% of the replicas are running.
+        Even if some replicas are not running, the app can still process requests,
+        and in a cluster setup some replicas on other nodes may be down or just starting up,
+        which is not a reason to consider the whole app unhealthy.
+
+        Returns:
+            bool: True if the app is healthy, False otherwise
+        """
+        serve_status = serve.status()
+        for app in serve_status.applications.values():
+            for deployment in app.deployments.values():
+                num_replicas = sum(deployment.replica_states.values())
+                if num_replicas > 0:
+                    health_ratio = (
+                        deployment.replica_states.get("RUNNING", 0) / num_replicas
+                    )
+                    if health_ratio < 0.5:
+                        return False
+        return True
+
     async def apply_config(self, config: dict[str, Any]):
         """Apply the configuration.

@@ -72,6 +98,7 @@ async def apply_config(self, config: dict[str, Any]):
         """
         config_obj = TaskQueueConfig(**config)
         self.app_name = config_obj.app_name
+        self.retryable_exceptions = config_obj.retryable_exceptions

     async def loop(self):  # noqa: C901
         """The main loop for the task queue deployment.
@@ -79,11 +106,22 @@ async def loop(self):  # noqa: C901
         The loop will check the queue and assign tasks to workers.
         """
         handle = None
+        app_health_check_attempts = 0
         configuration_attempts = 0
         full_queue_attempts = 0
         no_tasks_attempts = 0

         while True:
+            # Check the health of the app
+            app_health = await self.app_health_check()
+            if not app_health:
+                # If the app is not healthy, wait and retry
+                await sleep_exponential_backoff(1.0, 5.0, app_health_check_attempts)
+                app_health_check_attempts += 1
+                continue
+            else:
+                app_health_check_attempts = 0
+
             if not self._configured:
                 # Wait for the deployment to be configured.
                 await sleep_exponential_backoff(1.0, 5.0, configuration_attempts)
@@ -141,7 +179,9 @@ async def loop(self):  # noqa: C901
                     self.running_task_ids
                 )
                 tasks = TaskRepository(session).fetch_unprocessed_tasks(
-                    limit=num_tasks_to_assign
+                    limit=num_tasks_to_assign,
+                    max_retries=aana_settings.task_queue.max_retries,
+                    retryable_exceptions=self.retryable_exceptions,
                 )

                 # If there are no tasks, wait and retry
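Note on the health check: the 50% threshold reduces to a per-deployment ratio of RUNNING replicas. A minimal sketch of the same rule, using plain dicts in place of the `serve.status()` objects the deployment iterates over (the deployment name and state counts below are made up for illustration):

```python
# Sketch of the 50%-healthy rule from app_health_check, applied to plain dicts.
def is_app_healthy(replica_states_by_deployment: dict[str, dict[str, int]]) -> bool:
    for states in replica_states_by_deployment.values():
        num_replicas = sum(states.values())
        # A deployment with zero replicas is skipped; otherwise at least half
        # of its replicas must be in the RUNNING state.
        if num_replicas > 0 and states.get("RUNNING", 0) / num_replicas < 0.5:
            return False
    return True

assert is_app_healthy({"whisper_deployment": {"RUNNING": 2, "STARTING": 1}})      # 2/3 >= 0.5
assert not is_app_healthy({"whisper_deployment": {"RUNNING": 1, "STARTING": 2}})  # 1/3 < 0.5
```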
diff --git a/aana/sdk.py b/aana/sdk.py
index bd12fcec..a49317e6 100644
--- a/aana/sdk.py
+++ b/aana/sdk.py
@@ -23,6 +23,7 @@
     DeploymentException,
     EmptyMigrationsException,
     FailedDeployment,
+    InferenceException,
     InsufficientResources,
 )
 from aana.storage.op import run_alembic_migrations
@@ -32,18 +33,36 @@
 class AanaSDK:
     """Aana SDK to deploy and manage Aana deployments and endpoints."""

-    def __init__(self, name: str = "app", migration_func: Callable | None = None):
+    def __init__(
+        self,
+        name: str = "app",
+        migration_func: Callable | None = None,
+        retryable_exceptions: list[type[Exception] | str] | None = None,
+    ):
         """Aana SDK to deploy and manage Aana deployments and endpoints.

         Args:
             name (str, optional): The name of the application. Defaults to "app".
             migration_func (Callable | None): The migration function to run. Defaults to None.
+            retryable_exceptions (list[type[Exception] | str] | None): The exceptions that can be
+                retried in the task queue. Defaults to ['InferenceException'].
         """
         self.name = name
         self.migration_func = migration_func
         self.endpoints: dict[str, Endpoint] = {}
         self.deployments: dict[str, Deployment] = {}

+        if retryable_exceptions is None:
+            self.retryable_exceptions = [InferenceException]
+        else:
+            self.retryable_exceptions = retryable_exceptions
+
+        # Convert exceptions to string if they are not already
+        # to avoid serialization issues
+        self.retryable_exceptions = [
+            exc if isinstance(exc, str) else exc.__name__
+            for exc in self.retryable_exceptions
+        ]
+
         if aana_settings.task_queue.enabled:
             self.add_task_queue(deploy=False)
@@ -151,6 +170,7 @@ def add_task_queue(self, deploy: bool = False):
             num_replicas=1,
             user_config=TaskQueueConfig(
                 app_name=self.name,
+                retryable_exceptions=self.retryable_exceptions,
             ).model_dump(mode="json"),
         )
         self.register_deployment(
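With this change, retryable exceptions can be passed either as exception classes or as strings; both are normalized to class names before being serialized into the task queue config. A hypothetical usage sketch (the import path for `InferenceException` is assumed, since the diff truncates the import block):

```python
from aana.sdk import AanaSDK
from aana.exceptions.runtime import InferenceException  # assumed import path

# Classes and plain strings can be mixed; both end up as names.
app = AanaSDK(
    name="my_app",
    retryable_exceptions=[InferenceException, "TimeoutError"],
)
assert app.retryable_exceptions == ["InferenceException", "TimeoutError"]
```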
""" - tasks = ( - self.session.query(TaskEntity) - .filter( - TaskEntity.status.in_([TaskStatus.CREATED, TaskStatus.NOT_FINISHED]) + if retryable_exceptions: + # Convert the list of exceptions to a string for the query: + # e.g., ["InferenceException", "ValueError"] -> "'InferenceException', 'ValueError'" + exceptions_str = ", ".join([f"'{ex}'" for ex in retryable_exceptions]) + if self.session.bind.dialect.name == "postgresql": + exception_name_query = f"result->>'error' IN ({exceptions_str})" + elif self.session.bind.dialect.name == "sqlite": + exception_name_query = ( + f"json_extract(result, '$.error') IN ({exceptions_str})" + ) + + tasks = ( + self.session.query(TaskEntity) + .filter( + or_( + TaskEntity.status.in_( + [TaskStatus.CREATED, TaskStatus.NOT_FINISHED] + ), + and_( + TaskEntity.status == TaskStatus.FAILED, + text(exception_name_query), + TaskEntity.num_retries < max_retries, + ), + ) + ) + .order_by(desc(TaskEntity.priority), TaskEntity.created_at) + .limit(limit) + .populate_existing() + .with_for_update(skip_locked=True) + .all() ) - .order_by(desc(TaskEntity.priority), TaskEntity.created_at) - .limit(limit) - .populate_existing() - .with_for_update(skip_locked=True) - .all() - ) + else: + tasks = ( + self.session.query(TaskEntity) + .filter( + TaskEntity.status.in_( + [TaskStatus.CREATED, TaskStatus.NOT_FINISHED] + ), + ) + .order_by(desc(TaskEntity.priority), TaskEntity.created_at) + .limit(limit) + .populate_existing() + .with_for_update(skip_locked=True) + .all() + ) + for task in tasks: self.update_status( task_id=task.id, diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py index 026da902..521b0dee 100644 --- a/aana/tests/db/datastore/test_task_repo.py +++ b/aana/tests/db/datastore/test_task_repo.py @@ -81,33 +81,62 @@ def _create_sample_tasks(): priority=2, created_at=now - timedelta(hours=3), ) + task5 = TaskEntity( + endpoint="/test5", + data={"test": "data5"}, + status=TaskStatus.FAILED, + priority=1, + created_at=now - timedelta(minutes=1), + result={"error": "InferenceException"}, + ) + task6 = TaskEntity( + endpoint="/test6", + data={"test": "data6"}, + status=TaskStatus.RUNNING, + priority=3, + created_at=now - timedelta(minutes=2), + ) + task7 = TaskEntity( + endpoint="/test7", + data={"test": "data7"}, + status=TaskStatus.FAILED, + priority=1, + created_at=now - timedelta(minutes=3), + result={"error": "NonRecoverableError"}, + ) - db_session.add_all([task1, task2, task3, task4]) + db_session.add_all([task1, task2, task3, task4, task5, task6, task7]) db_session.commit() - return task1, task2, task3, task4 + return task1, task2, task3, task4, task5, task6, task7 # Create sample tasks - task1, task2, task3, task4 = _create_sample_tasks() + task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks() # Fetch unprocessed tasks without any limit - unprocessed_tasks = task_repo.fetch_unprocessed_tasks() + unprocessed_tasks = task_repo.fetch_unprocessed_tasks( + max_retries=3, retryable_exceptions=["InferenceException"] + ) # Assert that only tasks with CREATED and NOT_FINISHED status are returned - assert len(unprocessed_tasks) == 3 + assert len(unprocessed_tasks) == 4 assert task1 in unprocessed_tasks assert task2 in unprocessed_tasks assert task4 in unprocessed_tasks + assert task5 in unprocessed_tasks # Ensure tasks are ordered by priority and then by created_at assert unprocessed_tasks[0].id == task4.id # Highest priority assert unprocessed_tasks[1].id == task2.id # Same priority, but a newer task 
diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py
index 026da902..521b0dee 100644
--- a/aana/tests/db/datastore/test_task_repo.py
+++ b/aana/tests/db/datastore/test_task_repo.py
@@ -81,33 +81,62 @@ def _create_sample_tasks():
             priority=2,
             created_at=now - timedelta(hours=3),
         )
+        task5 = TaskEntity(
+            endpoint="/test5",
+            data={"test": "data5"},
+            status=TaskStatus.FAILED,
+            priority=1,
+            created_at=now - timedelta(minutes=1),
+            result={"error": "InferenceException"},
+        )
+        task6 = TaskEntity(
+            endpoint="/test6",
+            data={"test": "data6"},
+            status=TaskStatus.RUNNING,
+            priority=3,
+            created_at=now - timedelta(minutes=2),
+        )
+        task7 = TaskEntity(
+            endpoint="/test7",
+            data={"test": "data7"},
+            status=TaskStatus.FAILED,
+            priority=1,
+            created_at=now - timedelta(minutes=3),
+            result={"error": "NonRecoverableError"},
+        )

-        db_session.add_all([task1, task2, task3, task4])
+        db_session.add_all([task1, task2, task3, task4, task5, task6, task7])
         db_session.commit()

-        return task1, task2, task3, task4
+        return task1, task2, task3, task4, task5, task6, task7

     # Create sample tasks
-    task1, task2, task3, task4 = _create_sample_tasks()
+    task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()

     # Fetch unprocessed tasks without any limit
-    unprocessed_tasks = task_repo.fetch_unprocessed_tasks()
+    unprocessed_tasks = task_repo.fetch_unprocessed_tasks(
+        max_retries=3, retryable_exceptions=["InferenceException"]
+    )

-    # Assert that only tasks with CREATED and NOT_FINISHED status are returned
-    assert len(unprocessed_tasks) == 3
+    # Assert that CREATED, NOT_FINISHED, and retryable FAILED tasks are returned
+    assert len(unprocessed_tasks) == 4
     assert task1 in unprocessed_tasks
     assert task2 in unprocessed_tasks
     assert task4 in unprocessed_tasks
+    assert task5 in unprocessed_tasks

     # Ensure tasks are ordered by priority and then by created_at
     assert unprocessed_tasks[0].id == task4.id  # Highest priority
     assert unprocessed_tasks[1].id == task2.id  # Same priority as task4, but newer
     assert unprocessed_tasks[2].id == task1.id  # Lowest priority
+    assert unprocessed_tasks[3].id == task5.id  # Same priority as task1, but newer

     # Create sample tasks
-    task1, task2, task3, task4 = _create_sample_tasks()
+    task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()

     # Fetch unprocessed tasks with a limit
-    limited_tasks = task_repo.fetch_unprocessed_tasks(limit=2)
+    limited_tasks = task_repo.fetch_unprocessed_tasks(
+        limit=2, max_retries=3, retryable_exceptions=["InferenceException"]
+    )

     # Assert that only the specified number of tasks is returned
     assert len(limited_tasks) == 2
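For reference, the ordering assertions follow from sorting by priority (descending), then created_at (ascending, i.e. oldest first). A standalone sketch of that ordering over the sample tasks; the priorities and timestamps for task1 and task2 are not visible in this hunk and are assumed from the assertion comments:

```python
from datetime import datetime, timedelta

now = datetime.now()
# (name, priority, created_at); task1/task2 values are assumptions.
sample = [
    ("task4", 2, now - timedelta(hours=3)),
    ("task2", 2, now - timedelta(hours=1)),
    ("task1", 1, now - timedelta(hours=2)),
    ("task5", 1, now - timedelta(minutes=1)),
]
# Higher priority first, then oldest first within the same priority.
ordered = sorted(sample, key=lambda t: (-t[1], t[2]))
assert [name for name, _, _ in ordered] == ["task4", "task2", "task1", "task5"]
```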