From d18916ae9f29a10e62c43102f3f73ce56d936209 Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Mon, 4 Nov 2024 07:04:59 -0800 Subject: [PATCH] CV2-5050 add additional context when transformer vectorization fails --- lib/model/generic_transformer.py | 2 +- lib/model/model.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/model/generic_transformer.py b/lib/model/generic_transformer.py index 193365a3..43fd858d 100644 --- a/lib/model/generic_transformer.py +++ b/lib/model/generic_transformer.py @@ -69,7 +69,7 @@ def _vectorize_and_cache(self, docs_to_process: List[schemas.Message], texts_to_ doc.body.result = vector Cache.set_cached_result(doc.body.content_hash, vector) except Exception as e: - self.handle_fingerprinting_error(e) + self.handle_fingerprinting_error(e, 500, {"texts_to_vectorize": texts_to_vectorize}) def vectorize(self, texts: List[str]) -> List[List[float]]: """ diff --git a/lib/model/model.py b/lib/model/model.py index 044f6945..dd0917ed 100644 --- a/lib/model/model.py +++ b/lib/model/model.py @@ -46,7 +46,7 @@ def get_tempfile(self) -> Any: def process(self, messages: Union[List[schemas.Message], schemas.Message]) -> List[schemas.Message]: return [] - def handle_fingerprinting_error(self, e: Exception, response_code: int = 500) -> schemas.ErrorResponse: + def handle_fingerprinting_error(self, e: Exception, response_code: int = 500, additional_context: dict = {}) -> schemas.ErrorResponse: error_context = {"error": str(e)} for attr in ["__cause__", "__context__", "args", "__traceback__"]: if attr in dir(e): @@ -54,6 +54,8 @@ def handle_fingerprinting_error(self, e: Exception, response_code: int = 500) -> error_context[attr] = '\n'.join(traceback.format_tb(getattr(e, attr))) else: error_context[attr] = str(getattr(e, attr)) + for k,v in additional_context.items(): + error_context[k] = v capture_custom_message(f"Error during fingerprinting for {self.model_name}", 'error', error_context) return schemas.ErrorResponse(error=str(e), error_details=error_context, error_code=response_code)