diff --git a/lib/base_exception.py b/lib/base_exception.py new file mode 100644 index 0000000..534d08c --- /dev/null +++ b/lib/base_exception.py @@ -0,0 +1,8 @@ +class PrestoBaseException(Exception): + def __init__(self, message: str, error_code: int = 500): + self.message = message + self.error_code = error_code + super().__init__(message) + + def __str__(self): + return f"{self.error_code}: {self.message}" \ No newline at end of file diff --git a/lib/model/classycat.py b/lib/model/classycat.py index 219016f..cceb17f 100644 --- a/lib/model/classycat.py +++ b/lib/model/classycat.py @@ -5,6 +5,7 @@ from lib.model.classycat_classify import Model as ClassifyModel from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel +from lib.base_exception import PrestoBaseException class Model(Model): @@ -21,5 +22,4 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB return ClassyCatSchemaCreateModel().process(message) else: logger.error(f"Unknown event type {event_type}") - message.body.result.responseMessage = f"Unknown event type {event_type}" - return message.body.result \ No newline at end of file + raise PrestoBaseException(f"Unknown event type {event_type}", 422) diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index 1b47579..2f4b185 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -8,6 +8,7 @@ from lib.model.model import Model from lib.schemas import Message, ClassyCatBatchClassificationResponse from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3 +from lib.base_exception import PrestoBaseException class LLMClient: @@ -65,7 +66,9 @@ def classify(self, task_prompt, items_count, max_tokens_per_item=200): max_tokens=(max_tokens_per_item * items_count) + 15, temperature=0.5 ) -# TODO: record metric here with model name and number of items submitted 
(https://meedan.atlassian.net/browse/CV2-4987) + + # TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987) + return completion.choices[0].message.content @@ -84,7 +87,7 @@ def get_llm_client(self, client_type, model_name): elif client_type == 'openrouter': return OpenRouterClient(model_name) else: - raise Exception(f"Unknown client type: {client_type}") + raise PrestoBaseException(f"Unknown LLM client type {client_type}", 500) def format_input_for_classification_prompt(self, items): return '\n'.join([f"{item}" for i, item in enumerate(items)]) @@ -119,8 +122,7 @@ def classify_and_store_results(self, schema_id, items): if (classification_results is None or len(classification_results) == 0 or len(classification_results) != len(items)): logger.info(f"Classification results: {classification_results}") - raise Exception(f"Not all items were classified successfully: " - f"input length {len(items)}, output length {len(classification_results)}") + raise PrestoBaseException(f"Not all items were classified successfully: input length {len(items)}, output length {len(classification_results or [])}", 502) final_results = [{'id': items[i]['id'], 'text': items[i]['text'], 'labels': classification_results[i]} for i in range(len(items))] @@ -213,12 +215,10 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse: result = message.body.result if not self.schema_id_exists(schema_id): - result.responseMessage = f"Schema id {schema_id} cannot be found" - return result + raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404) if len(items) > self.batch_size_limit: - result.responseMessage = f"Number of items exceeds batch size limit of {self.batch_size_limit}" - return result + raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422) try: result.classification_results = self.classify_and_store_results(schema_id, items) @@ -226,5 +226,7 @@ def 
process(self, message: Message) -> ClassyCatBatchClassificationResponse: return result except Exception as e: logger.exception(f"Error classifying items: {e}") - result.responseMessage = f"Error classifying items: {e}" - return result + if isinstance(e, PrestoBaseException): + raise + else: + raise PrestoBaseException(f"Error classifying items: {e}", 500) from e diff --git a/lib/model/classycat_schema_create.py b/lib/model/classycat_schema_create.py index edbd9bb..b229c44 100644 --- a/lib/model/classycat_schema_create.py +++ b/lib/model/classycat_schema_create.py @@ -5,6 +5,7 @@ from lib.logger import logger from lib.model.model import Model from lib.schemas import Message, ClassyCatSchemaResponse +from lib.base_exception import PrestoBaseException class Model(Model): @@ -192,15 +193,13 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: result = message.body.result if self.schema_name_exists(schema_name): - result.responseMessage = f"Schema name {schema_name} already exists" - return result + raise PrestoBaseException(f"Schema name {schema_name} already exists", 422) try: self.verify_schema_parameters(schema_name, topics, examples, languages) except Exception as e: logger.exception(f"Error verifying schema parameters: {e}") - result.responseMessage = f"Error verifying schema parameters. Stack trace: {e}" - return result + raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e try: result.schema_id = self.create_schema(schema_name, topics, examples, languages) @@ -208,8 +207,7 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: return result except Exception as e: logger.exception(f"Error creating schema: {e}") - result.responseMessage = f"Error creating schema. 
Stack trace: {e}" - return result + raise PrestoBaseException(f"Error creating schema: {e}", 500) from e def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo diff --git a/lib/model/classycat_schema_lookup.py b/lib/model/classycat_schema_lookup.py index f017800..af51b44 100644 --- a/lib/model/classycat_schema_lookup.py +++ b/lib/model/classycat_schema_lookup.py @@ -4,6 +4,7 @@ from lib.model.model import Model from lib.s3 import load_file_from_s3, file_exists_in_s3 from lib.schemas import Message, ClassyCatSchemaResponse +from lib.base_exception import PrestoBaseException class Model(Model): @@ -63,8 +64,7 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: result = message.body.result if not self.schema_name_exists(schema_name): - result.responseMessage = f"Schema name {schema_name} does not exist" - return result + raise PrestoBaseException(f"Schema name {schema_name} does not exist", 404) logger.debug(f"located schema_name record for '{schema_name}'") @@ -74,5 +74,4 @@ def process(self, message: Message) -> ClassyCatSchemaResponse: return result except Exception as e: logger.error(f"Error looking up schema name {schema_name}: {e}") - result.responseMessage = f"Error looking up schema name {schema_name}: {e}" - return result \ No newline at end of file + raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e \ No newline at end of file diff --git a/lib/model/model.py b/lib/model/model.py index 3e359e2..fc79271 100644 --- a/lib/model/model.py +++ b/lib/model/model.py @@ -10,9 +10,12 @@ from lib import schemas from lib.cache import Cache from lib.sentry import capture_custom_message +from lib.base_exception import PrestoBaseException + class Model(ABC): BATCH_SIZE = 1 + def __init__(self): self.model_name = os.environ.get("MODEL_NAME") @@ -43,7 +46,7 @@ def get_tempfile(self) -> Any: def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]: return 
[] - def handle_fingerprinting_error(self, e): + def handle_fingerprinting_error(self, e: Exception, response_code: int = 500) -> schemas.ErrorResponse: error_context = {"error": str(e)} for attr in ["__cause__", "__context__", "args", "__traceback__"]: if attr in dir(e): @@ -51,10 +54,10 @@ def handle_fingerprinting_error(self, e): error_context[attr] = '\n'.join(traceback.format_tb(getattr(e, attr))) else: error_context[attr] = str(getattr(e, attr)) - capture_custom_message("Error during fingerprinting for {self.model_name}", 'info', error_context) - return schemas.ErrorResponse(error=str(e), error_details=error_context) + capture_custom_message(f"Error during fingerprinting for {self.model_name}", 'error', error_context) + return schemas.ErrorResponse(error=str(e), error_details=error_context, error_code=response_code) - def get_response(self, message: schemas.Message) -> schemas.GenericItem: + def get_response(self, message: schemas.Message) -> schemas.GenericItem: # TODO note: the return type is wrong here """ Perform a lookup on the cache for a message, and if found, return that cached value. 
""" @@ -64,7 +67,10 @@ def get_response(self, message: schemas.Message) -> schemas.GenericItem: result = self.process(message) Cache.set_cached_result(message.body.content_hash, result) except Exception as e: - return self.handle_fingerprinting_error(e) + if isinstance(e, PrestoBaseException): + return self.handle_fingerprinting_error(e, e.error_code) + else: + return self.handle_fingerprinting_error(e) return result def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]: @@ -76,7 +82,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li for message in messages: message.body.result = self.get_response(message) return messages - + @classmethod def create(cls): """ diff --git a/test/lib/model/test_classycat.py b/test/lib/model/test_classycat.py index 46d4b44..e85e708 100644 --- a/test/lib/model/test_classycat.py +++ b/test/lib/model/test_classycat.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from lib.model.classycat import Model as ClassyCatModel from lib import schemas +from lib.base_exception import PrestoBaseException import json @@ -562,9 +563,11 @@ def test_classify_fail_wrong_response_format(self, file_exists_in_s3_mock, uploa } } classify_message = schemas.parse_message(classify_input) - result = self.classycat_model.process(classify_message) - self.assertEqual(result.responseMessage, "Error classifying items: list index out of range") + with self.assertRaises(PrestoBaseException) as e: + self.classycat_model.process(classify_message) + self.assertIn("Error classifying items: list index out of range", e.exception.message) + self.assertEqual(e.exception.error_code, 500) @patch('lib.model.classycat_classify.OpenRouterClient.classify') @patch('lib.model.classycat_classify.load_file_from_s3') @@ -700,15 +703,18 @@ def test_classify_fail_wrong_number_of_results(self, file_exists_in_s3_mock, upl } } classify_message = schemas.parse_message(classify_input) - result = 
self.classycat_model.process(classify_message) - self.assertEqual(result.responseMessage, "Error classifying items: Not all items were classified successfully: input length 1, output length 2") + with self.assertRaises(PrestoBaseException) as e: + self.classycat_model.process(classify_message) + self.assertIn("Not all items were classified successfully: input length 1, output length 2", e.exception.message) + self.assertEqual(e.exception.error_code, 502) + @patch('lib.model.classycat_classify.OpenRouterClient.classify') @patch('lib.model.classycat_classify.load_file_from_s3') @patch('lib.model.classycat_classify.upload_file_to_s3') @patch('lib.model.classycat_classify.file_exists_in_s3') - def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock, + def test_classify_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock, load_file_from_s3_mock, openrouter_classify_mock): file_exists_in_s3_mock.return_value = True upload_file_to_s3_mock.return_value = None @@ -858,7 +864,7 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u @patch('lib.model.classycat_classify.load_file_from_s3') @patch('lib.model.classycat_classify.upload_file_to_s3') @patch('lib.model.classycat_classify.file_exists_in_s3') - def test_classify_fail_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock, + def test_classify_all_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock, load_file_from_s3_mock, openrouter_classify_mock): file_exists_in_s3_mock.return_value = True upload_file_to_s3_mock.return_value = None