-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring parse_message() #106
Changes from all commits
6d1d357
2aef392
350e892
5229b1a
08662e0
243c4c6
7f6954f
8e6f491
5505b4c
483a970
58e3979
932e23c
238710c
75cc530
1d794ca
3eb9861
3c7b942
a2c79e3
542e3bf
aa49ec8
66a6537
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# How to make a presto model | ||
## Your go-to guide, one-stop shop for writing your model in presto | ||
|
||
A model in Presto is a (mostly) stateless Python class that has the following properties: | ||
- Extends the `lib.model.model.Model` class | ||
- `__init__()`: Overrides the default constructor method that initializes the model and its parameters. | ||
from lib.model.classycat_classify: | ||
```python | ||
class Model(Model): | ||
def __init__(self): | ||
super().__init__() | ||
self.output_bucket = os.getenv("CLASSYCAT_OUTPUT_BUCKET") | ||
self.batch_size_limit = int(os.environ.get("CLASSYCAT_BATCH_SIZE_LIMIT")) | ||
llm_client_type = os.environ.get('CLASSYCAT_LLM_CLIENT_TYPE') | ||
llm_model_name = os.environ.get('CLASSYCAT_LLM_MODEL_NAME') | ||
self.llm_client = self.get_llm_client(llm_client_type, llm_model_name) | ||
``` | ||
|
||
- `process()`: A method that processes the input data and returns the specific output of the specified predefined schema | ||
This method is the meat of your model, and is supposed to statelessly process the input data and return the output data. | ||
It's good practice to copy a standard input and output schema so that others can consume your model easily. | ||
This method should return your own custom defined response classes, see ClassyCatBatchClassificationResponse for an example. | ||
```python | ||
def process(self, message: Message) -> ClassyCatBatchClassificationResponse: | ||
# Example input: | ||
# { | ||
# "model_name": "classycat__Model", | ||
# "body": { | ||
# "id": 1200, | ||
# "parameters": { | ||
# "event_type": "classify", | ||
# "schema_id": "4a026b82-4a16-440d-aed7-bec07af12205", | ||
# "items": [ | ||
# { | ||
# "id": "11", | ||
# "text": "modi and bjp want to rule india by dividing people against each other" | ||
# } | ||
# ] | ||
# }, | ||
# "callback_url": "http://example.com?callback" | ||
# } | ||
# } | ||
# | ||
# Example output: | ||
# { | ||
# "body": { | ||
# "id": 1200, | ||
# "content_hash": null, | ||
# "callback_url": "http://host.docker.internal:9888", | ||
# "url": null, | ||
# "text": null, | ||
# "raw": {}, | ||
# "parameters": { | ||
# "event_type": "classify", | ||
# "schema_id": "12589852-4fff-430b-bf77-adad202d03ca", | ||
# "items": [ | ||
# { | ||
# "id": "11", | ||
# "text": "modi and bjp want to rule india by dividing people against each other" | ||
# } | ||
# ] | ||
# }, | ||
# "result": { | ||
# "responseMessage": "success", | ||
# "classification_results": [ | ||
# { | ||
# "id": "11", | ||
# "text": "modi and bjp want to rule india by dividing people against each other", | ||
# "labels": [ | ||
# "Politics", | ||
# "Communalism" | ||
# ] | ||
# } | ||
# ] | ||
# } | ||
# }, | ||
# "model_name": "classycat.Model", | ||
# "retry_count": 0 | ||
# } | ||
|
||
# unpack parameters for classify | ||
batch_to_classify = message.body.parameters | ||
schema_id = batch_to_classify["schema_id"] | ||
items = batch_to_classify["items"] | ||
|
||
result = message.body.result | ||
|
||
if not self.schema_id_exists(schema_id): | ||
raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404) | ||
|
||
if len(items) > self.batch_size_limit: | ||
raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422) | ||
|
||
try: | ||
result.classification_results = self.classify_and_store_results(schema_id, items) | ||
result.responseMessage = "success" | ||
return result | ||
except Exception as e: | ||
logger.exception(f"Error classifying items: {e}") | ||
if isinstance(e, PrestoBaseException): | ||
raise e | ||
else: | ||
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e | ||
``` | ||
- In case of errors, raise a `PrestoBaseException` with the appropriate error message and http status code, as seen above. | ||
feel free to extend this class and implement more sophisticated error handling for your usecase if needed. | ||
- `validate_input()`: A method that validates the input data and raises an exception if the input is invalid. | ||
```python | ||
@classmethod | ||
def validate_input(cls, data: Dict) -> None: | ||
""" | ||
Validate input data. Must be implemented by all child "Model" classes. | ||
""" | ||
if "schema_id" not in data["parameters"] or data["parameters"]["schema_id"] == "": | ||
raise PrestoBaseException("schema_id is required as input to classify", 422) | ||
|
||
if "items" not in data["parameters"] or len(data["parameters"]["items"]) == 0: | ||
raise PrestoBaseException("items are required as input to classify", 422) | ||
|
||
for item in data["parameters"]["items"]: | ||
if "id" not in item or item["id"] == "": | ||
raise PrestoBaseException("id is required for each item", 422) | ||
if "text" not in item or item["text"] == "": | ||
raise PrestoBaseException("text is required for each item", 422) | ||
``` | ||
- `parse_input_message()`: generates the right result/response type from the raw input body: | ||
```python | ||
@classmethod | ||
def parse_input_message(cls, data: Dict) -> Any: | ||
""" | ||
Parse input into appropriate response instances. | ||
""" | ||
event_type = data['parameters']['event_type'] | ||
result_data = data.get('result', {}) | ||
|
||
if event_type == 'classify': | ||
return ClassyCatBatchClassificationResponse(**result_data) | ||
else: | ||
logger.error(f"Unknown event type {event_type}") | ||
raise PrestoBaseException(f"Unknown event type {event_type}", 422) | ||
``` | ||
- Your own custom defined response classes, ideally defined inside the model file or as a separate file if too complex. | ||
see `ClassyCatBatchClassificationResponse` for an example: | ||
```python | ||
class ClassyCatResponse(BaseModel): | ||
responseMessage: Optional[str] = None | ||
|
||
class ClassyCatBatchClassificationResponse(ClassyCatResponse): | ||
classification_results: Optional[List[dict]] = [] | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Union, List, Dict | ||
from typing import Union, List, Dict, Any | ||
import os | ||
import tempfile | ||
|
||
|
@@ -27,3 +27,18 @@ def process(self, audio: schemas.Message) -> Dict[str, Union[str, List[int]]]: | |
finally: | ||
os.remove(temp_file_name) | ||
return {"hash_value": hash_value} | ||
|
||
@classmethod | ||
def validate_input(cls, data: Dict) -> None: | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we just be passing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think passing is equivalent to what we have right now for most of these models (with the exception of classycat), not counting the validations that happen inside schema.py I do agree that we should start implementing them, and I have created a ticket for that work (CV2-5093), but the current design is backward compatible and there is no need to implement these right now? unless you think it's urgent we address it? |
||
Validate input data. Must be implemented by all child "Model" classes. | ||
""" | ||
pass | ||
|
||
|
||
@classmethod | ||
def parse_input_message(cls, data: Dict) -> Any: | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we do nothing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edit: ah I see what the issue is after reading your Jira message - I think we should talk through how to not just stub all these on our next ML call, but in the meantime I think looking at what we typically unit-test for, for each model, and just doing type-checking based on the types of responses we're testing for would be appropriate. |
||
Validate input data. Must be implemented by all child "Model" classes. | ||
""" | ||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Optional, List | ||
from pydantic import BaseModel | ||
|
||
|
||
class ClassyCatResponse(BaseModel): | ||
responseMessage: Optional[str] = None | ||
|
||
|
||
class ClassyCatBatchClassificationResponse(ClassyCatResponse): | ||
classification_results: List[dict] = [] | ||
|
||
|
||
class ClassyCatSchemaResponse(ClassyCatResponse): | ||
schema_id: Optional[str] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really helpful! :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes this rules
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very needed. Thank you!