diff --git a/internal/clientio/resp.go b/internal/clientio/resp.go index fdfc1dccbd..1394594500 100644 --- a/internal/clientio/resp.go +++ b/internal/clientio/resp.go @@ -236,6 +236,16 @@ func Encode(value interface{}, isSimple bool) []byte { buf.Write(encodeString(b)) // Encode each string and write to the buffer. } return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + case [][]interface{}: + var b []byte + buf := bytes.NewBuffer(b) + + buf.WriteString(fmt.Sprintf("*%d\r\n", len(v))) + + for _, list := range v { + buf.Write(Encode(list, false)) + } + return buf.Bytes() // Handle slices of custom objects (Obj). case []*object.Obj: @@ -255,6 +265,15 @@ func Encode(value interface{}, isSimple bool) []byte { } return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + // Handle slices of int64. + case []float64: + var b []byte + buf := bytes.NewBuffer(b) // Create a buffer for accumulating encoded values. + for _, b := range value.([]float64) { + buf.Write(Encode(b, false)) // Encode each int64 and write to the buffer. + } + return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + // Handle slices of int64. case []int64: var b []byte diff --git a/internal/eval/sortedset/sorted_set.go b/internal/eval/sortedset/sorted_set.go index 9afcdbfe3b..9eeb17d02b 100644 --- a/internal/eval/sortedset/sorted_set.go +++ b/internal/eval/sortedset/sorted_set.go @@ -253,14 +253,11 @@ func (ss *Set) CountInRange(minVal, maxVal float64) int { } // GetScoreRange returns a slice of members with scores between min and max, inclusive. -// It returns the members in ascending order if reverse is false, and descending order if reverse is true. // If withScores is true, the members will be returned with their scores. -func (ss *Set) GetScoreRange( - minScore, maxScore float64, - withScores bool, - reverse bool, -) []string { - var result []string +func (ss *Set) GetMemberScoresInRange(minScore, maxScore float64, count, max int) ([]string, []float64) { + var members []string + var scores []float64 + iterFunc := func(item btree.Item) bool { ssi := item.(*Item) if ssi.Score < minScore { @@ -269,17 +266,17 @@ func (ss *Set) GetScoreRange( if ssi.Score > maxScore { return false } - result = append(result, ssi.Member) - if withScores { - scoreStr := strconv.FormatFloat(ssi.Score, 'g', -1, 64) - result = append(result, scoreStr) + members = append(members, ssi.Member) + scores = append(scores, ssi.Score) + count++ + + if max > 0 && count == max { + return false } + return true } - if reverse { - ss.tree.Descend(iterFunc) - } else { - ss.tree.Ascend(iterFunc) - } - return result -} \ No newline at end of file + + ss.tree.Ascend(iterFunc) + return members, scores +} diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index e375b69ad0..a866a83a97 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -6945,6 +6945,90 @@ func evalCommandDocs(args []string) *EvalResponse { return makeEvalResult(result) } +type geoRadiusOpts struct { + WithCoord bool + WithDist bool + WithHash bool + Count int // 0 means no count specified + CountAny bool // true if ANY was specified with COUNT + IsSorted bool // By default return items are not sorted + Ascending bool // If IsSorted is true, return items nearest to farthest relative to the center (ascending) or farthest to nearest relative to the center (descending) + Store string + StoreDist string +} + +func parseGeoRadiusOpts(args []string) (*geoRadiusOpts, error) { + opts := &geoRadiusOpts{ + Ascending: true, // Default to ascending order if sorted + } + + for i := 0; i < len(args); i++ { + param := strings.ToUpper(args[i]) + + switch param { + case "WITHDIST": + opts.WithDist = true + case "WITHCOORD": + opts.WithCoord = true + case "WITHHASH": + opts.WithHash = true + case "COUNT": + + // TODO validate this logic + + if i+1 >= len(args) { + return nil, fmt.Errorf("ERR syntax error") + } + + count, err := strconv.Atoi(args[i+1]) + if err != nil { + return nil, fmt.Errorf("ERR value is not an integer or out of range") + } + if count <= 0 { + return nil, fmt.Errorf("ERR COUNT must be > 0") + } + opts.Count = count + i++ + + // Check for ANY option after COUNT + if i+1 < len(args) && strings.ToUpper(args[i+1]) == "ANY" { + opts.CountAny = true + i++ + } + case "ASC": + opts.IsSorted = true + opts.Ascending = true + + case "DESC": + opts.IsSorted = true + opts.Ascending = false + + case "STORE": + if i+1 >= len(args) { + return nil, fmt.Errorf("STORE option requires a key name") + } + opts.Store = args[i+1] + i++ + + case "STOREDIST": + if i+1 >= len(args) { + return nil, fmt.Errorf("STOREDIST option requires a key name") + } + opts.StoreDist = args[i+1] + i++ + + default: + return nil, fmt.Errorf("unknown parameter: %s", args[i]) + } + } + + if opts.Store != "" && opts.StoreDist != "" { + return nil, fmt.Errorf("STORE and STOREDIST are mutually exclusive") + } + + return opts, nil +} + func evalGEORADIUSBYMEMBER(args []string, store *dstore.Store) *EvalResponse { if len(args) < 4 { return &EvalResponse{ @@ -6961,31 +7045,33 @@ func evalGEORADIUSBYMEMBER(args []string, store *dstore.Store) *EvalResponse { distVal, parseErr := strconv.ParseFloat(dist, 64) if parseErr != nil { return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrInvalidFloat, + Error: diceerrors.ErrInvalidFloat, } } - // TODO parse options - // parseGeoRadiusOptions(args[4:]) + opts, parseErr := parseGeoRadiusOpts(args[4:]) + if parseErr != nil { + return &EvalResponse{ + Result: nil, + Error: parseErr, + } + } obj := store.Get(key) if obj == nil { return &EvalResponse{ Result: clientio.NIL, - Error: nil, } } ss, err := sortedset.FromObject(obj) if err != nil { return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongTypeOperation, + Error: diceerrors.ErrWrongTypeOperation, } } - memberHash, ok := ss.Get(member) + centerHash, ok := ss.Get(member) if !ok { return &EvalResponse{ Result: nil, @@ -7001,14 +7087,20 @@ func evalGEORADIUSBYMEMBER(args []string, store *dstore.Store) *EvalResponse { } } - area, steps := geo.Area(memberHash, radius) + area, steps := geo.Area(centerHash, radius) /* When a huge Radius (in the 5000 km range or more) is used, * adjacent neighbors can be the same, leading to duplicated * elements. Skip every range which is the same as the one * processed previously. */ - var members []string + var hashes []float64 + + anyMax, count := 0, 0 + if opts.CountAny { + anyMax = opts.Count + } + var lastProcessed uint64 for _, hash := range area { if hash == 0 { @@ -7019,18 +7111,117 @@ func evalGEORADIUSBYMEMBER(args []string, store *dstore.Store) *EvalResponse { continue } - // TODO handle COUNT arg to limit number of returned members - hashMin, hashMax := geo.HashMinMax(hash, steps) - rangeMembers := ss.GetScoreRange(float64(hashMin), float64(hashMax), false, false) - for _, member := range rangeMembers { - members = append(members, fmt.Sprintf("%q", member)) + rangeMembers, rangeHashes := ss.GetMemberScoresInRange(float64(hashMin), float64(hashMax), count, anyMax) + members = append(members, rangeMembers...) + hashes = append(hashes, rangeHashes...) + } + + dists := make([]float64, 0, len(members)) + coords := make([][]float64, 0, len(members)) + + centerLat, centerLon := geo.DecodeHash(centerHash) + + if opts.IsSorted || opts.WithDist || opts.WithCoord { + for i := range hashes { + msLat, msLon := geo.DecodeHash(hashes[i]) + + if opts.WithDist || opts.IsSorted { + dist := geo.GetDistance(centerLon, centerLat, msLon, msLat) + dists = append(dists, dist) + } + + if opts.WithCoord { + coords = append(coords, []float64{msLat, msLon}) + } + } + } + + // Sorting is done by distance. Since our output can be dynamic and we can avoid allocating memory + // for each optional output property (hash, dist, coord), we follow an indirect sort approach: + // 1. Save the member inidices. + // 2. Sort the indices based on the distances in ascending or descending order. + // 3. Build the response based on the requested options. + indices := make([]int, len(members)) + for i := range indices { + indices[i] = i + } + + if opts.IsSorted { + if opts.Ascending { + sort.Slice(indices, func(i, j int) bool { + return dists[indices[i]] < dists[indices[j]] + }) + } else { + sort.Slice(indices, func(i, j int) bool { + return dists[indices[i]] > dists[indices[j]] + }) } } - // TODO handle options + optCount := 0 + if opts.WithDist { + optCount++ + } + + if opts.WithHash { + optCount++ + } + + if opts.WithCoord { + optCount++ + } + + max := opts.Count + if max > len(members) { + max = len(members) + } + + if optCount == 0 { + response := make([]string, len(members)) + for i := range members { + response[i] = members[indices[i]] + } + + if max > 0 { + response = response[:max] + } + + return &EvalResponse{ + Result: clientio.Encode(response, false), + } + } + + response := make([][]interface{}, len(members)) + for i := range members { + item := make([]any, optCount+1) + item[0] = members[i] + + itemIdx := 1 + + if opts.WithDist { + item[itemIdx] = dists[i] + itemIdx++ + } + + if opts.WithHash { + item[itemIdx] = hashes[i] + itemIdx++ + } + + if opts.WithCoord { + item[itemIdx] = coords[i] + itemIdx++ + } + + response[indices[i]] = item + } + + if max > 0 { + response = response[:max] + } return &EvalResponse{ - Result: clientio.Encode(members, false), + Result: clientio.Encode(response, false), } } diff --git a/internal/server/cmd_meta.go b/internal/server/cmd_meta.go index 8212d04952..83e47bb050 100644 --- a/internal/server/cmd_meta.go +++ b/internal/server/cmd_meta.go @@ -458,6 +458,10 @@ var ( Cmd: "GEODIST", CmdType: SingleShard, } + geoRadiusByMemberCmdMeta = CmdsMeta{ + Cmd: "GEORADIUSBYMEMBER", + CmdType: SingleShard, + } clientCmdMeta = CmdsMeta{ Cmd: "CLIENT", CmdType: SingleShard, @@ -506,7 +510,6 @@ var ( Cmd: "COMMAND|GETKEYSANDFLAGS", CmdType: SingleShard, } - // Metadata for multishard commands would go here. // These commands require both breakup and gather logic. @@ -637,6 +640,7 @@ func init() { WorkerCmdsMeta["RESTORE"] = restoreCmdMeta WorkerCmdsMeta["GEOADD"] = geoaddCmdMeta WorkerCmdsMeta["GEODIST"] = geodistCmdMeta + WorkerCmdsMeta["GEORADIUSBYMEMBER"] = geoRadiusByMemberCmdMeta WorkerCmdsMeta["CLIENT"] = clientCmdMeta WorkerCmdsMeta["LATENCY"] = latencyCmdMeta WorkerCmdsMeta["FLUSHDB"] = flushDBCmdMeta @@ -649,4 +653,5 @@ func init() { WorkerCmdsMeta["COMMAND|DOCS"] = CmdCommandDocs WorkerCmdsMeta["COMMAND|GETKEYS"] = CmdCommandGetKeys WorkerCmdsMeta["COMMAND|GETKEYSANDFLAGS"] = CmdCommandGetKeysFlags + // Additional commands (multishard, custom) can be added here as needed. }