From 350e892894ccf18300fee4c23d85e2ac263cd373 Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Wed, 14 Aug 2024 10:58:09 -0700 Subject: [PATCH] fixing the unit tests, now they pass --- lib/model/audio.py | 17 ++++++++++++++++- lib/model/classycat.py | 16 +++++++++++++++- lib/model/classycat_classify.py | 16 ++++++++++++++++ lib/model/classycat_schema_create.py | 16 ++++++++++++++++ lib/model/classycat_schema_lookup.py | 18 +++++++++++++++++- lib/model/fasttext.py | 16 +++++++++++++++- lib/model/fptg.py | 4 ++++ lib/model/generic_transformer.py | 16 +++++++++++++++- lib/model/image.py | 17 ++++++++++++++++- lib/model/indian_sbert.py | 4 ++++ lib/model/mean_tokens.py | 4 ++++ lib/model/video.py | 16 +++++++++++++++- lib/model/yake_keywords.py | 16 +++++++++++++++- 13 files changed, 168 insertions(+), 8 deletions(-) diff --git a/lib/model/audio.py b/lib/model/audio.py index fea03cb..f23cd90 100644 --- a/lib/model/audio.py +++ b/lib/model/audio.py @@ -1,4 +1,4 @@ -from typing import Union, List, Dict +from typing import Union, List, Dict, Any import os import tempfile @@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]: finally: os.remove(temp_file_name) return {"hash_value": hash_value} + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/classycat.py b/lib/model/classycat.py index cceb17f..72e0ae9 100644 --- a/lib/model/classycat.py +++ b/lib/model/classycat.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Dict, Any from lib.logger import logger from lib.model.model import Model from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse @@ -23,3 +23,17 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB else: logger.error(f"Unknown event type {event_type}") raise PrestoBaseException(f"Unknown event type {event_type}", 422) + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index 2f4b185..1e63e69 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -1,3 +1,4 @@ +from typing import Dict, Any import os import json import uuid @@ -230,3 +231,18 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse: raise e else: raise PrestoBaseException(f"Error classifying items: {e}", 500) from e + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index b229c44..2ef8986 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -1,3 +1,4 @@ +from typing import Dict, Any import os import json import uuid @@ -243,3 +244,18 @@ def verify_schema_parameters(self, schema_name, topics, examples, languages): #t def schema_name_exists(self, schema_name): return file_exists_in_s3(self.output_bucket, f"{schema_name}.json") + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index af51b44..480bc63 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -1,3 +1,4 @@ +from typing import Dict, Any import os import json from lib.logger import logger @@ -74,4 +75,19 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: return result except Exception as e: logger.error(f"Error looking up schema name {schema_name}: {e}") - raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e \ No newline at end of file + raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/fasttext.py b/lib/model/fasttext.py index e251f56..0aa5a47 100644 --- a/lib/model/fasttext.py +++ b/lib/model/fasttext.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, List +from typing import Union, Dict, List, Any import fasttext from huggingface_hub import hf_hub_download @@ -44,3 +44,17 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s for doc, detected_lang in zip(docs, detected_langs): doc.body.result = detected_lang return docs + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/fptg.py b/lib/model/fptg.py index 2c0731f..f5244d9 100644 --- a/lib/model/fptg.py +++ b/lib/model/fptg.py @@ -1,7 +1,11 @@ from lib.model.generic_transformer import GenericTransformerModel + MODEL_NAME = 'meedan/paraphrase-filipino-mpnet-base-v2' + + class Model(GenericTransformerModel): BATCH_SIZE = 100 + def __init__(self): """ Init FPTG model. Fairly standard for all vectorizers. diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index cb567fe..e1d2a11 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -1,5 +1,5 @@ import os -from typing import Union, Dict, List +from typing import Union, Dict, List, Any from sentence_transformers import SentenceTransformer from lib.logger import logger from lib.model.model import Model @@ -34,3 +34,17 @@ def vectorize(self, texts: List[str]) -> List[List[float]]: Vectorize the text! Run as batch. """ return {"hash_value": self.model.encode(texts).tolist()} + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/image.py b/lib/model/image.py index 5197f11..a3a7fa7 100644 --- a/lib/model/image.py +++ b/lib/model/image.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import io import urllib.request @@ -35,3 +35,18 @@ def process(self, image: schemas.Message) -> schemas.GenericItem: Generic function for returning the actual response. """ return {"hash_value": self.compute_pdq(self.get_iobytes_for_image(image))} + + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None \ No newline at end of file diff --git a/lib/model/indian_sbert.py b/lib/model/indian_sbert.py index db529ba..d530216 100644 --- a/lib/model/indian_sbert.py +++ b/lib/model/indian_sbert.py @@ -1,7 +1,11 @@ from lib.model.generic_transformer import GenericTransformerModel + MODEL_NAME = 'meedan/indian-sbert' + + class Model(GenericTransformerModel): BATCH_SIZE = 100 + def __init__(self): """ Init IndianSbert model. Fairly standard for all vectorizers. diff --git a/lib/model/mean_tokens.py b/lib/model/mean_tokens.py index a3f77e0..3953bff 100644 --- a/lib/model/mean_tokens.py +++ b/lib/model/mean_tokens.py @@ -1,7 +1,11 @@ from lib.model.generic_transformer import GenericTransformerModel + MODEL_NAME = 'xlm-r-bert-base-nli-stsb-mean-tokens' + + class Model(GenericTransformerModel): BATCH_SIZE = 100 + def __init__(self): """ Init MeanTokens model. Fairly standard for all vectorizers. diff --git a/lib/model/video.py b/lib/model/video.py index cbb95c4..285ccf5 100644 --- a/lib/model/video.py +++ b/lib/model/video.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import os import uuid import shutil @@ -64,3 +64,17 @@ def process(self, video: schemas.Message) -> schemas.GenericItem: if os.path.exists(file_path): os.remove(file_path) return {"folder": self.tmk_bucket(), "filepath": self.tmk_file_path(video_filename), "hash_value": hash_value} + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None diff --git a/lib/model/yake_keywords.py b/lib/model/yake_keywords.py index 8fc8948..e8adc80 100644 --- a/lib/model/yake_keywords.py +++ b/lib/model/yake_keywords.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any import io import urllib.request @@ -50,3 +50,17 @@ def process(self, message: schemas.Message) -> schemas.YakeKeywordsResponse: """ keywords = self.run_yake(**self.get_params(message)) return keywords + + @classmethod + def validate_input(cls, data: Dict) -> None: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + pass + + @classmethod + def parse_input_message(cls, data: Dict) -> Any: + """ + Validate input data. Must be implemented by all child "Model" classes. + """ + return None