Skip to content

Commit

Permalink
Cv2 5861 fix openai embeddings (#473)
Browse files Browse the repository at this point in the history
* CV2-5861 bypass presto for openai

* refactor for singular path
  • Loading branch information
DGaffney authored Dec 17, 2024
1 parent afb4431 commit 4e0503d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 7 additions & 4 deletions app/main/lib/elastic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask import current_app as app
from app.main.lib.presto import Presto, PRESTO_MODEL_MAP
from app.main.lib.elasticsearch import store_document, get_by_doc_id

from app.main.lib.openai import PREFIX_OPENAI
def _after_log(retry_state):
app.logger.debug("Retrying image similarity...")

Expand Down Expand Up @@ -40,9 +40,12 @@ def get_presto_request_response(modality, callback_url, task):
assert isinstance(response["body"], dict), f"Bad body for {modality}, {callback_url}, {task} - response was {response}"
return response

def encodable_model(model_key, obj):
return model_key != "elasticsearch" and not obj.get('model_'+model_key) and model_key[:len(PREFIX_OPENAI)] != PREFIX_OPENAI

def requires_encoding(obj):
for model_key in obj.get("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
if encodable_model(model_key, obj):
return True
return False

Expand All @@ -55,7 +58,7 @@ def get_blocked_presto_response(task, model, modality):
if requires_encoding(obj):
blocked_results = []
for model_key in obj.get("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
if encodable_model(model_key, obj):
response = get_presto_request_response(model_key, callback_url, obj)
blocked_results.append({"model": model_key, "response": Presto.blocked_response(response, modality)})
# Warning: this is a blocking hold to wait until we get a response in
Expand All @@ -73,7 +76,7 @@ def get_async_presto_response(task, model, modality):
if requires_encoding(obj):
responses = []
for model_key in obj.get("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
if encodable_model(model_key, obj):
task["model"] = model_key
responses.append(get_presto_request_response(model_key, callback_url, task))
return responses, True
Expand Down
4 changes: 4 additions & 0 deletions app/test/test_elastic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def test_requires_encoding(self):
obj = {'models': ['model1'], 'model_model1': 'encoded_data'}
self.assertFalse(requires_encoding(obj))

obj = {'models': ['openai-text-embedding-ada-002']}
self.assertFalse(requires_encoding(obj))


@patch('app.main.lib.elastic_crud.Presto.blocked_response')
@patch('app.main.lib.elastic_crud.Presto.send_request')
@patch('app.main.lib.elastic_crud.store_document')
Expand Down

0 comments on commit 4e0503d

Please sign in to comment.