Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CV2-4952: ClassyCat Error Handling and Metrics #104

Merged
merged 12 commits into from
Aug 13, 2024
8 changes: 8 additions & 0 deletions lib/base_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class PrestoBaseException(Exception):
def __init__(self, message: str, error_code: int = 500):
self.message = message
self.error_code = error_code
super().__init__(message)

def __str__(self):
return f"{self.error_code}: {self.message}"
4 changes: 2 additions & 2 deletions lib/model/classycat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.model.classycat_classify import Model as ClassifyModel
from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel
from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand All @@ -21,5 +22,4 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
return ClassyCatSchemaCreateModel().process(message)
else:
logger.error(f"Unknown event type {event_type}")
message.body.result.responseMessage = f"Unknown event type {event_type}"
return message.body.result
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
22 changes: 12 additions & 10 deletions lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lib.model.model import Model
from lib.schemas import Message, ClassyCatBatchClassificationResponse
from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3
from lib.base_exception import PrestoBaseException


class LLMClient:
Expand Down Expand Up @@ -65,7 +66,9 @@ def classify(self, task_prompt, items_count, max_tokens_per_item=200):
max_tokens=(max_tokens_per_item * items_count) + 15,
temperature=0.5
)
# TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)

# TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)

return completion.choices[0].message.content


Expand All @@ -84,7 +87,7 @@ def get_llm_client(self, client_type, model_name):
elif client_type == 'openrouter':
return OpenRouterClient(model_name)
else:
raise Exception(f"Unknown client type: {client_type}")
raise PrestoBaseException(f"Unknown LLM client type {client_type}", 500)
skyemeedan marked this conversation as resolved.
Show resolved Hide resolved

def format_input_for_classification_prompt(self, items):
return '\n'.join([f"<ITEM_{i}>{item}</ITEM_{i}>" for i, item in enumerate(items)])
Expand Down Expand Up @@ -119,8 +122,7 @@ def classify_and_store_results(self, schema_id, items):
if (classification_results is None or len(classification_results) == 0
or len(classification_results) != len(items)):
logger.info(f"Classification results: {classification_results}")
raise Exception(f"Not all items were classified successfully: "
f"input length {len(items)}, output length {len(classification_results)}")
raise PrestoBaseException(f"Not all items were classified successfully: input length {len(items)}, output length {len(classification_results)}", 502)

final_results = [{'id': items[i]['id'], 'text': items[i]['text'], 'labels': classification_results[i]}
for i in range(len(items))]
Expand Down Expand Up @@ -213,18 +215,18 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
result = message.body.result

if not self.schema_id_exists(schema_id):
result.responseMessage = f"Schema id {schema_id} cannot be found"
return result
raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404)

if len(items) > self.batch_size_limit:
result.responseMessage = f"Number of items exceeds batch size limit of {self.batch_size_limit}"
return result
raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422)

try:
result.classification_results = self.classify_and_store_results(schema_id, items)
result.responseMessage = "success"
return result
except Exception as e:
logger.exception(f"Error classifying items: {e}")
result.responseMessage = f"Error classifying items: {e}"
return result
if isinstance(e, PrestoBaseException):
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e
10 changes: 4 additions & 6 deletions lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand Down Expand Up @@ -192,24 +193,21 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
result = message.body.result

if self.schema_name_exists(schema_name):
result.responseMessage = f"Schema name {schema_name} already exists"
return result
raise PrestoBaseException(f"Schema name {schema_name} already exists", 422)

try:
self.verify_schema_parameters(schema_name, topics, examples, languages)
except Exception as e:
logger.exception(f"Error verifying schema parameters: {e}")
result.responseMessage = f"Error verifying schema parameters. Stack trace: {e}"
return result
raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e

try:
result.schema_id = self.create_schema(schema_name, topics, examples, languages)
result.responseMessage = 'success'
return result
except Exception as e:
logger.exception(f"Error creating schema: {e}")
result.responseMessage = f"Error creating schema. Stack trace: {e}"
return result
raise PrestoBaseException(f"Error creating schema: {e}", 500) from e


def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo
Expand Down
7 changes: 3 additions & 4 deletions lib/model/classycat_schema_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.model.model import Model
from lib.s3 import load_file_from_s3, file_exists_in_s3
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand Down Expand Up @@ -63,8 +64,7 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
result = message.body.result

if not self.schema_name_exists(schema_name):
result.responseMessage = f"Schema name {schema_name} does not exist"
return result
raise PrestoBaseException(f"Schema name {schema_name} does not exist", 404)

logger.debug(f"located schema_name record for '{schema_name}'")

Expand All @@ -74,5 +74,4 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
return result
except Exception as e:
logger.error(f"Error looking up schema name {schema_name}: {e}")
result.responseMessage = f"Error looking up schema name {schema_name}: {e}"
return result
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e
18 changes: 12 additions & 6 deletions lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from lib import schemas
from lib.cache import Cache
from lib.sentry import capture_custom_message
from lib.base_exception import PrestoBaseException


class Model(ABC):
BATCH_SIZE = 1

def __init__(self):
self.model_name = os.environ.get("MODEL_NAME")

Expand Down Expand Up @@ -43,18 +46,18 @@ def get_tempfile(self) -> Any:
def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
return []

def handle_fingerprinting_error(self, e):
def handle_fingerprinting_error(self, e: Exception, response_code: int = 500) -> schemas.ErrorResponse:
error_context = {"error": str(e)}
for attr in ["__cause__", "__context__", "args", "__traceback__"]:
if attr in dir(e):
if attr == "__traceback__":
error_context[attr] = '\n'.join(traceback.format_tb(getattr(e, attr)))
else:
error_context[attr] = str(getattr(e, attr))
capture_custom_message("Error during fingerprinting for {self.model_name}", 'info', error_context)
return schemas.ErrorResponse(error=str(e), error_details=error_context)
capture_custom_message(f"Error during fingerprinting for {self.model_name}", 'error', error_context)
return schemas.ErrorResponse(error=str(e), error_details=error_context, error_code=response_code)

def get_response(self, message: schemas.Message) -> schemas.GenericItem:
def get_response(self, message: schemas.Message) -> schemas.GenericItem: # TODO note: the return type is wrong here
"""
Perform a lookup on the cache for a message, and if found, return that cached value.
"""
Expand All @@ -64,7 +67,10 @@ def get_response(self, message: schemas.Message) -> schemas.GenericItem:
result = self.process(message)
Cache.set_cached_result(message.body.content_hash, result)
except Exception as e:
return self.handle_fingerprinting_error(e)
if isinstance(e, PrestoBaseException):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are (I think) nicely bubbling up errors to this level, should logger.exception() here so that it will still get logged to Sentry from the presto side?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this exists on line 58, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I didn't realize that is what capture_custom_messages does. In Presto we are explicitly capturing message so it will go to Sentry only and not normal python logs? If so maybe that capture message argument should be 'error' instead of 'info'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done! (changed info to error)

overall, I also think it's better to log errors when they happen. makes debugging easier, and we also won't lose the line number on which the error logging happens. for all of classycat errors, I made sure we log the errors where they happen.

return self.handle_fingerprinting_error(e, e.error_code)
else:
return self.handle_fingerprinting_error(e)
return result

def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
Expand All @@ -76,7 +82,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
for message in messages:
message.body.result = self.get_response(message)
return messages

@classmethod
def create(cls):
"""
Expand Down
18 changes: 12 additions & 6 deletions test/lib/model/test_classycat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from lib.model.classycat import Model as ClassyCatModel
from lib import schemas
from lib.base_exception import PrestoBaseException
import json


Expand Down Expand Up @@ -562,9 +563,11 @@ def test_classify_fail_wrong_response_format(self, file_exists_in_s3_mock, uploa
}
}
classify_message = schemas.parse_message(classify_input)
result = self.classycat_model.process(classify_message)

self.assertEqual(result.responseMessage, "Error classifying items: list index out of range")
with self.assertRaises(PrestoBaseException) as e:
self.classycat_model.process(classify_message)
self.assertIn("Error classifying items: list index out of range", e.message)
self.assertEqual(e.error_code, 500)

@patch('lib.model.classycat_classify.OpenRouterClient.classify')
@patch('lib.model.classycat_classify.load_file_from_s3')
Expand Down Expand Up @@ -700,15 +703,18 @@ def test_classify_fail_wrong_number_of_results(self, file_exists_in_s3_mock, upl
}
}
classify_message = schemas.parse_message(classify_input)
result = self.classycat_model.process(classify_message)

self.assertEqual(result.responseMessage, "Error classifying items: Not all items were classified successfully: input length 1, output length 2")
with self.assertRaises(PrestoBaseException) as e:
self.classycat_model.process(classify_message)
self.assertIn("Not all items were classified successfully: input length 1, output length 2", e.message)
self.assertEqual(e.error_code, 502)


@patch('lib.model.classycat_classify.OpenRouterClient.classify')
@patch('lib.model.classycat_classify.load_file_from_s3')
@patch('lib.model.classycat_classify.upload_file_to_s3')
@patch('lib.model.classycat_classify.file_exists_in_s3')
def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
def test_classify_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
load_file_from_s3_mock, openrouter_classify_mock):
file_exists_in_s3_mock.return_value = True
upload_file_to_s3_mock.return_value = None
Expand Down Expand Up @@ -858,7 +864,7 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u
@patch('lib.model.classycat_classify.load_file_from_s3')
@patch('lib.model.classycat_classify.upload_file_to_s3')
@patch('lib.model.classycat_classify.file_exists_in_s3')
def test_classify_fail_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
def test_classify_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
load_file_from_s3_mock, openrouter_classify_mock):
file_exists_in_s3_mock.return_value = True
upload_file_to_s3_mock.return_value = None
Expand Down
Loading