From 2cf164fea6f1eb1028e777e86ce4097eab54f912 Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Thu, 10 Oct 2024 14:59:32 +0000 Subject: [PATCH] Add context wrapping for syncer disconnections Any long-lived process (such as mixing) must return if the network backend (e.g. the dcrd RPC connection) is disconnected or the syncer encounters another error. However, it was observed that mixclient calls were continuing to execute despite the syncer erroring, resulting in locked outputs that could not be mixed after the syncer reconnected and restarted. To avoid this, the network backend interface gains additional methods to select on when disconnect occurs, and a new context wrapping function is added, which creates a derived context that is canceled (or "done") after the syncer exits. Wrapped contexts are added in the ticket purchasing and account mixing paths, including those being performed by the autobuyer. --- chain/backend.go | 21 +++++++++++++++ chain/sync.go | 15 +++++++++++ spv/backend.go | 21 +++++++++++++++ spv/sync.go | 17 ++++++++++++- ticketbuyer/tb.go | 2 ++ wallet/mixing.go | 7 +++++ wallet/network.go | 58 ++++++++++++++++++++++++++++++++++++++++++ wallet/network_test.go | 2 ++ wallet/wallet.go | 3 +++ 9 files changed, 145 insertions(+), 1 deletion(-) diff --git a/chain/backend.go b/chain/backend.go index 3a56e1b57..e514850e6 100644 --- a/chain/backend.go +++ b/chain/backend.go @@ -87,3 +87,24 @@ func (s *Syncer) ExistsLiveTickets(ctx context.Context, tickets []*chainhash.Has func (s *Syncer) UsedAddresses(ctx context.Context, addrs []stdaddr.Address) (bitset.Bytes, error) { return s.rpc.UsedAddresses(ctx, addrs) } + +func (s *Syncer) Done() <-chan struct{} { + s.doneMu.Lock() + c := s.done + s.doneMu.Unlock() + return c +} + +func (s *Syncer) Err() error { + s.doneMu.Lock() + c := s.done + err := s.err + s.doneMu.Unlock() + + select { + case <-c: + return err + default: + return nil + } +} diff --git a/chain/sync.go b/chain/sync.go index 82f16fad7..ca2996f51 100644 --- a/chain/sync.go +++ b/chain/sync.go @@ -55,6 +55,10 @@ type Syncer struct { relevantTxs map[chainhash.Hash][]*wire.MsgTx cb *Callbacks + + done chan struct{} + err error + doneMu sync.Mutex } // RPCOptions specifies the network and security settings for establishing a @@ -525,6 +529,17 @@ func (s *Syncer) Run(ctx context.Context) (err error) { } }() + s.doneMu.Lock() + s.done = make(chan struct{}) + s.err = nil + s.doneMu.Unlock() + defer func() { + s.doneMu.Lock() + close(s.done) + s.err = err + s.doneMu.Unlock() + }() + params := s.wallet.ChainParams() s.notifier = ¬ifier{ diff --git a/spv/backend.go b/spv/backend.go index 4a6d90397..ca6ff64c1 100644 --- a/spv/backend.go +++ b/spv/backend.go @@ -619,3 +619,24 @@ func (s *Syncer) Rescan(ctx context.Context, blockHashes []chainhash.Hash, save func (s *Syncer) StakeDifficulty(ctx context.Context) (dcrutil.Amount, error) { return 0, errors.E(errors.Invalid, "stake difficulty is not queryable over wire protocol") } + +func (s *Syncer) Done() <-chan struct{} { + s.doneMu.Lock() + c := s.done + s.doneMu.Unlock() + return c +} + +func (s *Syncer) Err() error { + s.doneMu.Lock() + c := s.done + err := s.err + s.doneMu.Unlock() + + select { + case <-c: + return err + default: + return nil + } +} diff --git a/spv/sync.go b/spv/sync.go index c626346c7..d5560a7f4 100644 --- a/spv/sync.go +++ b/spv/sync.go @@ -91,6 +91,10 @@ type Syncer struct { // Mempool for non-wallet-relevant transactions. mempool sync.Map // k=chainhash.Hash v=*wire.MsgTx mempoolAdds chan *chainhash.Hash + + done chan struct{} + err error + doneMu sync.Mutex } // Notifications struct to contain all of the upcoming callbacks that will @@ -318,7 +322,18 @@ func (s *Syncer) setRequiredHeight(tipHeight int32) { // Run synchronizes the wallet, returning when synchronization fails or the // context is cancelled. -func (s *Syncer) Run(ctx context.Context) error { +func (s *Syncer) Run(ctx context.Context) (err error) { + s.doneMu.Lock() + s.done = make(chan struct{}) + s.err = nil + s.doneMu.Unlock() + go func() { + s.doneMu.Lock() + close(s.done) + s.err = err + s.doneMu.Unlock() + }() + tipHash, tipHeight := s.wallet.MainChainTip(ctx) s.setRequiredHeight(tipHeight) rescanPoint, err := s.wallet.RescanPoint(ctx) diff --git a/ticketbuyer/tb.go b/ticketbuyer/tb.go index b74321fe5..a7b9f8a53 100644 --- a/ticketbuyer/tb.go +++ b/ticketbuyer/tb.go @@ -216,6 +216,8 @@ func (tb *TB) buy(ctx context.Context, passphrase []byte, tip *wire.BlockHeader, if err != nil { return err } + ctx, cancel := wallet.WrapNetworkBackendContext(n, ctx) + defer cancel() if len(passphrase) > 0 { // Ensure wallet is unlocked with the current passphrase. If the passphase diff --git a/wallet/mixing.go b/wallet/mixing.go index aa031a838..37c51e3ac 100644 --- a/wallet/mixing.go +++ b/wallet/mixing.go @@ -275,6 +275,13 @@ func (w *Wallet) MixOutput(ctx context.Context, output *wire.OutPoint, changeAcc return errors.E(op, errors.Invalid, s) } + nb, err := w.NetworkBackend() + if err != nil { + return err + } + ctx, cancel := WrapNetworkBackendContext(nb, ctx) + defer cancel() + sdiff, err := w.NextStakeDifficulty(ctx) if err != nil { return errors.E(op, err) diff --git a/wallet/network.go b/wallet/network.go index 511dbb587..fb920a248 100644 --- a/wallet/network.go +++ b/wallet/network.go @@ -6,6 +6,7 @@ package wallet import ( "context" + "sync" "decred.org/dcrwallet/v5/errors" "github.com/decred/dcrd/chaincfg/chainhash" @@ -49,6 +50,12 @@ type NetworkBackend interface { // the wallet to the underlying network, and if not, it returns the // target height that it is attempting to sync to. Synced(ctx context.Context) (bool, int32) + + // Done return a channel that is closed after the syncer disconnects. + // The error (if any) can be returned via Err. + // These semantics match that of context.Context. + Done() <-chan struct{} + Err() error } // NetworkBackend returns the currently associated network backend of the @@ -73,6 +80,47 @@ func (w *Wallet) SetNetworkBackend(n NetworkBackend) { w.networkBackendMu.Unlock() } +type networkContext struct { + context.Context + err error + mu sync.Mutex +} + +func (c *networkContext) Err() error { + c.mu.Lock() + err := c.err + c.mu.Unlock() + + if err != nil { + return err + } + return c.Context.Err() +} + +// WrapNetworkBackendContext returns a derived context that is canceled when +// the NetworkBackend is disconnected. The cancel func must be called +// (e.g. using defer) otherwise a goroutine leak may occur. +func WrapNetworkBackendContext(nb NetworkBackend, ctx context.Context) (context.Context, context.CancelFunc) { + childCtx, cancel := context.WithCancel(ctx) + nbContext := &networkContext{ + Context: childCtx, + } + + go func() { + select { + case <-nb.Done(): + err := nb.Err() + nbContext.mu.Lock() + nbContext.err = err + nbContext.mu.Unlock() + case <-childCtx.Done(): + } + cancel() + }() + + return nbContext, cancel +} + // Caller provides a client interface to perform remote procedure calls. // Serialization and calling conventions are implementation-specific. type Caller interface { @@ -122,6 +170,16 @@ func (o OfflineNetworkBackend) Synced(ctx context.Context) (bool, int32) { return true, 0 } +func (o OfflineNetworkBackend) Done() <-chan struct{} { + c := make(chan struct{}) + close(c) + return c +} + +func (o OfflineNetworkBackend) Err() error { + return errors.E("offline") +} + // Compile time check to ensure OfflineNetworkBackend fulfills the // NetworkBackend interface. var _ NetworkBackend = OfflineNetworkBackend{} diff --git a/wallet/network_test.go b/wallet/network_test.go index 10966f6c5..eb75a7202 100644 --- a/wallet/network_test.go +++ b/wallet/network_test.go @@ -35,3 +35,5 @@ func (mockNetwork) Rescan(ctx context.Context, blocks []chainhash.Hash, save fun } func (mockNetwork) StakeDifficulty(ctx context.Context) (dcrutil.Amount, error) { return 0, nil } func (mockNetwork) Synced(ctx context.Context) (bool, int32) { return false, 0 } +func (mockNetwork) Done() <-chan struct{} { return nil } +func (mockNetwork) Err() error { return nil } diff --git a/wallet/wallet.go b/wallet/wallet.go index c0eff9e76..0f9dda294 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -1588,6 +1588,9 @@ func (w *Wallet) PurchaseTickets(ctx context.Context, n NetworkBackend, const op errors.Op = "wallet.PurchaseTickets" + ctx, cancel := WrapNetworkBackendContext(n, ctx) + defer cancel() + resp, err := w.purchaseTickets(ctx, op, n, req) if err == nil || !errors.Is(err, errVSPFeeRequiresUTXOSplit) || req.DontSignTx { return resp, err