diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index a21bbf6..d30b4e9 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -5,7 +5,7 @@ import numpy as np import time from typing import Union, List -from lib.model.generic_transformer import GenericTransformerModel +from lib.model.audio import Model as AudioModel from lib.queue.queue import Queue from lib.queue.worker import QueueWorker from lib import schemas @@ -31,8 +31,8 @@ class TestQueueWorker(unittest.TestCase): @patch('lib.helpers.get_environment_setting', return_value='us-west-1') @patch('lib.telemetry.OpenTelemetryExporter.log_execution_time') def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resource): - self.model = GenericTransformerModel(None) - self.model.model_name = "generic" + self.model = AudioModel(None) + self.model.model_name = "audio" self.mock_model = MagicMock() self.queue_name_input = Queue.get_input_queue_name() self.queue_name_output = Queue.get_output_queue_name() @@ -84,7 +84,7 @@ def test_execute_with_timeout_success(self, mock_log_execution_status, mock_log_ def test_process(self): self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({ "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, - "model_name": "generic" + "model_name": "audio" })), self.mock_input_queue)]) self.queue.input_queue = MagicMock(return_value=None) self.model.model = self.mock_model @@ -212,7 +212,7 @@ def test_extract_messages(self): self.assertIsInstance(extracted_messages[0].body, schemas.GenericItem) self.assertEqual(extracted_messages[0].body.text, "Test message 1") self.assertEqual(extracted_messages[1].body.text, "Test message 2") - self.assertEqual(extracted_messages[0].model_name, "generic") + self.assertEqual(extracted_messages[0].model_name, "audio") @patch('lib.queue.worker.logger.error') def test_log_and_handle_error(self, mock_logger_error): @@ -235,7 +235,7 @@ def test_error_capturing_in_get_response(self, mock_cache_set, mock_cache_get): mock_cache_set.return_value = True message_data = { "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, - "model_name": "generic" + "model_name": "audio" } message = schemas.parse_message(message_data) message.body.content_hash = "test_hash"