Skip to content

Commit

Permalink
Enhance task fetching logic to support PostgreSQL and SQLite error ha…
Browse files Browse the repository at this point in the history
…ndling; add additional test cases for unprocessed tasks
  • Loading branch information
Aleksandr Movchan committed Oct 30, 2024
1 parent 9a9e615 commit 939694b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
9 changes: 8 additions & 1 deletion aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ def fetch_unprocessed_tasks(
# Convert the list of exceptions to a string for the query:
# e.g., ["InferenceException", "ValueError"] -> "'InferenceException', 'ValueError'"
exceptions_str = ", ".join([f"'{ex}'" for ex in retryable_exceptions])
if self.session.bind.dialect.name == "postgresql":
exception_name_query = f"result->>'error' IN ({exceptions_str})"
elif self.session.bind.dialect.name == "sqlite":
exception_name_query = (
f"json_extract(result, '$.error') IN ({exceptions_str})"
)

tasks = (
self.session.query(TaskEntity)
.filter(
Expand All @@ -105,7 +112,7 @@ def fetch_unprocessed_tasks(
),
and_(
TaskEntity.status == TaskStatus.FAILED,
text(f"result->>'error' IN ({exceptions_str})"),
text(exception_name_query),
TaskEntity.num_retries < max_retries,
),
)
Expand Down
43 changes: 36 additions & 7 deletions aana/tests/db/datastore/test_task_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,62 @@ def _create_sample_tasks():
priority=2,
created_at=now - timedelta(hours=3),
)
task5 = TaskEntity(
endpoint="/test5",
data={"test": "data5"},
status=TaskStatus.FAILED,
priority=1,
created_at=now - timedelta(minutes=1),
result={"error": "InferenceException"},
)
task6 = TaskEntity(
endpoint="/test6",
data={"test": "data6"},
status=TaskStatus.RUNNING,
priority=3,
created_at=now - timedelta(minutes=2),
)
task7 = TaskEntity(
endpoint="/test7",
data={"test": "data7"},
status=TaskStatus.FAILED,
priority=1,
created_at=now - timedelta(minutes=3),
result={"error": "NonRecoverableError"},
)

db_session.add_all([task1, task2, task3, task4])
db_session.add_all([task1, task2, task3, task4, task5, task6, task7])
db_session.commit()
return task1, task2, task3, task4
return task1, task2, task3, task4, task5, task6, task7

# Create sample tasks
task1, task2, task3, task4 = _create_sample_tasks()
task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()

# Fetch unprocessed tasks without any limit
unprocessed_tasks = task_repo.fetch_unprocessed_tasks()
unprocessed_tasks = task_repo.fetch_unprocessed_tasks(
max_retries=3, retryable_exceptions=["InferenceException"]
)

# Assert that only tasks with CREATED and NOT_FINISHED status are returned
assert len(unprocessed_tasks) == 3
assert len(unprocessed_tasks) == 4
assert task1 in unprocessed_tasks
assert task2 in unprocessed_tasks
assert task4 in unprocessed_tasks
assert task5 in unprocessed_tasks

# Ensure tasks are ordered by priority and then by created_at
assert unprocessed_tasks[0].id == task4.id # Highest priority
assert unprocessed_tasks[1].id == task2.id # Same priority, but a newer task
assert unprocessed_tasks[2].id == task1.id # Lowest priority
assert unprocessed_tasks[3].id == task5.id # Highest priority, but older

# Create sample tasks
task1, task2, task3, task4 = _create_sample_tasks()
task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()

# Fetch unprocessed tasks with a limit
limited_tasks = task_repo.fetch_unprocessed_tasks(limit=2)
limited_tasks = task_repo.fetch_unprocessed_tasks(
limit=2, max_retries=3, retryable_exceptions=["InferenceException"]
)

# Assert that only the specified number of tasks is returned
assert len(limited_tasks) == 2
Expand Down

0 comments on commit 939694b

Please sign in to comment.