Skip to content

Commit

Permalink
Add select star option and nested field expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
nuric committed Sep 7, 2024
1 parent e4c7cf1 commit debf542
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 29 deletions.
6 changes: 4 additions & 2 deletions httpapi/v2/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,8 @@ func Test_UpdatePoints(t *testing.T) {
Operator: models.OperatorEquals,
},
},
Limit: 10,
Select: []string{"*"},
Limit: 10,
}
var searchResp v2.SearchPointsResponse
resp = makeRequest(t, router, "POST", "/v1/collections/gandalf/points/search", sr, &searchResp)
Expand Down Expand Up @@ -646,7 +647,8 @@ func Test_SearchPoints(t *testing.T) {
Limit: 10,
},
},
Limit: 10,
Select: []string{"description"},
Limit: 10,
}
var respBody v2.SearchPointsResponse
resp := makeRequest(t, router, "POST", "/v1/collections/gandalf/points/search", sr, &respBody)
Expand Down
7 changes: 5 additions & 2 deletions shard/points.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func GetPointByUUID(bucket diskstore.ReadOnlyBucket, pointId uuid.UUID) (ShardPo
return sp, nil
}

func GetPointByNodeId(bucket diskstore.ReadOnlyBucket, nodeId uint64) (ShardPoint, error) {
func GetPointByNodeId(bucket diskstore.ReadOnlyBucket, nodeId uint64, withData bool) (ShardPoint, error) {
pointIdBytes := bucket.Get(conversion.NodeKey(nodeId, 'i'))
if pointIdBytes == nil {
return ShardPoint{}, ErrPointDoesNotExist
Expand All @@ -102,7 +102,10 @@ func GetPointByNodeId(bucket diskstore.ReadOnlyBucket, nodeId uint64) (ShardPoin
if err != nil {
return ShardPoint{}, fmt.Errorf("could not parse point id: %w", err)
}
data := bucket.Get(conversion.NodeKey(nodeId, 'd'))
var data []byte
if withData {
data = bucket.Get(conversion.NodeKey(nodeId, 'd'))
}
sp := ShardPoint{
Point: models.Point{
Id: pointId,
Expand Down
61 changes: 49 additions & 12 deletions shard/shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"strings"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -354,7 +355,7 @@ func (s *Shard) SearchPoints(searchRequest models.SearchRequest) ([]models.Searc
// ---------------------------
// Backfill point UUID and data
for _, r := range results {
sp, err := GetPointByNodeId(bPoints, r.NodeId)
sp, err := GetPointByNodeId(bPoints, r.NodeId, len(searchRequest.Select) > 0)
if err != nil {
return fmt.Errorf("could not get point by node id %d: %w", r.NodeId, err)
}
Expand All @@ -366,7 +367,7 @@ func (s *Shard) SearchPoints(searchRequest models.SearchRequest) ([]models.Searc
it := rSet.Iterator()
for it.HasNext() {
nodeId := it.Next()
sp, err := GetPointByNodeId(bPoints, nodeId)
sp, err := GetPointByNodeId(bPoints, nodeId, len(searchRequest.Select) > 0)
if err != nil {
return fmt.Errorf("could not get point by node id %d: %w", nodeId, err)
}
Expand All @@ -381,8 +382,12 @@ func (s *Shard) SearchPoints(searchRequest models.SearchRequest) ([]models.Searc
}
cacheTx.Commit(false)
// ---------------------------
// Select and sort
if len(searchRequest.Select) > 0 {
/* Select and sort, if we have * (star) then we don't need to do anything and
* let upstream handle decoding the whole data point. Otherwise we need to
* selectively decode the required properties. Note that we are allowing
* sorting after selecting star. So if there is something to sort even if
* select star is given we need to decode the lot. */
if (len(searchRequest.Select) > 0 && searchRequest.Select[0] != "*") || len(searchRequest.Sort) > 0 {
selectSortStart := time.Now()
/* We are selecting only a subset of the point data. We need to partial
* decode and re-encode the point data. */
Expand All @@ -396,8 +401,14 @@ func (s *Shard) SearchPoints(searchRequest models.SearchRequest) ([]models.Searc
}
// E.g. ["name", "age"]
for _, p := range searchRequest.Select {
// E.g. p = "name"
// E.g. p = "name" or "*" (star)
dec.Reset(bytes.NewReader(r.Point.Data))
if p == "*" {
if err := dec.Decode(&finalResults[i].DecodedData); err != nil {
return nil, fmt.Errorf("could not decode all point data: %w", err)
}
break
}
res, err := dec.Query(p)
if err != nil {
return nil, fmt.Errorf("could not select point data, %s: %w", p, err)
Expand All @@ -407,16 +418,40 @@ func (s *Shard) SearchPoints(searchRequest models.SearchRequest) ([]models.Searc
continue
}
// ---------------------------
// We originally implemented nested fields to create nested maps
// and populate accordingly but it adds extra for loops and
// complexity. It also entangles the sorting code below as well.
// For now, a select field such as "nested.field" will comes
// back flattened, e.g. {"nested.field": value} as opposed to
// {"nested": {"field": value}}.
/* We originally implemented nested fields to create nested maps
* and populate accordingly but it adds extra for loops and
* complexity. It also entangles the sorting code below as well.
* For now, a select field such as "nested.field" will comes
* back flattened, e.g. {"nested.field": value} as opposed to
* {"nested": {"field": value}}.
*
* UPDATE: We have decided to implemented the nested fields as it
* is more consistent with how the data is inputted. That is, the
* user gives us nested fields but upon retrieval we used to
* flatten it. This was confusing and we had implemented it as
* expanding nested fields originally, so we are going back to
* how things were. */
// ---------------------------
// Assign the value to final decoded data. This makes
// {"property": value} e.g. {"name": "james"}
finalResults[i].DecodedData[p] = res[0]
segments := strings.Split(p, ".")
// e.g. segments = ["nested", "field"] or ["name"]
current := finalResults[i].DecodedData
for j, s := range segments {
if j == len(segments)-1 {
current[s] = res[0]
break
}
// If the nested field does not exist, we create it
if _, ok := current[s]; !ok {
current[s] = make(map[string]any)
}
var ok bool
current, ok = current[s].(map[string]any)
if !ok {
return nil, fmt.Errorf("could not access nested property when selecting: %s", p)
}
}
}
// We erase data information as it is not needed any more, saves us
// from transmitting it
Expand All @@ -430,6 +465,8 @@ func (s *Shard) SearchPoints(searchRequest models.SearchRequest) ([]models.Searc
// ---------------------------
s.logger.Debug().Str("duration", time.Since(selectSortStart).String()).Msg("Search - Select Sort")
}
/* End of select sort, if we skipped it then the encoded data is transmitted,
* otherwise DecodedData is populated and sent instead. */
// ---------------------------
// Offset and limit
if searchRequest.Limit == 0 {
Expand Down
111 changes: 104 additions & 7 deletions shard/shard_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,63 @@ if rand.Float32() < 0.5 {
}
*/

func TestSearch_SelectNone(t *testing.T) {
// ---------------------------
s := tempShard(t)
points := randPoints(100)
err := s.InsertPoints(points)
require.NoError(t, err)
// ---------------------------
sr := models.SearchRequest{
Query: models.Query{
Property: "size",
Integer: &models.SearchIntegerOptions{
Value: 10,
EndValue: 15,
Operator: models.OperatorInRange,
},
},
}
res, err := s.SearchPoints(sr)
require.NoError(t, err)
require.Len(t, res, 6)
for i := 0; i < len(res); i++ {
require.Nil(t, res[i].Data)
require.Nil(t, res[i].Distance)
require.Nil(t, res[i].Score)
require.Nil(t, res[i].DecodedData)
}
}

func TestSearch_SelectStar(t *testing.T) {
// ---------------------------
s := tempShard(t)
points := randPoints(100)
err := s.InsertPoints(points)
require.NoError(t, err)
// ---------------------------
sr := models.SearchRequest{
Query: models.Query{
Property: "size",
Integer: &models.SearchIntegerOptions{
Value: 10,
EndValue: 15,
Operator: models.OperatorInRange,
},
},
Select: []string{"*"},
}
res, err := s.SearchPoints(sr)
require.NoError(t, err)
require.Len(t, res, 6)
for i := 0; i < len(res); i++ {
require.NotNil(t, res[i].Data)
require.Nil(t, res[i].Distance)
require.Nil(t, res[i].Score)
require.Nil(t, res[i].DecodedData)
}
}

func TestSearch_Select(t *testing.T) {
// ---------------------------
s := tempShard(t)
Expand Down Expand Up @@ -54,7 +111,7 @@ func TestSearch_Select(t *testing.T) {
}
}

func TestSearch_NestedField(t *testing.T) {
func TestSearch_SelectNestedField(t *testing.T) {
// ---------------------------
s := tempShard(t)
points := randPoints(10)
Expand All @@ -71,22 +128,60 @@ func TestSearch_NestedField(t *testing.T) {
Operator: "near",
},
},
Select: []string{"nested.vector", "nested.size", "nested"},
Select: []string{"nested.vector", "nested.size", "nested", "nested.size"},
}
s.InsertPoints(points)
res, err := s.SearchPoints(sr)
require.NoError(t, err)
require.Len(t, res, 5)
require.Equal(t, points[3].Id, res[0].Point.Id)
require.EqualValues(t, 0, *res[0].Distance)
// We're expecting something like {"nested.vector": [0.0, 1.0, 2.0, 3.0, 4.0], "nested.size": 3, "nested": {...}}
require.Len(t, res[0].DecodedData, 3)
// We're expecting something like {"nested": {"vector": [0.0, 1.0, 2.0, 3.0, 4.0], "size": 3}}
require.Len(t, res[0].DecodedData, 1)
require.Len(t, res[0].DecodedData["nested"], 2)
require.EqualValues(t, 3, res[0].DecodedData["nested.size"])
require.EqualValues(t, 3, res[0].DecodedData["nested"].(map[string]interface{})["size"])
require.NoError(t, s.Close())
}

func TestSearch_SelectStarNestedFieldSort(t *testing.T) {
// ---------------------------
s := tempShard(t)
points := randPoints(10)
err := s.InsertPoints(points)
require.NoError(t, err)
// ---------------------------
sr := models.SearchRequest{
Query: models.Query{
Property: "nested.vector",
VectorVamana: &models.SearchVectorVamanaOptions{
Vector: getVector(points[3]),
SearchSize: 75,
Limit: 5,
Operator: "near",
},
},
Select: []string{"*"},
Sort: []models.SortOption{
{Property: "nested.size", Descending: true},
},
}
s.InsertPoints(points)
res, err := s.SearchPoints(sr)
require.NoError(t, err)
require.Len(t, res, 5)
require.EqualValues(t, 0, *res[0].Distance)
// Check if the results are sorted in descending order
for i := 0; i < 5; i++ {
for j := i + 1; j < 5; j++ {
iv := res[i].DecodedData["nested"].(map[string]interface{})["size"]
jv := res[j].DecodedData["nested"].(map[string]interface{})["size"]
require.GreaterOrEqual(t, iv, jv)
}
}
require.NoError(t, s.Close())
}

func TestSearch_NestedFieldSort(t *testing.T) {
func TestSearch_SelectNestedFieldSort(t *testing.T) {
// ---------------------------
s := tempShard(t)
points := randPoints(10)
Expand Down Expand Up @@ -116,7 +211,9 @@ func TestSearch_NestedFieldSort(t *testing.T) {
// Check if the results are sorted in descending order
for i := 0; i < 5; i++ {
for j := i + 1; j < 5; j++ {
require.GreaterOrEqual(t, res[i].DecodedData["nested.size"], res[j].DecodedData["nested.size"])
iv := res[i].DecodedData["nested"].(map[string]interface{})["size"]
jv := res[j].DecodedData["nested"].(map[string]interface{})["size"]
require.GreaterOrEqual(t, iv, jv)
}
}
require.NoError(t, s.Close())
Expand Down
3 changes: 2 additions & 1 deletion shard/shard_vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ func searchRequest(p models.Point, limit int) models.SearchRequest {
Operator: "near",
},
},
Limit: limit,
Select: []string{"*"},
Limit: limit,
}
}

Expand Down
25 changes: 22 additions & 3 deletions utils/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"cmp"
"reflect"
"slices"
"strings"

"github.com/semafind/semadb/models"
)
Expand Down Expand Up @@ -33,15 +34,33 @@ func CompareAny(a, b any) int {
return 0
}

// Attemps to sort search results by the given properties.
// Accesses a nested property in a map of the form path "a.b.c".
func AccessNestedProperty(data map[string]any, path string) (any, bool) {
var current any = data
for _, p := range strings.Split(path, ".") {
switch v := current.(type) {
case map[string]any:
var ok bool
current, ok = v[p]
if !ok {
return nil, false
}
default:
return nil, false
}
}
return current, true
}

// Attempts to sort search results by the given properties.
func SortSearchResults(results []models.SearchResult, sortOpts []models.SortOption) {
/* Because we don't know the type of the values, this may be a costly
* operation to undertake. We should monitor how this performs. */
slices.SortFunc(results, func(a, b models.SearchResult) int {
for _, s := range sortOpts {
// E.g. s = "age"
av, aok := a.DecodedData[s.Property]
bv, bok := b.DecodedData[s.Property]
av, aok := AccessNestedProperty(a.DecodedData, s.Property)
bv, bok := AccessNestedProperty(b.DecodedData, s.Property)
/* If the property is missing, we need to decide what to do
* here. We can either put it at the top or bottom. We put it
* at the bottom for now so that missing values are last. */
Expand Down
Loading

0 comments on commit debf542

Please sign in to comment.