Skip to content

Commit

Permalink
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
Browse files Browse the repository at this point in the history
# Conflicts:
#	lib/model/generic_transformer.py
#	test/lib/model/test_generic.py
  • Loading branch information
ashkankzme committed Aug 19, 2024
2 parents 1d794ca + 11c2f79 commit 3eb9861
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 23 deletions.
68 changes: 58 additions & 10 deletions lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.logger import logger
from lib.model.model import Model
from lib import schemas
from lib.cache import Cache

class GenericTransformerModel(Model):
def __init__(self, model_name: str):
Expand All @@ -17,23 +18,70 @@ def __init__(self, model_name: str):

def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.GenericItem]:
    """
    Respond to a single message or a batch of messages.

    Cached results are reused via each doc's content hash; only uncached
    texts are sent to the model. Docs are mutated in place (body.result)
    and the full, same-order list is returned.
    """
    docs = self._ensure_list(docs)
    self._log_docs(docs)
    docs_to_process, texts_to_vectorize = self._separate_cached_docs(docs)

    # Skip the (potentially expensive) model call when every doc was cached.
    if texts_to_vectorize:
        self._vectorize_and_cache(docs_to_process, texts_to_vectorize)

    return docs

def _ensure_list(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
"""
Ensure the input is a list of messages.
"""
return docs if isinstance(docs, list) else [docs]

def _log_docs(self, docs: List[schemas.Message]):
    """
    Log the documents for debugging purposes.

    NOTE(review): this logs the full message payloads at INFO level on
    every request — consider DEBUG level or truncation if message bodies
    are large or sensitive (can't tell from here what they contain).
    """
    logger.info(docs)

def _separate_cached_docs(self, docs: List[schemas.Message]) -> "tuple[list[schemas.Message], list[str]]":
    """
    Split docs into (uncached docs, their texts); attach cached results.

    Docs whose content hash has a cached result get that result written
    onto doc.body.result in place. All remaining docs are returned together
    with their texts, in matching order, for batch vectorization.
    """
    docs_to_process = []
    texts_to_vectorize = []

    for doc in docs:
        cached_result = Cache.get_cached_result(doc.body.content_hash)
        # Explicit None check: a falsy-but-present cached value (e.g. an
        # empty vector) still counts as a cache hit and must not be
        # re-vectorized.
        if cached_result is not None:
            doc.body.result = cached_result
        else:
            docs_to_process.append(doc)
            texts_to_vectorize.append(doc.body.text)

    return docs_to_process, texts_to_vectorize

def _vectorize_and_cache(self, docs_to_process: List[schemas.Message], texts_to_vectorize: List[str]):
    """
    Vectorize the uncached texts, attach each vector to its doc, and store
    it in the cache. Any failure is delegated to handle_fingerprinting_error.
    """
    try:
        # zip evaluates its arguments eagerly, so the batch model call runs
        # exactly once, before the first iteration.
        for doc, vector in zip(docs_to_process, self.vectorize(texts_to_vectorize)):
            doc.body.result = vector
            Cache.set_cached_result(doc.body.content_hash, vector)
    except Exception as err:
        self.handle_fingerprinting_error(err)

def vectorize(self, texts: List[str]) -> List[List[float]]:
    """
    Vectorize a batch of texts with the underlying model.

    Returns one embedding (a list of floats) per input text, in order.
    """
    # Merge residue resolved: the pre-merge branch wrapped this in
    # {"hash_value": ...}; callers and the updated tests expect the bare list.
    return self.model.encode(texts).tolist()

def handle_fingerprinting_error(self, error: Exception):
    """
    Handle any error that occurs during vectorization.

    Logs the error and re-raises it unchanged, so callers still observe
    the original exception type.
    """
    logger.error(f"Error during vectorization: {error}")
    raise error

@classmethod
def validate_input(cls, data: Dict) -> None:
Expand Down
1 change: 0 additions & 1 deletion lib/queue/queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import json
from typing import List, Dict, Tuple
import os
Expand Down
1 change: 0 additions & 1 deletion lib/queue/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import os
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
Expand Down
1 change: 0 additions & 1 deletion lib/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import os
from opentelemetry import metrics
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
Expand Down
2 changes: 1 addition & 1 deletion test/lib/model/test_fptg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_vectorize(self):
texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])
Expand Down
100 changes: 94 additions & 6 deletions test/lib/model/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import unittest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import numpy as np

Expand All @@ -13,20 +13,108 @@ def setUp(self):
self.mock_model = MagicMock()

def test_vectorize(self):
texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})]
texts = ["Hello, how are you?", "I'm doing great, thanks!"]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])

@patch('lib.cache.Cache.get_cached_result')
@patch('lib.cache.Cache.set_cached_result')
def test_respond_with_cache(self, mock_set_cache, mock_get_cache):
    """Cache hit: respond() uses the cached vector and never writes back."""
    # Simulate cache hit
    mock_get_cache.return_value = [1, 2, 3]

    query = schemas.parse_input_message({
        "body": {
            "id": "123",
            "callback_url": "http://example.com/callback",
            "text": "Anong pangalan mo?"
        },
        "model_name": "fptg__Model"
    })

    response = self.model.respond(query)
    self.assertEqual(len(response), 1)
    self.assertEqual(response[0].body.result, [1, 2, 3])
    mock_set_cache.assert_not_called()

@patch('lib.cache.Cache.get_cached_result')
@patch('lib.cache.Cache.set_cached_result')
def test_respond_without_cache(self, mock_set_cache, mock_get_cache):
    """Cache miss: vectorize runs and its result is written back to the cache."""
    mock_get_cache.return_value = None  # every lookup misses
    self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])

    query = schemas.parse_input_message({
        "body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"},
        "model_name": "fptg__Model",
    })

    response = self.model.respond(query)

    self.assertEqual(len(response), 1)
    self.assertEqual(response[0].body.result, [1, 2, 3])
    mock_set_cache.assert_called_once_with(query.body.content_hash, [1, 2, 3])

def test_ensure_list(self):
    """A bare message is wrapped into a one-element list, unchanged."""
    single_doc = schemas.parse_input_message({
        "body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello"},
        "model_name": "fptg__Model",
    })

    wrapped = self.model._ensure_list(single_doc)

    self.assertIsInstance(wrapped, list)
    self.assertEqual(len(wrapped), 1)
    self.assertEqual(wrapped[0], single_doc)

def test_separate_cached_docs(self):
    """First doc misses the cache, second hits; only the miss is returned."""
    with patch('lib.cache.Cache.get_cached_result') as mock_cache:
        # One return value per doc, in iteration order: miss, then hit.
        mock_cache.side_effect = [None, [4, 5, 6]]

        docs = [
            schemas.parse_input_message({
                "body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello"},
                "model_name": "fptg__Model",
            }),
            schemas.parse_input_message({
                "body": {"id": "456", "callback_url": "http://example.com/callback", "text": "How are you?"},
                "model_name": "fptg__Model",
            }),
        ]

        uncached, texts = self.model._separate_cached_docs(docs)

        self.assertEqual(len(uncached), 1)
        self.assertEqual(len(texts), 1)
        self.assertEqual(texts[0], "Hello")
        # The cache hit had its result attached in place.
        self.assertEqual(docs[1].body.result, [4, 5, 6])

@patch('lib.model.generic_transformer.logger')
def test_handle_fingerprinting_error(self, mock_logger):
    """The error is logged exactly once and then re-raised as-is."""
    with self.assertRaises(Exception) as context:
        self.model.handle_fingerprinting_error(ValueError("An error occurred"))

    mock_logger.error.assert_called_once_with("Error during vectorization: An error occurred")
    self.assertIsInstance(context.exception, ValueError)

if __name__ == '__main__':
    # Merge residue resolved: a duplicated unittest.main() line was left
    # behind by the merge; run the test suite exactly once.
    unittest.main()
2 changes: 1 addition & 1 deletion test/lib/model/test_indian_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_vectorize(self):
texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "indian_sbert__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "indian_sbert__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])
Expand Down
2 changes: 1 addition & 1 deletion test/lib/model/test_meantokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_vectorize(self):
texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "mean_tokens__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "mean_tokens__Model"})]
self.model.model = self.mock_model
self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
vectors = self.model.vectorize(texts)["hash_value"]
vectors = self.model.vectorize(texts)
self.assertEqual(len(vectors), 2)
self.assertEqual(vectors[0], [4, 5, 6])
self.assertEqual(vectors[1], [7, 8, 9])
Expand Down
1 change: 0 additions & 1 deletion test/lib/test_telemetry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import os
import pytest
from unittest.mock import patch, MagicMock
Expand Down

0 comments on commit 3eb9861

Please sign in to comment.