From 6d1d35726a431b45a2c37c199a19e2171abac848 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Tue, 13 Aug 2024 12:31:50 -0700 Subject: [PATCH 01/18] WIP: refactoring presto parse_message() --- lib/http.py | 2 +- lib/model/model.py | 17 +++++++++++++++ lib/queue/processor.py | 1 + lib/queue/worker.py | 6 +++--- lib/schemas.py | 47 ++++++++++++++++++++++++++++-------------- 5 files changed, 54 insertions(+), 19 deletions(-) diff --git a/lib/http.py b/lib/http.py index eada83da..2cf79285 100644 --- a/lib/http.py +++ b/lib/http.py @@ -36,7 +36,7 @@ def process_item(process_name: str, message: Dict[str, Any]): queue_prefix = Queue.get_queue_prefix() queue_suffix = Queue.get_queue_suffix() queue = QueueWorker.create(process_name) - queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_message({"body": message, "model_name": process_name})) + queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_input_message({"body": message, "model_name": process_name})) return {"message": "Message pushed successfully", "queue": process_name, "body": message} @app.post("/trigger_callback") diff --git a/lib/model/model.py b/lib/model/model.py index fc79271a..044f6945 100644 --- a/lib/model/model.py +++ b/lib/model/model.py @@ -83,6 +83,23 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li message.body.result = self.get_response(message) return messages + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by child classes. + """ + raise NotImplementedError + + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Parse input data. Must be implemented by child classes. + """ + raise NotImplementedError + + @classmethod def create(cls): """ diff --git a/lib/queue/processor.py b/lib/queue/processor.py index 50916199..a891baf3 100644 --- a/lib/queue/processor.py +++ b/lib/queue/processor.py @@ -58,6 +58,7 @@ def send_callback(self, message): """ logger.info(f"Message for callback is: {message}") try: + schemas.parse_output_message(message) # will raise exceptions if not valid, e.g. too large of a message callback_url = message.get("body", {}).get("callback_url") response = requests.post( callback_url, diff --git a/lib/queue/worker.py b/lib/queue/worker.py index edc99149..f3c8ef0b 100644 --- a/lib/queue/worker.py +++ b/lib/queue/worker.py @@ -82,7 +82,7 @@ def extract_messages(messages_with_queues: List[Tuple], model: Model) -> List[sc Returns: - List[schemas.Message]: A list of Message objects ready for processing. """ - return [schemas.parse_message({**json.loads(message.body), **{"model_name": model.model_name}}) + return [schemas.parse_input_message({**json.loads(message.body), **{"model_name": model.model_name}}) for message, queue in messages_with_queues] @staticmethod @@ -175,9 +175,9 @@ def increment_message_error_counts(self, messages_with_queues: List[Tuple]): if retry_count > MAX_RETRIES: logger.info(f"Message {message_body} exceeded max retries. Moving to DLQ.") capture_custom_message("Message exceeded max retries. Moving to DLQ.", 'info', {"message_body": message_body}) - self.push_to_dead_letter_queue(schemas.parse_message(message_body)) + self.push_to_dead_letter_queue(schemas.parse_input_message(message_body)) else: - updated_message = schemas.parse_message(message_body) + updated_message = schemas.parse_input_message(message_body) updated_message.retry_count = retry_count queue.delete_messages(Entries=[self.delete_message_entry(message)]) self.push_message(self.input_queue_name, updated_message) diff --git a/lib/schemas.py b/lib/schemas.py index bfbf7ab0..05a435c3 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -1,5 +1,9 @@ from pydantic import BaseModel, ValidationError from typing import Any, Dict, List, Optional, Union +from lib.helpers import get_class +import os + + class ErrorResponse(BaseModel): error: Optional[str] = None error_details: Optional[Dict] = None @@ -32,34 +36,47 @@ class GenericItem(BaseModel): text: Optional[str] = None raw: Optional[Dict] = {} parameters: Optional[Dict] = {} - result: Optional[Union[ErrorResponse, MediaResponse, VideoResponse, YakeKeywordsResponse, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse]] = None + result: Optional[Any] = None class Message(BaseModel): body: GenericItem model_name: str retry_count: int = 0 -def parse_message(message_data: Dict) -> Message: +def parse_input_message(message_data: Dict) -> Message: body_data = message_data['body'] model_name = message_data['model_name'] result_data = body_data.get('result', {}) - if 'yake_keywords' in model_name: - result_instance = YakeKeywordsResponse(**result_data) - elif 'classycat' in model_name: - event_type = body_data['parameters']['event_type'] - if event_type == 'classify': - result_instance = ClassyCatBatchClassificationResponse(**result_data) - elif event_type == 'schema_lookup' or event_type == 'schema_create': - result_instance = ClassyCatSchemaResponse(**result_data) + + modelClass = get_class('lib.model.', os.environ.get('MODEL_NAME')) + modelClass.validate_input(result_data) # will raise exceptions in case of validation errors + # parse_input_message will enable us to have more complicated result types without having to change the schema file + result_instance = modelClass.parse_input_message(result_data) # assumes input is valid + + if result_instance is None: # in case the model does not have a parse_input_message method implemented + if 'yake_keywords' in model_name: + result_instance = YakeKeywordsResponse(**result_data) + elif 'classycat' in model_name: + event_type = body_data['parameters']['event_type'] + if event_type == 'classify': + result_instance = ClassyCatBatchClassificationResponse(**result_data) + elif event_type == 'schema_lookup' or event_type == 'schema_create': + result_instance = ClassyCatSchemaResponse(**result_data) + else: + result_instance = ClassyCatResponse(**result_data) + elif 'video' in model_name: + result_instance = VideoResponse(**result_data) else: - result_instance = ClassyCatResponse(**result_data) - elif 'video' in model_name: - result_instance = VideoResponse(**result_data) - else: - result_instance = MediaResponse(**result_data) + result_instance = MediaResponse(**result_data) + if 'result' in body_data: del body_data['result'] + body_instance = GenericItem(**body_data) body_instance.result = result_instance message_instance = Message(body=body_instance, model_name=model_name) return message_instance + + +def parse_output_message(message_data: Message) -> None: + pass \ No newline at end of file From 2aef39274741d7a906c9c880dd91f99ab03a9ae1 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Wed, 14 Aug 2024 10:47:57 -0700 Subject: [PATCH 02/18] WIP: updating unit tests (doesn't pass yet) --- test/lib/model/test_audio.py | 4 ++-- test/lib/model/test_classycat.py | 14 +++++++------- test/lib/model/test_fasttext.py | 2 +- test/lib/model/test_fptg.py | 4 ++-- test/lib/model/test_generic.py | 4 ++-- test/lib/model/test_image.py | 6 +++--- test/lib/model/test_indian_sbert.py | 4 ++-- test/lib/model/test_meantokens.py | 4 ++-- test/lib/model/test_model.py | 2 +- test/lib/model/test_video.py | 6 +++--- test/lib/model/test_yake_keywords.py | 12 ++++++------ test/lib/queue/test_queue.py | 6 +++--- test/lib/test_schemas.py | 10 +++++----- 13 files changed, 39 insertions(+), 39 deletions(-) diff --git a/test/lib/model/test_audio.py b/test/lib/model/test_audio.py index 7ad6822f..c0c009a4 100644 --- a/test/lib/model/test_audio.py +++ b/test/lib/model/test_audio.py @@ -24,7 +24,7 @@ def test_process_audio_success(self, mock_fingerprint_file, mock_request, mock_u mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents)) - audio = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, "model_name": "audio__Model"}) + audio = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, "model_name": "audio__Model"}) result = self.audio_model.process(audio) mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) mock_urlopen.assert_called_once_with(mock_request) @@ -45,7 +45,7 @@ def test_process_audio_failure(self, mock_decode_fingerprint, mock_fingerprint_f mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=contents)) - audio = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, "model_name": "audio__Model"}) + audio = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "url": "https://example.com/audio.mp3"}, "model_name": "audio__Model"}) result = self.audio_model.process(audio) mock_request.assert_called_once_with(audio.body.url, headers={'User-Agent': 'Mozilla/5.0'}) mock_urlopen.assert_called_once_with(mock_request) diff --git a/test/lib/model/test_classycat.py b/test/lib/model/test_classycat.py index e85e7088..fdc6e0bf 100644 --- a/test/lib/model/test_classycat.py +++ b/test/lib/model/test_classycat.py @@ -129,7 +129,7 @@ def test_schema_create(self, file_exists_mock, upload_file_to_s3_mock): ] }, "callback_url": "http://example.com?callback"}} - schema_message = schemas.parse_message(schema_input) + schema_message = schemas.parse_input_message(schema_input) result = self.classycat_model.process(schema_message) self.assertEqual("success", result.responseMessage) @@ -272,7 +272,7 @@ def test_schema_lookup(self, file_exists_mock, load_file_from_s3_mock): } } - schema_lookup_message = schemas.parse_message(schema_lookup_input) + schema_lookup_message = schemas.parse_input_message(schema_lookup_input) result = self.classycat_model.process(schema_lookup_message) self.assertEqual("success", result.responseMessage) @@ -413,7 +413,7 @@ def test_classify_success(self, file_exists_in_s3_mock, upload_file_to_s3_mock, "callback_url": "http://example.com?callback" } } - classify_message = schemas.parse_message(classify_input) + classify_message = schemas.parse_input_message(classify_input) result = self.classycat_model.process(classify_message) # example response for this input: @@ -562,7 +562,7 @@ def test_classify_fail_wrong_response_format(self, file_exists_in_s3_mock, uploa "callback_url": "http://example.com?callback" } } - classify_message = schemas.parse_message(classify_input) + classify_message = schemas.parse_input_message(classify_input) with self.assertRaises(PrestoBaseException) as e: self.classycat_model.process(classify_message) @@ -702,7 +702,7 @@ def test_classify_fail_wrong_number_of_results(self, file_exists_in_s3_mock, upl "callback_url": "http://example.com?callback" } } - classify_message = schemas.parse_message(classify_input) + classify_message = schemas.parse_input_message(classify_input) with self.assertRaises(PrestoBaseException) as e: self.classycat_model.process(classify_message) @@ -851,7 +851,7 @@ def test_classify_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload "callback_url": "http://example.com?callback" } } - classify_message = schemas.parse_message(classify_input) + classify_message = schemas.parse_input_message(classify_input) result = self.classycat_model.process(classify_message) self.assertEqual("success", result.responseMessage) @@ -997,7 +997,7 @@ def test_classify_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_ "callback_url": "http://example.com?callback" } } - classify_message = schemas.parse_message(classify_input) + classify_message = schemas.parse_input_message(classify_input) result = self.classycat_model.process(classify_message) self.assertEqual("success", result.responseMessage) diff --git a/test/lib/model/test_fasttext.py b/test/lib/model/test_fasttext.py index ed0c5f09..2f8ca7f9 100644 --- a/test/lib/model/test_fasttext.py +++ b/test/lib/model/test_fasttext.py @@ -15,7 +15,7 @@ def test_respond(self, mock_fasttext_load_model, mock_hf_hub_download): mock_fasttext_load_model.return_value = self.mock_model self.mock_model.predict.return_value = (['__label__eng_Latn'], np.array([0.9])) model = FasttextModel() # Now it uses mocked functions - query = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fasttext__Model"})] + query = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fasttext__Model"})] response = model.respond(query) self.assertEqual(len(response), 1) self.assertEqual(response[0].body.result, {'language': 'en', 'script': None, 'score': 0.9}) diff --git a/test/lib/model/test_fptg.py b/test/lib/model/test_fptg.py index 335b3ad8..e8aad673 100644 --- a/test/lib/model/test_fptg.py +++ b/test/lib/model/test_fptg.py @@ -13,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})] + texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts)["hash_value"] @@ -22,7 +22,7 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"}) + query = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"}) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) diff --git a/test/lib/model/test_generic.py b/test/lib/model/test_generic.py index bdd7eafc..6a4991ef 100644 --- a/test/lib/model/test_generic.py +++ b/test/lib/model/test_generic.py @@ -13,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})] + texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "fptg__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "fptg__Model"})] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts)["hash_value"] @@ -22,7 +22,7 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"}) + query = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Anong pangalan mo?"}, "model_name": "fptg__Model"}) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) diff --git a/test/lib/model/test_image.py b/test/lib/model/test_image.py index fcccdd4f..a15940a9 100644 --- a/test/lib/model/test_image.py +++ b/test/lib/model/test_image.py @@ -25,7 +25,7 @@ def test_get_iobytes_for_image(self, mock_urlopen): mock_response = Mock() mock_response.read.return_value = image_content mock_urlopen.return_value = mock_response - image = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, "model_name": "image__Model"}) + image = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, "model_name": "image__Model"}) result = Model().get_iobytes_for_image(image) self.assertIsInstance(result, io.BytesIO) self.assertEqual(result.read(), image_content) @@ -33,7 +33,7 @@ def test_get_iobytes_for_image(self, mock_urlopen): @patch("urllib.request.urlopen") def test_get_iobytes_for_image_raises_error(self, mock_urlopen): mock_urlopen.side_effect = URLError('test error') - image = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, "model_name": "image__Model"}) + image = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, "model_name": "image__Model"}) with self.assertRaises(URLError): Model().get_iobytes_for_image(image) @@ -42,7 +42,7 @@ def test_get_iobytes_for_image_raises_error(self, mock_urlopen): def test_process(self, mock_compute_pdq, mock_get_iobytes_for_image): mock_compute_pdq.return_value = "1001" mock_get_iobytes_for_image.return_value = io.BytesIO(b"image_bytes") - image = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, "model_name": "image__Model"}) + image = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, "model_name": "image__Model"}) result = Model().process(image) self.assertEqual(result, {"hash_value": "1001"}) diff --git a/test/lib/model/test_indian_sbert.py b/test/lib/model/test_indian_sbert.py index a944adfd..b5b02650 100644 --- a/test/lib/model/test_indian_sbert.py +++ b/test/lib/model/test_indian_sbert.py @@ -13,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "indian_sbert__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "indian_sbert__Model"})] + texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "indian_sbert__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "indian_sbert__Model"})] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts)["hash_value"] @@ -22,7 +22,7 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of India?"}, "model_name": "indian_sbert__Model"}) + query = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of India?"}, "model_name": "indian_sbert__Model"}) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) diff --git a/test/lib/model/test_meantokens.py b/test/lib/model/test_meantokens.py index 8845fb14..67adc4b4 100644 --- a/test/lib/model/test_meantokens.py +++ b/test/lib/model/test_meantokens.py @@ -13,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "mean_tokens__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "mean_tokens__Model"})] + texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "mean_tokens__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "mean_tokens__Model"})] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts)["hash_value"] @@ -22,7 +22,7 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of France?"}, "model_name": "mean_tokens__Model"}) + query = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of France?"}, "model_name": "mean_tokens__Model"}) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1) diff --git a/test/lib/model/test_model.py b/test/lib/model/test_model.py index d6941804..a51c856b 100644 --- a/test/lib/model/test_model.py +++ b/test/lib/model/test_model.py @@ -6,7 +6,7 @@ # class TestModel(unittest.TestCase): # def test_respond(self): # model = Model() -# self.assertEqual(model.respond(schemas.parse_message({"body": schemas.GenericItem(id='123', callback_url="http://example.com/callback", text="hello")})), model.respond(schemas.parse_message({"body": schemas.GenericItem(id='123', callback_url="http://example.com/callback", text="hello"), "response": []}))) +# self.assertEqual(model.respond(schemas.parse_input_message({"body": schemas.GenericItem(id='123', callback_url="http://example.com/callback", text="hello")})), model.respond(schemas.parse_input_message({"body": schemas.GenericItem(id='123', callback_url="http://example.com/callback", text="hello"), "response": []}))) # # if __name__ == '__main__': # unittest.main() \ No newline at end of file diff --git a/test/lib/model/test_video.py b/test/lib/model/test_video.py index e8c802ce..b93ee7bf 100644 --- a/test/lib/model/test_video.py +++ b/test/lib/model/test_video.py @@ -32,7 +32,7 @@ def test_process_video(self, mock_pathlib, mock_upload_file_to_s3, mock_hash_video_output.getPureAverageFeature.return_value = "hash_value" mock_hash_video.return_value = mock_hash_video_output mock_urlopen.return_value = MagicMock(read=MagicMock(return_value=video_contents)) - self.video_model.process(schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"})) + self.video_model.process(schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"})) mock_urlopen.assert_called_once() mock_hash_video.assert_called_once_with(ANY, "/usr/local/bin/ffmpeg") @@ -53,7 +53,7 @@ def test_tmk_program_name(self): def test_respond_with_single_video(self, mock_cache_set, mock_cache_get): mock_cache_get.return_value = None mock_cache_set.return_value = True - video = schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"}) + video = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"}) mock_process = MagicMock() self.video_model.process = mock_process result = self.video_model.respond(video) @@ -65,7 +65,7 @@ def test_respond_with_single_video(self, mock_cache_set, mock_cache_get): def test_respond_with_multiple_videos(self, mock_cache_set, mock_cache_get): mock_cache_get.return_value = None mock_cache_set.return_value = True - videos = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video2.mp4"}, "model_name": "video__Model"})] + videos = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video.mp4"}, "model_name": "video__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://blah.com?callback_id=123", "url": "http://example.com/video2.mp4"}, "model_name": "video__Model"})] mock_process = MagicMock() self.video_model.process = mock_process result = self.video_model.respond(videos) diff --git a/test/lib/model/test_yake_keywords.py b/test/lib/model/test_yake_keywords.py index 47ebef57..87e87e3a 100644 --- a/test/lib/model/test_yake_keywords.py +++ b/test/lib/model/test_yake_keywords.py @@ -10,7 +10,7 @@ def setUp(self): @patch('yake.KeywordExtractor.extract_keywords') def test_process(self, mock_yake_response): - message = schemas.parse_message({ + message = schemas.parse_input_message({ "body": { "id": "1234", "text": "Some Text", @@ -22,7 +22,7 @@ def test_process(self, mock_yake_response): @patch('yake.KeywordExtractor.extract_keywords') def test_run_yake(self, mock_yake_response): - message = schemas.parse_message({ + message = schemas.parse_input_message({ "body": { "id": "1234", "text": "Some Text", @@ -33,7 +33,7 @@ def test_run_yake(self, mock_yake_response): self.assertEqual(self.yake_model.run_yake(**self.yake_model.get_params(message)), {"keywords": [["ball", 0.23]]}) def test_run_yake_real(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ "body": { "id": "1234", "text": "I love Meedan", @@ -44,7 +44,7 @@ def test_run_yake_real(self): self.assertEqual(results, {"keywords": [('love Meedan', 0.0013670273525686505)]}) def test_get_params_with_defaults(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ "body": { "id": "1234", "text": "Some Text", @@ -56,7 +56,7 @@ def test_get_params_with_defaults(self): def test_get_params_with_specifics(self): params = {'language': "hi", 'max_ngram_size': 10, 'deduplication_threshold': 0.2, 'deduplication_algo': 'goop', 'window_size': 10, 'num_of_keywords': 100} - message = schemas.parse_message({ + message = schemas.parse_input_message({ "body": { "id": "1234", "text": "Some Text", @@ -68,7 +68,7 @@ def test_get_params_with_specifics(self): self.assertEqual(self.yake_model.get_params(message), expected) def test_get_params_with_defaults_no_text(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ "body": { "id": "1234", }, diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index a21bbf6c..0547e333 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -148,7 +148,7 @@ def test_delete_messages_from_queue(self, mock_logger): mock_logger.assert_called_with(f"Deleting message: {mock_messages[-1]}") def test_push_message(self): - message_to_push = schemas.parse_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"}) + message_to_push = schemas.parse_input_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"}) # Call push_message returned_message = self.queue.push_message(self.queue_name_output, message_to_push) # Check if the message was correctly serialized and sent @@ -156,7 +156,7 @@ def test_push_message(self): self.assertEqual(returned_message, message_to_push) def test_push_to_dead_letter_queue(self): - message_to_push = schemas.parse_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"}) + message_to_push = schemas.parse_input_message({"body": {"id": 1, "content_hash": None, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "mean_tokens__Model"}) # Call push_to_dead_letter_queue self.queue.push_to_dead_letter_queue(message_to_push) # Check if the message was correctly serialized and sent to the DLQ @@ -237,7 +237,7 @@ def test_error_capturing_in_get_response(self, mock_cache_set, mock_cache_get): "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, "model_name": "generic" } - message = schemas.parse_message(message_data) + message = schemas.parse_input_message(message_data) message.body.content_hash = "test_hash" # Simulate an error in the process method diff --git a/test/lib/test_schemas.py b/test/lib/test_schemas.py index 2d531699..e5a1ce08 100644 --- a/test/lib/test_schemas.py +++ b/test/lib/test_schemas.py @@ -7,7 +7,7 @@ from lib import schemas class TestSchemas(unittest.TestCase): def test_audio_output(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ 'body': { 'id': '123', 'callback_url': 'http://0.0.0.0:80/callback_url', @@ -23,7 +23,7 @@ def test_audio_output(self): self.assertIsInstance(message.body.result, schemas.MediaResponse) def test_image_output(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ 'body': { 'id': '123', 'callback_url': 'http://0.0.0.0:80/callback_url', @@ -39,7 +39,7 @@ def test_image_output(self): self.assertIsInstance(message.body.result, schemas.MediaResponse) def test_video_output(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ 'body': { 'id': '123', 'callback_url': 'http://0.0.0.0:80/callback_url', @@ -55,7 +55,7 @@ def test_video_output(self): self.assertIsInstance(message.body.result, schemas.VideoResponse) def test_text_output(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ 'body': { 'id': '123', 'callback_url': 'http://0.0.0.0:80/callback_url', @@ -71,7 +71,7 @@ def test_text_output(self): self.assertIsInstance(message.body.result, schemas.MediaResponse) def test_yake_keyword_output(self): - message = schemas.parse_message({ + message = schemas.parse_input_message({ 'body': { 'id': '123', 'callback_url': 'http://0.0.0.0:80/callback_url', From 350e892894ccf18300fee4c23d85e2ac263cd373 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Wed, 14 Aug 2024 10:58:09 -0700 Subject: [PATCH 03/18] fixing the unit tests, now they pass --- lib/model/audio.py | 17 ++++++++++++++++- lib/model/classycat.py | 16 +++++++++++++++- lib/model/classycat_classify.py | 16 ++++++++++++++++ lib/model/classycat_schema_create.py | 16 ++++++++++++++++ lib/model/classycat_schema_lookup.py | 18 +++++++++++++++++- lib/model/fasttext.py | 16 +++++++++++++++- lib/model/fptg.py | 4 ++++ lib/model/generic_transformer.py | 16 +++++++++++++++- lib/model/image.py | 17 ++++++++++++++++- lib/model/indian_sbert.py | 4 ++++ lib/model/mean_tokens.py | 4 ++++ lib/model/video.py | 16 +++++++++++++++- lib/model/yake_keywords.py | 16 +++++++++++++++- 13 files changed, 168 insertions(+), 8 deletions(-) diff --git a/lib/model/audio.py b/lib/model/audio.py index fea03cbf..f23cd906 100644 --- a/lib/model/audio.py +++ b/lib/model/audio.py @@ -1,4 +1,4 @@ -from typing import Union, List, Dict +from typing import Union, List, Dict, Any import os import tempfile @@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]: finally: os.remove(temp_file_name) return {"hash_value": hash_value} + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/classycat.py b/lib/model/classycat.py index cceb17f1..72e0ae96 100644 --- a/lib/model/classycat.py +++ b/lib/model/classycat.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Dict, Any from lib.logger import logger from lib.model.model import Model from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse @@ -23,3 +23,17 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB else: logger.error(f"Unknown event type {event_type}") raise PrestoBaseException(f"Unknown event type {event_type}", 422) + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index 2f4b1857..1e63e694 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -1,3 +1,4 @@ +from typing import Dict, Any import os import json import uuid @@ -230,3 +231,18 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse: raise e else: raise PrestoBaseException(f"Error classifying items: {e}", 500) from e + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index b229c44e..2ef8986f 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -1,3 +1,4 @@ +from typing import Dict, Any import os import json import uuid @@ -243,3 +244,18 @@ def verify_schema_parameters(self, schema_name, topics, examples, languages): #t def schema_name_exists(self, schema_name): return file_exists_in_s3(self.output_bucket, f"{schema_name}.json") + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index af51b443..480bc634 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -1,3 +1,4 @@ +from typing import Dict, Any import os import json from lib.logger import logger @@ -74,4 +75,19 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: return result except Exception as e: logger.error(f"Error looking up schema name {schema_name}: {e}") - raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e \ No newline at end of file + raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/fasttext.py b/lib/model/fasttext.py index e251f566..0aa5a47a 100644 --- a/lib/model/fasttext.py +++ b/lib/model/fasttext.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, List +from typing import Union, Dict, List, Any import fasttext from huggingface_hub import hf_hub_download @@ -44,3 +44,17 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s for doc, detected_lang in zip(docs, detected_langs): doc.body.result = detected_lang return docs + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/fptg.py b/lib/model/fptg.py index 2c0731ff..f5244d9f 100644 --- a/lib/model/fptg.py +++ b/lib/model/fptg.py @@ -1,7 +1,11 @@ from lib.model.generic_transformer import GenericTransformerModel + MODEL_NAME = 'meedan/paraphrase-filipino-mpnet-base-v2' + + class Model(GenericTransformerModel): BATCH_SIZE = 100 + def __init__(self): """ Init FPTG model. Fairly standard for all vectorizers. diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index cb567fef..e1d2a118 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -1,5 +1,5 @@ import os -from typing import Union, Dict, List +from typing import Union, Dict, List, Any from sentence_transformers import SentenceTransformer from lib.logger import logger from lib.model.model import Model @@ -34,3 +34,17 @@ def vectorize(self, texts: List[str]) -> List[List[float]]: Vectorize the text! Run as batch. """ return {"hash_value": self.model.encode(texts).tolist()} + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/image.py b/lib/model/image.py index 5197f11d..a3a7fa79 100644 --- a/lib/model/image.py +++ b/lib/model/image.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import io import urllib.request @@ -35,3 +35,18 @@ def process(self, image: schemas.Message) -> schemas.GenericItem: Generic function for returning the actual response. """ return {"hash_value": self.compute_pdq(self.get_iobytes_for_image(image))} + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/indian_sbert.py b/lib/model/indian_sbert.py index db529ba3..d5302163 100644 --- a/lib/model/indian_sbert.py +++ b/lib/model/indian_sbert.py @@ -1,7 +1,11 @@ from lib.model.generic_transformer import GenericTransformerModel + MODEL_NAME = 'meedan/indian-sbert' + + class Model(GenericTransformerModel): BATCH_SIZE = 100 + def __init__(self): """ Init IndianSbert model. Fairly standard for all vectorizers. diff --git a/lib/model/mean_tokens.py b/lib/model/mean_tokens.py index a3f77e0d..3953bff8 100644 --- a/lib/model/mean_tokens.py +++ b/lib/model/mean_tokens.py @@ -1,7 +1,11 @@ from lib.model.generic_transformer import GenericTransformerModel + MODEL_NAME = 'xlm-r-bert-base-nli-stsb-mean-tokens' + + class Model(GenericTransformerModel): BATCH_SIZE = 100 + def __init__(self): """ Init MeanTokens model. Fairly standard for all vectorizers. diff --git a/lib/model/video.py b/lib/model/video.py index cbb95c46..285ccf5f 100644 --- a/lib/model/video.py +++ b/lib/model/video.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import os import uuid import shutil @@ -64,3 +64,17 @@ def process(self, video: schemas.Message) -> schemas.GenericItem: if os.path.exists(file_path): os.remove(file_path) return {"folder": self.tmk_bucket(), "filepath": self.tmk_file_path(video_filename), "hash_value": hash_value} + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/yake_keywords.py b/lib/model/yake_keywords.py index 8fc89481..e8adc80d 100644 --- a/lib/model/yake_keywords.py +++ b/lib/model/yake_keywords.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import io import urllib.request @@ -50,3 +50,17 @@ def process(self, message: schemas.Message) -> schemas.YakeKeywordsResponse: """ keywords = self.run_yake(**self.get_params(message)) return keywords + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None From 5229b1acdec540c4e459a156be31e53cf38c0ca3 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Wed, 14 Aug 2024 15:18:33 -0700 Subject: [PATCH 04/18] WIP: implementing a sample input parsing and verification implementation for classycat --- lib/model/classycat.py | 31 ++++++++++++++++++++++++---- lib/model/classycat_schema_create.py | 28 +++++++++++++++++++++---- lib/model/classycat_schema_lookup.py | 15 +++++++++++--- lib/schemas.py | 2 +- 4 files changed, 64 insertions(+), 12 deletions(-) diff --git a/lib/model/classycat.py b/lib/model/classycat.py index 72e0ae96..a339ea8d 100644 --- a/lib/model/classycat.py +++ b/lib/model/classycat.py @@ -1,7 +1,7 @@ from typing import Union, Dict, Any from lib.logger import logger from lib.model.model import Model -from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse +from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse, ClassyCatResponse from lib.model.classycat_classify import Model as ClassifyModel from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel @@ -29,11 +29,34 @@ def validate_input(cls, data: Dict) -> None: """ Validate input data. Must be implemented by all child "Model" classes. """ - pass + event_type = data['parameters']['event_type'] + + if event_type == 'classify': + ClassifyModel.validate_input(data) + elif event_type == 'schema_lookup': + ClassyCatSchemaLookupModel.validate_input(data) + elif event_type == 'schema_create': + ClassyCatSchemaCreateModel.validate_input(data) + else: + logger.error(f"Unknown event type {event_type}") + raise PrestoBaseException(f"Unknown event type {event_type}", 422) @classmethod def parse_input_message(cls, data: Dict) -> Any: """ - Validate input data. Must be implemented by all child "Model" classes. + Parse input into appropriate response instances. """ - return None + event_type = data['parameters']['event_type'] + + if event_type == 'classify': + result_instance_class = ClassifyModel + elif event_type == 'schema_lookup': + result_instance_class = ClassyCatSchemaLookupModel + elif event_type == 'schema_create': + result_instance_class = ClassyCatSchemaCreateModel + + else: + logger.error(f"Unknown event type {event_type}") + raise PrestoBaseException(f"Unknown event type {event_type}", 422) + + return result_instance_class.parse_input_message(data) diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index 2ef8986f..d8dc0094 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -211,7 +211,8 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: raise PrestoBaseException(f"Error creating schema: {e}", 500) from e - def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo + @classmethod + def verify_schema_parameters(schema_name, topics, examples, languages): if not schema_name or not isinstance(schema_name, str) or len(schema_name) == 0: raise ValueError("schema_name is invalid. It must be a non-empty string") @@ -251,11 +252,30 @@ def validate_input(cls, data: Dict) -> None: """ Validate input data. Must be implemented by all child "Model" classes. """ - pass + schema_specs = data['parameters'] + + schema_name = schema_specs["schema_name"] + topics = schema_specs["topics"] + examples = schema_specs["examples"] + languages = schema_specs["languages"] # ['English', 'Spanish'] + + try: + cls.verify_schema_parameters(schema_name, topics, examples, languages) + except Exception as e: + logger.exception(f"Error verifying schema parameters: {e}") + raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e @classmethod def parse_input_message(cls, data: Dict) -> Any: """ - Validate input data. Must be implemented by all child "Model" classes. + Parse input into appropriate response instances. """ - return None + event_type = data['parameters']['event_type'] + result_data = data.get('result', {}) + + if event_type == 'schema_create': + result_instance = ClassyCatSchemaResponse(**result_data) + else: + raise PrestoBaseException(f"Unknown event type {event_type}", 422) + + return result_instance diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index 480bc634..b7d8cc98 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -83,11 +83,20 @@ def validate_input(cls, data: Dict) -> None: """ Validate input data. Must be implemented by all child "Model" classes. """ - pass + if "schema_name" not in data["parameters"] or data["parameters"]["schema_name"] == "": + raise PrestoBaseException("schema_name is required as input to schema look up", 422) @classmethod def parse_input_message(cls, data: Dict) -> Any: """ - Validate input data. Must be implemented by all child "Model" classes. + Parse input into appropriate response instances. """ - return None + event_type = data['parameters']['event_type'] + result_data = data.get('result', {}) + + if event_type == 'schema_lookup': + result_instance = ClassyCatSchemaResponse(**result_data) + else: + raise PrestoBaseException(f"Unknown event type {event_type}", 422) + + return result_instance diff --git a/lib/schemas.py b/lib/schemas.py index 05a435c3..3ede687a 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -49,7 +49,7 @@ def parse_input_message(message_data: Dict) -> Message: result_data = body_data.get('result', {}) modelClass = get_class('lib.model.', os.environ.get('MODEL_NAME')) - modelClass.validate_input(result_data) # will raise exceptions in case of validation errors + modelClass.validate_input(body_data) # will raise exceptions in case of validation errors # parse_input_message will enable us to have more complicated result types without having to change the schema file result_instance = modelClass.parse_input_message(result_data) # assumes input is valid From 08662e073655a517e7ae707e290c00e1d081559d Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 10:30:19 -0700 Subject: [PATCH 05/18] implementing sample validataion code for classycat, untested --- lib/model/classycat_classify.py | 23 ++++++++++++++++++++--- lib/model/classycat_schema_create.py | 5 ++--- lib/model/classycat_schema_lookup.py | 5 ++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index 1e63e694..cb5b28d4 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -238,11 +238,28 @@ def validate_input(cls, data: Dict) -> None: """ Validate input data. Must be implemented by all child "Model" classes. """ - pass + if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "": + raise PrestoBaseException("schema_id is required as input to classify", 422) + + if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0: + raise PrestoBaseException("items are required as input to classify", 422) + + for item in data["parameters"]["items"]: + if "id" not in item or item["id"] == "": + raise PrestoBaseException("id is required for each item", 422) + if "text" not in item or item["text"] == "": + raise PrestoBaseException("text is required for each item", 422) @classmethod def parse_input_message(cls, data: Dict) -> Any: """ - Validate input data. Must be implemented by all child "Model" classes. + Parse input into appropriate response instances. """ - return None \ No newline at end of file + event_type = data['parameters']['event_type'] + result_data = data.get('result', {}) + + if event_type == 'classify': + return ClassyCatBatchClassificationResponse(**result_data) + else: + logger.error(f"Unknown event type {event_type}") + raise PrestoBaseException(f"Unknown event type {event_type}", 422) \ No newline at end of file diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index d8dc0094..f04994de 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -274,8 +274,7 @@ def parse_input_message(cls, data: Dict) -> Any: result_data = data.get('result', {}) if event_type == 'schema_create': - result_instance = ClassyCatSchemaResponse(**result_data) + return ClassyCatSchemaResponse(**result_data) else: + logger.error(f"Unknown event type {event_type}") raise PrestoBaseException(f"Unknown event type {event_type}", 422) - - return result_instance diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index b7d8cc98..23f6b360 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -95,8 +95,7 @@ def parse_input_message(cls, data: Dict) -> Any: result_data = data.get('result', {}) if event_type == 'schema_lookup': - result_instance = ClassyCatSchemaResponse(**result_data) + return ClassyCatSchemaResponse(**result_data) else: + logger.error(f"Unknown event type {event_type}") raise PrestoBaseException(f"Unknown event type {event_type}", 422) - - return result_instance From 243c4c6ef3af5a1c9376a8dd9a344de47e358cc6 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 10:42:56 -0700 Subject: [PATCH 06/18] WIP: fixing a minor bug for a class method --- lib/model/classycat_schema_create.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index f04994de..9b66fc06 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -212,7 +212,7 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: @classmethod - def verify_schema_parameters(schema_name, topics, examples, languages): + def verify_schema_parameters(cls, schema_name, topics, examples, languages): if not schema_name or not isinstance(schema_name, str) or len(schema_name) == 0: raise ValueError("schema_name is invalid. It must be a non-empty string") From 7f6954f0c5b7919fc219b35dab00612b9fc7f0b0 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 11:21:35 -0700 Subject: [PATCH 07/18] presto refactoring all tested and fixed (verified locally and ran unit tests) --- lib/schemas.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/schemas.py b/lib/schemas.py index 3ede687a..ce5997fe 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -51,8 +51,11 @@ def parse_input_message(message_data: Dict) -> Message: modelClass = get_class('lib.model.', os.environ.get('MODEL_NAME')) modelClass.validate_input(body_data) # will raise exceptions in case of validation errors # parse_input_message will enable us to have more complicated result types without having to change the schema file - result_instance = modelClass.parse_input_message(result_data) # assumes input is valid + result_instance = modelClass.parse_input_message(body_data) # assumes input is valid + # TODO: the following is a temporary solution to handle the case where the model does not have a + # parse_input_message method implemented but we must ultimately implement parse_input_message and + # validate_input in all models. ticket: https://meedan.atlassian.net/browse/CV2-5093 if result_instance is None: # in case the model does not have a parse_input_message method implemented if 'yake_keywords' in model_name: result_instance = YakeKeywordsResponse(**result_data) From 8e6f491e7b17cfbd9b4cb675cd8a03bae9c551d5 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 13:06:09 -0700 Subject: [PATCH 08/18] developer guide for writing presto models --- docs/how.to.make.model.md | 150 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 docs/how.to.make.model.md diff --git a/docs/how.to.make.model.md b/docs/how.to.make.model.md new file mode 100644 index 00000000..b4d9b3a4 --- /dev/null +++ b/docs/how.to.make.model.md @@ -0,0 +1,150 @@ +# How to make a presto model +## Your go-to guide, one-stop shop for writing your model in presto + +A model in Presto is a (mostly) stateless Python class that has the following properties: +- Extends the `lib.model.model.Model` class +- `__init__()`: Overrides the default constructor method that initializes the model and its parameters. +from lib.model.classycat_classify: +```python +class Model(Model): + def __init__(self): + super().__init__() + self.output_bucket = os.getenv("CLASSYCAT_OUTPUT_BUCKET") + self.batch_size_limit = int(os.environ.get("CLASSYCAT_BATCH_SIZE_LIMIT")) + llm_client_type = os.environ.get('CLASSYCAT_LLM_CLIENT_TYPE') + llm_model_name = os.environ.get('CLASSYCAT_LLM_MODEL_NAME') + self.llm_client = self.get_llm_client(llm_client_type, llm_model_name) +``` + +- `process()`: A method that processes the input data and returns the specific output of the specified predefined schema +This method is the meat of your model, and is supposed to statelessly process the input data and return the output data. +It's good practice to copy a standard input and output schema so that others can consume your model easily. +This method should return your own custom defined response classes, see ClassyCatBatchClassificationResponse for an example. +```python +def process(self, message: Message) -> ClassyCatBatchClassificationResponse: + # Example input: + # { + # "model_name": "classycat__Model", + # "body": { + # "id": 1200, + # "parameters": { + # "event_type": "classify", + # "schema_id": "4a026b82-4a16-440d-aed7-bec07af12205", + # "items": [ + # { + # "id": "11", + # "text": "modi and bjp want to rule india by dividing people against each other" + # } + # ] + # }, + # "callback_url": "http://example.com?callback" + # } + # } + # + # Example output: + # { + # "body": { + # "id": 1200, + # "content_hash": null, + # "callback_url": "http://host.docker.internal:9888", + # "url": null, + # "text": null, + # "raw": {}, + # "parameters": { + # "event_type": "classify", + # "schema_id": "12589852-4fff-430b-bf77-adad202d03ca", + # "items": [ + # { + # "id": "11", + # "text": "modi and bjp want to rule india by dividing people against each other" + # } + # ] + # }, + # "result": { + # "responseMessage": "success", + # "classification_results": [ + # { + # "id": "11", + # "text": "modi and bjp want to rule india by dividing people against each other", + # "labels": [ + # "Politics", + # "Communalism" + # ] + # } + # ] + # } + # }, + # "model_name": "classycat.Model", + # "retry_count": 0 + # } + + # unpack parameters for classify + batch_to_classify = message.body.parameters + schema_id = batch_to_classify["schema_id"] + items = batch_to_classify["items"] + + result = message.body.result + + if not self.schema_id_exists(schema_id): + raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404) + + if len(items) > self.batch_size_limit: + raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422) + + try: + result.classification_results = self.classify_and_store_results(schema_id, items) + result.responseMessage = "success" + return result + except Exception as e: + logger.exception(f"Error classifying items: {e}") + if isinstance(e, PrestoBaseException): + raise e + else: + raise PrestoBaseException(f"Error classifying items: {e}", 500) from e +``` +- In case of errors, raise a `PrestoBaseException` with the appropriate error message and http status code, as seen above. +feel free to extend this class and implement more sophisticated error handling for your usecase if needed. +- `validate_input()`: A method that validates the input data and raises an exception if the input is invalid. +```python +@classmethod +def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "": + raise PrestoBaseException("schema_id is required as input to classify", 422) + + if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0: + raise PrestoBaseException("items are required as input to classify", 422) + + for item in data["parameters"]["items"]: + if "id" not in item or item["id"] == "": + raise PrestoBaseException("id is required for each item", 422) + if "text" not in item or item["text"] == "": + raise PrestoBaseException("text is required for each item", 422) +``` +- `parse_input_message()`: generates the right result/response type from the raw input body: +```python +@classmethod +def parse_input_message(cls, data: Dict) -> Any: + """ + Parse input into appropriate response instances. + """ + event_type = data['parameters']['event_type'] + result_data = data.get('result', {}) + + if event_type == 'classify': + return ClassyCatBatchClassificationResponse(**result_data) + else: + logger.error(f"Unknown event type {event_type}") + raise PrestoBaseException(f"Unknown event type {event_type}", 422) +``` +- Your own custom defined response classes, ideally defined inside the model file or as a separate file if too complex. +see `ClassyCatBatchClassificationResponse` for an example: +```python +class ClassyCatResponse(BaseModel): + responseMessage: Optional[str] = None + +class ClassyCatBatchClassificationResponse(ClassyCatResponse): + classification_results: Optional[List[dict]] = [] +``` \ No newline at end of file From 5505b4c72d98c618670693ac87ecfdfee6e05304 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 13:10:34 -0700 Subject: [PATCH 09/18] removing classycat garbage code --- lib/schemas.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/lib/schemas.py b/lib/schemas.py index ce5997fe..480a6242 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -59,14 +59,6 @@ def parse_input_message(message_data: Dict) -> Message: if result_instance is None: # in case the model does not have a parse_input_message method implemented if 'yake_keywords' in model_name: result_instance = YakeKeywordsResponse(**result_data) - elif 'classycat' in model_name: - event_type = body_data['parameters']['event_type'] - if event_type == 'classify': - result_instance = ClassyCatBatchClassificationResponse(**result_data) - elif event_type == 'schema_lookup' or event_type == 'schema_create': - result_instance = ClassyCatSchemaResponse(**result_data) - else: - result_instance = ClassyCatResponse(**result_data) elif 'video' in model_name: result_instance = VideoResponse(**result_data) else: From 483a970b2abfc0344e0071782d17b29940ab7ac0 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 13:18:47 -0700 Subject: [PATCH 10/18] WIP: taking classycat response class definitions out of schemas.py and into appropriate model file. untested. --- lib/model/classycat.py | 17 +++++++++++++++-- lib/model/classycat_classify.py | 3 ++- lib/model/classycat_schema_create.py | 3 ++- lib/model/classycat_schema_lookup.py | 3 ++- lib/schemas.py | 13 ++++--------- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/lib/model/classycat.py b/lib/model/classycat.py index a339ea8d..28845e24 100644 --- a/lib/model/classycat.py +++ b/lib/model/classycat.py @@ -1,13 +1,26 @@ -from typing import Union, Dict, Any +from typing import Union, Dict, Any, Optional, List +from pydantic import BaseModel from lib.logger import logger from lib.model.model import Model -from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse, ClassyCatResponse +from lib.schemas import Message from lib.model.classycat_classify import Model as ClassifyModel from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel from lib.base_exception import PrestoBaseException +class ClassyCatResponse(BaseModel): + responseMessage: Optional[str] = None + + +class ClassyCatBatchClassificationResponse(ClassyCatResponse): + classification_results: Optional[List[dict]] = [] + + +class ClassyCatSchemaResponse(ClassyCatResponse): + schema_id: Optional[str] = None + + class Model(Model): def __init__(self): super().__init__() diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index cb5b28d4..319d9018 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -7,7 +7,8 @@ from anthropic import Anthropic from lib.logger import logger from lib.model.model import Model -from lib.schemas import Message, ClassyCatBatchClassificationResponse +from lib.schemas import Message +from lib.model.classycat import ClassyCatBatchClassificationResponse from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3 from lib.base_exception import PrestoBaseException diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index 9b66fc06..080634bd 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -5,7 +5,8 @@ from lib.s3 import upload_file_to_s3, file_exists_in_s3 from lib.logger import logger from lib.model.model import Model -from lib.schemas import Message, ClassyCatSchemaResponse +from lib.schemas import Message +from lib.model.classycat import ClassyCatSchemaResponse from lib.base_exception import PrestoBaseException diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index 23f6b360..b59ab4fb 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -4,7 +4,8 @@ from lib.logger import logger from lib.model.model import Model from lib.s3 import load_file_from_s3, file_exists_in_s3 -from lib.schemas import Message, ClassyCatSchemaResponse +from lib.schemas import Message +from lib.model.classycat import ClassyCatSchemaResponse from lib.base_exception import PrestoBaseException diff --git a/lib/schemas.py b/lib/schemas.py index 480a6242..5672cf85 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from typing import Any, Dict, List, Optional, Union from lib.helpers import get_class import os @@ -9,24 +9,19 @@ class ErrorResponse(BaseModel): error_details: Optional[Dict] = None error_code: int = 500 +# TODO move below definition to the model specific file. ticket: https://meedan.atlassian.net/browse/CV2-5093 class MediaResponse(BaseModel): hash_value: Optional[Any] = None +# TODO move below definition to the model specific file. ticket: https://meedan.atlassian.net/browse/CV2-5093 class VideoResponse(MediaResponse): folder: Optional[str] = None filepath: Optional[str] = None +# TODO move below definition to the model specific file. ticket: https://meedan.atlassian.net/browse/CV2-5093 class YakeKeywordsResponse(BaseModel): keywords: Optional[List[List[Union[str, float]]]] = None -class ClassyCatResponse(BaseModel): - responseMessage: Optional[str] = None - -class ClassyCatBatchClassificationResponse(ClassyCatResponse): - classification_results: Optional[List[dict]] = [] - -class ClassyCatSchemaResponse(ClassyCatResponse): - schema_id: Optional[str] = None class GenericItem(BaseModel): id: Union[str, int, float] From 58e3979123dce8238f4fd61fdc6cc3493f6668eb Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 14:25:05 -0700 Subject: [PATCH 11/18] PR comments --- lib/schemas.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/lib/schemas.py b/lib/schemas.py index 5672cf85..a732e321 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from typing import Any, Dict, List, Optional, Union from lib.helpers import get_class +from base_exception import PrestoBaseException import os @@ -9,15 +10,18 @@ class ErrorResponse(BaseModel): error_details: Optional[Dict] = None error_code: int = 500 + # TODO move below definition to the model specific file. ticket: https://meedan.atlassian.net/browse/CV2-5093 class MediaResponse(BaseModel): hash_value: Optional[Any] = None + # TODO move below definition to the model specific file. ticket: https://meedan.atlassian.net/browse/CV2-5093 class VideoResponse(MediaResponse): folder: Optional[str] = None filepath: Optional[str] = None + # TODO move below definition to the model specific file. ticket: https://meedan.atlassian.net/browse/CV2-5093 class YakeKeywordsResponse(BaseModel): keywords: Optional[List[List[Union[str, float]]]] = None @@ -33,20 +37,31 @@ class GenericItem(BaseModel): parameters: Optional[Dict] = {} result: Optional[Any] = None + class Message(BaseModel): body: GenericItem model_name: str retry_count: int = 0 + def parse_input_message(message_data: Dict) -> Message: + if 'body' not in message_data or 'model_name' not in message_data: + raise PrestoBaseException("Invalid message data: message should at minimum include body and model_name" + , 422) + body_data = message_data['body'] model_name = message_data['model_name'] result_data = body_data.get('result', {}) - modelClass = get_class('lib.model.', os.environ.get('MODEL_NAME')) - modelClass.validate_input(body_data) # will raise exceptions in case of validation errors + try: + presto_model_class_name = model_name.replace('__', '.') # todo don't love this line of code + model_class = get_class('lib.model.', presto_model_class_name) + except Exception as e: + raise PrestoBaseException(f"Error loading model {model_name}, model_name is not supported: {e}", 404) from e + + model_class.validate_input(body_data) # will raise exceptions in case of validation errors # parse_input_message will enable us to have more complicated result types without having to change the schema file - result_instance = modelClass.parse_input_message(body_data) # assumes input is valid + result_instance = model_class.parse_input_message(body_data) # assumes input is valid # TODO: the following is a temporary solution to handle the case where the model does not have a # parse_input_message method implemented but we must ultimately implement parse_input_message and @@ -69,4 +84,4 @@ def parse_input_message(message_data: Dict) -> Message: def parse_output_message(message_data: Message) -> None: - pass \ No newline at end of file + pass From 932e23c722c835045fa41aabb9c8dcdb4f3bb8ae Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 14:36:48 -0700 Subject: [PATCH 12/18] addressing PR comments, plus some refactoring to fix the tests (mostly), still not green --- lib/model/classycat.py | 16 ++-------------- lib/model/classycat_classify.py | 2 +- lib/model/classycat_response.py | 14 ++++++++++++++ lib/model/classycat_schema_create.py | 2 +- lib/model/classycat_schema_lookup.py | 2 +- lib/schemas.py | 2 +- 6 files changed, 20 insertions(+), 18 deletions(-) create mode 100644 lib/model/classycat_response.py diff --git a/lib/model/classycat.py b/lib/model/classycat.py index 28845e24..27d8d962 100644 --- a/lib/model/classycat.py +++ b/lib/model/classycat.py @@ -1,26 +1,14 @@ -from typing import Union, Dict, Any, Optional, List -from pydantic import BaseModel +from typing import Union, Dict, Any from lib.logger import logger from lib.model.model import Model from lib.schemas import Message from lib.model.classycat_classify import Model as ClassifyModel from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel +from lib.model.classycat_response import ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse from lib.base_exception import PrestoBaseException -class ClassyCatResponse(BaseModel): - responseMessage: Optional[str] = None - - -class ClassyCatBatchClassificationResponse(ClassyCatResponse): - classification_results: Optional[List[dict]] = [] - - -class ClassyCatSchemaResponse(ClassyCatResponse): - schema_id: Optional[str] = None - - class Model(Model): def __init__(self): super().__init__() diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index 319d9018..f8ed2b83 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -8,7 +8,7 @@ from lib.logger import logger from lib.model.model import Model from lib.schemas import Message -from lib.model.classycat import ClassyCatBatchClassificationResponse +from lib.model.classycat_response import ClassyCatBatchClassificationResponse from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3 from lib.base_exception import PrestoBaseException diff --git a/lib/model/classycat_response.py b/lib/model/classycat_response.py new file mode 100644 index 00000000..9a02c64f --- /dev/null +++ b/lib/model/classycat_response.py @@ -0,0 +1,14 @@ +from typing import Optional, List +from pydantic import BaseModel + + +class ClassyCatResponse(BaseModel): + responseMessage: Optional[str] = None + + +class ClassyCatBatchClassificationResponse(ClassyCatResponse): + classification_results: Optional[List[dict]] = [] + + +class ClassyCatSchemaResponse(ClassyCatResponse): + schema_id: Optional[str] = None \ No newline at end of file diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index 080634bd..fa49bbf7 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -6,7 +6,7 @@ from lib.logger import logger from lib.model.model import Model from lib.schemas import Message -from lib.model.classycat import ClassyCatSchemaResponse +from lib.model.classycat_response import ClassyCatSchemaResponse from lib.base_exception import PrestoBaseException diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index b59ab4fb..94526a45 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -5,7 +5,7 @@ from lib.model.model import Model from lib.s3 import load_file_from_s3, file_exists_in_s3 from lib.schemas import Message -from lib.model.classycat import ClassyCatSchemaResponse +from lib.model.classycat_response import ClassyCatSchemaResponse from lib.base_exception import PrestoBaseException diff --git a/lib/schemas.py b/lib/schemas.py index a732e321..6e606932 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from typing import Any, Dict, List, Optional, Union from lib.helpers import get_class -from base_exception import PrestoBaseException +from lib.base_exception import PrestoBaseException import os From 238710cc8218e7adefbfa0a6d822eef60ff3f95a Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 15:12:18 -0700 Subject: [PATCH 13/18] fixing model name in queue tests --- test/lib/queue/test_queue.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/lib/queue/test_queue.py b/test/lib/queue/test_queue.py index 0547e333..f166a78b 100644 --- a/test/lib/queue/test_queue.py +++ b/test/lib/queue/test_queue.py @@ -5,7 +5,7 @@ import numpy as np import time from typing import Union, List -from lib.model.generic_transformer import GenericTransformerModel +from lib.model.audio import Model as AudioModel from lib.queue.queue import Queue from lib.queue.worker import QueueWorker from lib import schemas @@ -31,8 +31,8 @@ class TestQueueWorker(unittest.TestCase): @patch('lib.helpers.get_environment_setting', return_value='us-west-1') @patch('lib.telemetry.OpenTelemetryExporter.log_execution_time') def setUp(self, mock_log_execution_time, mock_get_env_setting, mock_boto_resource): - self.model = GenericTransformerModel(None) - self.model.model_name = "generic" + self.model = AudioModel() + self.model.model_name = "audio__Model" self.mock_model = MagicMock() self.queue_name_input = Queue.get_input_queue_name() self.queue_name_output = Queue.get_output_queue_name() @@ -84,7 +84,7 @@ def test_execute_with_timeout_success(self, mock_log_execution_status, mock_log_ def test_process(self): self.queue.receive_messages = MagicMock(return_value=[(FakeSQSMessage(receipt_handle="blah", body=json.dumps({ "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, - "model_name": "generic" + "model_name": "audio__Model" })), self.mock_input_queue)]) self.queue.input_queue = MagicMock(return_value=None) self.model.model = self.mock_model @@ -212,7 +212,7 @@ def test_extract_messages(self): self.assertIsInstance(extracted_messages[0].body, schemas.GenericItem) self.assertEqual(extracted_messages[0].body.text, "Test message 1") self.assertEqual(extracted_messages[1].body.text, "Test message 2") - self.assertEqual(extracted_messages[0].model_name, "generic") + self.assertEqual(extracted_messages[0].model_name, "audio__Model") @patch('lib.queue.worker.logger.error') def test_log_and_handle_error(self, mock_logger_error): @@ -234,8 +234,8 @@ def test_error_capturing_in_get_response(self, mock_cache_set, mock_cache_get): mock_cache_get.return_value = None mock_cache_set.return_value = True message_data = { - "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a test"}, - "model_name": "generic" + "body": {"id": 1, "callback_url": "http://example.com", "text": "This is a testzzz"}, + "model_name": "audio__Model" } message = schemas.parse_input_message(message_data) message.body.content_hash = "test_hash" From 75cc53062e2a7a65ebb448fddfa6d77d9f608da8 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 15:17:49 -0700 Subject: [PATCH 14/18] fixing fasttext code and tests to follow the Presto model naming convention, now all tests pass :yay: --- lib/model/fasttext.py | 2 +- test/lib/model/test_fasttext.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/model/fasttext.py b/lib/model/fasttext.py index 0aa5a47a..dc8aba09 100644 --- a/lib/model/fasttext.py +++ b/lib/model/fasttext.py @@ -9,7 +9,7 @@ from lib import schemas -class FasttextModel(Model): +class Model(Model): def __init__(self): """ Load fasttext model (https://huggingface.co/facebook/fasttext-language-identification) diff --git a/test/lib/model/test_fasttext.py b/test/lib/model/test_fasttext.py index 2f8ca7f9..d7012738 100644 --- a/test/lib/model/test_fasttext.py +++ b/test/lib/model/test_fasttext.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch, MagicMock import numpy as np -from lib.model.fasttext import FasttextModel +from lib.model.fasttext import Model as FasttextModel from lib import schemas class TestFasttextModel(unittest.TestCase): From 1d794cadc0932e77a69f0be46e61955dc74c3cf1 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Thu, 15 Aug 2024 15:25:32 -0700 Subject: [PATCH 15/18] addressing PR comments, making sure classification_results is a mandatory field for classycatbatchresponse class. --- lib/model/classycat_response.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/model/classycat_response.py b/lib/model/classycat_response.py index 9a02c64f..56931dc6 100644 --- a/lib/model/classycat_response.py +++ b/lib/model/classycat_response.py @@ -7,8 +7,8 @@ class ClassyCatResponse(BaseModel): class ClassyCatBatchClassificationResponse(ClassyCatResponse): - classification_results: Optional[List[dict]] = [] + classification_results: List[dict] = [] class ClassyCatSchemaResponse(ClassyCatResponse): - schema_id: Optional[str] = None \ No newline at end of file + schema_id: Optional[str] = None From 3c7b9421e465c108cbc7e15beaf5e113d7b75019 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Mon, 26 Aug 2024 11:19:48 -0700 Subject: [PATCH 16/18] updating presto gitignore --- .gitignore | 163 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 162 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 88125d6a..09fa136d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,165 @@ *.cpython-39.pyc *.pyc .env_file -.env \ No newline at end of file +.env + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ \ No newline at end of file From 542e3bf78042fea79651323eb3fe50bad491dc49 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Mon, 26 Aug 2024 11:35:34 -0700 Subject: [PATCH 17/18] merging with master + removing unused imports --- test/lib/model/test_generic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/lib/model/test_generic.py b/test/lib/model/test_generic.py index 326bd2c4..d245e84e 100644 --- a/test/lib/model/test_generic.py +++ b/test/lib/model/test_generic.py @@ -1,5 +1,3 @@ -import traceback -import os import unittest from unittest.mock import MagicMock, patch From 66a65372052b129425b1acc65594ea078f93f6ae Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Mon, 26 Aug 2024 13:50:56 -0700 Subject: [PATCH 18/18] fixing those units --- test/lib/model/test_paraphrase_multilingual.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/lib/model/test_paraphrase_multilingual.py b/test/lib/model/test_paraphrase_multilingual.py index 2a387097..ed0588db 100644 --- a/test/lib/model/test_paraphrase_multilingual.py +++ b/test/lib/model/test_paraphrase_multilingual.py @@ -13,7 +13,7 @@ def setUp(self): self.mock_model = MagicMock() def test_vectorize(self): - texts = [schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "paraphrase_multilingual__Model"}), schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "paraphrase_multilingual__Model"})] + texts = [schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "Hello, how are you?"}, "model_name": "paraphrase_multilingual__Model"}), schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "I'm doing great, thanks!"}, "model_name": "paraphrase_multilingual__Model"})] self.model.model = self.mock_model self.model.model.encode = MagicMock(return_value=np.array([[4, 5, 6], [7, 8, 9]])) vectors = self.model.vectorize(texts) @@ -22,7 +22,7 @@ def test_vectorize(self): self.assertEqual(vectors[1], [7, 8, 9]) def test_respond(self): - query = schemas.parse_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of India?"}, "model_name": "paraphrase_multilingual__Model"}) + query = schemas.parse_input_message({"body": {"id": "123", "callback_url": "http://example.com/callback", "text": "What is the capital of India?"}, "model_name": "paraphrase_multilingual__Model"}) self.model.vectorize = MagicMock(return_value=[[1, 2, 3]]) response = self.model.respond(query) self.assertEqual(len(response), 1)