Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring parse_message() #106

Merged
merged 21 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6d1d357
WIP: refactoring presto parse_message()
ashkankzme Aug 13, 2024
2aef392
WIP: updating unit tests (doesn't pass yet)
ashkankzme Aug 14, 2024
350e892
fixing the unit tests, now they pass
ashkankzme Aug 14, 2024
5229b1a
WIP: implementing a sample input parsing and verification implementat…
ashkankzme Aug 14, 2024
08662e0
implementing sample validataion code for classycat, untested
ashkankzme Aug 15, 2024
243c4c6
WIP: fixing a minor bug for a class method
ashkankzme Aug 15, 2024
7f6954f
presto refactoring all tested and fixed (verified locally and ran uni…
ashkankzme Aug 15, 2024
8e6f491
developer guide for writing presto models
ashkankzme Aug 15, 2024
5505b4c
removing classycat garbage code
ashkankzme Aug 15, 2024
483a970
WIP: taking classycat response class definitions out of schemas.py an…
ashkankzme Aug 15, 2024
58e3979
PR comments
ashkankzme Aug 15, 2024
932e23c
addressing PR comments, plus some refactoring to fix the tests (mostl…
ashkankzme Aug 15, 2024
238710c
fixing model name in queue tests
ashkankzme Aug 15, 2024
75cc530
fixing fasttext code and tests to follow the Presto model naming conv…
ashkankzme Aug 15, 2024
1d794ca
addressing PR comments, making sure classification_results is a manda…
ashkankzme Aug 15, 2024
3eb9861
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
ashkankzme Aug 19, 2024
3c7b942
updating presto gitignore
ashkankzme Aug 26, 2024
a2c79e3
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
ashkankzme Aug 26, 2024
542e3bf
merging with master + removing unused imports
ashkankzme Aug 26, 2024
aa49ec8
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
ashkankzme Aug 26, 2024
66a6537
fixing those units
ashkankzme Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def process_item(process_name: str, message: Dict[str, Any]):
queue_prefix = Queue.get_queue_prefix()
queue_suffix = Queue.get_queue_suffix()
queue = QueueWorker.create(process_name)
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_message({"body": message, "model_name": process_name}))
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_input_message({"body": message, "model_name": process_name}))
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
Expand Down
17 changes: 16 additions & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Dict
from typing import Union, List, Dict, Any
import os
import tempfile

Expand Down Expand Up @@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
finally:
os.remove(temp_file_name)
return {"hash_value": hash_value}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just be passing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think passing is equivalent to what we have right now for most of these models (with the exception of classycat), not counting the validations that happen inside schema.py

I do agree that we should start implementing them, and I have created a ticket for that work (CV2-5093), but the current design is backward compatible and there is no need to implement these right now? unless you think it's urgent we address it?

Validate input data. Must be implemented by all child "Model" classes.
"""
pass


@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do nothing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: ah I see what the issue is after reading your Jira message - I think we should talk through how to not just stub all these on our next ML call, but in the meantime I think looking at what we typically unit-test for, for each model, and just doing type-checking based on the types of responses we're testing for would be appropriate.

Validate input data. Must be implemented by all child "Model" classes.
"""
return None
41 changes: 39 additions & 2 deletions lib/model/classycat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union
from typing import Union, Dict, Any
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse
from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse, ClassyCatResponse
from lib.model.classycat_classify import Model as ClassifyModel
from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel
from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel
Expand All @@ -23,3 +23,40 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
event_type = data['parameters']['event_type']

if event_type == 'classify':
ClassifyModel.validate_input(data)
elif event_type == 'schema_lookup':
ClassyCatSchemaLookupModel.validate_input(data)
elif event_type == 'schema_create':
ClassyCatSchemaCreateModel.validate_input(data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']

if event_type == 'classify':
result_instance_class = ClassifyModel
elif event_type == 'schema_lookup':
result_instance_class = ClassyCatSchemaLookupModel
elif event_type == 'schema_create':
result_instance_class = ClassyCatSchemaCreateModel

else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

return result_instance_class.parse_input_message(data)
33 changes: 33 additions & 0 deletions lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
import uuid
Expand Down Expand Up @@ -230,3 +231,35 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "":
raise PrestoBaseException("schema_id is required as input to classify", 422)

if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0:
raise PrestoBaseException("items are required as input to classify", 422)

for item in data["parameters"]["items"]:
if "id" not in item or item["id"] == "":
raise PrestoBaseException("id is required for each item", 422)
if "text" not in item or item["text"] == "":
raise PrestoBaseException("text is required for each item", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'classify':
return ClassyCatBatchClassificationResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
37 changes: 36 additions & 1 deletion lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
import uuid
Expand Down Expand Up @@ -210,7 +211,8 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
raise PrestoBaseException(f"Error creating schema: {e}", 500) from e


def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo
@classmethod
def verify_schema_parameters(cls, schema_name, topics, examples, languages):

if not schema_name or not isinstance(schema_name, str) or len(schema_name) == 0:
raise ValueError("schema_name is invalid. It must be a non-empty string")
Expand Down Expand Up @@ -243,3 +245,36 @@ def verify_schema_parameters(self, schema_name, topics, examples, languages): #t

def schema_name_exists(self, schema_name):
return file_exists_in_s3(self.output_bucket, f"{schema_name}.json")


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
schema_specs = data['parameters']

schema_name = schema_specs["schema_name"]
topics = schema_specs["topics"]
examples = schema_specs["examples"]
languages = schema_specs["languages"] # ['English', 'Spanish']

try:
cls.verify_schema_parameters(schema_name, topics, examples, languages)
except Exception as e:
logger.exception(f"Error verifying schema parameters: {e}")
raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'schema_create':
return ClassyCatSchemaResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
26 changes: 25 additions & 1 deletion lib/model/classycat_schema_lookup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
from lib.logger import logger
Expand Down Expand Up @@ -74,4 +75,27 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
return result
except Exception as e:
logger.error(f"Error looking up schema name {schema_name}: {e}")
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
if "schema_name" not in data["parameters"] or data["parameters"]["schema_name"] == "":
raise PrestoBaseException("schema_name is required as input to schema look up", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'schema_lookup':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like this needs to check that either the schema_id or schema_name are not empty? It looks like ClassyCatSchemaResponse considers schema_id as optional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the class specific validator checks that those fields exist, look at classycat_schema_lookup.py

however this function as implemented before only searches by name, not both name and id. maybe we can file a ticket for that if it's necessary?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought there was already a function for lookup by schema name and a separate one for id? (I just wasn't sure which this was)

return ClassyCatSchemaResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
16 changes: 15 additions & 1 deletion lib/model/fasttext.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Dict, List
from typing import Union, Dict, List, Any

import fasttext
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -44,3 +44,17 @@ def respond(self, docs: Union[List[schemas.Message], schemas.Message]) -> List[s
for doc, detected_lang in zip(docs, detected_langs):
doc.body.result = detected_lang
return docs

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
return None
4 changes: 4 additions & 0 deletions lib/model/fptg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib.model.generic_transformer import GenericTransformerModel

MODEL_NAME = 'meedan/paraphrase-filipino-mpnet-base-v2'


class Model(GenericTransformerModel):
BATCH_SIZE = 100

def __init__(self):
"""
Init FPTG model. Fairly standard for all vectorizers.
Expand Down
16 changes: 15 additions & 1 deletion lib/model/generic_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Union, Dict, List
from typing import Union, Dict, List, Any
from sentence_transformers import SentenceTransformer
from lib.logger import logger
from lib.model.model import Model
Expand Down Expand Up @@ -34,3 +34,17 @@ def vectorize(self, texts: List[str]) -> List[List[float]]:
Vectorize the text! Run as batch.
"""
return {"hash_value": self.model.encode(texts).tolist()}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
return None
17 changes: 16 additions & 1 deletion lib/model/image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import io
import urllib.request

Expand Down Expand Up @@ -35,3 +35,18 @@ def process(self, image: schemas.Message) -> schemas.GenericItem:
Generic function for returning the actual response.
"""
return {"hash_value": self.compute_pdq(self.get_iobytes_for_image(image))}


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
return None
4 changes: 4 additions & 0 deletions lib/model/indian_sbert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib.model.generic_transformer import GenericTransformerModel

MODEL_NAME = 'meedan/indian-sbert'


class Model(GenericTransformerModel):
BATCH_SIZE = 100

def __init__(self):
"""
Init IndianSbert model. Fairly standard for all vectorizers.
Expand Down
4 changes: 4 additions & 0 deletions lib/model/mean_tokens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib.model.generic_transformer import GenericTransformerModel

MODEL_NAME = 'xlm-r-bert-base-nli-stsb-mean-tokens'


class Model(GenericTransformerModel):
BATCH_SIZE = 100

def __init__(self):
"""
Init MeanTokens model. Fairly standard for all vectorizers.
Expand Down
17 changes: 17 additions & 0 deletions lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
message.body.result = self.get_response(message)
return messages


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by child classes.
"""
raise NotImplementedError
skyemeedan marked this conversation as resolved.
Show resolved Hide resolved


@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input data. Must be implemented by child classes.
"""
raise NotImplementedError


@classmethod
def create(cls):
"""
Expand Down
16 changes: 15 additions & 1 deletion lib/model/video.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import os
import uuid
import shutil
Expand Down Expand Up @@ -64,3 +64,17 @@ def process(self, video: schemas.Message) -> schemas.GenericItem:
if os.path.exists(file_path):
os.remove(file_path)
return {"folder": self.tmk_bucket(), "filepath": self.tmk_file_path(video_filename), "hash_value": hash_value}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
return None
16 changes: 15 additions & 1 deletion lib/model/yake_keywords.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any
import io
import urllib.request

Expand Down Expand Up @@ -50,3 +50,17 @@ def process(self, message: schemas.Message) -> schemas.YakeKeywordsResponse:
"""
keywords = self.run_yake(**self.get_params(message))
return keywords

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
pass

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
return None
Loading
Loading