From 83fd94cb279d50b8c349791f50966c6b46529ef5 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 31 Oct 2024 16:15:43 +0000 Subject: [PATCH 1/3] Add health check and task heartbeat functionality to RequestHandler --- aana/api/request_handler.py | 21 ++++++++++++++++----- aana/storage/repository/task.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index 3e2728cf..b8068a09 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -1,7 +1,7 @@ import json import time from typing import Annotated, Any -from uuid import uuid4 +from uuid import UUID, uuid4 import orjson import ray @@ -68,6 +68,7 @@ def __init__( app.openapi = self.custom_openapi self.ready = True + self.running_tasks = set() def custom_openapi(self) -> dict[str, Any]: """Returns OpenAPI schema, generating it if necessary.""" @@ -95,16 +96,25 @@ async def is_ready(self): """ return AanaJSONResponse(content={"ready": self.ready}) - async def execute_task(self, task_id: str) -> Any: + async def check_health(self): + """Check the health of the application.""" + # Heartbeat for the running tasks + with get_session() as session: + task_repo = TaskRepository(session) + task_repo.heartbeat(self.running_tasks) + + async def execute_task(self, task_id: str | UUID) -> Any: """Execute a task. Args: - task_id (str): The task ID. + task_id (str | UUID): The ID of the task. Returns: Any: The response from the endpoint. """ try: + print(f"Executing task {task_id}, type: {type(task_id)}") + self.running_tasks.add(task_id) with get_session() as session: task_repo = TaskRepository(session) task = task_repo.read(task_id) @@ -139,8 +149,9 @@ async def execute_task(self, task_id: str) -> Any: TaskRepository(session).update_status( task_id, TaskStatus.FAILED, 0, error ) - else: - return out + finally: + self.running_tasks.remove(task_id) + return out @app.get( "/tasks/get/{task_id}", diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index e3d3bf07..0705246f 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -285,3 +285,20 @@ def update_expired_tasks( ) self.session.commit() return tasks + + def heartbeat(self, task_ids: list[str] | set[str]): + """Updates the updated_at timestamp for multiple tasks. + + Args: + task_ids (list[str] | set[str]): List or set of task IDs to update + """ + print(f"Heartbeat: {task_ids}") + task_ids = [ + UUID(task_id) if isinstance(task_id, str) else task_id + for task_id in task_ids + ] + self.session.query(TaskEntity).filter(TaskEntity.id.in_(task_ids)).update( + {TaskEntity.updated_at: datetime.now()}, # noqa: DTZ005 + synchronize_session=False, + ) + self.session.commit() From b59272d9699b7375eaef719bd56e822cff0dc217 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Fri, 1 Nov 2024 13:28:44 +0000 Subject: [PATCH 2/3] Add heartbeat timeout to task management and update expired task logic --- aana/api/request_handler.py | 1 - aana/configs/settings.py | 2 ++ aana/deployments/task_queue_deployment.py | 6 ++-- aana/sdk.py | 1 - aana/storage/repository/task.py | 35 ++++++++++++++++------- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index b8068a09..1b3770ae 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -113,7 +113,6 @@ async def execute_task(self, task_id: str | UUID) -> Any: Any: The response from the endpoint. """ try: - print(f"Executing task {task_id}, type: {type(task_id)}") self.running_tasks.add(task_id) with get_session() as session: task_repo = TaskRepository(session) diff --git a/aana/configs/settings.py b/aana/configs/settings.py index ec161ac6..a2ef7a6b 100644 --- a/aana/configs/settings.py +++ b/aana/configs/settings.py @@ -28,12 +28,14 @@ class TaskQueueSettings(BaseModel): execution_timeout (int): The maximum execution time for a task in seconds. After this time, if the task is still running, it will be considered as stuck and will be reassign to another worker. + heartbeat_timeout (int): The maximum time between heartbeats in seconds. max_retries (int): The maximum number of retries for a task. """ enabled: bool = True num_workers: int = 4 execution_timeout: int = 600 + heartbeat_timeout: int = 60 max_retries: int = 3 diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py index 392d889a..1d09ad42 100644 --- a/aana/deployments/task_queue_deployment.py +++ b/aana/deployments/task_queue_deployment.py @@ -137,10 +137,10 @@ async def loop(self): # noqa: C901 ) # Check for expired tasks - execution_timeout = aana_settings.task_queue.execution_timeout - max_retries = aana_settings.task_queue.max_retries expired_tasks = TaskRepository(session).update_expired_tasks( - execution_timeout=execution_timeout, max_retries=max_retries + execution_timeout=aana_settings.task_queue.execution_timeout, + heartbeat_timeout=aana_settings.task_queue.heartbeat_timeout, + max_retries=aana_settings.task_queue.max_retries, ) for task in expired_tasks: deployment_response = self.deployment_responses.get(task.id) diff --git a/aana/sdk.py b/aana/sdk.py index ea9554f1..f25a765f 100644 --- a/aana/sdk.py +++ b/aana/sdk.py @@ -23,7 +23,6 @@ DeploymentException, EmptyMigrationsException, FailedDeployment, - InferenceException, InsufficientResources, ) from aana.storage.op import run_alembic_migrations diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index 0705246f..670d00f6 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -223,7 +223,7 @@ def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]: return incomplete_task_ids def update_expired_tasks( - self, execution_timeout: float, max_retries: int + self, execution_timeout: float, heartbeat_timeout: float, max_retries: int ) -> list[TaskEntity]: """Fetches all tasks that are expired and updates their status. @@ -243,18 +243,23 @@ def update_expired_tasks( Args: execution_timeout (float): The maximum execution time for a task in seconds + heartbeat_timeout (float): The maximum time since the last heartbeat in seconds max_retries (int): The maximum number of retries for a task Returns: list[TaskEntity]: the expired tasks. """ - cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + timeout_cutoff = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + heartbeat_cutoff = datetime.now() - timedelta(seconds=heartbeat_timeout) # noqa: DTZ005 tasks = ( self.session.query(TaskEntity) .filter( and_( TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]), - TaskEntity.updated_at <= cutoff_time, + or_( + TaskEntity.updated_at <= timeout_cutoff, + TaskEntity.updated_at <= heartbeat_cutoff, + ), ), ) .populate_existing() @@ -263,17 +268,28 @@ def update_expired_tasks( ) for task in tasks: if task.num_retries >= max_retries: - self.update_status( - task_id=task.id, - status=TaskStatus.FAILED, - progress=0, - result={ + if task.updated_at <= timeout_cutoff: + result = { "error": "TimeoutError", "message": ( f"Task execution timed out after {execution_timeout} seconds and " f"exceeded the maximum number of retries ({max_retries})" ), - }, + } + else: + result = { + "error": "HeartbeatTimeoutError", + "message": ( + f"The task has not received a heartbeat for {heartbeat_timeout} seconds and " + f"exceeded the maximum number of retries ({max_retries})" + ), + } + + self.update_status( + task_id=task.id, + status=TaskStatus.FAILED, + progress=0, + result=result, commit=False, ) else: @@ -292,7 +308,6 @@ def heartbeat(self, task_ids: list[str] | set[str]): Args: task_ids (list[str] | set[str]): List or set of task IDs to update """ - print(f"Heartbeat: {task_ids}") task_ids = [ UUID(task_id) if isinstance(task_id, str) else task_id for task_id in task_ids From bd54ad52b0e75ff34d2bbb50e040806f317d4078 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Fri, 1 Nov 2024 14:34:20 +0000 Subject: [PATCH 3/3] Update task expiration logic to use assigned_at timestamp for execution timeout and add heartbeat timeout to tests --- aana/storage/repository/task.py | 4 ++-- aana/tests/db/datastore/test_task_repo.py | 28 ++++++++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index 670d00f6..74971013 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -257,7 +257,7 @@ def update_expired_tasks( and_( TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]), or_( - TaskEntity.updated_at <= timeout_cutoff, + TaskEntity.assigned_at <= timeout_cutoff, TaskEntity.updated_at <= heartbeat_cutoff, ), ), @@ -268,7 +268,7 @@ def update_expired_tasks( ) for task in tasks: if task.num_retries >= max_retries: - if task.updated_at <= timeout_cutoff: + if task.assigned_at <= timeout_cutoff: result = { "error": "TimeoutError", "message": ( diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py index 521b0dee..f8ab11ab 100644 --- a/aana/tests/db/datastore/test_task_repo.py +++ b/aana/tests/db/datastore/test_task_repo.py @@ -293,49 +293,65 @@ def test_update_expired_tasks(db_session): # Set up current time and a cutoff time current_time = datetime.now() # noqa: DTZ005 execution_timeout = 3600 # 1 hour in seconds + heartbeat_timeout = 60 # 1 minute in seconds # Create tasks with different updated_at times and statuses task1 = TaskEntity( endpoint="/task1", data={"test": "data1"}, status=TaskStatus.RUNNING, - updated_at=current_time - timedelta(hours=2), + assigned_at=current_time - timedelta(hours=2), + updated_at=current_time - timedelta(seconds=10), ) task2 = TaskEntity( endpoint="/task2", data={"test": "data2"}, status=TaskStatus.ASSIGNED, - updated_at=current_time - timedelta(seconds=2), + assigned_at=current_time - timedelta(seconds=2), + updated_at=current_time - timedelta(seconds=5), ) task3 = TaskEntity( endpoint="/task3", data={"test": "data3"}, status=TaskStatus.RUNNING, + assigned_at=current_time - timedelta(seconds=2), updated_at=current_time, ) task4 = TaskEntity( endpoint="/task4", data={"test": "data4"}, status=TaskStatus.COMPLETED, + assigned_at=current_time - timedelta(hours=1), updated_at=current_time - timedelta(hours=2), ) task5 = TaskEntity( endpoint="/task5", data={"test": "data5"}, status=TaskStatus.FAILED, + assigned_at=current_time - timedelta(minutes=1), updated_at=current_time - timedelta(seconds=4), ) + task6 = TaskEntity( + endpoint="/task6", + data={"test": "data6"}, + status=TaskStatus.RUNNING, + assigned_at=current_time - timedelta(minutes=3), + updated_at=current_time - timedelta(minutes=2), + ) - db_session.add_all([task1, task2, task3, task4, task5]) + db_session.add_all([task1, task2, task3, task4, task5, task6]) db_session.commit() # Fetch expired tasks expired_tasks = task_repo.update_expired_tasks( - execution_timeout=execution_timeout, max_retries=3 + execution_timeout=execution_timeout, + heartbeat_timeout=heartbeat_timeout, + max_retries=3, ) - # Assert that only tasks with RUNNING or ASSIGNED status and an updated_at older than the cutoff are returned - expected_task_ids = {str(task1.id)} + # Assert that only tasks with RUNNING or ASSIGNED status and an assigned_at time older than the execution_timeout or + # heartbeat_timeout are returned + expected_task_ids = {str(task1.id), str(task6.id)} returned_task_ids = {str(task.id) for task in expired_tasks} assert returned_task_ids == expected_task_ids