Skip to content

Commit

Permalink
Merge pull request #106 from meedan/cv2-5001
Browse files Browse the repository at this point in the history
Refactoring parse_message()
  • Loading branch information
ashkankzme authored Aug 26, 2024
2 parents 02ac2ac + 66a6537 commit 0b8857c
Show file tree
Hide file tree
Showing 35 changed files with 525 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
150 changes: 150 additions & 0 deletions docs/how.to.make.model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# How to make a presto model
## Your go-to guide, one-stop shop for writing your model in presto

A model in Presto is a (mostly) stateless Python class that has the following properties:
- Extends the `lib.model.model.Model` class
- `__init__()`: Overrides the default constructor method that initializes the model and its parameters.
from lib.model.classycat_classify:
```python
class Model(Model):
def __init__(self):
super().__init__()
self.output_bucket = os.getenv("CLASSYCAT_OUTPUT_BUCKET")
self.batch_size_limit = int(os.environ.get("CLASSYCAT_BATCH_SIZE_LIMIT"))
llm_client_type = os.environ.get('CLASSYCAT_LLM_CLIENT_TYPE')
llm_model_name = os.environ.get('CLASSYCAT_LLM_MODEL_NAME')
self.llm_client = self.get_llm_client(llm_client_type, llm_model_name)
```

- `process()`: A method that processes the input data and returns the specific output of the specified predefined schema
This method is the meat of your model, and is supposed to statelessly process the input data and return the output data.
It's good practice to copy a standard input and output schema so that others can consume your model easily.
This method should return your own custom defined response classes, see ClassyCatBatchClassificationResponse for an example.
```python
def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
# Example input:
# {
# "model_name": "classycat__Model",
# "body": {
# "id": 1200,
# "parameters": {
# "event_type": "classify",
# "schema_id": "4a026b82-4a16-440d-aed7-bec07af12205",
# "items": [
# {
# "id": "11",
# "text": "modi and bjp want to rule india by dividing people against each other"
# }
# ]
# },
# "callback_url": "http://example.com?callback"
# }
# }
#
# Example output:
# {
# "body": {
# "id": 1200,
# "content_hash": null,
# "callback_url": "http://host.docker.internal:9888",
# "url": null,
# "text": null,
# "raw": {},
# "parameters": {
# "event_type": "classify",
# "schema_id": "12589852-4fff-430b-bf77-adad202d03ca",
# "items": [
# {
# "id": "11",
# "text": "modi and bjp want to rule india by dividing people against each other"
# }
# ]
# },
# "result": {
# "responseMessage": "success",
# "classification_results": [
# {
# "id": "11",
# "text": "modi and bjp want to rule india by dividing people against each other",
# "labels": [
# "Politics",
# "Communalism"
# ]
# }
# ]
# }
# },
# "model_name": "classycat.Model",
# "retry_count": 0
# }

# unpack parameters for classify
batch_to_classify = message.body.parameters
schema_id = batch_to_classify["schema_id"]
items = batch_to_classify["items"]

result = message.body.result

if not self.schema_id_exists(schema_id):
raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404)

if len(items) > self.batch_size_limit:
raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422)

try:
result.classification_results = self.classify_and_store_results(schema_id, items)
result.responseMessage = "success"
return result
except Exception as e:
logger.exception(f"Error classifying items: {e}")
if isinstance(e, PrestoBaseException):
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e
```
- In case of errors, raise a `PrestoBaseException` with the appropriate error message and http status code, as seen above.
feel free to extend this class and implement more sophisticated error handling for your usecase if needed.
- `validate_input()`: A method that validates the input data and raises an exception if the input is invalid.
```python
@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "":
raise PrestoBaseException("schema_id is required as input to classify", 422)

if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0:
raise PrestoBaseException("items are required as input to classify", 422)

for item in data["parameters"]["items"]:
if "id" not in item or item["id"] == "":
raise PrestoBaseException("id is required for each item", 422)
if "text" not in item or item["text"] == "":
raise PrestoBaseException("text is required for each item", 422)
```
- `parse_input_message()`: generates the right result/response type from the raw input body:
```python
@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'classify':
return ClassyCatBatchClassificationResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
```
- Your own custom defined response classes, ideally defined inside the model file or as a separate file if too complex.
see `ClassyCatBatchClassificationResponse` for an example:
```python
class ClassyCatResponse(BaseModel):
responseMessage: Optional[str] = None

class ClassyCatBatchClassificationResponse(ClassyCatResponse):
classification_results: Optional[List[dict]] = []
```
2 changes: 1 addition & 1 deletion lib/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def process_item(process_name: str, message: Dict[str, Any]):
queue_prefix = Queue.get_queue_prefix()
queue_suffix = Queue.get_queue_suffix()
queue = QueueWorker.create(process_name)
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_message({"body": message, "model_name": process_name}))
queue.push_message(f"{queue_prefix}{process_name}{queue_suffix}", schemas.parse_input_message({"body": message, "model_name": process_name}))
return {"message": "Message pushed successfully", "queue": process_name, "body": message}

@app.post("/trigger_callback")
Expand Down
17 changes: 16 additions & 1 deletion lib/model/audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Dict
from typing import Union, List, Dict, Any
import os
import tempfile

Expand Down Expand Up @@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]:
finally:
os.remove(temp_file_name)
return {"hash_value": hash_value}

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
pass


@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
return None
42 changes: 40 additions & 2 deletions lib/model/classycat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Union
from typing import Union, Dict, Any
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse
from lib.schemas import Message
from lib.model.classycat_classify import Model as ClassifyModel
from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel
from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel
from lib.model.classycat_response import ClassyCatSchemaResponse, ClassyCatBatchClassificationResponse
from lib.base_exception import PrestoBaseException


Expand All @@ -23,3 +24,40 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
event_type = data['parameters']['event_type']

if event_type == 'classify':
ClassifyModel.validate_input(data)
elif event_type == 'schema_lookup':
ClassyCatSchemaLookupModel.validate_input(data)
elif event_type == 'schema_create':
ClassyCatSchemaCreateModel.validate_input(data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']

if event_type == 'classify':
result_instance_class = ClassifyModel
elif event_type == 'schema_lookup':
result_instance_class = ClassyCatSchemaLookupModel
elif event_type == 'schema_create':
result_instance_class = ClassyCatSchemaCreateModel

else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)

return result_instance_class.parse_input_message(data)
36 changes: 35 additions & 1 deletion lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, Any
import os
import json
import uuid
Expand All @@ -6,7 +7,8 @@
from anthropic import Anthropic
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatBatchClassificationResponse
from lib.schemas import Message
from lib.model.classycat_response import ClassyCatBatchClassificationResponse
from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3
from lib.base_exception import PrestoBaseException

Expand Down Expand Up @@ -230,3 +232,35 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "":
raise PrestoBaseException("schema_id is required as input to classify", 422)

if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0:
raise PrestoBaseException("items are required as input to classify", 422)

for item in data["parameters"]["items"]:
if "id" not in item or item["id"] == "":
raise PrestoBaseException("id is required for each item", 422)
if "text" not in item or item["text"] == "":
raise PrestoBaseException("text is required for each item", 422)

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'classify':
return ClassyCatBatchClassificationResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
14 changes: 14 additions & 0 deletions lib/model/classycat_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Optional, List
from pydantic import BaseModel


class ClassyCatResponse(BaseModel):
responseMessage: Optional[str] = None


class ClassyCatBatchClassificationResponse(ClassyCatResponse):
classification_results: List[dict] = []


class ClassyCatSchemaResponse(ClassyCatResponse):
schema_id: Optional[str] = None
40 changes: 38 additions & 2 deletions lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Dict, Any
import os
import json
import uuid
from lib.s3 import upload_file_to_s3, file_exists_in_s3
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.schemas import Message
from lib.model.classycat_response import ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


Expand Down Expand Up @@ -210,7 +212,8 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
raise PrestoBaseException(f"Error creating schema: {e}", 500) from e


def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo
@classmethod
def verify_schema_parameters(cls, schema_name, topics, examples, languages):

if not schema_name or not isinstance(schema_name, str) or len(schema_name) == 0:
raise ValueError("schema_name is invalid. It must be a non-empty string")
Expand Down Expand Up @@ -243,3 +246,36 @@ def verify_schema_parameters(self, schema_name, topics, examples, languages): #t

def schema_name_exists(self, schema_name):
return file_exists_in_s3(self.output_bucket, f"{schema_name}.json")


@classmethod
def validate_input(cls, data: Dict) -> None:
"""
Validate input data. Must be implemented by all child "Model" classes.
"""
schema_specs = data['parameters']

schema_name = schema_specs["schema_name"]
topics = schema_specs["topics"]
examples = schema_specs["examples"]
languages = schema_specs["languages"] # ['English', 'Spanish']

try:
cls.verify_schema_parameters(schema_name, topics, examples, languages)
except Exception as e:
logger.exception(f"Error verifying schema parameters: {e}")
raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e

@classmethod
def parse_input_message(cls, data: Dict) -> Any:
"""
Parse input into appropriate response instances.
"""
event_type = data['parameters']['event_type']
result_data = data.get('result', {})

if event_type == 'schema_create':
return ClassyCatSchemaResponse(**result_data)
else:
logger.error(f"Unknown event type {event_type}")
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
Loading

0 comments on commit 0b8857c

Please sign in to comment.