Skip to content

Commit

Permalink
Add health check and task heartbeat functionality to RequestHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Oct 31, 2024
1 parent 9d4952e commit 83fd94c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
21 changes: 16 additions & 5 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import time
from typing import Annotated, Any
from uuid import uuid4
from uuid import UUID, uuid4

import orjson
import ray
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(

app.openapi = self.custom_openapi
self.ready = True
self.running_tasks = set()

def custom_openapi(self) -> dict[str, Any]:
"""Returns OpenAPI schema, generating it if necessary."""
Expand Down Expand Up @@ -95,16 +96,25 @@ async def is_ready(self):
"""
return AanaJSONResponse(content={"ready": self.ready})

async def execute_task(self, task_id: str) -> Any:
async def check_health(self):
"""Check the health of the application."""
# Heartbeat for the running tasks
with get_session() as session:
task_repo = TaskRepository(session)
task_repo.heartbeat(self.running_tasks)

async def execute_task(self, task_id: str | UUID) -> Any:
"""Execute a task.
Args:
task_id (str): The task ID.
task_id (str | UUID): The ID of the task.
Returns:
Any: The response from the endpoint.
"""
try:
print(f"Executing task {task_id}, type: {type(task_id)}")
self.running_tasks.add(task_id)
with get_session() as session:
task_repo = TaskRepository(session)
task = task_repo.read(task_id)
Expand Down Expand Up @@ -139,8 +149,9 @@ async def execute_task(self, task_id: str) -> Any:
TaskRepository(session).update_status(
task_id, TaskStatus.FAILED, 0, error
)
else:
return out
finally:
self.running_tasks.remove(task_id)
return out

@app.get(
"/tasks/get/{task_id}",
Expand Down
17 changes: 17 additions & 0 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,20 @@ def update_expired_tasks(
)
self.session.commit()
return tasks

def heartbeat(self, task_ids: list[str] | set[str]):
"""Updates the updated_at timestamp for multiple tasks.
Args:
task_ids (list[str] | set[str]): List or set of task IDs to update
"""
print(f"Heartbeat: {task_ids}")
task_ids = [
UUID(task_id) if isinstance(task_id, str) else task_id
for task_id in task_ids
]
self.session.query(TaskEntity).filter(TaskEntity.id.in_(task_ids)).update(
{TaskEntity.updated_at: datetime.now()}, # noqa: DTZ005
synchronize_session=False,
)
self.session.commit()

0 comments on commit 83fd94c

Please sign in to comment.