diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index 3e2728cf..b8068a09 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -1,7 +1,7 @@ import json import time from typing import Annotated, Any -from uuid import uuid4 +from uuid import UUID, uuid4 import orjson import ray @@ -68,6 +68,7 @@ def __init__( app.openapi = self.custom_openapi self.ready = True + self.running_tasks = set() def custom_openapi(self) -> dict[str, Any]: """Returns OpenAPI schema, generating it if necessary.""" @@ -95,16 +96,25 @@ async def is_ready(self): """ return AanaJSONResponse(content={"ready": self.ready}) - async def execute_task(self, task_id: str) -> Any: + async def check_health(self): + """Check the health of the application.""" + # Heartbeat for the running tasks + with get_session() as session: + task_repo = TaskRepository(session) + task_repo.heartbeat(self.running_tasks) + + async def execute_task(self, task_id: str | UUID) -> Any: """Execute a task. Args: - task_id (str): The task ID. + task_id (str | UUID): The ID of the task. Returns: Any: The response from the endpoint. """ try: + print(f"Executing task {task_id}, type: {type(task_id)}") + self.running_tasks.add(task_id) with get_session() as session: task_repo = TaskRepository(session) task = task_repo.read(task_id) @@ -139,8 +149,9 @@ async def execute_task(self, task_id: str) -> Any: TaskRepository(session).update_status( task_id, TaskStatus.FAILED, 0, error ) - else: - return out + finally: + self.running_tasks.remove(task_id) + return out @app.get( "/tasks/get/{task_id}", diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index e3d3bf07..0705246f 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -285,3 +285,20 @@ def update_expired_tasks( ) self.session.commit() return tasks + + def heartbeat(self, task_ids: list[str] | set[str]): + """Updates the updated_at timestamp for multiple tasks. + + Args: + task_ids (list[str] | set[str]): List or set of task IDs to update + """ + print(f"Heartbeat: {task_ids}") + task_ids = [ + UUID(task_id) if isinstance(task_id, str) else task_id + for task_id in task_ids + ] + self.session.query(TaskEntity).filter(TaskEntity.id.in_(task_ids)).update( + {TaskEntity.updated_at: datetime.now()}, # noqa: DTZ005 + synchronize_session=False, + ) + self.session.commit()