Skip to content

Commit

Permalink
fixing the unit tests, now they pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ashkankzme committed Aug 14, 2024
1 parent 2aef392 commit 350e892
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 8 deletions.
17 changes: 16 additions & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Dict
from typing import Union, List, Dict, Any
import os
import tempfile

Expand Down Expand Up @@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
finally:
os.remove(temp_file_name)
return {"hash_value": hash_value}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass


@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
16 changes: 15 additions & 1 deletion lib/model/classycat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Dict, Any
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse
Expand All @@ -23,3 +23,17 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
16 changes: 16 additions & 0 deletions lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
import uuid
Expand Down Expand Up @@ -230,3 +231,18 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
16 changes: 16 additions & 0 deletions lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
import uuid
Expand Down Expand Up @@ -243,3 +244,18 @@ def verify_schema_parameters(self, schema_name, topics, examples, languages): #t

def schema_name_exists(self, schema_name):
return file_exists_in_s3(self.output_bucket, f"{schema_name}.json")


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
18 changes: 17 additions & 1 deletion lib/model/classycat_schema_lookup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
from lib.logger import logger
Expand Down Expand Up @@ -74,4 +75,19 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
return result
except Exception as e:
logger.error(f"Error looking up schema name {schema_name}: {e}")
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
16 changes: 15 additions & 1 deletion lib/model/fasttext.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Dict, List
from typing import Union, Dict, List, Any

import fasttext
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -44,3 +44,17 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
for doc, detected_lang in zip(docs, detected_langs):
doc.body.result = detected_lang
return docs

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
4 changes: 4 additions & 0 deletions lib/model/fptg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib.model.generic_transformer import GenericTransformerModel

MODEL_NAME = 'meedan/paraphrase-filipino-mpnet-base-v2'


class Model(GenericTransformerModel):
BATCH_SIZE = 100

def __init__(self):
"""
Init FPTG model. Fairly standard for all vectorizers.
Expand Down
16 changes: 15 additions & 1 deletion lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Union, Dict, List
from typing import Union, Dict, List, Any
from sentence_transformers import SentenceTransformer
from lib.logger import logger
from lib.model.model import Model
Expand Down Expand Up @@ -34,3 +34,17 @@ def vectorize(self, texts: List[str]) -> List[List[float]]:
Vectorize the text! Run as batch.
"""
return {"hash_value": self.model.encode(texts).tolist()}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
17 changes: 16 additions & 1 deletion lib/model/image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import io
import urllib.request

Expand Down Expand Up @@ -35,3 +35,18 @@ def process(self, image: schemas.Message) -> schemas.GenericItem:
Generic function for returning the actual response.
"""
return {"hash_value": self.compute_pdq(self.get_iobytes_for_image(image))}


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
4 changes: 4 additions & 0 deletions lib/model/indian_sbert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib.model.generic_transformer import GenericTransformerModel

MODEL_NAME = 'meedan/indian-sbert'


class Model(GenericTransformerModel):
BATCH_SIZE = 100

def __init__(self):
"""
Init IndianSbert model. Fairly standard for all vectorizers.
Expand Down
4 changes: 4 additions & 0 deletions lib/model/mean_tokens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib.model.generic_transformer import GenericTransformerModel

MODEL_NAME = 'xlm-r-bert-base-nli-stsb-mean-tokens'


class Model(GenericTransformerModel):
BATCH_SIZE = 100

def __init__(self):
"""
Init MeanTokens model. Fairly standard for all vectorizers.
Expand Down
16 changes: 15 additions & 1 deletion lib/model/video.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import os
import uuid
import shutil
Expand Down Expand Up @@ -64,3 +64,17 @@ def process(self, video: schemas.Message) -> schemas.GenericItem:
if os.path.exists(file_path):
os.remove(file_path)
return {"folder": self.tmk_bucket(), "filepath": self.tmk_file_path(video_filename), "hash_value": hash_value}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None
16 changes: 15 additions & 1 deletion lib/model/yake_keywords.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import io
import urllib.request

Expand Down Expand Up @@ -50,3 +50,17 @@ def process(self, message: schemas.Message) -> schemas.YakeKeywordsResponse:
"""
keywords = self.run_yake(**self.get_params(message))
return keywords

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it performs no checks on `data`.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse the input message. Must be implemented by all child "Model" classes.

This implementation is a placeholder: it ignores `data` and always
returns None.
"""
return None

0 comments on commit 350e892

Please sign in to comment.