fix: panic caused by type assert LocalSegment on Segment (#29018)
related #29017
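
Background: in Go, a single-result type assertion such as segment.(*segments.LocalSegment) panics with an "interface conversion" runtime error whenever the interface value holds a different implementation. That is the crash reported in #29017. A minimal standalone sketch of the failure mode and the comma-ok alternative (Segment, LocalSegment, and RemoteSegment here are stand-ins, not the real Milvus types):

package main

import "fmt"

// Stand-ins for the real Segment interface and its implementations.
type Segment interface{ ID() int64 }

type LocalSegment struct{ id int64 }

func (s *LocalSegment) ID() int64 { return s.id }

type RemoteSegment struct{ id int64 }

func (s *RemoteSegment) ID() int64 { return s.id }

func main() {
	var seg Segment = &RemoteSegment{id: 1}

	// Safe form: the comma-ok assertion reports failure instead of panicking.
	if local, ok := seg.(*LocalSegment); ok {
		fmt.Println("local segment", local.ID())
	} else {
		fmt.Println("not a LocalSegment; skipping")
	}

	// Unsafe form: panics with "interface conversion" because the
	// dynamic type is *RemoteSegment, not *LocalSegment.
	_ = seg.(*LocalSegment)
}

Rather than switching to the comma-ok form, this commit removes the assertions entirely and programs against the Segment interface.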

---------

Signed-off-by: yah01 <[email protected]>
yah01 authored Dec 7, 2023
1 parent 43abe9c commit c4dda3c
Showing 4 changed files with 15 additions and 63 deletions.
3 changes: 1 addition & 2 deletions internal/querynodev2/handlers.go
@@ -123,8 +123,7 @@ func (node *QueryNode) loadDeltaLogs(ctx context.Context, req *querypb.LoadSegme
 			continue
 		}
 
-		local := segment.(*segments.LocalSegment)
-		err := node.loader.LoadDeltaLogs(ctx, local, info.GetDeltalogs())
+		err := node.loader.LoadDeltaLogs(ctx, segment, info.GetDeltalogs())
 		if err != nil {
 			if finalErr == nil {
 				finalErr = err
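For the new call to compile, the loader's LoadDeltaLogs must accept the Segment interface rather than *segments.LocalSegment. A minimal sketch of that accept-the-interface signature change, with stand-in types and an illustrative body (the real loader does far more):

package main

import (
	"context"
	"fmt"
)

// Stand-ins for the real types in internal/querynodev2/segments.
type Segment interface {
	ID() int64
	Delete(ctx context.Context, pks []int64) error
}

type LocalSegment struct{ id int64 }

func (s *LocalSegment) ID() int64 { return s.id }

func (s *LocalSegment) Delete(ctx context.Context, pks []int64) error {
	fmt.Printf("segment %d: applied %d deletes\n", s.id, len(pks))
	return nil
}

// LoadDeltaLogs declared against the interface: callers no longer need a
// type assertion, and any Segment implementation can be passed in.
func LoadDeltaLogs(ctx context.Context, seg Segment, deletedPKs []int64) error {
	return seg.Delete(ctx, deletedPKs)
}

func main() {
	var seg Segment = &LocalSegment{id: 42}
	_ = LoadDeltaLogs(context.Background(), seg, []int64{1, 2, 3})
}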
65 changes: 6 additions & 59 deletions internal/querynodev2/local_worker.go
@@ -18,18 +18,12 @@ package querynodev2

 import (
 	"context"
-	"fmt"
-
-	"github.com/samber/lo"
-	"go.uber.org/zap"
 
-	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
-	"github.com/milvus-io/milvus/internal/querynodev2/segments"
 	"github.com/milvus-io/milvus/internal/util/streamrpc"
-	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/merr"
 )
 
 var _ cluster.Worker = &LocalWorker{}
@@ -45,65 +39,18 @@ func NewLocalWorker(node *QueryNode) *LocalWorker {
 }
 
 func (w *LocalWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", req.GetCollectionID()),
-		zap.Int64s("segmentIDs", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 {
-			return info.GetSegmentID()
-		})),
-		zap.String("loadScope", req.GetLoadScope().String()),
-	)
-	w.node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(),
-		w.node.composeIndexMeta(req.GetIndexInfoList(), req.GetSchema()), req.GetLoadMeta())
-	defer w.node.manager.Collection.Unref(req.GetCollectionID(), 1)
-	log.Info("start to load segments...")
-	loaded, err := w.node.loader.Load(ctx,
-		req.GetCollectionID(),
-		segments.SegmentTypeSealed,
-		req.GetVersion(),
-		req.GetInfos()...,
-	)
-	if err != nil {
-		return err
-	}
-
-	w.node.manager.Collection.Ref(req.GetCollectionID(), uint32(len(loaded)))
-
-	log.Info("load segments done...",
-		zap.Int64s("segments", lo.Map(loaded, func(s segments.Segment, _ int) int64 { return s.ID() })))
-	return err
+	status, err := w.node.LoadSegments(ctx, req)
+	return merr.CheckRPCCall(status, err)
 }
 
 func (w *LocalWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error {
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", req.GetCollectionID()),
-		zap.Int64s("segmentIDs", req.GetSegmentIDs()),
-		zap.String("scope", req.GetScope().String()),
-	)
-	log.Info("start to release segments")
-	sealedCount := 0
-	for _, id := range req.GetSegmentIDs() {
-		_, count := w.node.manager.Segment.Remove(id, req.GetScope())
-		sealedCount += count
-	}
-	w.node.manager.Collection.Unref(req.GetCollectionID(), uint32(sealedCount))
-
-	return nil
+	status, err := w.node.ReleaseSegments(ctx, req)
+	return merr.CheckRPCCall(status, err)
 }
 
 func (w *LocalWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", req.GetCollectionId()),
-		zap.Int64("segmentID", req.GetSegmentId()),
-	)
-	log.Debug("start to process segment delete")
 	status, err := w.node.Delete(ctx, req)
-	if err != nil {
-		return err
-	}
-	if status.GetErrorCode() != commonpb.ErrorCode_Success {
-		return fmt.Errorf(status.GetReason())
-	}
-	return nil
+	return merr.CheckRPCCall(status, err)
 }
 
 func (w *LocalWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
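The worker methods now delegate to the node's own RPC handlers and fold the returned (status, err) pair into a single error via merr.CheckRPCCall from pkg/util/merr. This also retires the hand-rolled check in the old Delete body, where fmt.Errorf(status.GetReason()) would misinterpret any "%" in the reason string as a format verb. A self-contained sketch of the pattern with stand-in types (the real helper handles more cases, such as responses that embed a status):

package main

import (
	"errors"
	"fmt"
)

// Stand-in for commonpb.Status.
type Status struct {
	ErrorCode int32 // 0 means success
	Reason    string
}

// checkRPCCall mimics the merr.CheckRPCCall pattern: surface the transport
// error first, then convert a non-success status into a plain error.
func checkRPCCall(status *Status, err error) error {
	if err != nil {
		return err
	}
	if status == nil {
		return errors.New("nil status")
	}
	if status.ErrorCode != 0 {
		return errors.New(status.Reason)
	}
	return nil
}

func main() {
	fmt.Println(checkRPCCall(&Status{}, nil))                             // <nil>
	fmt.Println(checkRPCCall(&Status{ErrorCode: 1, Reason: "boom"}, nil)) // boom
}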
7 changes: 7 additions & 0 deletions internal/querynodev2/local_worker_test.go
@@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
@@ -112,6 +113,9 @@ func (suite *LocalWorkerTestSuite) TestLoadSegment() {
 	// load empty
 	schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
 	req := &querypb.LoadSegmentsRequest{
+		Base: &commonpb.MsgBase{
+			TargetID: suite.node.session.GetServerID(),
+		},
 		CollectionID: suite.collectionID,
 		Infos: lo.Map(suite.segmentIDs, func(segID int64, _ int) *querypb.SegmentLoadInfo {
 			return &querypb.SegmentLoadInfo{
@@ -129,6 +133,9 @@

 func (suite *LocalWorkerTestSuite) TestReleaseSegment() {
 	req := &querypb.ReleaseSegmentsRequest{
+		Base: &commonpb.MsgBase{
+			TargetID: suite.node.session.GetServerID(),
+		},
 		CollectionID: suite.collectionID,
 		SegmentIDs:   suite.segmentIDs,
 	}
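The tests gain Base.TargetID because LocalWorker now routes through the node's RPC entrypoints, which (presumably, as elsewhere in Milvus) reject requests whose target ID does not match the serving node's server ID. A hedged sketch of that kind of guard, with stand-in types:

package main

import (
	"errors"
	"fmt"
)

type MsgBase struct{ TargetID int64 }

type QueryNode struct{ serverID int64 }

// checkTarget mirrors the guard an RPC entrypoint is assumed to apply
// before doing any work: the request must be addressed to this node.
func (n *QueryNode) checkTarget(base *MsgBase) error {
	if base == nil || base.TargetID != n.serverID {
		return errors.New("node not match")
	}
	return nil
}

func main() {
	node := &QueryNode{serverID: 7}
	fmt.Println(node.checkTarget(&MsgBase{TargetID: 7})) // <nil>: accepted
	fmt.Println(node.checkTarget(&MsgBase{TargetID: 9})) // node not match
}

Without the TargetID, the delegated LoadSegments/ReleaseSegments calls would fail this check, so the tests set it explicitly.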
3 changes: 1 addition & 2 deletions internal/querynodev2/segments/retrieve.go
@@ -94,9 +94,8 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
 		wg.Add(1)
 		go func(segment Segment, i int) {
 			defer wg.Done()
-			seg := segment.(*LocalSegment)
 			tr := timerecord.NewTimeRecorder("retrieveOnSegmentsWithStream")
-			result, err := seg.Retrieve(ctx, plan)
+			result, err := segment.Retrieve(ctx, plan)
 			if err != nil {
 				errs[i] = err
 				return
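With the assertion gone, Retrieve dispatches through the Segment interface, so every implementation takes the same fan-out path. A minimal sketch of the surrounding pattern — one goroutine per segment, results and errors collected by index so no locking is needed (stand-in types; the real function streams results and records timing):

package main

import (
	"context"
	"fmt"
	"sync"
)

type Segment interface {
	Retrieve(ctx context.Context) (string, error)
}

type LocalSegment struct{ id int64 }

func (s *LocalSegment) Retrieve(ctx context.Context) (string, error) {
	return fmt.Sprintf("rows from segment %d", s.id), nil
}

// Fan out one goroutine per segment, writing into per-index slots.
// The interface call works for every Segment implementation, unlike
// the removed seg := segment.(*LocalSegment) assertion.
func retrieveOnSegments(ctx context.Context, segs []Segment) ([]string, error) {
	var wg sync.WaitGroup
	results := make([]string, len(segs))
	errs := make([]error, len(segs))
	for i, seg := range segs {
		wg.Add(1)
		go func(seg Segment, i int) {
			defer wg.Done()
			results[i], errs[i] = seg.Retrieve(ctx)
		}(seg, i)
	}
	wg.Wait()
	for _, err := range errs {
		if err != nil {
			return nil, err
		}
	}
	return results, nil
}

func main() {
	out, err := retrieveOnSegments(context.Background(),
		[]Segment{&LocalSegment{id: 1}, &LocalSegment{id: 2}})
	fmt.Println(out, err)
}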
