From 4394c8b59872bdebfcd7f7ba1e2b5883f75f8857 Mon Sep 17 00:00:00 2001 From: Abhinav Dangeti Date: Wed, 18 Sep 2024 12:23:08 -0600 Subject: [PATCH] MB-63334: Use faiss's advertized method to normalize vector Uses new version of blevesearch/go-faiss that brings in: * f19c1d4 Abhi Dangeti | MB-63334: Method to normalize a single vector (https://github.com/blevesearch/go-faiss/pull/36) --- go.mod | 2 +- go.sum | 4 ++-- mapping/mapping_vectors.go | 20 +++----------------- mapping/mapping_vectors_test.go | 26 ++++++++++++++++++++++++++ 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index 7da6dbbde..e1d30aadb 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/bits-and-blooms/bitset v1.12.0 github.com/blevesearch/bleve_index_api v1.1.12 github.com/blevesearch/geo v0.1.20 + github.com/blevesearch/go-faiss v1.0.22-0.20240918182005-f19c1d446e92 github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/blevesearch/go-porterstemmer v1.0.3 github.com/blevesearch/goleveldb v1.0.1 @@ -32,7 +33,6 @@ require ( ) require ( - github.com/blevesearch/go-faiss v1.0.22-0.20240909180832-35a1ff78ead4 // indirect github.com/blevesearch/mmap-go v1.0.4 // indirect github.com/couchbase/ghistogram v0.1.0 // indirect github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect diff --git a/go.sum b/go.sum index 660c832bb..137f0da0c 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/blevesearch/bleve_index_api v1.1.12 h1:P4bw9/G/5rulOF7SJ9l4FsDoo7UFJ+ github.com/blevesearch/bleve_index_api v1.1.12/go.mod h1:PbcwjIcRmjhGbkS/lJCpfgVSMROV6TRubGGAODaK1W8= github.com/blevesearch/geo v0.1.20 h1:paaSpu2Ewh/tn5DKn/FB5SzvH0EWupxHEIwbCk/QPqM= github.com/blevesearch/geo v0.1.20/go.mod h1:DVG2QjwHNMFmjo+ZgzrIq2sfCh6rIHzy9d9d0B59I6w= -github.com/blevesearch/go-faiss v1.0.22-0.20240909180832-35a1ff78ead4 h1:riy8XP3UIBeVjMhsq1r1aGfjvTf3aPp2PuXxdiw9P4s= -github.com/blevesearch/go-faiss v1.0.22-0.20240909180832-35a1ff78ead4/go.mod h1:OMGQwOaRRYxrmeNdMrXJPvVx8gBnvE5RYrr0BahNnkk= +github.com/blevesearch/go-faiss v1.0.22-0.20240918182005-f19c1d446e92 h1:pDbDTN8dgycpdp9eCzrNp9e6Z4C+UQhCUAZbaarQ6Bs= +github.com/blevesearch/go-faiss v1.0.22-0.20240918182005-f19c1d446e92/go.mod h1:OMGQwOaRRYxrmeNdMrXJPvVx8gBnvE5RYrr0BahNnkk= github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:kDy+zgJFJJoJYBvdfBSiZYBbdsUL0XcjHYWezpQBGPA= github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A= github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index 6591a1b13..ad9ffbcd2 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -19,9 +19,9 @@ package mapping import ( "fmt" - "math" "reflect" + faiss "github.com/blevesearch/go-faiss" "github.com/blevesearch/bleve/v2/document" "github.com/blevesearch/bleve/v2/util" index "github.com/blevesearch/bleve_index_api" @@ -262,20 +262,6 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string, return nil } -func NormalizeVector(vector []float32) []float32 { - // first calculate the magnitude of the vector - var mag float64 - for _, v := range vector { - mag += float64(v) * float64(v) - } - // cannot normalize a zero vector - // if the magnitude is 1, then the vector is already normalized - if mag != 0 && mag != 1 { - mag = math.Sqrt(mag) - // normalize the vector - for i, v := range vector { - vector[i] = float32(float64(v) / mag) - } - } - return vector +func NormalizeVector(vec []float32) []float32 { + return faiss.NormalizeVector(vec) } diff --git a/mapping/mapping_vectors_test.go b/mapping/mapping_vectors_test.go index 970a7de76..6d409760f 100644 --- a/mapping/mapping_vectors_test.go +++ b/mapping/mapping_vectors_test.go @@ -18,6 +18,7 @@ package mapping import ( + "reflect" "testing" ) @@ -306,3 +307,28 @@ func TestProcessVector(t *testing.T) { } } } + +func TestNormalizeVector(t *testing.T) { + vectors := [][]float32{ + []float32{1,2,3,4,5}, + []float32{1,0,0,0,0}, + []float32{0.182574183,0.365148365,0.547722578,0.730296731}, + []float32{1,1,1,1,1,1,1,1}, + []float32{0}, + } + + expectedNormalizedVectors := [][]float32{ + []float32{0.13483998,0.26967996,0.40451995,0.5393599,0.67419994}, + []float32{1,0,0,0,0}, + []float32{0.18257418,0.36514837,0.5477226,0.73029673}, + []float32{0.35355338,0.35355338,0.35355338,0.35355338,0.35355338,0.35355338,0.35355338,0.35355338}, + []float32{0}, + } + + for i := 0; i < len(vectors); i++ { + normalizedVector := NormalizeVector(vectors[i]) + if !reflect.DeepEqual(normalizedVector, expectedNormalizedVectors[i]) { + t.Errorf("[vector-%d] Expected: %v, Got: %v", i+1, expectedNormalizedVectors[i], normalizedVector) + } + } +}