Skip to content

Commit

Permalink
Add l2 normalization if needed when setting vector as target
Browse files Browse the repository at this point in the history
  • Loading branch information
ffont committed Dec 11, 2024
1 parent 9f0fba6 commit 8f29b21
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
6 changes: 2 additions & 4 deletions utils/search/backends/solr555pysolr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from utils.text import remove_control_chars
from utils.search import SearchEngineBase, SearchResults, SearchEngineException
from utils.search.backends.solr_common import SolrQuery, SolrResponseInterpreter
from utils.similarity_utilities import get_similarity_search_target_vector
from utils.similarity_utilities import get_similarity_search_target_vector, get_l2_normalized_vector


SOLR_FORUM_URL = f"{settings.SOLR5_BASE_URL}/forum"
Expand Down Expand Up @@ -278,9 +278,7 @@ def add_similarity_vectors_to_documents(self, sound_objects, documents):

if config_options.get('l2_norm', False):
# Normalize the vector to have unit length
norm = math.sqrt(sum([v*v for v in vector_data]))
if norm > 0:
vector_data = [v/norm for v in vector_data]
vector_data = get_l2_normalized_vector(vector_data)

sim_vector_document_data = {
'content_type': SOLR_DOC_CONTENT_TYPES['similarity_vector'],
Expand Down
10 changes: 10 additions & 0 deletions utils/similarity_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#

import logging
import math
import traceback

from django.conf import settings
Expand Down Expand Up @@ -170,6 +171,13 @@ def hash_cache_key(key):
return create_hash(key, limit=32)


def get_l2_normalized_vector(vector):
    """Return ``vector`` scaled to unit Euclidean (L2) length.

    Args:
        vector: Iterable of numbers. It is materialized once, so one-shot
            iterators (e.g. generators) are handled correctly.

    Returns:
        A new list with every component divided by the L2 norm. When the
        norm is not greater than 0 (empty or all-zero input) the components
        are returned unchanged, as there is no direction to normalize to.
    """
    # The values are traversed twice (norm, then scaling), so materialize
    # them first; otherwise an iterator would be exhausted after the sum.
    values = list(vector)
    norm = math.sqrt(sum(v * v for v in values))
    if norm > 0:
        return [v / norm for v in values]
    return values


def get_similarity_search_target_vector(sound_id, analyzer=settings.SEARCH_ENGINE_DEFAULT_SIMILARITY_ANALYZER):
# If the sound has been analyzed for similarity, returns the vector to be used for similarity search
sa = sounds.models.SoundAnalysis.objects.filter(sound_id=sound_id, analyzer=analyzer, analysis_status="OK")
Expand All @@ -180,5 +188,7 @@ def get_similarity_search_target_vector(sound_id, analyzer=settings.SEARCH_ENGIN
if data is not None:
vector_raw = data[config_options['vector_property_name']]
if vector_raw is not None:
if config_options['l2_norm']:
vector_raw = get_l2_normalized_vector(vector_raw)
return vector_raw
return None

0 comments on commit 8f29b21

Please sign in to comment.