diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 59cbc3c9a26d7..2e50721ed2d38 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -55,6 +55,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/tikv" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -83,8 +85,8 @@ type Server struct { // component client etcdCli *clientv3.Client tikvCli *txnkv.Client - rootCoord types.RootCoordClient - dataCoord types.DataCoordClient + rootCoord *syncutil.Future[types.RootCoordClient] + dataCoord *syncutil.Future[types.DataCoordClient] chunkManager storage.ChunkManager componentState *componentutil.ComponentStateService } @@ -95,6 +97,8 @@ func NewServer(ctx context.Context, f dependency.Factory) (*Server, error) { return &Server{ stopOnce: sync.Once{}, factory: f, + dataCoord: syncutil.NewFuture[types.DataCoordClient](), + rootCoord: syncutil.NewFuture[types.RootCoordClient](), grpcServerChan: make(chan struct{}), componentState: componentutil.NewComponentStateService(typeutil.StreamingNodeRole), ctx: ctx1, @@ -166,8 +170,17 @@ func (s *Server) stop() { // Stop rootCoord client. log.Info("streamingnode stop rootCoord client...") - if err := s.rootCoord.Close(); err != nil { - log.Warn("streamingnode stop rootCoord client failed", zap.Error(err)) + if s.rootCoord.Ready() { + if err := s.rootCoord.Get().Close(); err != nil { + log.Warn("streamingnode stop rootCoord client failed", zap.Error(err)) + } + } + + log.Info("streamingnode stop dataCoord client...") + if s.dataCoord.Ready() { + if err := s.dataCoord.Get().Close(); err != nil { + log.Warn("streamingnode stop dataCoord client failed", zap.Error(err)) + } } // Stop tikv @@ -216,12 +229,8 @@ func (s *Server) init() (err error) { if err := s.initSession(); err != nil { return err } - if err := s.initRootCoord(); err != nil { - return err - } - if err := s.initDataCoord(); err != nil { - return err - } + s.initRootCoord() + s.initDataCoord() s.initGRPCServer() // Create StreamingNode service. @@ -300,36 +309,48 @@ func (s *Server) initMeta() error { return nil } -func (s *Server) initRootCoord() (err error) { +func (s *Server) initRootCoord() { log := log.Ctx(s.ctx) - log.Info("StreamingNode connect to rootCoord...") - s.rootCoord, err = rcc.NewClient(s.ctx) - if err != nil { - return errors.Wrap(err, "StreamingNode try to new RootCoord client failed") - } + go func() { + retry.Do(s.ctx, func() error { + log.Info("StreamingNode connect to rootCoord...") + rootCoord, err := rcc.NewClient(s.ctx) + if err != nil { + return errors.Wrap(err, "StreamingNode try to new RootCoord client failed") + } - log.Info("StreamingNode try to wait for RootCoord ready") - err = componentutil.WaitForComponentHealthy(s.ctx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200) - if err != nil { - return errors.Wrap(err, "StreamingNode wait for RootCoord ready failed") - } - return nil + log.Info("StreamingNode try to wait for RootCoord ready") + err = componentutil.WaitForComponentHealthy(s.ctx, rootCoord, "RootCoord", 1000000, time.Millisecond*200) + if err != nil { + return errors.Wrap(err, "StreamingNode wait for RootCoord ready failed") + } + log.Info("StreamingNode wait for RootCoord done") + s.rootCoord.Set(rootCoord) + return nil + }, retry.AttemptAlways()) + }() } -func (s *Server) initDataCoord() (err error) { +func (s *Server) initDataCoord() { log := log.Ctx(s.ctx) - log.Info("StreamingNode connect to dataCoord...") - s.dataCoord, err = dcc.NewClient(s.ctx) - if err != nil { - return errors.Wrap(err, "StreamingNode try to new DataCoord client failed") - } + go func() { + retry.Do(s.ctx, func() error { + log.Info("StreamingNode connect to dataCoord...") + dataCoord, err := dcc.NewClient(s.ctx) + if err != nil { + return errors.Wrap(err, "StreamingNode try to new DataCoord client failed") + } - log.Info("StreamingNode try to wait for DataCoord ready") - err = componentutil.WaitForComponentHealthy(s.ctx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200) - if err != nil { - return errors.Wrap(err, "StreamingNode wait for DataCoord ready failed") - } - return nil + log.Info("StreamingNode try to wait for DataCoord ready") + err = componentutil.WaitForComponentHealthy(s.ctx, dataCoord, "DataCoord", 1000000, time.Millisecond*200) + if err != nil { + return errors.Wrap(err, "StreamingNode wait for DataCoord ready failed") + } + log.Info("StreamingNode wait for DataCoord ready") + s.dataCoord.Set(dataCoord) + return nil + }, retry.AttemptAlways()) + }() } func (s *Server) initChunkManager() (err error) { diff --git a/internal/streamingnode/server/builder.go b/internal/streamingnode/server/builder.go index cdf725df55d01..f35e76b233375 100644 --- a/internal/streamingnode/server/builder.go +++ b/internal/streamingnode/server/builder.go @@ -11,6 +11,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) // ServerBuilder is used to build a server. @@ -18,8 +19,8 @@ import ( type ServerBuilder struct { etcdClient *clientv3.Client grpcServer *grpc.Server - rc types.RootCoordClient - dc types.DataCoordClient + rc *syncutil.Future[types.RootCoordClient] + dc *syncutil.Future[types.DataCoordClient] session *sessionutil.Session kv kv.MetaKv chunkManager storage.ChunkManager @@ -49,13 +50,13 @@ func (b *ServerBuilder) WithGRPCServer(svr *grpc.Server) *ServerBuilder { } // WithRootCoordClient sets root coord client to the server builder. -func (b *ServerBuilder) WithRootCoordClient(rc types.RootCoordClient) *ServerBuilder { +func (b *ServerBuilder) WithRootCoordClient(rc *syncutil.Future[types.RootCoordClient]) *ServerBuilder { b.rc = rc return b } // WithDataCoordClient sets data coord client to the server builder. -func (b *ServerBuilder) WithDataCoordClient(dc types.DataCoordClient) *ServerBuilder { +func (b *ServerBuilder) WithDataCoordClient(dc *syncutil.Future[types.DataCoordClient]) *ServerBuilder { b.dc = dc return b } diff --git a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go index 670677006a27a..51965267f56fb 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go +++ b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go @@ -86,8 +86,17 @@ func (c *channelLifetime) Run() error { // Get recovery info from datacoord. ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - resp, err := resource.Resource().DataCoordClient(). - GetChannelRecoveryInfo(ctx, &datapb.GetChannelRecoveryInfoRequest{Vchannel: c.vchannel}) + + pipelineParams, err := c.f.getPipelineParams(ctx) + if err != nil { + return err + } + + dc, err := resource.Resource().DataCoordClient().GetWithContext(ctx) + if err != nil { + return errors.Wrap(err, "At Get DataCoordClient") + } + resp, err := dc.GetChannelRecoveryInfo(ctx, &datapb.GetChannelRecoveryInfoRequest{Vchannel: c.vchannel}) if err = merr.CheckRPCCall(resp, err); err != nil { return err } @@ -115,7 +124,7 @@ func (c *channelLifetime) Run() error { } // Build and add pipeline. - ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, c.f.pipelineParams, + ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, pipelineParams, // TODO fubang add the db properties &datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan(), func(t syncmgr.Task, err error) { if err != nil || t == nil { diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go index f87acd7353c5e..c97c9b491bba4 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go @@ -18,7 +18,6 @@ package flusherimpl import ( "context" - "sync" "time" "github.com/cockroachdb/errors" @@ -34,56 +33,55 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) var _ flusher.Flusher = (*flusherImpl)(nil) type flusherImpl struct { - broker broker.Broker - fgMgr pipeline.FlowgraphManager - syncMgr syncmgr.SyncManager - wbMgr writebuffer.BufferManager - cpUpdater *util.ChannelCheckpointUpdater + fgMgr pipeline.FlowgraphManager + wbMgr writebuffer.BufferManager + syncMgr syncmgr.SyncManager + cpUpdater *syncutil.Future[*util.ChannelCheckpointUpdater] + chunkManager storage.ChunkManager channelLifetimes *typeutil.ConcurrentMap[string, ChannelLifetime] - notifyCh chan struct{} - stopChan lifetime.SafeChan - stopWg sync.WaitGroup - pipelineParams *util.PipelineParams + notifyCh chan struct{} + notifier *syncutil.AsyncTaskNotifier[struct{}] } func NewFlusher(chunkManager storage.ChunkManager) flusher.Flusher { - params := getPipelineParams(chunkManager) - return newFlusherWithParam(params) -} - -func newFlusherWithParam(params *util.PipelineParams) flusher.Flusher { - fgMgr := pipeline.NewFlowgraphManager() + syncMgr := syncmgr.NewSyncManager(chunkManager) + wbMgr := writebuffer.NewManager(syncMgr) return &flusherImpl{ - broker: params.Broker, - fgMgr: fgMgr, - syncMgr: params.SyncMgr, - wbMgr: params.WriteBufferManager, - cpUpdater: params.CheckpointUpdater, + fgMgr: pipeline.NewFlowgraphManager(), + wbMgr: wbMgr, + syncMgr: syncMgr, + cpUpdater: syncutil.NewFuture[*util.ChannelCheckpointUpdater](), + chunkManager: chunkManager, channelLifetimes: typeutil.NewConcurrentMap[string, ChannelLifetime](), notifyCh: make(chan struct{}, 1), - stopChan: lifetime.NewSafeChan(), - pipelineParams: params, + notifier: syncutil.NewAsyncTaskNotifier[struct{}](), } } func (f *flusherImpl) RegisterPChannel(pchannel string, wal wal.WAL) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - resp, err := resource.Resource().RootCoordClient().GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ + rc, err := resource.Resource().RootCoordClient().GetWithContext(ctx) + if err != nil { + return errors.Wrap(err, "At Get RootCoordClient") + } + resp, err := rc.GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ Pchannel: pchannel, }) if err = merr.CheckRPCCall(resp, err); err != nil { @@ -126,11 +124,18 @@ func (f *flusherImpl) notify() { } func (f *flusherImpl) Start() { - f.stopWg.Add(1) f.wbMgr.Start() - go f.cpUpdater.Start() go func() { - defer f.stopWg.Done() + defer f.notifier.Finish(struct{}{}) + dc, err := resource.Resource().DataCoordClient().GetWithContext(f.notifier.Context()) + if err != nil { + return + } + broker := broker.NewCoordBroker(dc, paramtable.GetNodeID()) + cpUpdater := util.NewChannelCheckpointUpdater(broker) + go cpUpdater.Start() + f.cpUpdater.Set(cpUpdater) + backoff := typeutil.NewBackoffTimer(typeutil.BackoffTimerConfig{ Default: 5 * time.Second, Backoff: typeutil.BackoffConfig{ @@ -143,7 +148,7 @@ func (f *flusherImpl) Start() { var nextTimer <-chan time.Time for { select { - case <-f.stopChan.CloseCh(): + case <-f.notifier.Context().Done(): log.Info("flusher exited") return case <-f.notifyCh: @@ -190,13 +195,37 @@ func (f *flusherImpl) handle(backoff *typeutil.BackoffTimer) <-chan time.Time { } func (f *flusherImpl) Stop() { - f.stopChan.Close() - f.stopWg.Wait() + f.notifier.Cancel() + f.notifier.BlockUntilFinish() f.channelLifetimes.Range(func(vchannel string, lifetime ChannelLifetime) bool { lifetime.Cancel() return true }) f.fgMgr.ClearFlowgraphs() f.wbMgr.Stop() - f.cpUpdater.Close() + if f.cpUpdater.Ready() { + f.cpUpdater.Get().Close() + } +} + +func (f *flusherImpl) getPipelineParams(ctx context.Context) (*util.PipelineParams, error) { + dc, err := resource.Resource().DataCoordClient().GetWithContext(ctx) + if err != nil { + return nil, err + } + + cpUpdater, err := f.cpUpdater.GetWithContext(ctx) + if err != nil { + return nil, err + } + return &util.PipelineParams{ + Ctx: context.Background(), + Broker: broker.NewCoordBroker(dc, paramtable.GetNodeID()), + SyncMgr: f.syncMgr, + ChunkManager: f.chunkManager, + WriteBufferManager: f.wbMgr, + CheckpointUpdater: cpUpdater, + Allocator: idalloc.NewMAllocator(resource.Resource().IDAllocator()), + MsgHandler: newMsgHandler(f.wbMgr), + }, nil } diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go index aef723e7a59f6..f4f0116231962 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go @@ -30,8 +30,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" - "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -39,9 +37,11 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func init() { @@ -106,22 +106,8 @@ func newMockWAL(t *testing.T, vchannels []string, maybe bool) *mock_wal.MockWAL } func newTestFlusher(t *testing.T, maybe bool) flusher.Flusher { - wbMgr := writebuffer.NewMockBufferManager(t) - register := wbMgr.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - removeChannel := wbMgr.EXPECT().RemoveChannel(mock.Anything).Return() - start := wbMgr.EXPECT().Start().Return() - stop := wbMgr.EXPECT().Stop().Return() - if maybe { - register.Maybe() - removeChannel.Maybe() - start.Maybe() - stop.Maybe() - } m := mocks.NewChunkManager(t) - params := getPipelineParams(m) - params.SyncMgr = syncmgr.NewMockSyncManager(t) - params.WriteBufferManager = wbMgr - return newFlusherWithParam(params) + return NewFlusher(m) } func TestFlusher_RegisterPChannel(t *testing.T) { @@ -146,10 +132,16 @@ func TestFlusher_RegisterPChannel(t *testing.T) { rootcoord.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything). Return(&rootcoordpb.GetPChannelInfoResponse{Collections: collectionsInfo}, nil) datacoord := newMockDatacoord(t, maybe) + + fDatacoord := syncutil.NewFuture[types.DataCoordClient]() + fDatacoord.Set(datacoord) + + fRootcoord := syncutil.NewFuture[types.RootCoordClient]() + fRootcoord.Set(rootcoord) resource.InitForTest( t, - resource.OptRootCoordClient(rootcoord), - resource.OptDataCoordClient(datacoord), + resource.OptRootCoordClient(fRootcoord), + resource.OptDataCoordClient(fDatacoord), ) f := newTestFlusher(t, maybe) @@ -182,9 +174,11 @@ func TestFlusher_RegisterVChannel(t *testing.T) { } datacoord := newMockDatacoord(t, maybe) + fDatacoord := syncutil.NewFuture[types.DataCoordClient]() + fDatacoord.Set(datacoord) resource.InitForTest( t, - resource.OptDataCoordClient(datacoord), + resource.OptDataCoordClient(fDatacoord), ) f := newTestFlusher(t, maybe) @@ -220,9 +214,11 @@ func TestFlusher_Concurrency(t *testing.T) { } datacoord := newMockDatacoord(t, maybe) + fDatacoord := syncutil.NewFuture[types.DataCoordClient]() + fDatacoord.Set(datacoord) resource.InitForTest( t, - resource.OptDataCoordClient(datacoord), + resource.OptDataCoordClient(fDatacoord), ) f := newTestFlusher(t, maybe) diff --git a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go b/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go deleted file mode 100644 index 79751dff73444..0000000000000 --- a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flusherimpl - -import ( - "context" - - "github.com/milvus-io/milvus/internal/flushcommon/broker" - "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" - "github.com/milvus-io/milvus/internal/flushcommon/util" - "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" - "github.com/milvus-io/milvus/pkg/util/paramtable" -) - -// getPipelineParams initializes the pipeline parameters. -func getPipelineParams(chunkManager storage.ChunkManager) *util.PipelineParams { - var ( - rsc = resource.Resource() - syncMgr = syncmgr.NewSyncManager(chunkManager) - wbMgr = writebuffer.NewManager(syncMgr) - coordBroker = broker.NewCoordBroker(rsc.DataCoordClient(), paramtable.GetNodeID()) - cpUpdater = util.NewChannelCheckpointUpdater(coordBroker) - ) - return &util.PipelineParams{ - Ctx: context.Background(), - Broker: coordBroker, - SyncMgr: syncMgr, - ChunkManager: chunkManager, - WriteBufferManager: wbMgr, - CheckpointUpdater: cpUpdater, - Allocator: idalloc.NewMAllocator(rsc.IDAllocator()), - MsgHandler: newMsgHandler(wbMgr), - } -} diff --git a/internal/streamingnode/server/resource/idalloc/allocator.go b/internal/streamingnode/server/resource/idalloc/allocator.go index 3e8b7bdb59d23..f614d6f5ec3d6 100644 --- a/internal/streamingnode/server/resource/idalloc/allocator.go +++ b/internal/streamingnode/server/resource/idalloc/allocator.go @@ -22,6 +22,7 @@ import ( "time" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) // batchAllocateSize is the size of batch allocate from remote allocator. @@ -30,7 +31,7 @@ const batchAllocateSize = 1000 var _ Allocator = (*allocatorImpl)(nil) // NewTSOAllocator creates a new allocator. -func NewTSOAllocator(rc types.RootCoordClient) Allocator { +func NewTSOAllocator(rc *syncutil.Future[types.RootCoordClient]) Allocator { return &allocatorImpl{ mu: sync.Mutex{}, remoteAllocator: newTSOAllocator(rc), @@ -39,7 +40,7 @@ func NewTSOAllocator(rc types.RootCoordClient) Allocator { } // NewIDAllocator creates a new allocator. -func NewIDAllocator(rc types.RootCoordClient) Allocator { +func NewIDAllocator(rc *syncutil.Future[types.RootCoordClient]) Allocator { return &allocatorImpl{ mu: sync.Mutex{}, remoteAllocator: newIDAllocator(rc), diff --git a/internal/streamingnode/server/resource/idalloc/allocator_test.go b/internal/streamingnode/server/resource/idalloc/allocator_test.go index c4db2e520a578..26eb9e90c2b1a 100644 --- a/internal/streamingnode/server/resource/idalloc/allocator_test.go +++ b/internal/streamingnode/server/resource/idalloc/allocator_test.go @@ -11,7 +11,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestTimestampAllocator(t *testing.T) { @@ -19,7 +21,10 @@ func TestTimestampAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) - allocator := NewTSOAllocator(client) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator := NewTSOAllocator(f) for i := 0; i < 5000; i++ { ts, err := allocator.Allocate(context.Background()) @@ -46,7 +51,10 @@ func TestTimestampAllocator(t *testing.T) { }, nil }, ) - allocator = NewTSOAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = NewTSOAllocator(f) _, err := allocator.Allocate(context.Background()) assert.Error(t, err) } diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator.go b/internal/streamingnode/server/resource/idalloc/basic_allocator.go index 8e0ad90e63d1c..8b9e220cc410a 100644 --- a/internal/streamingnode/server/resource/idalloc/basic_allocator.go +++ b/internal/streamingnode/server/resource/idalloc/basic_allocator.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) var errExhausted = errors.New("exhausted") @@ -56,12 +57,12 @@ func (a *localAllocator) exhausted() { // tsoAllocator allocate timestamp from remote root coordinator. type tsoAllocator struct { - rc types.RootCoordClient + rc *syncutil.Future[types.RootCoordClient] nodeID int64 } // newTSOAllocator creates a new remote allocator. -func newTSOAllocator(rc types.RootCoordClient) *tsoAllocator { +func newTSOAllocator(rc *syncutil.Future[types.RootCoordClient]) *tsoAllocator { a := &tsoAllocator{ nodeID: paramtable.GetNodeID(), rc: rc, @@ -80,8 +81,12 @@ func (ta *tsoAllocator) batchAllocate(ctx context.Context, count uint32) (uint64 ), Count: count, } + rc, err := ta.rc.GetWithContext(ctx) + if err != nil { + return 0, 0, fmt.Errorf("get root coordinator client timeout: %w", err) + } - resp, err := ta.rc.AllocTimestamp(ctx, req) + resp, err := rc.AllocTimestamp(ctx, req) if err != nil { return 0, 0, fmt.Errorf("syncTimestamp Failed:%w", err) } @@ -96,12 +101,12 @@ func (ta *tsoAllocator) batchAllocate(ctx context.Context, count uint32) (uint64 // idAllocator allocate timestamp from remote root coordinator. type idAllocator struct { - rc types.RootCoordClient + rc *syncutil.Future[types.RootCoordClient] nodeID int64 } // newIDAllocator creates a new remote allocator. -func newIDAllocator(rc types.RootCoordClient) *idAllocator { +func newIDAllocator(rc *syncutil.Future[types.RootCoordClient]) *idAllocator { a := &idAllocator{ nodeID: paramtable.GetNodeID(), rc: rc, @@ -120,8 +125,12 @@ func (ta *idAllocator) batchAllocate(ctx context.Context, count uint32) (uint64, ), Count: count, } + rc, err := ta.rc.GetWithContext(ctx) + if err != nil { + return 0, 0, fmt.Errorf("get root coordinator client timeout: %w", err) + } - resp, err := ta.rc.AllocID(ctx, req) + resp, err := rc.AllocID(ctx, req) if err != nil { return 0, 0, fmt.Errorf("AllocID Failed:%w", err) } diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go b/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go index 081832006f017..549f78cc00d8b 100644 --- a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go +++ b/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go @@ -13,7 +13,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestLocalAllocator(t *testing.T) { @@ -63,8 +65,10 @@ func TestRemoteTSOAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) - allocator := newTSOAllocator(client) + allocator := newTSOAllocator(f) ts, count, err := allocator.batchAllocate(context.Background(), 100) assert.NoError(t, err) assert.NotZero(t, ts) @@ -77,7 +81,10 @@ func TestRemoteTSOAllocator(t *testing.T) { return nil, errors.New("test") }, ) - allocator = newTSOAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = newTSOAllocator(f) _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) @@ -91,7 +98,10 @@ func TestRemoteTSOAllocator(t *testing.T) { }, nil }, ) - allocator = newTSOAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = newTSOAllocator(f) _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) } @@ -101,8 +111,11 @@ func TestRemoteIDAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator := newIDAllocator(f) - allocator := newIDAllocator(client) ts, count, err := allocator.batchAllocate(context.Background(), 100) assert.NoError(t, err) assert.NotZero(t, ts) @@ -115,7 +128,10 @@ func TestRemoteIDAllocator(t *testing.T) { return nil, errors.New("test") }, ) - allocator = newIDAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = newIDAllocator(f) _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) @@ -129,7 +145,10 @@ func TestRemoteIDAllocator(t *testing.T) { }, nil }, ) - allocator = newIDAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = newIDAllocator(f) _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) } diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index 23ff6316052b9..06edb5a5cd32b 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) var r = &resourceImpl{} // singleton resource instance @@ -41,7 +42,7 @@ func OptChunkManager(chunkManager storage.ChunkManager) optResourceInit { } // OptRootCoordClient provides the root coordinator client to the resource. -func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit { +func OptRootCoordClient(rootCoordClient *syncutil.Future[types.RootCoordClient]) optResourceInit { return func(r *resourceImpl) { r.rootCoordClient = rootCoordClient r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) @@ -50,7 +51,7 @@ func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit { } // OptDataCoordClient provides the data coordinator client to the resource. -func OptDataCoordClient(dataCoordClient types.DataCoordClient) optResourceInit { +func OptDataCoordClient(dataCoordClient *syncutil.Future[types.DataCoordClient]) optResourceInit { return func(r *resourceImpl) { r.dataCoordClient = dataCoordClient } @@ -96,8 +97,8 @@ type resourceImpl struct { idAllocator idalloc.Allocator etcdClient *clientv3.Client chunkManager storage.ChunkManager - rootCoordClient types.RootCoordClient - dataCoordClient types.DataCoordClient + rootCoordClient *syncutil.Future[types.RootCoordClient] + dataCoordClient *syncutil.Future[types.DataCoordClient] streamingNodeCatalog metastore.StreamingNodeCataLog segmentAssignStatsManager *stats.StatsManager timeTickInspector tinspector.TimeTickSyncInspector @@ -129,12 +130,12 @@ func (r *resourceImpl) ChunkManager() storage.ChunkManager { } // RootCoordClient returns the root coordinator client. -func (r *resourceImpl) RootCoordClient() types.RootCoordClient { +func (r *resourceImpl) RootCoordClient() *syncutil.Future[types.RootCoordClient] { return r.rootCoordClient } // DataCoordClient returns the data coordinator client. -func (r *resourceImpl) DataCoordClient() types.DataCoordClient { +func (r *resourceImpl) DataCoordClient() *syncutil.Future[types.DataCoordClient] { return r.dataCoordClient } diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index 1d8d4f976f784..8c219d86ff0c8 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -6,9 +6,10 @@ import ( "github.com/stretchr/testify/assert" clientv3 "go.etcd.io/etcd/client/v3" - "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestApply(t *testing.T) { @@ -16,7 +17,7 @@ func TestApply(t *testing.T) { Apply() Apply(OptETCD(&clientv3.Client{})) - Apply(OptRootCoordClient(mocks.NewMockRootCoordClient(t))) + Apply(OptRootCoordClient(syncutil.NewFuture[types.RootCoordClient]())) assert.Panics(t, func() { Done() @@ -24,8 +25,8 @@ func TestApply(t *testing.T) { Apply( OptETCD(&clientv3.Client{}), - OptRootCoordClient(mocks.NewMockRootCoordClient(t)), - OptDataCoordClient(mocks.NewMockDataCoordClient(t)), + OptRootCoordClient(syncutil.NewFuture[types.RootCoordClient]()), + OptDataCoordClient(syncutil.NewFuture[types.DataCoordClient]()), OptStreamingNodeCatalog(mock_metastore.NewMockStreamingNodeCataLog(t)), ) Done() diff --git a/internal/streamingnode/server/resource/test_utility.go b/internal/streamingnode/server/resource/test_utility.go index bad9e0f4bf1de..3fddc19b893f2 100644 --- a/internal/streamingnode/server/resource/test_utility.go +++ b/internal/streamingnode/server/resource/test_utility.go @@ -9,6 +9,8 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) // InitForTest initializes the singleton of resources for test. @@ -21,7 +23,9 @@ func InitForTest(t *testing.T, opts ...optResourceInit) { r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } else { - r.rootCoordClient = idalloc.NewMockRootCoordClient(t) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(idalloc.NewMockRootCoordClient(t)) + r.rootCoordClient = f r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index f9f1fb80be165..b217af0d521e9 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -24,10 +24,12 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry" + internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) const testVChannel = "v1" @@ -53,8 +55,15 @@ func initResourceForTest(t *testing.T) { rc := idalloc.NewMockRootCoordClient(t) rc.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{}, nil) + fRootCoordClient := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootCoordClient.Set(rc) + dc := mocks.NewMockDataCoordClient(t) dc.EXPECT().AllocSegment(mock.Anything, mock.Anything).Return(&datapb.AllocSegmentResponse{}, nil) + + fDataCoordClient := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDataCoordClient.Set(dc) + catalog := mock_metastore.NewMockStreamingNodeCataLog(t) catalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return(nil, nil) catalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -67,8 +76,8 @@ func initResourceForTest(t *testing.T) { resource.InitForTest( t, - resource.OptRootCoordClient(rc), - resource.OptDataCoordClient(dc), + resource.OptRootCoordClient(fRootCoordClient), + resource.OptDataCoordClient(fDataCoordClient), resource.OptFlusher(flusher), resource.OptStreamingNodeCatalog(catalog), ) diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go index def30b9575115..bce92f57960d6 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go @@ -225,7 +225,11 @@ func (m *partitionSegmentManager) allocNewGrowingSegment(ctx context.Context) (* // Transfer the pending segment into growing state. // Alloc the growing segment at datacoord first. - resp, err := resource.Resource().DataCoordClient().AllocSegment(ctx, &datapb.AllocSegmentRequest{ + dc, err := resource.Resource().DataCoordClient().GetWithContext(ctx) + if err != nil { + return nil, err + } + resp, err := dc.AllocSegment(ctx, &datapb.AllocSegmentRequest{ CollectionId: pendingSegment.GetCollectionID(), PartitionId: pendingSegment.GetPartitionID(), SegmentId: pendingSegment.GetSegmentID(), diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go index fe30a7e2fbde2..e942ffae35c55 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go @@ -32,7 +32,11 @@ func RecoverPChannelSegmentAllocManager( return nil, errors.Wrap(err, "failed to list segment assignment from catalog") } // get collection and parition info from rootcoord. - resp, err := resource.Resource().RootCoordClient().GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ + rc, err := resource.Resource().RootCoordClient().GetWithContext(ctx) + if err != nil { + return nil, err + } + resp, err := rc.GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ Pchannel: pchannel.Name, }) if err := merr.CheckRPCCall(resp, err); err != nil { diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go index 33597cce87c25..7093f1139e89a 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go @@ -20,6 +20,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" + internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq" @@ -311,6 +312,8 @@ func initializeTestState(t *testing.T) { Status: merr.Success(), }, nil }) + fDataCoordClient := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDataCoordClient.Set(dataCoordClient) rootCoordClient := idalloc.NewMockRootCoordClient(t) rootCoordClient.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{ @@ -325,11 +328,13 @@ func initializeTestState(t *testing.T) { }, }, }, nil) + fRootCoordClient := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootCoordClient.Set(rootCoordClient) resource.InitForTest(t, resource.OptStreamingNodeCatalog(streamingNodeCatalog), - resource.OptDataCoordClient(dataCoordClient), - resource.OptRootCoordClient(rootCoordClient), + resource.OptDataCoordClient(fDataCoordClient), + resource.OptRootCoordClient(fRootCoordClient), ) streamingNodeCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return( []*streamingpb.SegmentAssignmentMeta{ diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go index 0803931c3b909..0ba11fd88b499 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go @@ -17,9 +17,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/metricsutil" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestAck(t *testing.T) { @@ -43,7 +45,9 @@ func TestAck(t *testing.T) { }, nil }, ) - resource.InitForTest(t, resource.OptRootCoordClient(rc)) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(rc) + resource.InitForTest(t, resource.OptRootCoordClient(f)) ackManager := NewAckManager(0, nil, metricsutil.NewTimeTickMetrics("test")) @@ -160,7 +164,9 @@ func TestAckManager(t *testing.T) { }, nil }, ) - resource.InitForTest(t, resource.OptRootCoordClient(rc)) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(rc) + resource.InitForTest(t, resource.OptRootCoordClient(f)) ackManager := NewAckManager(0, walimplstest.NewTestMessageID(0), metricsutil.NewTimeTickMetrics("test")) diff --git a/internal/streamingnode/server/walmanager/manager_impl_test.go b/internal/streamingnode/server/walmanager/manager_impl_test.go index 35b269cc04a85..cdaa931e3c51d 100644 --- a/internal/streamingnode/server/walmanager/manager_impl_test.go +++ b/internal/streamingnode/server/walmanager/manager_impl_test.go @@ -12,10 +12,12 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestMain(m *testing.M) { @@ -25,7 +27,11 @@ func TestMain(m *testing.M) { func TestManager(t *testing.T) { rootcoord := mocks.NewMockRootCoordClient(t) + fRootcoord := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootcoord.Set(rootcoord) datacoord := mocks.NewMockDataCoordClient(t) + fDatacoord := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDatacoord.Set(datacoord) flusher := mock_flusher.NewMockFlusher(t) flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil) @@ -33,8 +39,8 @@ func TestManager(t *testing.T) { resource.InitForTest( t, resource.OptFlusher(flusher), - resource.OptRootCoordClient(rootcoord), - resource.OptDataCoordClient(datacoord), + resource.OptRootCoordClient(fRootcoord), + resource.OptDataCoordClient(fDatacoord), ) opener := mock_wal.NewMockOpener(t) diff --git a/internal/streamingnode/server/walmanager/wal_lifetime_test.go b/internal/streamingnode/server/walmanager/wal_lifetime_test.go index d34bfe4f88896..a14464df8b594 100644 --- a/internal/streamingnode/server/walmanager/wal_lifetime_test.go +++ b/internal/streamingnode/server/walmanager/wal_lifetime_test.go @@ -12,14 +12,20 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestWALLifetime(t *testing.T) { channel := "test" rootcoord := mocks.NewMockRootCoordClient(t) + fRootcoord := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootcoord.Set(rootcoord) datacoord := mocks.NewMockDataCoordClient(t) + fDatacoord := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDatacoord.Set(datacoord) flusher := mock_flusher.NewMockFlusher(t) flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil) @@ -28,8 +34,8 @@ func TestWALLifetime(t *testing.T) { resource.InitForTest( t, resource.OptFlusher(flusher), - resource.OptRootCoordClient(rootcoord), - resource.OptDataCoordClient(datacoord), + resource.OptRootCoordClient(fRootcoord), + resource.OptDataCoordClient(fDatacoord), ) opener := mock_wal.NewMockOpener(t) diff --git a/pkg/util/retry/options.go b/pkg/util/retry/options.go index 80f00a9ffc8f9..852e4ec7d786e 100644 --- a/pkg/util/retry/options.go +++ b/pkg/util/retry/options.go @@ -31,6 +31,12 @@ func newDefaultConfig() *config { // Option is used to config the retry function. type Option func(*config) +func AttemptAlways() Option { + return func(c *config) { + c.attempts = 0 + } +} + // Attempts is used to config the max retry times. func Attempts(attempts uint) Option { return func(c *config) { diff --git a/pkg/util/retry/retry.go b/pkg/util/retry/retry.go index c623bb1dbeb4d..a2a722ec13571 100644 --- a/pkg/util/retry/retry.go +++ b/pkg/util/retry/retry.go @@ -40,7 +40,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error { var lastErr error - for i := uint(0); i < c.attempts; i++ { + for i := uint(0); c.attempts == 0 || i < c.attempts; i++ { if err := fn(); err != nil { if i%4 == 0 { log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err)) diff --git a/pkg/util/retry/retry_test.go b/pkg/util/retry/retry_test.go index d0936a70dba85..e4c86d0b7521d 100644 --- a/pkg/util/retry/retry_test.go +++ b/pkg/util/retry/retry_test.go @@ -50,6 +50,17 @@ func TestAttempts(t *testing.T) { err := Do(ctx, testFn, Attempts(1)) assert.Error(t, err) t.Log(err) + + ctx = context.Background() + testOperation := 0 + testFn = func() error { + testOperation++ + return nil + } + + err = Do(ctx, testFn, AttemptAlways()) + assert.Equal(t, testOperation, 1) + assert.NoError(t, err) } func TestMaxSleepTime(t *testing.T) {