diff --git a/utils/search/backends/solr555pysolr.py b/utils/search/backends/solr555pysolr.py index 2a7154365..f9ecd4c84 100644 --- a/utils/search/backends/solr555pysolr.py +++ b/utils/search/backends/solr555pysolr.py @@ -33,7 +33,7 @@ from utils.text import remove_control_chars from utils.search import SearchEngineBase, SearchResults, SearchEngineException from utils.search.backends.solr_common import SolrQuery, SolrResponseInterpreter -from utils.similarity_utilities import get_similarity_search_target_vector +from utils.similarity_utilities import get_similarity_search_target_vector, get_l2_normalized_vector SOLR_FORUM_URL = f"{settings.SOLR5_BASE_URL}/forum" @@ -278,9 +278,7 @@ def add_similarity_vectors_to_documents(self, sound_objects, documents): if config_options.get('l2_norm', False): # Normalize the vector to have unit length - norm = math.sqrt(sum([v*v for v in vector_data])) - if norm > 0: - vector_data = [v/norm for v in vector_data] + vector_data = get_l2_normalized_vector(vector_data) sim_vector_document_data = { 'content_type': SOLR_DOC_CONTENT_TYPES['similarity_vector'], diff --git a/utils/similarity_utilities.py b/utils/similarity_utilities.py index e7719d7b1..1178f105d 100644 --- a/utils/similarity_utilities.py +++ b/utils/similarity_utilities.py @@ -19,6 +19,7 @@ # import logging +import math import traceback from django.conf import settings @@ -170,6 +171,13 @@ def hash_cache_key(key): return create_hash(key, limit=32) +def get_l2_normalized_vector(vector): + norm = math.sqrt(sum([v*v for v in vector])) + if norm > 0: + vector = [v/norm for v in vector] + return vector + + def get_similarity_search_target_vector(sound_id, analyzer=settings.SEARCH_ENGINE_DEFAULT_SIMILARITY_ANALYZER): # If the sound has been analyzed for similarity, returns the vector to be used for similarity search sa = sounds.models.SoundAnalysis.objects.filter(sound_id=sound_id, analyzer=analyzer, analysis_status="OK") @@ -180,5 +188,7 @@ def get_similarity_search_target_vector(sound_id, analyzer=settings.SEARCH_ENGIN if data is not None: vector_raw = data[config_options['vector_property_name']] if vector_raw is not None: + if config_options['l2_norm']: + vector_raw = get_l2_normalized_vector(vector_raw) return vector_raw return None