Add the paraphrase-multilingual-mpnet-base-v2 vectorizer and its unit tests.

NOTE(review): the two lib/model/__pycache__/*.pyc entries below are compiled
bytecode and carry no binary payload in this patch; they are presumably an
unintended part of the commit - confirm and consider ignoring __pycache__/.
NOTE(review): blob ids on the "index" lines are those of the original commit
and were not recomputed for the reformatted file contents.

diff --git a/lib/model/__pycache__/fptg.cpython-39.pyc b/lib/model/__pycache__/fptg.cpython-39.pyc
index dc5a4af8..01d9571c 100644
Binary files a/lib/model/__pycache__/fptg.cpython-39.pyc and b/lib/model/__pycache__/fptg.cpython-39.pyc differ
diff --git a/lib/model/__pycache__/indian_sbert.cpython-39.pyc b/lib/model/__pycache__/indian_sbert.cpython-39.pyc
index 7eaafb95..2e88dcd4 100644
Binary files a/lib/model/__pycache__/indian_sbert.cpython-39.pyc and b/lib/model/__pycache__/indian_sbert.cpython-39.pyc differ
diff --git a/lib/model/paraphrase_multilingual.py b/lib/model/paraphrase_multilingual.py
new file mode 100644
index 00000000..05dcc4e6
--- /dev/null
+++ b/lib/model/paraphrase_multilingual.py
@@ -0,0 +1,13 @@
+from lib.model.generic_transformer import GenericTransformerModel
+
+MODEL_NAME = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+
+
+class Model(GenericTransformerModel):
+    """Vectorizer that configures GenericTransformerModel with MODEL_NAME."""
+
+    BATCH_SIZE = 100
+
+    def __init__(self):
+        """Init ParaphraseMultilingual model by handing MODEL_NAME to the base class."""
+        super().__init__(MODEL_NAME)
diff --git a/test/lib/model/test_indian_sbert.py b/test/lib/model/test_indian_sbert.py
index 0115dfe6..7c21fa6c 100644
--- a/test/lib/model/test_indian_sbert.py
+++ b/test/lib/model/test_indian_sbert.py
@@ -29,4 +29,4 @@ def test_respond(self):
         self.assertEqual(response[0].body.result, [1, 2, 3])
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/test/lib/model/test_paraphrase_multilingual.py b/test/lib/model/test_paraphrase_multilingual.py
new file mode 100644
index 00000000..2a387097
--- /dev/null
+++ b/test/lib/model/test_paraphrase_multilingual.py
@@ -0,0 +1,44 @@
+import os
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+
+from lib.model.generic_transformer import GenericTransformerModel
+from lib import schemas
+
+
+class TestParaphraseMultilingual(unittest.TestCase):
+    """Tests for vectorize() and respond() with the encoder mocked out."""
+
+    def setUp(self):
+        # NOTE(review): this builds GenericTransformerModel(None), not
+        # paraphrase_multilingual.Model, so Model.__init__ and BATCH_SIZE are
+        # not exercised here - presumably to avoid loading weights; confirm.
+        self.model = GenericTransformerModel(None)
+        self.mock_model = MagicMock()
+
+    def test_vectorize(self):
+        """Check that vectorize() yields the two vectors returned by the mocked encode()."""
+        texts = [
+            schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "paraphrase_multilingual__Model"}),
+            schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "paraphrase_multilingual__Model"}),
+        ]
+        self.model.model = self.mock_model
+        self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]]))
+        vectors = self.model.vectorize(texts)
+        self.assertEqual(len(vectors), 2)
+        self.assertEqual(vectors[0], [4, 5, 6])
+        self.assertEqual(vectors[1], [7, 8, 9])
+
+    def test_respond(self):
+        """Check that respond() returns one response whose body.result is the mocked vectorize() output."""
+        query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of India?"}, "model_name": "paraphrase_multilingual__Model"})
+        self.model.vectorize = MagicMock(return_value=[[1, 2, 3]])
+        response = self.model.respond(query)
+        self.assertEqual(len(response), 1)
+        self.assertEqual(response[0].body.result, [1, 2, 3])
+
+
+if __name__ == '__main__':
+    unittest.main()