Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring parse_message() #106

Merged
merged 21 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6d1d357
WIP: refactoring presto parse_message()
ashkankzme Aug 13, 2024
2aef392
WIP: updating unit tests (doesn't pass yet)
ashkankzme Aug 14, 2024
350e892
fixing the unit tests, now they pass
ashkankzme Aug 14, 2024
5229b1a
WIP: implementing a sample input parsing and verification implementat…
ashkankzme Aug 14, 2024
08662e0
implementing sample validataion code for classycat, untested
ashkankzme Aug 15, 2024
243c4c6
WIP: fixing a minor bug for a class method
ashkankzme Aug 15, 2024
7f6954f
presto refactoring all tested and fixed (verified locally and ran uni…
ashkankzme Aug 15, 2024
8e6f491
developer guide for writing presto models
ashkankzme Aug 15, 2024
5505b4c
removing classycat garbage code
ashkankzme Aug 15, 2024
483a970
WIP: taking classycat response class definitions out of schemas.py an…
ashkankzme Aug 15, 2024
58e3979
PR comments
ashkankzme Aug 15, 2024
932e23c
addressing PR comments, plus some refactoring to fix the tests (mostl…
ashkankzme Aug 15, 2024
238710c
fixing model name in queue tests
ashkankzme Aug 15, 2024
75cc530
fixing fasttext code and tests to follow the Presto model naming conv…
ashkankzme Aug 15, 2024
1d794ca
addressing PR comments, making sure classification_results is a manda…
ashkankzme Aug 15, 2024
3eb9861
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
ashkankzme Aug 19, 2024
3c7b942
updating presto gitignore
ashkankzme Aug 26, 2024
a2c79e3
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
ashkankzme Aug 26, 2024
542e3bf
merging with master + removing unused imports
ashkankzme Aug 26, 2024
aa49ec8
Merge branch 'refs/heads/master' into cv2-5001-parse-message-refactor
ashkankzme Aug 26, 2024
66a6537
fixing those units
ashkankzme Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
150 changes: 150 additions & 0 deletions docs/how.to.make.model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# How to make a presto model
## Your go-to guide, one-stop shop for writing your model in presto
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really helpful! :-)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this rules

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very needed. Thank you!


A model in Presto is a (mostly) stateless Python class that has the following properties:
- Extends the `lib.model.model.Model` class
- `__init__()`: Overrides the default constructor method that initializes the model and its parameters.
from lib.model.classycat_classify:
```python
class Model(Model):
def __init__(self):
super().__init__()
self.output_bucket = os.getenv("CLASSYCAT_OUTPUT_BUCKET")
self.batch_size_limit = int(os.environ.get("CLASSYCAT_BATCH_SIZE_LIMIT"))
llm_client_type = os.environ.get('CLASSYCAT_LLM_CLIENT_TYPE')
llm_model_name = os.environ.get('CLASSYCAT_LLM_MODEL_NAME')
self.llm_client = self.get_llm_client(llm_client_type, llm_model_name)
```

- `process()`: A method that processes the input data and returns the specific output of the specified predefined schema
This method is the meat of your model, and is supposed to statelessly process the input data and return the output data.
It's good practice to copy a standard input and output schema so that others can consume your model easily.
This method should return your own custom defined response classes, see ClassyCatBatchClassificationResponse for an example.
```python
def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
# Example input:
# {
# "model_name": "classycat__Model",
# "body": {
# "id": 1200,
# "parameters": {
# "event_type": "classify",
# "schema_id": "4a026b82-4a16-440d-aed7-bec07af12205",
# "items": [
# {
# "id": "11",
# "text": "modi and bjp want to rule india by dividing people against each other"
# }
# ]
# },
# "callback_url": "http://example.com?callback"
# }
# }
#
# Example output:
# {
# "body": {
# "id": 1200,
# "content_hash": null,
# "callback_url": "http://host.docker.internal:9888",
# "url": null,
# "text": null,
# "raw": {},
# "parameters": {
# "event_type": "classify",
# "schema_id": "12589852-4fff-430b-bf77-adad202d03ca",
# "items": [
# {
# "id": "11",
# "text": "modi and bjp want to rule india by dividing people against each other"
# }
# ]
# },
# "result": {
# "responseMessage": "success",
# "classification_results": [
# {
# "id": "11",
# "text": "modi and bjp want to rule india by dividing people against each other",
# "labels": [
# "Politics",
# "Communalism"
# ]
# }
# ]
# }
# },
# "model_name": "classycat.Model",
# "retry_count": 0
# }

# unpack parameters for classify
batch_to_classify = message.body.parameters
schema_id = batch_to_classify["schema_id"]
items = batch_to_classify["items"]

result = message.body.result

if not self.schema_id_exists(schema_id):
raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404)

if len(items) > self.batch_size_limit:
raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422)

try:
result.classification_results = self.classify_and_store_results(schema_id, items)
result.responseMessage = "success"
return result
except Exception as e:
logger.exception(f"Error classifying items: {e}")
if isinstance(e, PrestoBaseException):
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e
```
- In case of errors, raise a `PrestoBaseException` with the appropriate error message and http status code, as seen above.
feel free to extend this class and implement more sophisticated error handling for your usecase if needed.
- `validate_input()`: A method that validates the input data and raises an exception if the input is invalid.
```python
@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "":
raise PrestoBaseException("schema_id is required as input to classify", 422)

if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0:
raise PrestoBaseException("items are required as input to classify", 422)

for item in data["parameters"]["items"]:
if "id" not in item or item["id"] == "":
raise PrestoBaseException("id is required for each item", 422)
if "text" not in item or item["text"] == "":
raise PrestoBaseException("text is required for each item", 422)
```
- `parse_input_message()`: generates the right result/response type from the raw input body:
```python
@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'classify':
return ClassyCatBatchClassificationResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
```
- Your own custom defined response classes, ideally defined inside the model file or as a separate file if too complex.
see `ClassyCatBatchClassificationResponse` for an example:
```python
class ClassyCatResponse(BaseModel):
responseMessage: Optional[str] = None

class ClassyCatBatchClassificationResponse(ClassyCatResponse):
classification_results: Optional[List[dict]] = []
```
2 changes: 1 addition & 1 deletion lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def process_item(process_name: str, message: Dict[str, Any]):
queue_prefix = Queue.get_queue_prefix()
queue_suffix = Queue.get_queue_suffix()
queue = QueueWorker.create(process_name)
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_message({"body": message, "model_name": process_name}))
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_input_message({"body": message, "model_name": process_name}))
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
Expand Down
17 changes: 16 additions & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Dict
from typing import Union, List, Dict, Any
import os
import tempfile

Expand Down Expand Up @@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
finally:
os.remove(temp_file_name)
return {"hash_value": hash_value}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just be passing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think passing is equivalent to what we have right now for most of these models (with the exception of classycat), not counting the validations that happen inside schema.py

I do agree that we should start implementing them, and I have created a ticket for that work (CV2-5093), but the current design is backward compatible and there is no need to implement these right now? unless you think it's urgent we address it?

Validate input data. Must be implemented by all child "Model" classes.
"""
pass


@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do nothing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: ah I see what the issue is after reading your Jira message - I think we should talk through how to not just stub all these on our next ML call, but in the meantime I think looking at what we typically unit-test for, for each model, and just doing type-checking based on the types of responses we're testing for would be appropriate.

Validate input data. Must be implemented by all child "Model" classes.
"""
return None
42 changes: 40 additions & 2 deletions lib/model/classycat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Union
from typing import Union, Dict, Any
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse
from lib.schemas import Message
from lib.model.classycat_classify import Model as ClassifyModel
from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel
from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel
from lib.model.classycat_response import ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse
from lib.base_exception import PrestoBaseException


Expand All @@ -23,3 +24,40 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
event_type = data['parameters']['event_type']

if event_type == 'classify':
ClassifyModel.validate_input(data)
elif event_type == 'schema_lookup':
ClassyCatSchemaLookupModel.validate_input(data)
elif event_type == 'schema_create':
ClassyCatSchemaCreateModel.validate_input(data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']

if event_type == 'classify':
result_instance_class = ClassifyModel
elif event_type == 'schema_lookup':
result_instance_class = ClassyCatSchemaLookupModel
elif event_type == 'schema_create':
result_instance_class = ClassyCatSchemaCreateModel

else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

return result_instance_class.parse_input_message(data)
36 changes: 35 additions & 1 deletion lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
import uuid
Expand All @@ -6,7 +7,8 @@
from anthropic import Anthropic
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatBatchClassificationResponse
from lib.schemas import Message
from lib.model.classycat_response import ClassyCatBatchClassificationResponse
from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3
from lib.base_exception import PrestoBaseException

Expand Down Expand Up @@ -230,3 +232,35 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "":
raise PrestoBaseException("schema_id is required as input to classify", 422)

if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0:
raise PrestoBaseException("items are required as input to classify", 422)

for item in data["parameters"]["items"]:
if "id" not in item or item["id"] == "":
raise PrestoBaseException("id is required for each item", 422)
if "text" not in item or item["text"] == "":
raise PrestoBaseException("text is required for each item", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'classify':
return ClassyCatBatchClassificationResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
14 changes: 14 additions & 0 deletions lib/model/classycat_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Optional, List
from pydantic import BaseModel


class ClassyCatResponse(BaseModel):
responseMessage: Optional[str] = None


class ClassyCatBatchClassificationResponse(ClassyCatResponse):
classification_results: List[dict] = []


class ClassyCatSchemaResponse(ClassyCatResponse):
schema_id: Optional[str] = None
40 changes: 38 additions & 2 deletions lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Dict, Any
import os
import json
import uuid
from lib.s3 import upload_file_to_s3, file_exists_in_s3
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.schemas import Message
from lib.model.classycat_response import ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


Expand Down Expand Up @@ -210,7 +212,8 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
raise PrestoBaseException(f"Error creating schema: {e}", 500) from e


def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo
@classmethod
def verify_schema_parameters(cls, schema_name, topics, examples, languages):

if not schema_name or not isinstance(schema_name, str) or len(schema_name) == 0:
raise ValueError("schema_name is invalid. It must be a non-empty string")
Expand Down Expand Up @@ -243,3 +246,36 @@ def verify_schema_parameters(self, schema_name, topics, examples, languages): #t

def schema_name_exists(self, schema_name):
return file_exists_in_s3(self.output_bucket, f"{schema_name}.json")


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
schema_specs = data['parameters']

schema_name = schema_specs["schema_name"]
topics = schema_specs["topics"]
examples = schema_specs["examples"]
languages = schema_specs["languages"] # ['English', 'Spanish']

try:
cls.verify_schema_parameters(schema_name, topics, examples, languages)
except Exception as e:
logger.exception(f"Error verifying schema parameters: {e}")
raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'schema_create':
return ClassyCatSchemaResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
Loading
Loading