From cc051a8e3229e53a5903e5e1ab319093715c0eb0 Mon Sep 17 00:00:00 2001
From: Devin Gaffney
Date: Fri, 9 Aug 2024 07:13:11 -0700
Subject: [PATCH 1/3] refactor to accommodate erroring in transformers

---
 lib/model/generic_transformer.py | 69 ++++++++++++++++++----
 test/lib/model/test_generic.py   | 98 ++++++++++++++++++++++++++++++--
 2 files changed, 152 insertions(+), 15 deletions(-)

diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py
index 3bb2747..88a1717 100644
--- a/lib/model/generic_transformer.py
+++ b/lib/model/generic_transformer.py
@@ -1,9 +1,10 @@
 import os
-from typing import Union, Dict, List
+from typing import Union, List
 from sentence_transformers import SentenceTransformer
 from lib.logger import logger
 from lib.model.model import Model
 from lib import schemas
+from lib.cache import Cache
 
 class GenericTransformerModel(Model):
     def __init__(self, model_name: str):
@@ -17,20 +18,68 @@ def __init__(self, model_name: str):
 
     def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.GenericItem]:
         """
-        Force messages as list of messages in case we get a singular item. Then, run fingerprint routine.
-        Respond can probably be genericized across all models.
+        Respond to a batch of messages by vectorizing uncached texts.
         """
-        if not isinstance(docs, list):
-            docs = [docs]
-        logger.info(docs)
-        vectorizable_texts = [e.body.text for e in docs]
-        vectorized = self.vectorize(vectorizable_texts)
-        for doc, vector in zip(docs, vectorized):
-            doc.body.result = vector
+        docs = self._ensure_list(docs)
+        self._log_docs(docs)
+
+        docs_to_process, texts_to_vectorize = self._separate_cached_docs(docs)
+
+        if texts_to_vectorize:
+            self._vectorize_and_cache(docs_to_process, texts_to_vectorize)
+
         return docs
 
+    def _ensure_list(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
+        """
+        Ensure the input is a list of messages.
+        """
+        return docs if isinstance(docs, list) else [docs]
+
+    def _log_docs(self, docs: List[schemas.Message]):
+        """
+        Log the documents for debugging purposes.
+        """
+        logger.info(docs)
+
+    def _separate_cached_docs(self, docs: List[schemas.Message]) -> (List[schemas.Message], List[str]):
+        """
+        Separate cached documents from those that need to be vectorized.
+        """
+        docs_to_process = []
+        texts_to_vectorize = []
+
+        for doc in docs:
+            cached_result = Cache.get_cached_result(doc.body.content_hash)
+            if cached_result:
+                doc.body.result = cached_result
+            else:
+                docs_to_process.append(doc)
+                texts_to_vectorize.append(doc.body.text)
+
+        return docs_to_process, texts_to_vectorize
+
+    def _vectorize_and_cache(self, docs_to_process: List[schemas.Message], texts_to_vectorize: List[str]):
+        """
+        Vectorize the uncached texts and store the results in the cache.
+        """
+        try:
+            vectorized = self.vectorize(texts_to_vectorize)
+            for doc, vector in zip(docs_to_process, vectorized):
+                doc.body.result = vector
+                Cache.set_cached_result(doc.body.content_hash, vector)
+        except Exception as e:
+            self.handle_fingerprinting_error(e)
+
     def vectorize(self, texts: List[str]) -> List[List[float]]:
         """
         Vectorize the text! Run as batch.
         """
         return self.model.encode(texts).tolist()
+
+    def handle_fingerprinting_error(self, error: Exception):
+        """
+        Handle any error that occurs during vectorization.
+        """
+        logger.error(f"Error during vectorization: {error}")
+        raise error
diff --git a/test/lib/model/test_generic.py b/test/lib/model/test_generic.py
index 7d696a1..1ad3f21 100644
--- a/test/lib/model/test_generic.py
+++ b/test/lib/model/test_generic.py
@@ -1,6 +1,6 @@
 import os
 import unittest
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 
@@ -13,7 +13,7 @@ def setUp(self):
         self.mock_model = MagicMock()
 
     def test_vectorize(self):
-        texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})]
+        texts = ["Hello, how are you?", "I'm doing great, thanks!"]
         self.model.model = self.mock_model
         self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
         vectors = self.model.vectorize(texts)
@@ -21,12 +21,100 @@ def test_vectorize(self):
         self.assertEqual(vectors[0], [4, 5, 6])
         self.assertEqual(vectors[1], [7, 8, 9])
 
-    def test_respond(self):
-        query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"})
+    @patch('lib.cache.Cache.get_cached_result')
+    @patch('lib.cache.Cache.set_cached_result')
+    def test_respond_with_cache(self, mock_set_cache, mock_get_cache):
+        # Simulate cache hit
+        mock_get_cache.return_value = [1, 2, 3]
+
+        query = schemas.parse_message({
+            "body": {
+                "id": "123",
+                "callback_url": "http://example.com/callback",
+                "text": "Anong pangalan mo?"
+            },
+            "model_name": "fptg__Model"
+        })
+
+        response = self.model.respond(query)
+        self.assertEqual(len(response), 1)
+        self.assertEqual(response[0].body.result, [1, 2, 3])
+        mock_set_cache.assert_not_called()
+
+    @patch('lib.cache.Cache.get_cached_result')
+    @patch('lib.cache.Cache.set_cached_result')
+    def test_respond_without_cache(self, mock_set_cache, mock_get_cache):
+        # Simulate cache miss
+        mock_get_cache.return_value = None
+
+        query = schemas.parse_message({
+            "body": {
+                "id": "123",
+                "callback_url": "http://example.com/callback",
+                "text": "Anong pangalan mo?"
+            },
+            "model_name": "fptg__Model"
+        })
+
         self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
+
         response = self.model.respond(query)
         self.assertEqual(len(response), 1)
         self.assertEqual(response[0].body.result, [1, 2, 3])
+        mock_set_cache.assert_called_once_with(query.body.content_hash, [1, 2, 3])
+
+    def test_ensure_list(self):
+        single_doc = schemas.parse_message({
+            "body": {
+                "id": "123",
+                "callback_url": "http://example.com/callback",
+                "text": "Hello"
+            },
+            "model_name": "fptg__Model"
+        })
+
+        result = self.model._ensure_list(single_doc)
+        self.assertIsInstance(result, list)
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0], single_doc)
+
+    def test_separate_cached_docs(self):
+        # Mock cache
+        with patch('lib.cache.Cache.get_cached_result') as mock_cache:
+            mock_cache.side_effect = [None, [4, 5, 6]]
+
+            docs = [
+                schemas.parse_message({
+                    "body": {
+                        "id": "123",
+                        "callback_url": "http://example.com/callback",
+                        "text": "Hello"
+                    },
+                    "model_name": "fptg__Model"
+                }),
+                schemas.parse_message({
+                    "body": {
+                        "id": "456",
+                        "callback_url": "http://example.com/callback",
+                        "text": "How are you?"
+                    },
+                    "model_name": "fptg__Model"
+                })
+            ]
+
+            docs_to_process, texts_to_vectorize = self.model._separate_cached_docs(docs)
+            self.assertEqual(len(docs_to_process), 1)
+            self.assertEqual(len(texts_to_vectorize), 1)
+            self.assertEqual(texts_to_vectorize[0], "Hello")
+            self.assertEqual(docs[1].body.result, [4, 5, 6])
+
+    @patch('lib.model.generic_transformer.logger')
+    def test_handle_fingerprinting_error(self, mock_logger):
+        with self.assertRaises(Exception) as context:
+            self.model.handle_fingerprinting_error(ValueError("An error occurred"))
+
+        mock_logger.error.assert_called_once_with("Error during vectorization: An error occurred")
+        self.assertTrue(isinstance(context.exception, ValueError))
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
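Note on PATCH 1/3: respond() is now split into _separate_cached_docs (a cache lookup keyed on body.content_hash) and _vectorize_and_cache (encode, write the result back, populate the cache), with handle_fingerprinting_error logging and re-raising any encoder failure. The short sketch below mirrors that control flow outside the repository, with a plain dict standing in for lib.cache.Cache and dicts standing in for schemas.Message; it is an illustration of the intended behaviour, not code from the patch.

# Illustrative sketch of the cache-then-vectorize flow introduced in PATCH 1/3.
# A dict stands in for lib.cache.Cache; dicts stand in for schemas.Message.
cache = {}

def vectorize(texts):
    # Stand-in for SentenceTransformer.encode(texts).tolist()
    return [[float(len(text))] for text in texts]

def respond(docs):
    docs = docs if isinstance(docs, list) else [docs]
    docs_to_process, texts_to_vectorize = [], []
    for doc in docs:
        cached = cache.get(doc["content_hash"])
        if cached:
            doc["result"] = cached          # cache hit: reuse the stored vector
        else:
            docs_to_process.append(doc)     # cache miss: vectorize below
            texts_to_vectorize.append(doc["text"])
    if texts_to_vectorize:
        try:
            for doc, vector in zip(docs_to_process, vectorize(texts_to_vectorize)):
                doc["result"] = vector
                cache[doc["content_hash"]] = vector
        except Exception as error:
            # Equivalent of handle_fingerprinting_error: log, then re-raise.
            print(f"Error during vectorization: {error}")
            raise
    return docs

first = respond([{"content_hash": "abc", "text": "Anong pangalan mo?"}])   # miss: encodes and caches
second = respond([{"content_hash": "abc", "text": "Anong pangalan mo?"}])  # hit: served from cache
assert first[0]["result"] == second[0]["result"]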
+ """ + logger.error(f"Error during vectorization: {error}") + raise error diff --git a/test/lib/model/test_generic.py b/test/lib/model/test_generic.py index 7d696a1..1ad3f21 100644 --- a/test/lib/model/test_generic.py +++ b/test/lib/model/test_generic.py @@ -1,6 +1,6 @@ import os import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np @@ -13,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})] + texts = ["Hello, how are you?", "I'm doing great, thanks!"] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts) @@ -21,12 +21,100 @@ def test_vectorize(self): self.assertEqual(vectors[0], [4, 5, 6]) self.assertEqual(vectors[1], [7, 8, 9]) - def test_respond(self): - query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"}) + @patch('lib.cache.Cache.get_cached_result') + @patch('lib.cache.Cache.set_cached_result') + def test_respond_with_cache(self, mock_set_cache, mock_get_cache): + # Simulate cache hit + mock_get_cache.return_value = [1, 2, 3] + + query = schemas.parse_message({ + "body": { + "id": "123", + "callback_url": "http://example.com/callback", + "text": "Anong pangalan mo?" + }, + "model_name": "fptg__Model" + }) + + response = self.model.respond(query) + self.assertEqual(len(response), 1) + self.assertEqual(response[0].body.result, [1, 2, 3]) + mock_set_cache.assert_not_called() + + @patch('lib.cache.Cache.get_cached_result') + @patch('lib.cache.Cache.set_cached_result') + def test_respond_without_cache(self, mock_set_cache, mock_get_cache): + # Simulate cache miss + mock_get_cache.return_value = None + + query = schemas.parse_message({ + "body": { + "id": "123", + "callback_url": "http://example.com/callback", + "text": "Anong pangalan mo?" + }, + "model_name": "fptg__Model" + }) + self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) + response = self.model.respond(query) self.assertEqual(len(response), 1) self.assertEqual(response[0].body.result, [1, 2, 3]) + mock_set_cache.assert_called_once_with(query.body.content_hash, [1, 2, 3]) + + def test_ensure_list(self): + single_doc = schemas.parse_message({ + "body": { + "id": "123", + "callback_url": "http://example.com/callback", + "text": "Hello" + }, + "model_name": "fptg__Model" + }) + + result = self.model._ensure_list(single_doc) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0], single_doc) + + def test_separate_cached_docs(self): + # Mock cache + with patch('lib.cache.Cache.get_cached_result') as mock_cache: + mock_cache.side_effect = [None, [4, 5, 6]] + + docs = [ + schemas.parse_message({ + "body": { + "id": "123", + "callback_url": "http://example.com/callback", + "text": "Hello" + }, + "model_name": "fptg__Model" + }), + schemas.parse_message({ + "body": { + "id": "456", + "callback_url": "http://example.com/callback", + "text": "How are you?" 
From 06395f9b4ad98b86658b291ee3612d589337a7f5 Mon Sep 17 00:00:00 2001
From: Devin Gaffney
Date: Fri, 9 Aug 2024 10:48:56 -0700
Subject: [PATCH 3/3] remove bad imports, fix broken test cases

---
 lib/model/generic_transformer.py | 1 -
 lib/queue/queue.py               | 1 -
 lib/queue/worker.py              | 1 -
 lib/telemetry.py                 | 1 -
 test/lib/queue/test_queue.py     | 4 ++--
 test/lib/test_telemetry.py       | 1 -
 6 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py
index 88a1717..c8f3fc8 100644
--- a/lib/model/generic_transformer.py
+++ b/lib/model/generic_transformer.py
@@ -22,7 +22,6 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
         """
         docs = self._ensure_list(docs)
         self._log_docs(docs)
-
         docs_to_process, texts_to_vectorize = self._separate_cached_docs(docs)
 
         if texts_to_vectorize:
diff --git a/lib/queue/queue.py b/lib/queue/queue.py
index 6a11cb0..f54a286 100644
--- a/lib/queue/queue.py
+++ b/lib/queue/queue.py
@@ -1,4 +1,3 @@
-import pdb
 import json
 from typing import List, Dict, Tuple
 import os
diff --git a/lib/queue/worker.py b/lib/queue/worker.py
index edc9914..4435988 100644
--- a/lib/queue/worker.py
+++ b/lib/queue/worker.py
@@ -1,4 +1,3 @@
-import pdb
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor, TimeoutError
diff --git a/lib/telemetry.py b/lib/telemetry.py
index 40916fe..f557a47 100644
--- a/lib/telemetry.py
+++ b/lib/telemetry.py
@@ -1,4 +1,3 @@
-import pdb
 import os
 from opentelemetry import metrics
 from opentelemetry.sdk.resources import SERVICE_NAME, Resource
diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py
index d30b4e9..c531722 100644
--- a/test/lib/queue/test_queue.py
+++ b/test/lib/queue/test_queue.py
@@ -31,7 +31,7 @@ class TestQueueWorker(unittest.TestCase):
     @patch('lib.helpers.get_environment_setting', return_value='us-west-1')
     @patch('lib.telemetry.OpenTelemetryExporter.log_execution_time')
     def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resource):
-        self.model = AudioModel(None)
+        self.model = AudioModel()
         self.model.model_name = "audio"
         self.mock_model = MagicMock()
         self.queue_name_input = Queue.get_input_queue_name()
@@ -234,7 +234,7 @@ def test_error_capturing_in_get_response(self, mock_cache_set, mock_cache_get):
         mock_cache_get.return_value = None
         mock_cache_set.return_value = True
         message_data = {
-            "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"},
+            "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a testzzz"},
             "model_name": "audio"
         }
         message = schemas.parse_message(message_data)
diff --git a/test/lib/test_telemetry.py b/test/lib/test_telemetry.py
index 6e577a4..48e9b3e 100644
--- a/test/lib/test_telemetry.py
+++ b/test/lib/test_telemetry.py
@@ -1,4 +1,3 @@
-import pdb
 import os
 import pytest
 from unittest.mock import patch, MagicMock
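PATCH 3/3 strips leftover import pdb statements from the library and test modules. If keeping debugger imports out of the tree is a goal, a small guard test along the following lines could be added later; this is only a suggestion sketched against an assumed layout (a lib/ directory of .py files), not something these patches introduce.

# Hypothetical guard test (a suggestion, not part of the patches above): fail
# if a stray debugger import reappears anywhere under lib/.
import pathlib
import unittest

class TestNoDebuggerImports(unittest.TestCase):
    def test_no_pdb_imports(self):
        offenders = [
            str(path)
            for path in pathlib.Path("lib").rglob("*.py")
            if any(line.strip().startswith("import pdb") for line in path.read_text().splitlines())
        ]
        self.assertEqual(offenders, [], f"Remove debugger imports from: {offenders}")

if __name__ == '__main__':
    unittest.main()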