Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MB-57926 - Searcher for similarity search #1869

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions index/scorch/snapshot_index_vr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package scorch

import (
"bytes"
"context"
"fmt"
"reflect"

"github.com/blevesearch/bleve/v2/size"
index "github.com/blevesearch/bleve_index_api"
segment_api "github.com/blevesearch/scorch_segment_api/v2"
)

// reflectStaticSizeIndexSnapshotVectorReader caches the size, in bytes, of an
// IndexSnapshotVectorReader value as reported by reflect. It is filled in
// once by init and read by Size.
var reflectStaticSizeIndexSnapshotVectorReader int

func init() {
	readerType := reflect.TypeOf(IndexSnapshotVectorReader{})
	reflectStaticSizeIndexSnapshotVectorReader = int(readerType.Size())
}

// IndexSnapshotVectorReader iterates, segment by segment, over the vector
// postings produced for one query vector against one field of an
// IndexSnapshot.
type IndexSnapshotVectorReader struct {
	// vector is the query vector handed to each segment's SimilarVectors call.
	vector []float32
	// field is the name of the field being searched.
	field string
	// k is forwarded unchanged to each segment's SimilarVectors call.
	k int64
	// snapshot is the index snapshot this reader was created from.
	snapshot *IndexSnapshot
	// postings holds one postings list per segment, indexed by segment position.
	postings []segment_api.VecPostingsList
	// iterators holds one postings iterator per segment, indexed like postings.
	iterators []segment_api.VecPostingsIterator
	// segmentOffset is the index of the segment currently being iterated.
	segmentOffset int
	// currPosting is the posting most recently returned by Next or Advance.
	currPosting segment_api.VecPosting
	// currID is the global internal ID of the most recently returned document.
	currID index.IndexInternalID
	// ctx is passed to snapshot.VectorReader when Advance rebuilds the reader.
	ctx context.Context
}

// Size returns an estimate, in bytes, of the memory held by this reader: the
// static struct size, the query vector, the field name, the current ID and
// whatever each per-segment postings list, iterator and the current posting
// report for themselves.
func (i *IndexSnapshotVectorReader) Size() int {
	// The vector is a []float32, so its footprint is the element count times
	// the float32 width, not the bare element count.
	sizeInBytes := reflectStaticSizeIndexSnapshotVectorReader + size.SizeOfPtr +
		len(i.vector)*size.SizeOfFloat32 + len(i.field) + len(i.currID)

	// Slots stay nil for segments that do not implement VectorSegment, so
	// they must be skipped instead of dereferenced.
	for _, entry := range i.postings {
		if entry != nil {
			sizeInBytes += entry.Size()
		}
	}

	for _, entry := range i.iterators {
		if entry != nil {
			sizeInBytes += entry.Size()
		}
	}

	if i.currPosting != nil {
		sizeInBytes += i.currPosting.Size()
	}

	return sizeInBytes
}

// Next returns the next vector hit, walking the per-segment iterators in
// segment order. If preAlloced is non-nil it is reused for the result;
// otherwise a new VectorDoc is allocated. It returns (nil, nil) once every
// segment has been exhausted, and (nil, err) if a segment iterator fails.
func (i *IndexSnapshotVectorReader) Next(preAlloced *index.VectorDoc) (
	*index.VectorDoc, error) {
	rv := preAlloced
	if rv == nil {
		rv = &index.VectorDoc{}
	}

	for ; i.segmentOffset < len(i.iterators); i.segmentOffset++ {
		itr := i.iterators[i.segmentOffset]
		if itr == nil {
			// No iterator was created for this segment (it does not
			// implement VectorSegment), so it cannot contribute hits.
			continue
		}

		next, err := itr.Next()
		if err != nil {
			return nil, err
		}
		if next == nil {
			// This segment is exhausted; move on to the following one.
			continue
		}

		// make segment number into global number by adding offset
		globalOffset := i.snapshot.offsets[i.segmentOffset]
		rv.ID = docNumberToBytes(rv.ID, next.Number()+globalOffset)
		rv.Score = float64(next.Score())

		i.currID = rv.ID
		i.currPosting = next

		// Returning here leaves segmentOffset on the current segment so the
		// following call resumes from the same iterator.
		return rv, nil
	}

	return nil, nil
}

// Advance moves the reader to the first hit whose ID is at or after ID and
// returns it. If preAlloced is non-nil it is reused for the result. When the
// reader is already positioned at or past ID it is rebuilt from the snapshot
// first, because the per-segment iterators only move forward. It returns
// (nil, nil) when there is no such hit and (nil, err) on failure.
func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
	preAlloced *index.VectorDoc) (*index.VectorDoc, error) {

	if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
		ctx := i.ctx
		i2, err := i.snapshot.VectorReader(ctx, i.vector, i.field, i.k)
		if err != nil {
			return nil, err
		}
		// close the current vector reader before replacing it with a new one;
		// Close currently cannot fail, so its result is deliberately ignored
		_ = i.Close()
		*i = *(i2.(*IndexSnapshotVectorReader))
		// keep the caller's context across the rebuild
		i.ctx = ctx
	}

	num, err := docInternalToNumber(ID)
	if err != nil {
		return nil, fmt.Errorf("error converting to doc number % x - %w", ID, err)
	}
	segIndex, ldocNum := i.snapshot.segmentIndexAndLocalDocNumFromGlobal(num)
	if segIndex >= len(i.snapshot.segment) {
		return nil, fmt.Errorf("computed segment index %d out of bounds %d",
			segIndex, len(i.snapshot.segment))
	}
	// skip directly to the target segment
	i.segmentOffset = segIndex
	itr := i.iterators[i.segmentOffset]
	if itr == nil {
		// No iterator exists for this segment (it does not implement
		// VectorSegment), so the next hit can only come from a later one.
		i.segmentOffset++
		return i.Next(preAlloced)
	}
	next, err := itr.Advance(ldocNum)
	if err != nil {
		return nil, err
	}
	if next == nil {
		// we jumped directly to the segment that should have contained it
		// but it wasn't there, so reuse Next() which should correctly
		// get the next hit after it (we moved i.segmentOffset)
		return i.Next(preAlloced)
	}

	if preAlloced == nil {
		preAlloced = &index.VectorDoc{}
	}
	preAlloced.ID = docNumberToBytes(preAlloced.ID, next.Number()+
		i.snapshot.offsets[segIndex])
	// Populate the score exactly as Next does; otherwise a reused preAlloced
	// would keep the score of whatever hit it held before.
	preAlloced.Score = float64(next.Score())
	i.currID = preAlloced.ID
	i.currPosting = next
	return preAlloced, nil
}

// Count returns the sum of the counts reported by every per-segment
// postings list held by this reader.
func (i *IndexSnapshotVectorReader) Count() uint64 {
	var rv uint64
	for _, posting := range i.postings {
		// Slots stay nil for segments that do not implement VectorSegment.
		if posting != nil {
			rv += posting.Count()
		}
	}
	return rv
}

// Close releases the reader. It currently holds nothing that needs to be
// released, so it always returns nil.
func (i *IndexSnapshotVectorReader) Close() error {
	// TODO: consider whether anything here could be recycled.
	return nil
}
57 changes: 57 additions & 0 deletions index/scorch/snapshot_vector_index.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package scorch

import (
"context"

index "github.com/blevesearch/bleve_index_api"
segment_api "github.com/blevesearch/scorch_segment_api/v2"
)

// VectorReader builds a reader over the vectors similar to the given query
// vector in the given field. For every segment of the snapshot that
// implements segment_api.VectorSegment it asks the segment for its similar
// vectors (passing k and the segment's deleted set) and stores the resulting
// postings list and iterator at that segment's position. Slots for segments
// that do not implement VectorSegment are left nil. It returns the first
// error reported by a segment.
func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32,
	field string, k int64) (
	index.VectorReader, error) {

	rv := &IndexSnapshotVectorReader{
		vector:   vector,
		field:    field,
		k:        k,
		snapshot: is,
		// Advance hands this context back to VectorReader when it has to
		// rebuild the reader, so it must be recorded here.
		ctx:       ctx,
		postings:  make([]segment_api.VecPostingsList, len(is.segment)),
		iterators: make([]segment_api.VecPostingsIterator, len(is.segment)),
	}

	for i, seg := range is.segment {
		sv, ok := seg.segment.(segment_api.VectorSegment)
		if !ok {
			continue
		}
		pl, err := sv.SimilarVectors(field, vector, k, seg.deleted)
		if err != nil {
			return nil, err
		}
		rv.postings[i] = pl
		rv.iterators[i] = pl.Iterator(rv.iterators[i])
	}

	return rv, nil
}
6 changes: 5 additions & 1 deletion index_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,11 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey,
search.GeoBufferPoolCallbackFunc(getBufferPool))

searcher, err := req.Query.Searcher(ctx, indexReader, i.m, search.SearcherOptions{
// Using a disjunction query to get union of results from KNN query
// and the original query
searchQuery := disjunctQueryWithKNN(req)

searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{
Explain: req.Explain,
IncludeTermVectors: req.IncludeLocations || req.Highlight != nil,
Score: req.Score,
Expand Down
33 changes: 33 additions & 0 deletions knn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package bleve

import (
"github.com/blevesearch/bleve/v2/search/query"
)

// disjunctQueryWithKNN returns the query to execute for req. When the request
// carries no KNN section the original query is returned untouched; otherwise
// a KNN query is built from the section's vector, field, k and boost, and
// combined with the original query in a disjunction.
func disjunctQueryWithKNN(req *SearchRequest) query.Query {
	knn := req.KNN
	if knn == nil {
		return req.Query
	}

	vectorQuery := query.NewKNNQuery(knn.Vector)
	vectorQuery.SetFieldVal(knn.Field)
	vectorQuery.SetK(knn.K)
	vectorQuery.SetBoost(knn.Boost.Value())

	return query.NewDisjunctionQuery([]query.Query{req.Query, vectorQuery})
}
24 changes: 24 additions & 0 deletions knn_noop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !vectors
// +build !vectors

package bleve

import "github.com/blevesearch/bleve/v2/search/query"

// disjunctQueryWithKNN is the variant compiled when the "vectors" build tag
// is absent: it returns the request's query unchanged.
func disjunctQueryWithKNN(req *SearchRequest) query.Query {
	return req.Query
}
27 changes: 27 additions & 0 deletions mapping/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,33 @@ func (im *IndexMappingImpl) FieldAnalyzer(field string) string {
return im.AnalyzerNameForPath(field)
}

// FieldMappingForPath returns the mapping for a specific field 'path'.
func (im *IndexMappingImpl) FieldMappingForPath(path string) FieldMapping {
if im.TypeMapping != nil {
for _, v := range im.TypeMapping {
for field, property := range v.Properties {
for _, v1 := range property.Fields {
if field == path {
// Return field mapping if the name matches the path param.
return *v1
}
}
}
}
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From @Thejas-bhat: I think you might need to check the default mapping for the path as well (in case the type mappings don't contain the path), just like AnalyzerNameForPath(path string) does.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, pls take a look and let me know.


for field, property := range im.DefaultMapping.Properties {
for _, v1 := range property.Fields {
if field == path {
// Return field mapping if the name matches the path param.
return *v1
}
}
}

return FieldMapping{}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether it would be better to return nil here, which avoids allocating a struct, and let the caller handle a nil field mapping accordingly. For example, in the searcher construction for KNN, if the field mapping is nil, we would use the default SimilarityDefaultVal. Any thoughts on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I can do that since I'm not returning a pointer over here.
I will add a check in the caller though to use the default val if it's "".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pls resolve this if it makes sense, thanks!

}

// wrapper to satisfy new interface

func (im *IndexMappingImpl) DefaultSearchField() string {
Expand Down
2 changes: 2 additions & 0 deletions mapping/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@ type IndexMapping interface {

AnalyzerNameForPath(path string) string
AnalyzerNamed(name string) analysis.Analyzer

FieldMappingForPath(path string) FieldMapping
}
4 changes: 2 additions & 2 deletions mapping_with_dv.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ package bleve

import "github.com/blevesearch/bleve/v2/mapping"

func NewVectorFieldMapping() *mapping.FieldMapping {
return mapping.NewVectorFieldMapping()
// NewDenseVectorFieldMapping returns the field mapping produced by
// mapping.NewDenseVectorFieldMapping.
func NewDenseVectorFieldMapping() *mapping.FieldMapping {
	return mapping.NewDenseVectorFieldMapping()
}
Loading
Loading