diff --git a/syncer/syncer.go b/syncer/syncer.go
index e60c01b..297ca06 100644
--- a/syncer/syncer.go
+++ b/syncer/syncer.go
@@ -206,8 +206,7 @@ type Syncer struct {
 	config config
 	log    *zap.Logger // redundant, but convenient
 
-	shutdownCtx       context.Context
-	shutdownCtxCancel context.CancelFunc
+	wg sync.WaitGroup
 
 	mu    sync.Mutex
 	peers map[string]*Peer
@@ -257,6 +256,9 @@ func (s *Syncer) ban(p *Peer, err error) error {
 }
 
 func (s *Syncer) runPeer(p *Peer) error {
+	s.wg.Add(1)
+	defer s.wg.Done()
+
 	if err := s.pm.AddPeer(p.t.Addr); err != nil {
 		return fmt.Errorf("failed to add peer: %w", err)
 	}
@@ -276,8 +278,7 @@ func (s *Syncer) runPeer(p *Peer) error {
 	}()
 	inflight := make(chan struct{}, s.config.MaxInflightRPCs)
-	var wg sync.WaitGroup
-	defer wg.Wait()
+
 	for {
 		if p.Err() != nil {
 			return fmt.Errorf("peer error: %w", p.Err())
 		}
@@ -288,9 +289,9 @@ func (s *Syncer) runPeer(p *Peer) error {
 			return fmt.Errorf("failed to accept rpc: %w", err)
 		}
 		inflight <- struct{}{}
-		wg.Add(1)
+		s.wg.Add(1)
 		go func() {
-			defer wg.Done()
+			defer s.wg.Done()
 			defer stream.Close()
 			// NOTE: we do not set any deadlines on the stream. If a peer is
 			// slow, fine; we don't need to worry about resource exhaustion
@@ -358,7 +359,7 @@ func (s *Syncer) relayV2TransactionSet(index types.ChainIndex, txns []types.V2Tr
 	}
 }
 
-func (s *Syncer) allowConnect(peer string, inbound bool) error {
+func (s *Syncer) allowConnect(ctx context.Context, peer string, inbound bool) error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.l == nil {
@@ -368,7 +369,7 @@
 	var addrs []net.IPAddr
 	if peerHost, _, err := net.SplitHostPort(peer); err != nil {
 		return fmt.Errorf("failed to split peer host and port: %w", err)
-	} else if addrs, err = (&net.Resolver{}).LookupIPAddr(s.shutdownCtx, peerHost); err != nil {
+	} else if addrs, err = (&net.Resolver{}).LookupIPAddr(ctx, peerHost); err != nil {
 		return fmt.Errorf("failed to resolve peer address: %w", err)
 	} else if len(addrs) == 0 {
 		return fmt.Errorf("peer didn't resolve to any addresses")
@@ -408,19 +409,23 @@ func (s *Syncer) alreadyConnected(id gateway.UniqueID) bool {
 	return false
 }
 
-func (s *Syncer) acceptLoop() error {
-	var wg sync.WaitGroup
-	defer wg.Wait()
+func (s *Syncer) acceptLoop(ctx context.Context) error {
 	for {
+		select {
+		case <-ctx.Done():
+			return nil
+		default:
+		}
+
 		conn, err := s.l.Accept()
 		if err != nil {
 			return err
 		}
-		wg.Add(1)
+		s.wg.Add(1)
 		go func() {
-			defer wg.Done()
+			defer s.wg.Done()
 			defer conn.Close()
-			if err := s.allowConnect(conn.RemoteAddr().String(), true); err != nil {
+			if err := s.allowConnect(ctx, conn.RemoteAddr().String(), true); err != nil {
 				s.log.Debug("rejected inbound connection", zap.Stringer("remoteAddress", conn.RemoteAddr()), zap.Error(err))
 			} else if t, err := gateway.Accept(conn, s.header); err != nil {
 				s.log.Debug("failed to accept inbound connection", zap.Stringer("remoteAddress", conn.RemoteAddr()), zap.Error(err))
@@ -437,7 +442,7 @@
 	}
 }
 
-func (s *Syncer) peerLoop() error {
+func (s *Syncer) peerLoop(ctx context.Context) error {
 	log := s.log.Named("peerLoop")
 	numOutbound := func() (n int) {
 		s.mu.Lock()
@@ -501,11 +506,17 @@
 		select {
 		case <-ticker.C:
 			return true
-		case <-s.shutdownCtx.Done():
+		case <-ctx.Done():
 			return false
 		}
 	}
 	for fst := true; fst || sleep(); fst = false {
+		select {
+		case <-ctx.Done():
+			return nil // avoid spamming "failed to connect" after context is cancelled
+		default:
+		}
+
 		if numOutbound() >= s.config.MaxOutboundPeers {
 			continue
 		}
@@ -519,7 +530,7 @@
 		if numOutbound() >= s.config.MaxOutboundPeers {
 			break
 		}
-		ctx, cancel := context.WithTimeout(s.shutdownCtx, s.config.ConnectTimeout)
+		ctx, cancel := context.WithTimeout(ctx, s.config.ConnectTimeout)
 		if _, err := s.Connect(ctx, p); err == nil {
 			log.Debug("connected to peer", zap.String("peer", p))
 		} else {
@@ -532,7 +543,7 @@
 	return nil
 }
 
-func (s *Syncer) syncLoop() error {
+func (s *Syncer) syncLoop(ctx context.Context) error {
 	peersForSync := func() (peers []*Peer) {
 		s.mu.Lock()
 		defer s.mu.Unlock()
@@ -557,7 +568,7 @@
 		select {
 		case <-ticker.C:
 			return true
-		case <-s.shutdownCtx.Done():
+		case <-ctx.Done():
 			return false
 		}
 	}
@@ -629,15 +640,14 @@
 // connections, and syncing the blockchain from active peers. It blocks until an
 // error occurs, upon which all connections are closed and goroutines are
 // terminated.
-func (s *Syncer) Run() error {
+func (s *Syncer) Run(ctx context.Context) error {
 	errChan := make(chan error)
-	go func() { errChan <- s.acceptLoop() }()
-	go func() { errChan <- s.peerLoop() }()
-	go func() { errChan <- s.syncLoop() }()
+	go func() { errChan <- s.acceptLoop(ctx) }()
+	go func() { errChan <- s.peerLoop(ctx) }()
+	go func() { errChan <- s.syncLoop(ctx) }()
 	err := <-errChan
 
 	// when one goroutine exits, shutdown and wait for the others
-	s.shutdownCtxCancel()
 	s.l.Close()
 	s.mu.Lock()
 	for _, p := range s.peers {
@@ -665,29 +675,21 @@
 // Close closes the Syncer's net.Listener.
 func (s *Syncer) Close() error {
-	return s.l.Close()
+	err := s.l.Close()
+	s.wg.Wait()
+	return err
 }
 
 // Connect forms an outbound connection to a peer.
 func (s *Syncer) Connect(ctx context.Context, addr string) (*Peer, error) {
-	if err := s.allowConnect(addr, false); err != nil {
+	if err := s.allowConnect(ctx, addr, false); err != nil {
 		return nil, err
 	}
-	// ensure we cancel out immediately if the syncer is stopped
-	ctx, cancel := context.WithCancel(ctx)
-	go func() {
-		select {
-		case <-ctx.Done():
-		case <-s.shutdownCtx.Done():
-			cancel()
-		}
-	}()
 	conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
 	if err != nil {
 		return nil, err
 	}
-	cancel()
 	conn.SetDeadline(time.Now().Add(s.config.ConnectTimeout))
 	defer conn.SetDeadline(time.Time{})
 	t, err := gateway.Dial(conn, s.header)
@@ -776,17 +778,14 @@ func New(l net.Listener, cm ChainManager, pm PeerStore, header gateway.Header, o
 	for _, opt := range opts {
 		opt(&config)
 	}
-	ctx, cancel := context.WithCancel(context.Background())
 	return &Syncer{
-		l:                 l,
-		cm:                cm,
-		pm:                pm,
-		header:            header,
-		config:            config,
-		log:               config.Logger,
-		shutdownCtx:       ctx,
-		shutdownCtxCancel: cancel,
-		peers:             make(map[string]*Peer),
-		strikes:           make(map[string]int),
+		l:       l,
+		cm:      cm,
+		pm:      pm,
+		header:  header,
+		config:  config,
+		log:     config.Logger,
+		peers:   make(map[string]*Peer),
+		strikes: make(map[string]int),
 	}
 }
diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go
index 614eae3..89d8502 100644
--- a/syncer/syncer_test.go
+++ b/syncer/syncer_test.go
@@ -48,7 +48,7 @@ func TestSyncer(t *testing.T) {
 		NetAddress: l1.Addr().String(),
 	}, syncer.WithLogger(log.Named("syncer1")))
 	defer s1.Close()
-	go s1.Run()
+	go s1.Run(context.Background())
 
 	s2 := syncer.New(l2, cm2, testutil.NewMemPeerStore(), gateway.Header{
 		GenesisID:  genesis.ID(),
@@ -56,7 +56,7 @@
 		NetAddress: l2.Addr().String(),
 	}, syncer.WithLogger(log.Named("syncer2")), syncer.WithSyncInterval(10*time.Millisecond))
 	defer s2.Close()
-	go s2.Run()
+	go s2.Run(context.Background())
 
 	// mine a few blocks on cm1
 	testutil.MineBlocks(t, cm1, types.VoidAddress, 10)
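
For reviewers, a minimal sketch of the caller-side lifecycle this change implies. The `runUntilCancelled` wrapper and its name are hypothetical, and the import paths are assumed to match the repo layout; `New`, `Run`, `Close`, and `testutil.NewMemPeerStore` are used as they appear in the diff:

```go
package example

import (
	"context"
	"net"

	"go.sia.tech/core/gateway"
	"go.sia.tech/coreutils/syncer"
	"go.sia.tech/coreutils/testutil"
)

// runUntilCancelled is a hypothetical wrapper showing the new shutdown flow:
// the caller owns the context passed to Run, cancelling it stops the
// accept/peer/sync loops, and Close waits on the Syncer's internal WaitGroup
// so that no peer or RPC goroutines outlive shutdown.
func runUntilCancelled(ctx context.Context, l net.Listener, cm syncer.ChainManager, header gateway.Header) error {
	s := syncer.New(l, cm, testutil.NewMemPeerStore(), header)

	errCh := make(chan error, 1)
	go func() { errCh <- s.Run(ctx) }() // Run blocks until one of its loops exits

	<-ctx.Done()      // caller initiated shutdown
	runErr := <-errCh // Run unwinds: it closes the listener and all peer connections
	_ = s.Close()     // listener is already closed; this waits out in-flight goroutines
	return runErr
}
```

Calling Close concurrently with Run (as the test's `defer s1.Close()` does) also works: with this change, Close is a synchronization point that drains the WaitGroup rather than a bare listener teardown.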