diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py
index 7923f439..e3d3bf07 100644
--- a/aana/storage/repository/task.py
+++ b/aana/storage/repository/task.py
@@ -96,6 +96,13 @@ def fetch_unprocessed_tasks(
         # Convert the list of exceptions to a string for the query:
         # e.g., ["InferenceException", "ValueError"] -> "'InferenceException', 'ValueError'"
         exceptions_str = ", ".join([f"'{ex}'" for ex in retryable_exceptions])
+        if self.session.bind.dialect.name == "postgresql":
+            exception_name_query = f"result->>'error' IN ({exceptions_str})"
+        elif self.session.bind.dialect.name == "sqlite":
+            exception_name_query = (
+                f"json_extract(result, '$.error') IN ({exceptions_str})"
+            )
+
         tasks = (
             self.session.query(TaskEntity)
             .filter(
@@ -105,7 +112,7 @@ def fetch_unprocessed_tasks(
                     ),
                     and_(
                         TaskEntity.status == TaskStatus.FAILED,
-                        text(f"result->>'error' IN ({exceptions_str})"),
+                        text(exception_name_query),
                         TaskEntity.num_retries < max_retries,
                     ),
                 )
diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py
index 026da902..521b0dee 100644
--- a/aana/tests/db/datastore/test_task_repo.py
+++ b/aana/tests/db/datastore/test_task_repo.py
@@ -81,33 +81,62 @@ def _create_sample_tasks():
             priority=2,
             created_at=now - timedelta(hours=3),
         )
+        task5 = TaskEntity(
+            endpoint="/test5",
+            data={"test": "data5"},
+            status=TaskStatus.FAILED,
+            priority=1,
+            created_at=now - timedelta(minutes=1),
+            result={"error": "InferenceException"},
+        )
+        task6 = TaskEntity(
+            endpoint="/test6",
+            data={"test": "data6"},
+            status=TaskStatus.RUNNING,
+            priority=3,
+            created_at=now - timedelta(minutes=2),
+        )
+        task7 = TaskEntity(
+            endpoint="/test7",
+            data={"test": "data7"},
+            status=TaskStatus.FAILED,
+            priority=1,
+            created_at=now - timedelta(minutes=3),
+            result={"error": "NonRecoverableError"},
+        )
 
-        db_session.add_all([task1, task2, task3, task4])
+        db_session.add_all([task1, task2, task3, task4, task5, task6, task7])
         db_session.commit()
-        return task1, task2, task3, task4
+        return task1, task2, task3, task4, task5, task6, task7
 
     # Create sample tasks
-    task1, task2, task3, task4 = _create_sample_tasks()
+    task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()
 
     # Fetch unprocessed tasks without any limit
-    unprocessed_tasks = task_repo.fetch_unprocessed_tasks()
+    unprocessed_tasks = task_repo.fetch_unprocessed_tasks(
+        max_retries=3, retryable_exceptions=["InferenceException"]
+    )
 
     # Assert that only tasks with CREATED and NOT_FINISHED status are returned
-    assert len(unprocessed_tasks) == 3
+    assert len(unprocessed_tasks) == 4
     assert task1 in unprocessed_tasks
     assert task2 in unprocessed_tasks
     assert task4 in unprocessed_tasks
+    assert task5 in unprocessed_tasks
 
     # Ensure tasks are ordered by priority and then by created_at
     assert unprocessed_tasks[0].id == task4.id  # Highest priority
     assert unprocessed_tasks[1].id == task2.id  # Same priority, but a newer task
     assert unprocessed_tasks[2].id == task1.id  # Lowest priority
+    assert unprocessed_tasks[3].id == task5.id  # Highest priority, but older
 
     # Create sample tasks
-    task1, task2, task3, task4 = _create_sample_tasks()
+    task1, task2, task3, task4, task5, task6, task7 = _create_sample_tasks()
 
     # Fetch unprocessed tasks with a limit
-    limited_tasks = task_repo.fetch_unprocessed_tasks(limit=2)
+    limited_tasks = task_repo.fetch_unprocessed_tasks(
+        limit=2, max_retries=3, retryable_exceptions=["InferenceException"]
+    )
 
     # Assert that only the specified number of tasks is returned
     assert len(limited_tasks) == 2
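
The sketch below is not part of the diff; it only illustrates why the dialect branch is needed: `result->>'error'` is PostgreSQL-specific JSON syntax, while SQLite exposes the same lookup through `json_extract()`. The `DemoTask` model, in-memory SQLite engine, and column names are illustrative assumptions, not the project's actual entities.

# Standalone sketch (assumptions noted above): filter rows by a key inside a
# JSON column using a raw SQL predicate chosen per dialect, mirroring the
# branch added to fetch_unprocessed_tasks().
from sqlalchemy import JSON, Column, Integer, String, create_engine, text
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class DemoTask(Base):
    __tablename__ = "demo_tasks"
    id = Column(Integer, primary_key=True)
    status = Column(String)
    result = Column(JSON)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            DemoTask(status="FAILED", result={"error": "InferenceException"}),
            DemoTask(status="FAILED", result={"error": "NonRecoverableError"}),
        ]
    )
    session.commit()

    # Pick the JSON accessor the connected database understands.
    if session.bind.dialect.name == "postgresql":
        predicate = "result->>'error' IN ('InferenceException')"
    else:  # sqlite
        predicate = "json_extract(result, '$.error') IN ('InferenceException')"

    retryable = session.query(DemoTask).filter(text(predicate)).all()
    print([t.result for t in retryable])  # [{'error': 'InferenceException'}]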