Skip to content

Commit

Permalink
Merge pull request #104 from meedan/CV2-4952
Browse files Browse the repository at this point in the history
CV2-4952: ClassyCat Error Handling and Metrics
  • Loading branch information
ashkankzme authored Aug 13, 2024
2 parents 0278093 + 3bc9230 commit 7cf0f34
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 34 deletions.
8 changes: 8 additions & 0 deletions lib/base_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class PrestoBaseException(Exception):
def __init__(self, message: str, error_code: int = 500):
self.message = message
self.error_code = error_code
super().__init__(message)

def __str__(self):
return f"{self.error_code}: {self.message}"
4 changes: 2 additions & 2 deletions lib/model/classycat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.model.classycat_classify import Model as ClassifyModel
from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel
from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand All @@ -21,5 +22,4 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
return ClassyCatSchemaCreateModel().process(message)
else:
logger.error(f"Unknown event type {event_type}")
message.body.result.responseMessage = f"Unknown event type {event_type}"
return message.body.result
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
22 changes: 12 additions & 10 deletions lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lib.model.model import Model
from lib.schemas import Message, ClassyCatBatchClassificationResponse
from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3
from lib.base_exception import PrestoBaseException


class LLMClient:
Expand Down Expand Up @@ -65,7 +66,9 @@ def classify(self, task_prompt, items_count, max_tokens_per_item=200):
max_tokens=(max_tokens_per_item * items_count) + 15,
temperature=0.5
)
# TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)

# TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)

return completion.choices[0].message.content


Expand All @@ -84,7 +87,7 @@ def get_llm_client(self, client_type, model_name):
elif client_type == 'openrouter':
return OpenRouterClient(model_name)
else:
raise Exception(f"Unknown client type: {client_type}")
raise PrestoBaseException(f"Unknown LLM client type {client_type}", 500)

def format_input_for_classification_prompt(self, items):
return '\n'.join([f"<ITEM_{i}>{item}</ITEM_{i}>" for i, item in enumerate(items)])
Expand Down Expand Up @@ -119,8 +122,7 @@ def classify_and_store_results(self, schema_id, items):
if (classification_results is None or len(classification_results) == 0
or len(classification_results) != len(items)):
logger.info(f"Classification results: {classification_results}")
raise Exception(f"Not all items were classified successfully: "
f"input length {len(items)}, output length {len(classification_results)}")
raise PrestoBaseException(f"Not all items were classified successfully: input length {len(items)}, output length {len(classification_results)}", 502)

final_results = [{'id': items[i]['id'], 'text': items[i]['text'], 'labels': classification_results[i]}
for i in range(len(items))]
Expand Down Expand Up @@ -213,18 +215,18 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
result = message.body.result

if not self.schema_id_exists(schema_id):
result.responseMessage = f"Schema id {schema_id} cannot be found"
return result
raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404)

if len(items) > self.batch_size_limit:
result.responseMessage = f"Number of items exceeds batch size limit of {self.batch_size_limit}"
return result
raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422)

try:
result.classification_results = self.classify_and_store_results(schema_id, items)
result.responseMessage = "success"
return result
except Exception as e:
logger.exception(f"Error classifying items: {e}")
result.responseMessage = f"Error classifying items: {e}"
return result
if isinstance(e, PrestoBaseException):
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e
10 changes: 4 additions & 6 deletions lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand Down Expand Up @@ -192,24 +193,21 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
result = message.body.result

if self.schema_name_exists(schema_name):
result.responseMessage = f"Schema name {schema_name} already exists"
return result
raise PrestoBaseException(f"Schema name {schema_name} already exists", 422)

try:
self.verify_schema_parameters(schema_name, topics, examples, languages)
except Exception as e:
logger.exception(f"Error verifying schema parameters: {e}")
result.responseMessage = f"Error verifying schema parameters. Stack trace: {e}"
return result
raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e

try:
result.schema_id = self.create_schema(schema_name, topics, examples, languages)
result.responseMessage = 'success'
return result
except Exception as e:
logger.exception(f"Error creating schema: {e}")
result.responseMessage = f"Error creating schema. Stack trace: {e}"
return result
raise PrestoBaseException(f"Error creating schema: {e}", 500) from e


def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo
Expand Down
7 changes: 3 additions & 4 deletions lib/model/classycat_schema_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.model.model import Model
from lib.s3 import load_file_from_s3, file_exists_in_s3
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand Down Expand Up @@ -63,8 +64,7 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
result = message.body.result

if not self.schema_name_exists(schema_name):
result.responseMessage = f"Schema name {schema_name} does not exist"
return result
raise PrestoBaseException(f"Schema name {schema_name} does not exist", 404)

logger.debug(f"located schema_name record for '{schema_name}'")

Expand All @@ -74,5 +74,4 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
return result
except Exception as e:
logger.error(f"Error looking up schema name {schema_name}: {e}")
result.responseMessage = f"Error looking up schema name {schema_name}: {e}"
return result
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e
18 changes: 12 additions & 6 deletions lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from lib import schemas
from lib.cache import Cache
from lib.sentry import capture_custom_message
from lib.base_exception import PrestoBaseException


class Model(ABC):
BATCH_SIZE = 1

def __init__(self):
self.model_name = os.environ.get("MODEL_NAME")

Expand Down Expand Up @@ -43,18 +46,18 @@ def get_tempfile(self) -> Any:
def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
return []

def handle_fingerprinting_error(self, e):
def handle_fingerprinting_error(self, e: Exception, response_code: int = 500) -> schemas.ErrorResponse:
error_context = {"error": str(e)}
for attr in ["__cause__", "__context__", "args", "__traceback__"]:
if attr in dir(e):
if attr == "__traceback__":
error_context[attr] = '\n'.join(traceback.format_tb(getattr(e, attr)))
else:
error_context[attr] = str(getattr(e, attr))
capture_custom_message("Error during fingerprinting for {self.model_name}", 'info', error_context)
return schemas.ErrorResponse(error=str(e), error_details=error_context)
capture_custom_message(f"Error during fingerprinting for {self.model_name}", 'error', error_context)
return schemas.ErrorResponse(error=str(e), error_details=error_context, error_code=response_code)

def get_response(self, message: schemas.Message) -> schemas.GenericItem:
def get_response(self, message: schemas.Message) -> schemas.GenericItem: # TODO note: the return type is wrong here
"""
Perform a lookup on the cache for a message, and if found, return that cached value.
"""
Expand All @@ -64,7 +67,10 @@ def get_response(self, message: schemas.Message) -> schemas.GenericItem:
result = self.process(message)
Cache.set_cached_result(message.body.content_hash, result)
except Exception as e:
return self.handle_fingerprinting_error(e)
if isinstance(e, PrestoBaseException):
return self.handle_fingerprinting_error(e, e.error_code)
else:
return self.handle_fingerprinting_error(e)
return result

def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
Expand All @@ -76,7 +82,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
for message in messages:
message.body.result = self.get_response(message)
return messages

@classmethod
def create(cls):
"""
Expand Down
18 changes: 12 additions & 6 deletions test/lib/model/test_classycat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from lib.model.classycat import Model as ClassyCatModel
from lib import schemas
from lib.base_exception import PrestoBaseException
import json


Expand Down Expand Up @@ -562,9 +563,11 @@ def test_classify_fail_wrong_response_format(self, file_exists_in_s3_mock, uploa
}
}
classify_message = schemas.parse_message(classify_input)
result = self.classycat_model.process(classify_message)

self.assertEqual(result.responseMessage, "Error classifying items: list index out of range")
with self.assertRaises(PrestoBaseException) as e:
self.classycat_model.process(classify_message)
self.assertIn("Error classifying items: list index out of range", e.message)
self.assertEqual(e.error_code, 500)

@patch('lib.model.classycat_classify.OpenRouterClient.classify')
@patch('lib.model.classycat_classify.load_file_from_s3')
Expand Down Expand Up @@ -700,15 +703,18 @@ def test_classify_fail_wrong_number_of_results(self, file_exists_in_s3_mock, upl
}
}
classify_message = schemas.parse_message(classify_input)
result = self.classycat_model.process(classify_message)

self.assertEqual(result.responseMessage, "Error classifying items: Not all items were classified successfully: input length 1, output length 2")
with self.assertRaises(PrestoBaseException) as e:
self.classycat_model.process(classify_message)
self.assertIn("Not all items were classified successfully: input length 1, output length 2", e.message)
self.assertEqual(e.error_code, 502)


@patch('lib.model.classycat_classify.OpenRouterClient.classify')
@patch('lib.model.classycat_classify.load_file_from_s3')
@patch('lib.model.classycat_classify.upload_file_to_s3')
@patch('lib.model.classycat_classify.file_exists_in_s3')
def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
def test_classify_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
load_file_from_s3_mock, openrouter_classify_mock):
file_exists_in_s3_mock.return_value = True
upload_file_to_s3_mock.return_value = None
Expand Down Expand Up @@ -858,7 +864,7 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u
@patch('lib.model.classycat_classify.load_file_from_s3')
@patch('lib.model.classycat_classify.upload_file_to_s3')
@patch('lib.model.classycat_classify.file_exists_in_s3')
def test_classify_fail_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
def test_classify_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
load_file_from_s3_mock, openrouter_classify_mock):
file_exists_in_s3_mock.return_value = True
upload_file_to_s3_mock.return_value = None
Expand Down

0 comments on commit 7cf0f34

Please sign in to comment.