Skip to content

Commit

Permalink
Merge branch 'refs/heads/master' into cv2-5011-internal-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ashkankzme committed Aug 14, 2024
2 parents fb9db8b + 7cf0f34 commit a6f0083
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 35 deletions.
8 changes: 8 additions & 0 deletions lib/base_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class PrestoBaseException(Exception):
    """Base exception that pairs a human-readable message with a numeric code.

    Args:
        message: Description of what went wrong. Stored on ``self.message``
            and also forwarded to ``Exception.__init__`` so it appears in
            ``self.args``.
        error_code: Integer code describing the failure class; defaults to
            500. (Presumably an HTTP-style status such as 404 or 422 --
            confirm against the code that raises and handles this error.)
    """

    def __init__(self, message: str, error_code: int = 500) -> None:
        self.message = message
        self.error_code = error_code
        # Forward the message so standard exception machinery (args,
        # tracebacks, logging) sees it as well.
        super().__init__(message)

    def __str__(self) -> str:
        """Return the error rendered as ``"<error_code>: <message>"``."""
        return f"{self.error_code}: {self.message}"
4 changes: 2 additions & 2 deletions lib/model/classycat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.model.classycat_classify import Model as ClassifyModel
from lib.model.classycat_schema_create import Model as ClassyCatSchemaCreateModel
from lib.model.classycat_schema_lookup import Model as ClassyCatSchemaLookupModel
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand All @@ -21,5 +22,4 @@ def process(self, message: Message) -> Union[ClassyCatSchemaResponse, ClassyCatB
return ClassyCatSchemaCreateModel().process(message)
else:
logger.error(f"Unknown event type {event_type}")
message.body.result.responseMessage = f"Unknown event type {event_type}"
return message.body.result
raise PrestoBaseException(f"Unknown event type {event_type}", 422)
42 changes: 29 additions & 13 deletions lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lib.model.model import Model
from lib.schemas import Message, ClassyCatBatchClassificationResponse
from lib.s3 import load_file_from_s3, file_exists_in_s3, upload_file_to_s3
from lib.base_exception import PrestoBaseException


class LLMClient:
Expand Down Expand Up @@ -65,7 +66,9 @@ def classify(self, task_prompt, items_count, max_tokens_per_item=200):
max_tokens=(max_tokens_per_item * items_count) + 15,
temperature=0.5
)
# TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)

# TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)

return completion.choices[0].message.content


Expand All @@ -84,7 +87,7 @@ def get_llm_client(self, client_type, model_name):
elif client_type == 'openrouter':
return OpenRouterClient(model_name)
else:
raise Exception(f"Unknown client type: {client_type}")
raise PrestoBaseException(f"Unknown LLM client type {client_type}", 500)

def format_input_for_classification_prompt(self, items):
return '\n'.join([f"<ITEM_{i}>{item}</ITEM_{i}>" for i, item in enumerate(items)])
Expand Down Expand Up @@ -119,13 +122,26 @@ def classify_and_store_results(self, schema_id, items):
if (classification_results is None or len(classification_results) == 0
or len(classification_results) != len(items)):
logger.info(f"Classification results: {classification_results}")
raise Exception(f"Not all items were classified successfully: "
f"input length {len(items)}, output length {len(classification_results)}")
# TODO: validate response label against schema https://meedan.atlassian.net/browse/CV2-4801
raise PrestoBaseException(f"Not all items were classified successfully: input length {len(items)}, output length {len(classification_results)}", 502)

final_results = [{'id': items[i]['id'], 'text': items[i]['text'], 'labels': classification_results[i]}
for i in range(len(items))]
results_file_id = str(uuid.uuid4())
upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results))

# filtering out the results that have out-of-schema labels
# out-of-schema labels will not be included in the final results,
# and items with no labels can be retried later by the user, indicated by an empty list for labels
permitted_labels = [topic['topic'] for topic in schema['topics']] + ['Other', 'Unsure']
for result in final_results:

# log the items that had at least one out-of-schema label
if not all([label in permitted_labels for label in result['labels']]):
logger.error(f"Item {result['id']} had out-of-schema labels: {result['labels']}, permitted labels: {permitted_labels}")

result['labels'] = [label for label in result['labels'] if label in permitted_labels]

if not all([len(result['labels']) == 0 for result in final_results]):
results_file_id = str(uuid.uuid4())
upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results))

return final_results

Expand Down Expand Up @@ -199,18 +215,18 @@ def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
result = message.body.result

if not self.schema_id_exists(schema_id):
result.responseMessage = f"Schema id {schema_id} cannot be found"
return result
raise PrestoBaseException(f"Schema id {schema_id} cannot be found", 404)

if len(items) > self.batch_size_limit:
result.responseMessage = f"Number of items exceeds batch size limit of {self.batch_size_limit}"
return result
raise PrestoBaseException(f"Number of items exceeds batch size limit of {self.batch_size_limit}", 422)

try:
result.classification_results = self.classify_and_store_results(schema_id, items)
result.responseMessage = "success"
return result
except Exception as e:
logger.exception(f"Error classifying items: {e}")
result.responseMessage = f"Error classifying items: {e}"
return result
if isinstance(e, PrestoBaseException):
raise e
else:
raise PrestoBaseException(f"Error classifying items: {e}", 500) from e
10 changes: 4 additions & 6 deletions lib/model/classycat_schema_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lib.logger import logger
from lib.model.model import Model
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand Down Expand Up @@ -192,24 +193,21 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
result = message.body.result

if self.schema_name_exists(schema_name):
result.responseMessage = f"Schema name {schema_name} already exists"
return result
raise PrestoBaseException(f"Schema name {schema_name} already exists", 422)

try:
self.verify_schema_parameters(schema_name, topics, examples, languages)
except Exception as e:
logger.exception(f"Error verifying schema parameters: {e}")
result.responseMessage = f"Error verifying schema parameters. Stack trace: {e}"
return result
raise PrestoBaseException(f"Error verifying schema parameters: {e}", 422) from e

try:
result.schema_id = self.create_schema(schema_name, topics, examples, languages)
result.responseMessage = 'success'
return result
except Exception as e:
logger.exception(f"Error creating schema: {e}")
result.responseMessage = f"Error creating schema. Stack trace: {e}"
return result
raise PrestoBaseException(f"Error creating schema: {e}", 500) from e


def verify_schema_parameters(self, schema_name, topics, examples, languages): #todo
Expand Down
7 changes: 3 additions & 4 deletions lib/model/classycat_schema_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.model.model import Model
from lib.s3 import load_file_from_s3, file_exists_in_s3
from lib.schemas import Message, ClassyCatSchemaResponse
from lib.base_exception import PrestoBaseException


class Model(Model):
Expand Down Expand Up @@ -63,8 +64,7 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
result = message.body.result

if not self.schema_name_exists(schema_name):
result.responseMessage = f"Schema name {schema_name} does not exist"
return result
raise PrestoBaseException(f"Schema name {schema_name} does not exist", 404)

logger.debug(f"located schema_name record for '{schema_name}'")

Expand All @@ -74,5 +74,4 @@ def process(self, message: Message) -> ClassyCatSchemaResponse:
return result
except Exception as e:
logger.error(f"Error looking up schema name {schema_name}: {e}")
result.responseMessage = f"Error looking up schema name {schema_name}: {e}"
return result
raise PrestoBaseException(f"Error looking up schema name {schema_name}", 500) from e
18 changes: 12 additions & 6 deletions lib/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from lib import schemas
from lib.cache import Cache
from lib.sentry import capture_custom_message
from lib.base_exception import PrestoBaseException


class Model(ABC):
BATCH_SIZE = 1

def __init__(self):
self.model_name = os.environ.get("MODEL_NAME")

Expand Down Expand Up @@ -43,18 +46,18 @@ def get_tempfile(self) -> Any:
def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
return []

def handle_fingerprinting_error(self, e):
def handle_fingerprinting_error(self, e: Exception, response_code: int = 500) -> schemas.ErrorResponse:
error_context = {"error": str(e)}
for attr in ["__cause__", "__context__", "args", "__traceback__"]:
if attr in dir(e):
if attr == "__traceback__":
error_context[attr] = '\n'.join(traceback.format_tb(getattr(e, attr)))
else:
error_context[attr] = str(getattr(e, attr))
capture_custom_message("Error during fingerprinting for {self.model_name}", 'info', error_context)
return schemas.ErrorResponse(error=str(e), error_details=error_context)
capture_custom_message(f"Error during fingerprinting for {self.model_name}", 'error', error_context)
return schemas.ErrorResponse(error=str(e), error_details=error_context, error_code=response_code)

def get_response(self, message: schemas.Message) -> schemas.GenericItem:
def get_response(self, message: schemas.Message) -> schemas.GenericItem: # TODO note: the return type is wrong here
"""
Perform a lookup on the cache for a message, and if found, return that cached value.
"""
Expand All @@ -64,7 +67,10 @@ def get_response(self, message: schemas.Message) -> schemas.GenericItem:
result = self.process(message)
Cache.set_cached_result(message.body.content_hash, result)
except Exception as e:
return self.handle_fingerprinting_error(e)
if isinstance(e, PrestoBaseException):
return self.handle_fingerprinting_error(e, e.error_code)
else:
return self.handle_fingerprinting_error(e)
return result

def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]:
Expand All @@ -76,7 +82,7 @@ def respond(self, messages: Union[List[schemas.Message], schemas.Message]) -> Li
for message in messages:
message.body.result = self.get_response(message)
return messages

@classmethod
def create(cls):
"""
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fasttext==0.9.2
langcodes==3.3.0
requests==2.32.2
pytest==7.4.0
sentry-sdk==1.30.0
sentry-sdk==2.8.0
yake==0.4.8
opentelemetry-api==1.24.0
opentelemetry-exporter-otlp-proto-http==1.24.0
Expand Down
Loading

0 comments on commit a6f0083

Please sign in to comment.