From 4e0503d8f0ef22cda6f5b7527d3b4d5a32dd3fe1 Mon Sep 17 00:00:00 2001 From: Devin Gaffney Date: Tue, 17 Dec 2024 10:34:28 -0800 Subject: [PATCH] Cv2 5861 fix openai embeddings (#473) * CV2-5861 bypass presto for openai * refactor for singular path --- app/main/lib/elastic_crud.py | 11 +++++++---- app/test/test_elastic_crud.py | 4 ++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/app/main/lib/elastic_crud.py b/app/main/lib/elastic_crud.py index e8732d3b..0117c237 100644 --- a/app/main/lib/elastic_crud.py +++ b/app/main/lib/elastic_crud.py @@ -4,7 +4,7 @@ from flask import current_app as app from app.main.lib.presto import Presto, PRESTO_MODEL_MAP from app.main.lib.elasticsearch import store_document, get_by_doc_id - +from app.main.lib.openai import PREFIX_OPENAI def _after_log(retry_state): app.logger.debug("Retrying image similarity...") @@ -40,9 +40,12 @@ def get_presto_request_response(modality, callback_url, task): assert isinstance(response["body"], dict), f"Bad body for {modality}, {callback_url}, {task} - response was {response}" return response +def encodable_model(model_key, obj): + return model_key != "elasticsearch" and not obj.get('model_'+model_key) and model_key[:len(PREFIX_OPENAI)] != PREFIX_OPENAI + def requires_encoding(obj): for model_key in obj.get("models", []): - if model_key != "elasticsearch" and not obj.get('model_'+model_key): + if encodable_model(model_key, obj): return True return False @@ -55,7 +58,7 @@ def get_blocked_presto_response(task, model, modality): if requires_encoding(obj): blocked_results = [] for model_key in obj.get("models", []): - if model_key != "elasticsearch" and not obj.get('model_'+model_key): + if encodable_model(model_key, obj): response = get_presto_request_response(model_key, callback_url, obj) blocked_results.append({"model": model_key, "response": Presto.blocked_response(response, modality)}) # Warning: this is a blocking hold to wait until we get a response in @@ -73,7 +76,7 @@ def get_async_presto_response(task, model, modality): if requires_encoding(obj): responses = [] for model_key in obj.get("models", []): - if model_key != "elasticsearch" and not obj.get('model_'+model_key): + if encodable_model(model_key, obj): task["model"] = model_key responses.append(get_presto_request_response(model_key, callback_url, task)) return responses, True diff --git a/app/test/test_elastic_crud.py b/app/test/test_elastic_crud.py index 9d1b4345..8b14c203 100644 --- a/app/test/test_elastic_crud.py +++ b/app/test/test_elastic_crud.py @@ -73,6 +73,10 @@ def test_requires_encoding(self): obj = {'models': ['model1'], 'model_model1': 'encoded_data'} self.assertFalse(requires_encoding(obj)) + obj = {'models': ['openai-text-embedding-ada-002']} + self.assertFalse(requires_encoding(obj)) + + @patch('app.main.lib.elastic_crud.Presto.blocked_response') @patch('app.main.lib.elastic_crud.Presto.send_request') @patch('app.main.lib.elastic_crud.store_document')