Skip to content

Commit

Permalink
Merge pull request #194 from mobiusml/task_queue_reliablity_improvement
Browse files Browse the repository at this point in the history
Task Queue Reliability Enhancements
  • Loading branch information
movchan74 authored Oct 30, 2024
2 parents da5a42a + 939694b commit 6781b50
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 21 deletions.
42 changes: 41 additions & 1 deletion aana/deployments/task_queue_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class TaskQueueConfig(BaseModel):
"""The configuration for the task queue deployment."""

app_name: str = Field(description="The name of the Aana app")
retryable_exceptions: list[str] = Field(
description="The list of exceptions that should be retried"
)


@serve.deployment
Expand Down Expand Up @@ -63,6 +66,29 @@ def __del__(self):
progress=0,
)

async def app_health_check(self) -> bool:
    """Check whether the Aana app is healthy.

    A deployment counts as healthy when at least half of its replicas are in
    the RUNNING state. Partial replica loss is tolerated because the app can
    still serve requests, and in a cluster some replicas on other nodes may
    simply be starting up — that alone is no reason to flag the app as down.

    Returns:
        bool: True if every deployment is healthy, False otherwise.
    """
    serve_status = serve.status()
    all_deployments = (
        deployment
        for application in serve_status.applications.values()
        for deployment in application.deployments.values()
    )
    for deployment in all_deployments:
        total_replicas = sum(deployment.replica_states.values())
        if total_replicas == 0:
            # No replicas reported yet — nothing to judge for this deployment.
            continue
        running_replicas = deployment.replica_states.get("RUNNING", 0)
        if running_replicas / total_replicas < 0.5:
            return False
    return True

async def apply_config(self, config: dict[str, Any]):
"""Apply the configuration.
Expand All @@ -72,18 +98,30 @@ async def apply_config(self, config: dict[str, Any]):
"""
config_obj = TaskQueueConfig(**config)
self.app_name = config_obj.app_name
self.retryable_exceptions = config_obj.retryable_exceptions

async def loop(self): # noqa: C901
"""The main loop for the task queue deployment.
The loop will check the queue and assign tasks to workers.
"""
handle = None
app_health_check_attempts = 0
configuration_attempts = 0
full_queue_attempts = 0
no_tasks_attempts = 0

while True:
# Check the health of the app
app_health = await self.app_health_check()
if not app_health:
# If the app is not healthy, wait and retry
await sleep_exponential_backoff(1.0, 5.0, app_health_check_attempts)
app_health_check_attempts += 1
continue
else:
app_health_check_attempts = 0

if not self._configured:
# Wait for the deployment to be configured.
await sleep_exponential_backoff(1.0, 5.0, configuration_attempts)
Expand Down Expand Up @@ -141,7 +179,9 @@ async def loop(self): # noqa: C901
self.running_task_ids
)
tasks = TaskRepository(session).fetch_unprocessed_tasks(
limit=num_tasks_to_assign
limit=num_tasks_to_assign,
max_retries=aana_settings.task_queue.max_retries,
retryable_exceptions=self.retryable_exceptions,
)

# If there are no tasks, wait and retry
Expand Down
22 changes: 21 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DeploymentException,
EmptyMigrationsException,
FailedDeployment,
InferenceException,
InsufficientResources,
)
from aana.storage.op import run_alembic_migrations
Expand All @@ -32,18 +33,36 @@
class AanaSDK:
"""Aana SDK to deploy and manage Aana deployments and endpoints."""

def __init__(
    self,
    name: str = "app",
    migration_func: Callable | None = None,
    retryable_exceptions: list[type[Exception] | str] | None = None,
):
    """Aana SDK to deploy and manage Aana deployments and endpoints.

    Args:
        name (str, optional): The name of the application. Defaults to "app".
        migration_func (Callable | None): The migration function to run. Defaults to None.
        retryable_exceptions (list[type[Exception] | str] | None): Exception classes
            (or their names as strings) that the task queue may retry.
            Defaults to [InferenceException].
    """
    self.name = name
    self.migration_func = migration_func
    self.endpoints: dict[str, Endpoint] = {}
    self.deployments: dict[str, Deployment] = {}

    if retryable_exceptions is None:
        retryable_exceptions = [InferenceException]

    # Normalize exception classes to their names so the list can be
    # JSON-serialized when passed to the task queue deployment config.
    self.retryable_exceptions = [
        exc if isinstance(exc, str) else exc.__name__
        for exc in retryable_exceptions
    ]

    if aana_settings.task_queue.enabled:
        self.add_task_queue(deploy=False)

Expand Down Expand Up @@ -151,6 +170,7 @@ def add_task_queue(self, deploy: bool = False):
num_replicas=1,
user_config=TaskQueueConfig(
app_name=self.name,
retryable_exceptions=self.retryable_exceptions,
).model_dump(mode="json"),
)
self.register_deployment(
Expand Down
66 changes: 54 additions & 12 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any
from uuid import UUID

from sqlalchemy import and_, desc
from sqlalchemy import and_, desc, or_, text
from sqlalchemy.orm import Session

from aana.storage.models.task import Status as TaskStatus
Expand Down Expand Up @@ -67,7 +67,12 @@ def save(self, endpoint: str, data: Any, priority: int = 0):
self.session.commit()
return task

def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
def fetch_unprocessed_tasks(
self,
limit: int | None = None,
max_retries: int = 1,
retryable_exceptions: list[str] | None = None,
) -> list[TaskEntity]:
"""Fetches unprocessed tasks and marks them as ASSIGNED.
The task is considered unprocessed if it is in CREATED or NOT_FINISHED state.
Expand All @@ -81,21 +86,58 @@ def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
Args:
limit (int | None): The maximum number of tasks to fetch. If None, fetch all.
max_retries (int): The maximum number of retries for a task.
retryable_exceptions (list[str] | None): The list of exceptions that should be retried.
Returns:
list[TaskEntity]: the unprocessed tasks.
"""
tasks = (
self.session.query(TaskEntity)
.filter(
TaskEntity.status.in_([TaskStatus.CREATED, TaskStatus.NOT_FINISHED])
if retryable_exceptions:
# Convert the list of exceptions to a string for the query:
# e.g., ["InferenceException", "ValueError"] -> "'InferenceException', 'ValueError'"
exceptions_str = ", ".join([f"'{ex}'" for ex in retryable_exceptions])
if self.session.bind.dialect.name == "postgresql":
exception_name_query = f"result->>'error' IN ({exceptions_str})"
elif self.session.bind.dialect.name == "sqlite":
exception_name_query = (
f"json_extract(result, '$.error') IN ({exceptions_str})"
)

tasks = (
self.session.query(TaskEntity)
.filter(
or_(
TaskEntity.status.in_(
[TaskStatus.CREATED, TaskStatus.NOT_FINISHED]
),
and_(
TaskEntity.status == TaskStatus.FAILED,
text(exception_name_query),
TaskEntity.num_retries < max_retries,
),
)
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)
else:
tasks = (
self.session.query(TaskEntity)
.filter(
TaskEntity.status.in_(
[TaskStatus.CREATED, TaskStatus.NOT_FINISHED]
),
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)

for task in tasks:
self.update_status(
task_id=task.id,
Expand Down
43 changes: 36 additions & 7 deletions aana/tests/db/datastore/test_task_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,62 @@ def _create_sample_tasks():
priority=2,
created_at=now - timedelta(hours=3),
)
task5 = TaskEntity(
endpoint="/test5",
data={"test": "data5"},
status=TaskStatus.FAILED,
priority=1,
created_at=now - timedelta(minutes=1),
result={"error": "InferenceException"},
)
task6 = TaskEntity(
endpoint="/test6",
data={"test": "data6"},
status=TaskStatus.RUNNING,
priority=3,
created_at=now - timedelta(minutes=2),
)
task7 = TaskEntity(
endpoint="/test7",
data={"test": "data7"},
status=TaskStatus.FAILED,
priority=1,
created_at=now - timedelta(minutes=3),
result={"error": "NonRecoverableError"},
)

db_session.add_all([task1, task2, task3, task4])
db_session.add_all([task1, task2, task3, task4, task5, task6, task7])
db_session.commit()
return task1, task2, task3, task4
return task1, task2, task3, task4, task5, task6, task7

# Create sample tasks
task1, task2, task3, task4 = _create_sample_tasks()
task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()

# Fetch unprocessed tasks without any limit
unprocessed_tasks = task_repo.fetch_unprocessed_tasks()
unprocessed_tasks = task_repo.fetch_unprocessed_tasks(
max_retries=3, retryable_exceptions=["InferenceException"]
)

# Assert that only tasks with CREATED and NOT_FINISHED status are returned
assert len(unprocessed_tasks) == 3
assert len(unprocessed_tasks) == 4
assert task1 in unprocessed_tasks
assert task2 in unprocessed_tasks
assert task4 in unprocessed_tasks
assert task5 in unprocessed_tasks

# Ensure tasks are ordered by priority and then by created_at
assert unprocessed_tasks[0].id == task4.id # Highest priority
assert unprocessed_tasks[1].id == task2.id # Same priority, but a newer task
assert unprocessed_tasks[2].id == task1.id # Lowest priority
assert unprocessed_tasks[3].id == task5.id # Highest priority, but older

# Create sample tasks
task1, task2, task3, task4 = _create_sample_tasks()
task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()

# Fetch unprocessed tasks with a limit
limited_tasks = task_repo.fetch_unprocessed_tasks(limit=2)
limited_tasks = task_repo.fetch_unprocessed_tasks(
limit=2, max_retries=3, retryable_exceptions=["InferenceException"]
)

# Assert that only the specified number of tasks is returned
assert len(limited_tasks) == 2
Expand Down

0 comments on commit 6781b50

Please sign in to comment.