Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CV2-5370] media text split #460

Merged
merged 14 commits into from
Oct 14, 2024
4 changes: 4 additions & 0 deletions .env_file.example
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ PROVIDER_IMAGE_CLASSIFICATION=google
# AWS_ACCESS_KEY_ID=
# AWS_SECRET_ACCESS_KEY=
# AWS_SESSION_TOKEN=
S3_ENDPOINT=http://minio:9000
AWS_DEFAULT_REGION=us-east-1
AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE
AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY


# Service host URLs
Expand Down
2 changes: 1 addition & 1 deletion .env_file.test
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DATABASE_HOST=postgres
DATABASE_USER=postgres
DATABASE_PASS=postgres
S3_ENDPOINT=http://minio:9000
AWS_DEFAULT_REGION=eu-west-1
AWS_DEFAULT_REGION=us-east-1
DGaffney marked this conversation as resolved.
Show resolved Hide resolved
AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE
AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY

Expand Down
2 changes: 1 addition & 1 deletion app/main/controller/similarity_sync_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def post(self, similarity_type):
app.logger.debug(f"Args are {args}")
if similarity_type == "text":
package = similarity.get_body_for_text_document(args, 'query')
return similarity.get_similar_items(package, similarity_type)
return similarity.blocking_get_similar_items(package, similarity_type)
else:
package = similarity.get_body_for_media_document(args, 'query')
return similarity.blocking_get_similar_items(package, similarity_type)
8 changes: 4 additions & 4 deletions app/main/lib/elastic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_presto_request_response(modality, callback_url, task):

def requires_encoding(obj):
for model_key in obj.get("models", []):
if not obj.get('model_'+model_key):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
return True
return False

Expand All @@ -58,15 +58,15 @@ def get_blocked_presto_response(task, model, modality):
for model_key in obj.pop("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
response = get_presto_request_response(model_key, callback_url, obj)
blocked_results.append(Presto.blocked_response(response, modality))
blocked_results.append({"model": model_key, "response": Presto.blocked_response(response, modality)})
# Warning: this is a blocking hold to wait until we get a response in
# a redis key that we've received something from presto.
return obj, temporary, get_context_for_search(task), blocked_results[-1]
return obj, temporary, get_context_for_search(task), blocked_results
else:
return obj, temporary, get_context_for_search(task), {"body": obj}

def get_async_presto_response(task, model, modality):
app.logger.error(f"get_async_presto_response: {task} {model} {modality}")
app.logger.info(f"get_async_presto_response: {task} {model} {modality}")
obj, _ = get_object(task, model)
callback_url = Presto.add_item_callback_url(app.config['ALEGRE_HOST'], modality)
if task.get("doc_id") is None:
Expand Down
35 changes: 21 additions & 14 deletions app/main/lib/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,24 @@ def get_all_documents_matching_context(context):
return []

def generate_matches(context):
"""
If the keys are not project_media_id, has_custom_id, or field, return ANDs for each field,
with ORs for intra-key values (e.g. foo = bar AND baz = (blah|bat))
DGaffney marked this conversation as resolved.
Show resolved Hide resolved
"""
matches = []
clause_count = 0
for key in context:
if isinstance(context[key], list):
clause_count += len(context[key])
matches.append({
'query_string': { 'query': str.join(" OR ", [f"context.{key}: {v}" for v in context[key]])}
})
else:
clause_count += 1
matches.append({
'match': { 'context.' + key: context[key] }
})
if key not in ["project_media_id", "has_custom_id", "field"]:
DGaffney marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(context[key], list):
clause_count += len(context[key])
matches.append({
'query_string': { 'query': str.join(" OR ", [f"context.{key}: {v}" for v in context[key]])}
})
else:
clause_count += 1
matches.append({
'match': { 'context.' + key: context[key] }
})
return matches, clause_count

def truncate_query(query, clause_count):
Expand Down Expand Up @@ -112,12 +117,14 @@ def get_by_doc_id(doc_id):
return response['_source']

def store_document(body, doc_id, language=None):
for field in ["per_model_threshold", "threshold", "model", "confirmed", "limit", "requires_callback"]:
body.pop(field, None)
storable_doc = {}
for k, v in body.items():
if k not in ["per_model_threshold", "threshold", "model", "confirmed", "limit", "requires_callback"]:
storable_doc[k] = v
indices = [app.config['ELASTICSEARCH_SIMILARITY']]
# 'auto' indicates we should try to guess the appropriate language
if language == 'auto':
text = body['content']
text = storable_doc['content']
language = LangidProvider.langid(text)['result']['language']
if language not in SUPPORTED_LANGUAGES:
app.logger.warning('Detected language {} is not supported'.format(language))
Expand All @@ -129,7 +136,7 @@ def store_document(body, doc_id, language=None):

results = []
for index in indices:
index_result = update_or_create_document(body, doc_id, index)
index_result = update_or_create_document(storable_doc, doc_id, index)
results.append(index_result)
if index_result['result'] not in ['created', 'updated', 'noop']:
app.logger.warning('Problem adding document to ES index for language {0}: {1}'.format(language, index_result))
Expand Down
1 change: 1 addition & 0 deletions app/main/lib/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"xlm-r-bert-base-nli-stsb-mean-tokens": "mean_tokens__Model",
"indian-sbert": "indian_sbert__Model",
"paraphrase-filipino-mpnet-base-v2": "fptg__Model",
"paraphrase-multilingual-mpnet-base-v2": "paraphrase_multilingual__Model"
}
# Seconds to wait for a presto response before timing out. os.getenv returns a
# str when the variable is set, while the fallback is an int — coerce to int so
# the constant has a consistent type for comparisons/timeout arguments.
PRESTO_RESPONSE_TIMEOUT = int(os.getenv('PRESTO_RESPONSE_TIMEOUT', 120))

Expand Down
6 changes: 5 additions & 1 deletion app/main/lib/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from app.main.lib.shared_models.video_model import VideoModel
from app.main.lib.presto import Presto, PRESTO_MODEL_MAP
from app.main.lib.image_similarity import add_image, callback_add_image, delete_image, blocking_search_image, async_search_image, async_search_image_on_callback
from app.main.lib.text_similarity import add_text, async_search_text, async_search_text_on_callback, callback_add_text, delete_text, search_text
from app.main.lib.text_similarity import add_text, async_search_text, async_search_text_on_callback, callback_add_text, delete_text, search_text, sync_search_text
DEFAULT_SEARCH_LIMIT = 200
logging.basicConfig(level=logging.INFO)
def get_body_for_media_document(params, mode):
Expand Down Expand Up @@ -200,6 +200,10 @@ def blocking_get_similar_items(item, similarity_type):
response = video_model().blocking_search(model_response_package(item, "search"), "video")
app.logger.info(f"[Alegre Similarity] [Item {item}, Similarity type: {similarity_type}] response for search was {response}")
return response
elif similarity_type == "text":
response = sync_search_text(item, "text")
DGaffney marked this conversation as resolved.
Show resolved Hide resolved
app.logger.info(f"[Alegre Similarity] [Item {item}, Similarity type: {similarity_type}] response for search was {response}")
return response
else:
raise Exception(f"{similarity_type} modality not implemented for blocking requests!")

Expand Down
39 changes: 31 additions & 8 deletions app/main/lib/text_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,28 @@ def get_document_body(body):
def async_search_text(task, modality):
    # Thin wrapper: delegate to elastic_crud with the model fixed to "text";
    # `modality` is passed through for callback-URL routing. NOTE(review):
    # return shape is whatever get_async_presto_response yields — confirm
    # against elastic_crud before relying on it here.
    return elastic_crud.get_async_presto_response(task, "text", modality)

def sync_search_text(task, modality):
    """
    Blocking text search: wait (via presto) for any vector models the task
    requires, fold the returned vectors into the object, then run the search
    against the resulting document with its vectors already attached.
    """
    obj, temporary, context, presto_result = elastic_crud.get_blocked_presto_response(task, "text", modality)
    # Elasticsearch needs no vector, so it is always searchable; vectorized
    # models are appended below as their presto responses are merged in.
    obj["models"] = ["elasticsearch"]
    if isinstance(presto_result, list):
        for vector_result in presto_result:
            model_name = vector_result["model"]
            obj["vector_" + model_name] = vector_result["response"]["body"]["result"]
            obj["model_" + model_name] = 1
            obj["models"].append(model_name)
    document, _ = elastic_crud.get_object(obj, "text")
    # use_document_vectors=True: the vectors are already on the document.
    return search_text(document, True)

def fill_in_openai_embeddings(document):
for model_key in document.pop("models", []):
for model_key in document.get("models", []):
if model_key != "elasticsearch" and model_key[:len(PREFIX_OPENAI)] == PREFIX_OPENAI:
document['vector_'+model_key] = retrieve_openai_embeddings(document['content'], model_key)
document['model_'+model_key] = 1
store_document(document, document["doc_id"], document["language"])

def async_search_text_on_callback(task):
app.logger.info(f"async_search_text_on_callback(task) is {task}")
document = elastic_crud.get_object_by_doc_id(task["id"])
doc_id = task.get("raw", {}).get("doc_id")
document = elastic_crud.get_object_by_doc_id(doc_id)
fill_in_openai_embeddings(document)
app.logger.info(f"async_search_text_on_callback(task) document is {document}")
if not elastic_crud.requires_encoding(document):
Expand Down Expand Up @@ -76,7 +88,7 @@ def search_text(search_params, use_document_vectors=False):
if model_key != "elasticsearch":
search_params.pop("model", None)
if use_document_vectors:
vector_for_search = search_params[model_key+"-tokens"]
vector_for_search = search_params["vector_"+model_key]
else:
vector_for_search = None
result = search_text_by_model(dict(**search_params, **{'model': model_key}), vector_for_search)
Expand Down Expand Up @@ -175,6 +187,15 @@ def insert_model_into_response(hits, model_key):
hit["_source"]["model"] = model_key
return hits

def return_sources(results):
    """
    Flatten raw Elasticsearch hits so that each hit's `_source` becomes the
    root dict, with the hit's index name and score carried along under the
    `index` and `score` keys (other services expect this flattened shape).
    If a `_source` ever contains its own `index`/`score` keys this raises a
    TypeError — a deliberately noisy failure, since a silent overwrite would
    have downstream consequences.
    """
    flattened = []
    for hit in results:
        flattened.append(
            dict(**hit["_source"], **{"index": hit["_index"], "score": hit["_score"]})
        )
    return flattened
DGaffney marked this conversation as resolved.
Show resolved Hide resolved

def strip_vectors(results):
for result in results:
vector_keys = [key for key in result["_source"].keys() if key[:7] == "vector_"]
Expand Down Expand Up @@ -260,11 +281,13 @@ def search_text_by_model(search_params, vector_for_search):
body=body,
index=search_indices
)
response = strip_vectors(
restrict_results(
insert_model_into_response(result['hits']['hits'], model_key),
search_params,
model_key
response = return_sources(
strip_vectors(
restrict_results(
insert_model_into_response(result['hits']['hits'], model_key),
search_params,
model_key
)
)
)
return {
Expand Down
12 changes: 6 additions & 6 deletions app/test/test_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,10 @@ def test_elasticsearch_performs_correct_fuzzy_search(self):
post_response = self.client.post('/text/similarity/search/', data=json.dumps(lookup), content_type='application/json')
lookup["fuzzy"] = True
post_response_fuzzy = self.client.post('/text/similarity/search/', data=json.dumps(lookup), content_type='application/json')
self.assertGreater(json.loads(post_response_fuzzy.data.decode())["result"][0]["_score"], json.loads(post_response.data.decode())["result"][0]["_score"])
self.assertGreater(json.loads(post_response_fuzzy.data.decode())["result"][0]["score"], json.loads(post_response.data.decode())["result"][0]["score"])
lookup["fuzzy"] = False
post_response_fuzzy = self.client.post('/text/similarity/search/', data=json.dumps(lookup), content_type='application/json')
self.assertEqual(json.loads(post_response_fuzzy.data.decode())["result"][0]["_score"], json.loads(post_response.data.decode())["result"][0]["_score"])
self.assertEqual(json.loads(post_response_fuzzy.data.decode())["result"][0]["score"], json.loads(post_response.data.decode())["result"][0]["score"])

def test_elasticsearch_update_text(self):
with self.client:
Expand Down Expand Up @@ -455,7 +455,7 @@ def test_model_similarity(self):
)
result = json.loads(response.data.decode())
self.assertEqual(1, len(result['result']))
similarity = result['result'][0]['_score']
similarity = result['result'][0]['score']
self.assertGreater(similarity, 0.7)

response = self.client.post(
Expand Down Expand Up @@ -487,7 +487,7 @@ def test_model_similarity(self):
)
result = json.loads(response.data.decode())
self.assertEqual(1, len(result['result']))
similarity = result['result'][0]['_score']
similarity = result['result'][0]['score']
self.assertGreater(similarity, 0.7)

response = self.client.post(
Expand All @@ -501,7 +501,7 @@ def test_model_similarity(self):
)
result = json.loads(response.data.decode())
self.assertEqual(1, len(result['result']))
similarity = result['result'][0]['_score']
similarity = result['result'][0]['score']
self.assertGreater(similarity, 0.7)

def test_wrong_model_key(self):
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_min_es_search(self):
result = json.loads(response.data.decode())

self.assertEqual(1, len(result['result']))
data['min_es_score']=10+result['result'][0]['_score']
data['min_es_score']=10+result['result'][0]['score']

response = self.client.post(
'/text/similarity/search/',
Expand Down
6 changes: 3 additions & 3 deletions app/test/test_similarity_lang_analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_all_analyzers(self):
content_type='application/json'
)
result = json.loads(response.data.decode())
self.assertTrue(app.config['ELASTICSEARCH_SIMILARITY']+"_"+example['language'] in [e['_index'] for e in result['result']])
self.assertTrue(app.config['ELASTICSEARCH_SIMILARITY']+"_"+example['language'] in [e['index'] for e in result['result']])

def test_auto_language_id(self):
# language examples as input to language classifier
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_auto_language_id(self):
index_alias = app.config['ELASTICSEARCH_SIMILARITY']
if expected_lang is not None:
index_alias = app.config['ELASTICSEARCH_SIMILARITY']+"_"+expected_lang
self.assertTrue(index_alias in [e['_index'] for e in result['result']])
self.assertTrue(index_alias in [e['index'] for e in result['result']])

def test_auto_language_query(self):
# language examples as input to language classifier
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_auto_language_query(self):
index_alias = app.config['ELASTICSEARCH_SIMILARITY']
if expected_lang is not None:
index_alias = app.config['ELASTICSEARCH_SIMILARITY']+"_"+expected_lang
self.assertTrue(index_alias in [e['_index'] for e in result['result']])
self.assertTrue(index_alias in [e['index'] for e in result['result']])


if __name__ == '__main__':
Expand Down
Loading