syncer: wait for goroutines to exit before returning from Close()
n8maninger committed Aug 2, 2024
1 parent 6804566 commit 71608c6
Showing 2 changed files with 48 additions and 49 deletions.
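In short, the commit replaces the Syncer's internal shutdown context with a context passed to Run, and adds a sync.WaitGroup that tracks every goroutine the Syncer spawns, so Close can block until all of them have exited. Below is a minimal compilable sketch of that shutdown pattern; the Server type and package name are hypothetical stand-ins, not the actual Syncer code.

package sketch

import (
	"context"
	"net"
	"sync"
)

// Server stands in for the Syncer: it owns a listener and records every
// goroutine it starts in a WaitGroup.
type Server struct {
	l  net.Listener
	wg sync.WaitGroup
}

// Run accepts connections until the listener is closed or ctx is cancelled.
// The ctx check is non-blocking; a blocked Accept is unblocked by Close.
func (s *Server) Run(ctx context.Context) error {
	for {
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		conn, err := s.l.Accept()
		if err != nil {
			return err
		}
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			defer conn.Close()
			// handle the connection here
		}()
	}
}

// Close unblocks Accept by closing the listener, then waits for every tracked
// goroutine to finish before returning, mirroring the new Syncer.Close.
func (s *Server) Close() error {
	err := s.l.Close()
	s.wg.Wait()
	return err
}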
93 changes: 46 additions & 47 deletions syncer/syncer.go
@@ -206,8 +206,7 @@ type Syncer struct {
config config
log *zap.Logger // redundant, but convenient

- shutdownCtx context.Context
- shutdownCtxCancel context.CancelFunc
+ wg sync.WaitGroup

mu sync.Mutex
peers map[string]*Peer
@@ -257,6 +256,9 @@ func (s *Syncer) ban(p *Peer, err error) error {
}

func (s *Syncer) runPeer(p *Peer) error {
+ s.wg.Add(1)
+ defer s.wg.Done()

if err := s.pm.AddPeer(p.t.Addr); err != nil {
return fmt.Errorf("failed to add peer: %w", err)
}
@@ -276,8 +278,7 @@ func (s *Syncer) runPeer(p *Peer) error {
}()

inflight := make(chan struct{}, s.config.MaxInflightRPCs)
- var wg sync.WaitGroup
- defer wg.Wait()

for {
if p.Err() != nil {
return fmt.Errorf("peer error: %w", p.Err())
@@ -288,9 +289,9 @@ func (s *Syncer) runPeer(p *Peer) error {
return fmt.Errorf("failed to accept rpc: %w", err)
}
inflight <- struct{}{}
- wg.Add(1)
+ s.wg.Add(1)
go func() {
- defer wg.Done()
+ defer s.wg.Done()
defer stream.Close()
// NOTE: we do not set any deadlines on the stream. If a peer is
// slow, fine; we don't need to worry about resource exhaustion
@@ -358,7 +359,7 @@ func (s *Syncer) relayV2TransactionSet(index types.ChainIndex, txns []types.V2Tr
}
}

- func (s *Syncer) allowConnect(peer string, inbound bool) error {
+ func (s *Syncer) allowConnect(ctx context.Context, peer string, inbound bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.l == nil {
@@ -368,7 +369,7 @@ func (s *Syncer) allowConnect(peer string, inbound bool) error {
var addrs []net.IPAddr
if peerHost, _, err := net.SplitHostPort(peer); err != nil {
return fmt.Errorf("failed to split peer host and port: %w", err)
- } else if addrs, err = (&net.Resolver{}).LookupIPAddr(s.shutdownCtx, peerHost); err != nil {
+ } else if addrs, err = (&net.Resolver{}).LookupIPAddr(ctx, peerHost); err != nil {
return fmt.Errorf("failed to resolve peer address: %w", err)
} else if len(addrs) == 0 {
return fmt.Errorf("peer didn't resolve to any addresses")
@@ -408,19 +409,23 @@ func (s *Syncer) alreadyConnected(id gateway.UniqueID) bool {
return false
}

- func (s *Syncer) acceptLoop() error {
- var wg sync.WaitGroup
- defer wg.Wait()
+ func (s *Syncer) acceptLoop(ctx context.Context) error {
for {
+ select {
+ case <-ctx.Done():
+ return nil
+ default:
+ }

conn, err := s.l.Accept()
if err != nil {
return err
}
- wg.Add(1)
+ s.wg.Add(1)
go func() {
- defer wg.Done()
+ defer s.wg.Done()
defer conn.Close()
- if err := s.allowConnect(conn.RemoteAddr().String(), true); err != nil {
+ if err := s.allowConnect(ctx, conn.RemoteAddr().String(), true); err != nil {
s.log.Debug("rejected inbound connection", zap.Stringer("remoteAddress", conn.RemoteAddr()), zap.Error(err))
} else if t, err := gateway.Accept(conn, s.header); err != nil {
s.log.Debug("failed to accept inbound connection", zap.Stringer("remoteAddress", conn.RemoteAddr()), zap.Error(err))
@@ -437,7 +442,7 @@ func (s *Syncer) acceptLoop() error {
}
}

- func (s *Syncer) peerLoop() error {
+ func (s *Syncer) peerLoop(ctx context.Context) error {
log := s.log.Named("peerLoop")
numOutbound := func() (n int) {
s.mu.Lock()
@@ -501,11 +506,17 @@ func (s *Syncer) peerLoop() error {
select {
case <-ticker.C:
return true
- case <-s.shutdownCtx.Done():
+ case <-ctx.Done():
return false
}
}
for fst := true; fst || sleep(); fst = false {
+ select {
+ case <-ctx.Done():
+ return nil // avoid spamming "failed to connect" after context is cancelled
+ default:
+ }

if numOutbound() >= s.config.MaxOutboundPeers {
continue
}
@@ -519,7 +530,7 @@ func (s *Syncer) peerLoop() error {
if numOutbound() >= s.config.MaxOutboundPeers {
break
}
- ctx, cancel := context.WithTimeout(s.shutdownCtx, s.config.ConnectTimeout)
+ ctx, cancel := context.WithTimeout(ctx, s.config.ConnectTimeout)
if _, err := s.Connect(ctx, p); err == nil {
log.Debug("connected to peer", zap.String("peer", p))
} else {
@@ -532,7 +543,7 @@ func (s *Syncer) peerLoop() error {
return nil
}

- func (s *Syncer) syncLoop() error {
+ func (s *Syncer) syncLoop(ctx context.Context) error {
peersForSync := func() (peers []*Peer) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -557,7 +568,7 @@ func (s *Syncer) syncLoop() error {
select {
case <-ticker.C:
return true
- case <-s.shutdownCtx.Done():
+ case <-ctx.Done():
return false
}
}
@@ -629,15 +640,14 @@ func (s *Syncer) syncLoop() error {
// connections, and syncing the blockchain from active peers. It blocks until an
// error occurs, upon which all connections are closed and goroutines are
// terminated.
- func (s *Syncer) Run() error {
+ func (s *Syncer) Run(ctx context.Context) error {
errChan := make(chan error)
- go func() { errChan <- s.acceptLoop() }()
- go func() { errChan <- s.peerLoop() }()
- go func() { errChan <- s.syncLoop() }()
+ go func() { errChan <- s.acceptLoop(ctx) }()
+ go func() { errChan <- s.peerLoop(ctx) }()
+ go func() { errChan <- s.syncLoop(ctx) }()
err := <-errChan

// when one goroutine exits, shutdown and wait for the others
- s.shutdownCtxCancel()
s.l.Close()
s.mu.Lock()
for _, p := range s.peers {
@@ -665,29 +675,21 @@ func (s *Syncer) Run() error {

// Close closes the Syncer's net.Listener.
func (s *Syncer) Close() error {
- return s.l.Close()
+ err := s.l.Close()
+ s.wg.Wait()
+ return err
}

// Connect forms an outbound connection to a peer.
func (s *Syncer) Connect(ctx context.Context, addr string) (*Peer, error) {
- if err := s.allowConnect(addr, false); err != nil {
+ if err := s.allowConnect(ctx, addr, false); err != nil {
return nil, err
}

- // ensure we cancel out immediately if the syncer is stopped
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- select {
- case <-ctx.Done():
- case <-s.shutdownCtx.Done():
- cancel()
- }
- }()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
- cancel()
conn.SetDeadline(time.Now().Add(s.config.ConnectTimeout))
defer conn.SetDeadline(time.Time{})
t, err := gateway.Dial(conn, s.header)
@@ -776,17 +778,14 @@ func New(l net.Listener, cm ChainManager, pm PeerStore, header gateway.Header, o
for _, opt := range opts {
opt(&config)
}
- ctx, cancel := context.WithCancel(context.Background())
return &Syncer{
- l: l,
- cm: cm,
- pm: pm,
- header: header,
- config: config,
- log: config.Logger,
- shutdownCtx: ctx,
- shutdownCtxCancel: cancel,
- peers: make(map[string]*Peer),
- strikes: make(map[string]int),
+ l: l,
+ cm: cm,
+ pm: pm,
+ header: header,
+ config: config,
+ log: config.Logger,
+ peers: make(map[string]*Peer),
+ strikes: make(map[string]int),
}
}
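The block removed from Connect existed only to tie an in-flight dial to the old internal shutdown context. With the caller's ctx (and, in peerLoop, a timeout derived from it) flowing straight into DialContext, cancelling that context aborts the dial with no watcher goroutine. A small illustrative sketch of that derivation, not code from this commit; the dialWithTimeout name is made up:

package sketch

import (
	"context"
	"net"
	"time"
)

// dialWithTimeout derives a per-attempt deadline from the caller's context, so
// cancelling the parent context aborts the dial without a watcher goroutine.
func dialWithTimeout(ctx context.Context, addr string, d time.Duration) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(ctx, d)
	defer cancel()
	return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
}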
4 changes: 2 additions & 2 deletions syncer/syncer_test.go
@@ -48,15 +48,15 @@ func TestSyncer(t *testing.T) {
NetAddress: l1.Addr().String(),
}, syncer.WithLogger(log.Named("syncer1")))
defer s1.Close()
- go s1.Run()
+ go s1.Run(context.Background())

s2 := syncer.New(l2, cm2, testutil.NewMemPeerStore(), gateway.Header{
GenesisID: genesis.ID(),
UniqueID: gateway.GenerateUniqueID(),
NetAddress: l2.Addr().String(),
}, syncer.WithLogger(log.Named("syncer2")), syncer.WithSyncInterval(10*time.Millisecond))
defer s2.Close()
- go s2.Run()
+ go s2.Run(context.Background())

// mine a few blocks on cm1
testutil.MineBlocks(t, cm1, types.VoidAddress, 10)
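With Run now taking a context, a test can stop a syncer either by cancelling that context or by calling Close, which waits for every goroutine to exit. A hypothetical helper in the style of the test above, not part of this commit; it assumes the go.sia.tech/coreutils/syncer import path and the post-commit Run and Close signatures:

package sketch

import (
	"context"

	"go.sia.tech/coreutils/syncer"
)

// runForTest starts s in the background and returns a stop function that first
// cancels the context, so the loops stop dialing peers, and then calls Close,
// which blocks until every syncer goroutine has exited.
func runForTest(s *syncer.Syncer) (stop func() error) {
	ctx, cancel := context.WithCancel(context.Background())
	go s.Run(ctx)
	return func() error {
		cancel()
		return s.Close()
	}
}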