Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CV2-5011 refactors for making alegre dual purpose on text encoding #103

Merged
merged 8 commits into from
Aug 16, 2024
70 changes: 59 additions & 11 deletions lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Union, Dict, List
from typing import Union, List
from sentence_transformers import SentenceTransformer
from lib.logger import logger
from lib.model.model import Model
from lib import schemas
from lib.cache import Cache

class GenericTransformerModel(Model):
def __init__(self, model_name: str):
Expand All @@ -17,20 +18,67 @@ def __init__(self, model_name: str):

def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.GenericItem]:
    """
    Respond to a batch of messages by vectorizing only the uncached texts.

    Accepts either a single message or a list of messages; always returns
    a list. Each message's ``body.result`` is populated — from the cache
    when a cached vector exists, otherwise by vectorizing the text and
    writing the vector back to the cache.
    """
    docs = self._ensure_list(docs)
    self._log_docs(docs)
    # Cache hits get their result assigned in place inside
    # _separate_cached_docs; only the misses are vectorized below.
    docs_to_process, texts_to_vectorize = self._separate_cached_docs(docs)

    if texts_to_vectorize:
        self._vectorize_and_cache(docs_to_process, texts_to_vectorize)

    return docs

def _ensure_list(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
"""
Ensure the input is a list of messages.
"""
return docs if isinstance(docs, list) else [docs]

def _log_docs(self, docs: "List[schemas.Message]") -> None:
    """Record the incoming batch at info level so inbound work is traceable."""
    logger.info(docs)

def _separate_cached_docs(self, docs: List[schemas.Message]) -> "Tuple[List[schemas.Message], List[str]]":
    """
    Split a batch into cache misses and their texts.

    Cache hits have ``body.result`` assigned in place from the cached
    value. Returns a tuple of (messages still needing vectors, their
    texts), index-aligned so the two lists can be zipped together after
    vectorization.
    """
    docs_to_process = []
    texts_to_vectorize = []

    for doc in docs:
        cached_result = Cache.get_cached_result(doc.body.content_hash)
        # NOTE(review): a falsy cached value (e.g. []) is treated as a
        # miss and re-vectorized — presumably intentional; confirm.
        if cached_result:
            doc.body.result = cached_result
        else:
            docs_to_process.append(doc)
            texts_to_vectorize.append(doc.body.text)

    return docs_to_process, texts_to_vectorize

def _vectorize_and_cache(self, docs_to_process: "List[schemas.Message]", texts_to_vectorize: "List[str]") -> None:
    """
    Vectorize the uncached texts, attach each vector to its message, and
    write it through to the cache under the message's content hash.

    Any failure is delegated to handle_fingerprinting_error, which logs
    and re-raises.
    """
    try:
        vectors = self.vectorize(texts_to_vectorize)
        for doc, vector in zip(docs_to_process, vectors):
            doc.body.result = vector
            Cache.set_cached_result(doc.body.content_hash, vector)
    except Exception as caught:
        self.handle_fingerprinting_error(caught)

def vectorize(self, texts: List[str]) -> List[List[float]]:
    """
    Encode a batch of texts with the underlying sentence-transformer.

    Returns one embedding (a list of floats) per input text, in order.
    """
    # encode() returns a numpy array; convert to plain lists so the
    # result is JSON/cache serializable.
    return self.model.encode(texts).tolist()

def handle_fingerprinting_error(self, error: Exception) -> None:
    """Log a vectorization failure, then re-raise the original exception unchanged."""
    message = f"Error during vectorization: {error}"
    logger.error(message)
    raise error
1 change: 0 additions & 1 deletion lib/queue/queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import json
from typing import List, Dict, Tuple
import os
Expand Down
1 change: 0 additions & 1 deletion lib/queue/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import os
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
Expand Down
1 change: 0 additions & 1 deletion lib/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import os
from opentelemetry import metrics
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
Expand Down
2 changes: 1 addition & 1 deletion test/lib/model/test_fptg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_vectorize(self):
texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])
Expand Down
100 changes: 94 additions & 6 deletions test/lib/model/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import unittest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import numpy as np

Expand All @@ -13,20 +13,108 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
    """vectorize() should encode a batch of plain strings and return plain lists."""
    texts = ["Hello, how are you?", "I'm doing great, thanks!"]
    self.model.model = self.mock_model
    self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
    vectors = self.model.vectorize(texts)
    # One embedding per input text, converted from numpy to lists.
    self.assertEqual(len(vectors), 2)
    self.assertEqual(vectors[0], [4, 5, 6])
    self.assertEqual(vectors[1], [7, 8, 9])

@patch('lib.cache.Cache.get_cached_result')
@patch('lib.cache.Cache.set_cached_result')
def test_respond_with_cache(self, mock_set_cache, mock_get_cache):
    """A cache hit populates the result without vectorizing or re-caching."""
    # Simulate cache hit
    mock_get_cache.return_value = [1, 2, 3]

    query = schemas.parse_message({
        "body": {
            "id": "123",
            "callback_url": "http://example.com/callback",
            "text": "Anong pangalan mo?"
        },
        "model_name": "fptg__Model"
    })

    response = self.model.respond(query)
    self.assertEqual(len(response), 1)
    self.assertEqual(response[0].body.result, [1, 2, 3])
    # Cached values must not be written back to the cache.
    mock_set_cache.assert_not_called()

@patch('lib.cache.Cache.get_cached_result')
@patch('lib.cache.Cache.set_cached_result')
def test_respond_without_cache(self, mock_set_cache, mock_get_cache):
    """On a cache miss, respond() vectorizes the text and caches the vector."""
    # Simulate cache miss
    mock_get_cache.return_value = None

    query = schemas.parse_message({
        "body": {
            "id": "123",
            "callback_url": "http://example.com/callback",
            "text": "Anong pangalan mo?"
        },
        "model_name": "fptg__Model"
    })

    # Stub out the model call so only the caching flow is under test.
    self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])

    response = self.model.respond(query)
    self.assertEqual(len(response), 1)
    self.assertEqual(response[0].body.result, [1, 2, 3])
    # The freshly computed vector must be written back under the content hash.
    mock_set_cache.assert_called_once_with(query.body.content_hash, [1, 2, 3])

def test_ensure_list(self):
    """_ensure_list should wrap a single message in a one-element list."""
    single_doc = schemas.parse_message({
        "body": {
            "id": "123",
            "callback_url": "http://example.com/callback",
            "text": "Hello"
        },
        "model_name": "fptg__Model"
    })

    result = self.model._ensure_list(single_doc)
    self.assertIsInstance(result, list)
    self.assertEqual(len(result), 1)
    # The original object is preserved in the list, not copied.
    self.assertEqual(result[0], single_doc)

def test_separate_cached_docs(self):
    """Cache misses come back for vectorization; hits get results set in place."""
    # Mock cache: side_effect order matches doc order — first lookup
    # misses (None), second returns a cached vector.
    with patch('lib.cache.Cache.get_cached_result') as mock_cache:
        mock_cache.side_effect = [None, [4, 5, 6]]

        docs = [
            schemas.parse_message({
                "body": {
                    "id": "123",
                    "callback_url": "http://example.com/callback",
                    "text": "Hello"
                },
                "model_name": "fptg__Model"
            }),
            schemas.parse_message({
                "body": {
                    "id": "456",
                    "callback_url": "http://example.com/callback",
                    "text": "How are you?"
                },
                "model_name": "fptg__Model"
            })
        ]

        docs_to_process, texts_to_vectorize = self.model._separate_cached_docs(docs)
        # Only the first (uncached) doc should need vectorizing.
        self.assertEqual(len(docs_to_process), 1)
        self.assertEqual(len(texts_to_vectorize), 1)
        self.assertEqual(texts_to_vectorize[0], "Hello")
        # The cached doc had its result assigned directly from the cache.
        self.assertEqual(docs[1].body.result, [4, 5, 6])

@patch('lib.model.generic_transformer.logger')
def test_handle_fingerprinting_error(self, mock_logger):
    """Errors are logged exactly once and then re-raised unchanged."""
    with self.assertRaises(Exception) as context:
        self.model.handle_fingerprinting_error(ValueError("An error occurred"))

    mock_logger.error.assert_called_once_with("Error during vectorization: An error occurred")
    # The original exception type must be preserved by the re-raise.
    self.assertTrue(isinstance(context.exception, ValueError))

# Run the tests when the module is executed directly.
if __name__ == '__main__':
    unittest.main()
2 changes: 1 addition & 1 deletion test/lib/model/test_indian_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_vectorize(self):
texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "indian_sbert__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "indian_sbert__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])
Expand Down
2 changes: 1 addition & 1 deletion test/lib/model/test_meantokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_vectorize(self):
texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "mean_tokens__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "mean_tokens__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])
Expand Down
14 changes: 7 additions & 7 deletions test/lib/queue/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import time
from typing import Union, List
from lib.model.generic_transformer import GenericTransformerModel
from lib.model.audio import Model as AudioModel
from lib.queue.queue import Queue
from lib.queue.worker import QueueWorker
from lib import schemas
Expand All @@ -31,8 +31,8 @@ class TestQueueWorker(unittest.TestCase):
@patch('lib.helpers.get_environment_setting', return_value='us-west-1')
@patch('lib.telemetry.OpenTelemetryExporter.log_execution_time')
def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resource):
self.model = GenericTransformerModel(None)
self.model.model_name = "generic"
self.model = AudioModel()
self.model.model_name = "audio__Model"
self.mock_model = MagicMock()
self.queue_name_input = Queue.get_input_queue_name()
self.queue_name_output = Queue.get_output_queue_name()
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_execute_with_timeout_success(self, mock_log_execution_status, mock_log_
def test_process(self):
self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({
"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"},
"model_name": "generic"
"model_name": "audio__Model"
})), self.mock_input_queue)])
self.queue.input_queue = MagicMock(return_value=None)
self.model.model = self.mock_model
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_extract_messages(self):
self.assertIsInstance(extracted_messages[0].body, schemas.GenericItem)
self.assertEqual(extracted_messages[0].body.text, "Test message 1")
self.assertEqual(extracted_messages[1].body.text, "Test message 2")
self.assertEqual(extracted_messages[0].model_name, "generic")
self.assertEqual(extracted_messages[0].model_name, "audio__Model")

@patch('lib.queue.worker.logger.error')
def test_log_and_handle_error(self, mock_logger_error):
Expand All @@ -234,8 +234,8 @@ def test_error_capturing_in_get_response(self, mock_cache_set, mock_cache_get):
mock_cache_get.return_value = None
mock_cache_set.return_value = True
message_data = {
"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"},
"model_name": "generic"
"body": {"id": 1, "callback_url": "http://example.com", "text": "This is a testzzz"},
"model_name": "audio__Model"
}
message = schemas.parse_message(message_data)
message.body.content_hash = "test_hash"
Expand Down
1 change: 0 additions & 1 deletion test/lib/test_telemetry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import os
import pytest
from unittest.mock import patch, MagicMock
Expand Down
Loading