From cd38ea98e03c743444e432e4e3d6d1148afaf99c Mon Sep 17 00:00:00 2001 From: fanquake Date: Thu, 19 Aug 2021 09:18:28 +0800 Subject: [PATCH] Merge bitcoin/bitcoin#19101: refactor: remove ::vpwallets and related global variables 62a09a30772141ef4add2f10d29927211abf57eb refactor: remove ::vpwallets and related global variables (Russell Yanofsky) Pull request description: Get rid of global wallet list variables by moving them to WalletContext struct - [`cs_wallets`](https://github.com/bitcoin/bitcoin/blob/e638acf6970394f8eb1957366ad2d39512f33b31/src/wallet/wallet.cpp#L56) is now [`WalletContext::wallet_mutex`](https://github.com/ryanofsky/bitcoin/blob/4be544c7ec08a81952fd3f4349151cbb8bdb60e8/src/wallet/context.h#L37) - [`vpwallets`](https://github.com/bitcoin/bitcoin/blob/e638acf6970394f8eb1957366ad2d39512f33b31/src/wallet/wallet.cpp#L57) is now [`WalletContext::wallets`](https://github.com/ryanofsky/bitcoin/blob/4be544c7ec08a81952fd3f4349151cbb8bdb60e8/src/wallet/context.h#L38) - [`g_load_wallet_fns`](https://github.com/bitcoin/bitcoin/blob/e638acf6970394f8eb1957366ad2d39512f33b31/src/wallet/wallet.cpp#L58) is now [`WalletContext::wallet_load_fns`](https://github.com/ryanofsky/bitcoin/blob/4be544c7ec08a81952fd3f4349151cbb8bdb60e8/src/wallet/context.h#L39) ACKs for top commit: achow101: ACK 62a09a30772141ef4add2f10d29927211abf57eb meshcollider: re-utACK 62a09a30772141ef4add2f10d29927211abf57eb Tree-SHA512: 74428180d57b4214c3d96963e6ff43e8778f6f23b6880262d1272f2de67d02714fdc3ebb558f62e48655b221a642c36f80ef37c8f89d362e2d66fd93cbf03b8f --- src/interfaces/wallet.h | 5 +- src/qt/test/addressbooktests.cpp | 7 +-- src/qt/test/wallettests.cpp | 2 +- src/wallet/context.h | 15 +++++- src/wallet/interfaces.cpp | 30 ++++++----- src/wallet/load.cpp | 35 +++++++------ src/wallet/load.h | 13 ++--- src/wallet/rpcwallet.cpp | 18 ++++--- src/wallet/test/wallet_tests.cpp | 45 +++++++++------- src/wallet/wallet.cpp | 89 ++++++++++++++++---------------- src/wallet/wallet.h | 24 +++++---- src/wallet/walletdb.cpp | 4 +- src/wallet/walletdb.h | 3 +- 13 files changed, 165 insertions(+), 125 deletions(-) diff --git a/src/interfaces/wallet.h b/src/interfaces/wallet.h index 3339fcafbc3da2..8569098152011c 100644 --- a/src/interfaces/wallet.h +++ b/src/interfaces/wallet.h @@ -360,6 +360,9 @@ class WalletLoader : public ChainClient //! loaded at startup or by RPC. using LoadWalletFn = std::function wallet)>; virtual std::unique_ptr handleLoadWallet(LoadWalletFn fn) = 0; + + //! Return pointer to internal context, useful for testing. + virtual WalletContext* context() { return nullptr; } }; //! Information about one wallet address. @@ -443,7 +446,7 @@ struct WalletTxOut //! Return implementation of Wallet interface. This function is defined in //! dummywallet.cpp and throws if the wallet component is not compiled. -std::unique_ptr MakeWallet(const std::shared_ptr& wallet); +std::unique_ptr MakeWallet(WalletContext& context, const std::shared_ptr& wallet); //! Return implementation of ChainClient interface for a wallet loader. This //! function will be undefined in builds where ENABLE_WALLET is false. diff --git a/src/qt/test/addressbooktests.cpp b/src/qt/test/addressbooktests.cpp index 5eb00a1c32ce8f..6eb794202e3c2c 100644 --- a/src/qt/test/addressbooktests.cpp +++ b/src/qt/test/addressbooktests.cpp @@ -106,9 +106,10 @@ void TestAddAddressesToSendBook(interfaces::Node& node) // Initialize relevant QT models. OptionsModel optionsModel; ClientModel clientModel(node, &optionsModel); - AddWallet(wallet); - WalletModel walletModel(interfaces::MakeWallet(wallet), clientModel); - RemoveWallet(wallet, std::nullopt); + WalletContext& context = *node.walletClient().context(); + AddWallet(context, wallet); + WalletModel walletModel(interfaces::MakeWallet(context, wallet), clientModel); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); EditAddressDialog editAddressDialog(EditAddressDialog::NewSendingAddress); editAddressDialog.setModel(walletModel.getAddressTableModel()); diff --git a/src/qt/test/wallettests.cpp b/src/qt/test/wallettests.cpp index 15140e9645cad1..f5f123c835d5a2 100644 --- a/src/qt/test/wallettests.cpp +++ b/src/qt/test/wallettests.cpp @@ -133,7 +133,7 @@ void TestGUI(interfaces::Node& node) TransactionView transactionView; OptionsModel optionsModel; ClientModel clientModel(node, &optionsModel); - WalletModel walletModel(interfaces::MakeWallet(wallet), clientModel); + WalletModel walletModel(interfaces::MakeWallet(context, wallet), clientModel); sendCoinsDialog.setModel(&walletModel); transactionView.setModel(&walletModel); diff --git a/src/wallet/context.h b/src/wallet/context.h index 059b7f10623959..8876f99b760c45 100644 --- a/src/wallet/context.h +++ b/src/wallet/context.h @@ -6,15 +6,25 @@ #define BITCOIN_WALLET_CONTEXT_H #include +#include + +#include +#include +#include +#include class ArgsManager; +class CWallet; namespace interfaces { class Chain; namespace CoinJoin { class Loader; } // namspace CoinJoin +class Wallet; } // namespace interfaces +using LoadWalletFn = std::function wallet)>; + //! WalletContext struct containing references to state shared between CWallet //! instances, like the reference to the chain interface, and the list of opened //! wallets. @@ -27,7 +37,10 @@ class Loader; //! behavior. struct WalletContext { interfaces::Chain* chain{nullptr}; - ArgsManager* args{nullptr}; + ArgsManager* args{nullptr}; // Currently a raw pointer because the memory is not managed by this struct + Mutex wallets_mutex; + std::vector> wallets GUARDED_BY(wallets_mutex); + std::list wallet_load_fns GUARDED_BY(wallets_mutex); // TODO: replace this unique_ptr to a pointer // probably possible to do after bitcoin/bitcoin#22219 const std::unique_ptr& m_coinjoin_loader; diff --git a/src/wallet/interfaces.cpp b/src/wallet/interfaces.cpp index af1c53e0fc0da2..49b2abefa69e8a 100644 --- a/src/wallet/interfaces.cpp +++ b/src/wallet/interfaces.cpp @@ -127,7 +127,7 @@ WalletTxOut MakeWalletTxOut(const CWallet& wallet, class WalletImpl : public Wallet { public: - explicit WalletImpl(const std::shared_ptr& wallet) : m_wallet(wallet) {} + explicit WalletImpl(WalletContext& context, const std::shared_ptr& wallet) : m_context(context), m_wallet(wallet) {} void markDirty() override { @@ -509,7 +509,7 @@ class WalletImpl : public Wallet CAmount getDefaultMaxTxFee() override { return m_wallet->m_default_max_tx_fee; } void remove() override { - RemoveWallet(m_wallet, false /* load_on_start */); + RemoveWallet(m_context, m_wallet, false /* load_on_start */); } bool isLegacy() override { return m_wallet->IsLegacy(); } std::unique_ptr handleUnload(UnloadFn fn) override @@ -555,6 +555,7 @@ class WalletImpl : public Wallet } CWallet* wallet() override { return m_wallet.get(); } + WalletContext& m_context; std::shared_ptr m_wallet; std::unique_ptr m_coinjoin_client; }; @@ -568,7 +569,7 @@ class WalletLoaderImpl : public WalletLoader m_context.chain = &chain; m_context.args = &args; } - ~WalletLoaderImpl() override { UnloadWallets(); } + ~WalletClientImpl() override { UnloadWallets(m_context); } //! ChainClient methods void registerRpcs() override @@ -582,11 +583,11 @@ class WalletLoaderImpl : public WalletLoader m_rpc_handlers.emplace_back(m_context.chain->handleRpc(m_rpc_commands.back())); } } - bool verify() override { return VerifyWallets(*m_context.chain); } - bool load() override { assert(m_context.m_coinjoin_loader); return LoadWallets(*m_context.chain, *m_context.m_coinjoin_loader); } - void start(CScheduler& scheduler) override { return StartWallets(scheduler, *Assert(m_context.args)); } - void flush() override { return FlushWallets(); } - void stop() override { return StopWallets(); } + bool verify() override { return VerifyWallets(m_context); } + bool load() override { assert(m_context.m_coinjoin_loader); return LoadWallets(m_context, *m_context.m_coinjoin_loader); } + void start(CScheduler& scheduler) override { return StartWallets(m_context, scheduler); } + void flush() override { return FlushWallets(m_context); } + void stop() override { return StopWallets(m_context); } void setMockTime(int64_t time) override { return SetMockTime(time); } //! WalletLoader methods @@ -599,7 +600,7 @@ class WalletLoaderImpl : public WalletLoader options.create_flags = wallet_creation_flags; options.create_passphrase = passphrase; assert(m_context.m_coinjoin_loader); - return MakeWallet(CreateWallet(*m_context.chain, *m_context.m_coinjoin_loader, name, true /* load_on_start */, options, status, error, warnings)); + return MakeWallet(m_context, CreateWallet(*m_context.chain, *m_context.m_coinjoin_loader, name, true /* load_on_start */, options, status, error, warnings)); } std::unique_ptr loadWallet(const std::string& name, bilingual_str& error, std::vector& warnings) override { @@ -607,7 +608,7 @@ class WalletLoaderImpl : public WalletLoader DatabaseStatus status; options.require_existing = true; assert(m_context.m_coinjoin_loader); - return MakeWallet(LoadWallet(*m_context.chain, *m_context.m_coinjoin_loader, name, true /* load_on_start */, options, status, error, warnings)); + return MakeWallet(m_context, LoadWallet(m_context, *m_context.m_coinjoin_loader, name, true /* load_on_start */, options, status, error, warnings)); } std::unique_ptr restoreWallet(const fs::path& backup_file, const std::string& wallet_name, bilingual_str& error, std::vector& warnings) override { @@ -630,15 +631,16 @@ class WalletLoaderImpl : public WalletLoader std::vector> getWallets() override { std::vector> wallets; - for (const auto& wallet : GetWallets()) { - wallets.emplace_back(MakeWallet(wallet)); + for (const auto& wallet : GetWallets(m_context)) { + wallets.emplace_back(MakeWallet(m_context, wallet)); } return wallets; } std::unique_ptr handleLoadWallet(LoadWalletFn fn) override { - return HandleLoadWallet(std::move(fn)); + return HandleLoadWallet(m_context, std::move(fn)); } + WalletContext* context() override { return &m_context; } WalletContext m_context; const std::vector m_wallet_filenames; @@ -649,7 +651,7 @@ class WalletLoaderImpl : public WalletLoader } // namespace wallet namespace interfaces { -std::unique_ptr MakeWallet(const std::shared_ptr& wallet) { return wallet ? std::make_unique(wallet) : nullptr; } +std::unique_ptr MakeWallet(WalletContext& context, const std::shared_ptr& wallet) { return wallet ? std::make_unique(wallet) : nullptr; } std::unique_ptr MakeWalletLoader(Chain& chain, const std::unique_ptr& coinjoin_loader, ArgsManager& args) { return std::make_unique(chain, coinjoin_loader, args); } diff --git a/src/wallet/load.cpp b/src/wallet/load.cpp index f5091b37949b42..5e7110e4972d09 100644 --- a/src/wallet/load.cpp +++ b/src/wallet/load.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -21,8 +22,9 @@ #include -bool VerifyWallets(interfaces::Chain& chain) +bool VerifyWallets(WalletContext& context) { + interfaces::Chain& chain = *context.chain; if (gArgs.IsArgSet("-walletdir")) { fs::path wallet_dir = fs::PathFromString(gArgs.GetArg("-walletdir", "")); std::error_code error; @@ -94,8 +96,9 @@ bool VerifyWallets(interfaces::Chain& chain) return true; } -bool LoadWallets(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader) +bool LoadWallets(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader) { + interfaces::Chain& chain = *context.chain; try { std::set wallet_paths; for (const std::string& name : gArgs.GetArgs("-wallet")) { @@ -113,13 +116,13 @@ bool LoadWallets(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoi continue; } chain.initMessage(_("Loading wallet...").translated); - std::shared_ptr pwallet = database ? CWallet::Create(&chain, &coinjoin_loader, name, std::move(database), options.create_flags, error_string, warnings) : nullptr; + std::shared_ptr pwallet = database ? CWallet::Create(context, &coinjoin_loader, name, std::move(database), options.create_flags, error_string, warnings) : nullptr; if (!warnings.empty()) chain.initWarning(Join(warnings, Untranslated("\n"))); if (!pwallet) { chain.initError(error_string); return false; } - AddWallet(pwallet); + AddWallet(context, pwallet); } return true; } catch (const std::runtime_error& e) { @@ -128,22 +131,22 @@ bool LoadWallets(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoi } } -void StartWallets(CScheduler& scheduler, const ArgsManager& args) +void StartWallets(WalletContext& context, CScheduler& scheduler) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->postInitProcess(); } // Schedule periodic wallet flushes and tx rebroadcasts - if (args.GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { - scheduler.scheduleEvery(MaybeCompactWalletDB, std::chrono::milliseconds{500}); + if (context.args->GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { + scheduler.scheduleEvery([&context] { MaybeCompactWalletDB(context); }, std::chrono::milliseconds{500}); } - scheduler.scheduleEvery(MaybeResendWalletTxs, std::chrono::milliseconds{1000}); + scheduler.scheduleEvery([&context] { MaybeResendWalletTxs(context); }, std::chrono::milliseconds{1000}); } -void FlushWallets() +void FlushWallets(WalletContext& context) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { if (CCoinJoinClientOptions::IsEnabled()) { // Stop CoinJoin, release keys pwallet->coinjoin_loader().FlushWallet(pwallet->GetName()); @@ -152,21 +155,21 @@ void FlushWallets() } } -void StopWallets() +void StopWallets(WalletContext& context) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->Close(); } } -void UnloadWallets() +void UnloadWallets(WalletContext& context) { - auto wallets = GetWallets(); + auto wallets = GetWallets(context); while (!wallets.empty()) { auto wallet = wallets.back(); wallets.pop_back(); std::vector warnings; - RemoveWallet(wallet, std::nullopt, warnings); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt, warnings); UnloadWallet(std::move(wallet)); } } diff --git a/src/wallet/load.h b/src/wallet/load.h index cb7eb77b1b3c29..a293f86a7e72f3 100644 --- a/src/wallet/load.h +++ b/src/wallet/load.h @@ -12,6 +12,7 @@ class ArgsManager; class CConnman; class CScheduler; +struct WalletContext; namespace interfaces { class Chain; @@ -21,21 +22,21 @@ class Loader; } // namespace interfaces //! Responsible for reading and validating the -wallet arguments and verifying the wallet database. -bool VerifyWallets(interfaces::Chain& chain); +bool VerifyWallets(WalletContext& context); //! Load wallet databases. -bool LoadWallets(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader); +bool LoadWallets(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader); //! Complete startup of wallets. -void StartWallets(CScheduler& scheduler, const ArgsManager& args); +void StartWallets(WalletContext& context, CScheduler& scheduler); //! Flush all wallets in preparation for shutdown. -void FlushWallets(); +void FlushWallets(WalletContext& context); //! Stop all wallets. Wallets will be flushed first. -void StopWallets(); +void StopWallets(WalletContext& context); //! Close all wallets. -void UnloadWallets(); +void UnloadWallets(WalletContext& context); #endif // BITCOIN_WALLET_LOAD_H diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index fdea130b65c7a3..4d127e5f9baba0 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -103,15 +103,19 @@ bool GetWalletNameFromJSONRPCRequest(const JSONRPCRequest& request, std::string& std::shared_ptr GetWalletForJSONRPCRequest(const JSONRPCRequest& request) { CHECK_NONFATAL(request.mode == JSONRPCRequest::EXECUTE); +<<<<<<< HEAD +======= + WalletContext& context = EnsureWalletContext(request.context); +>>>>>>> 638855af63... Merge bitcoin/bitcoin#19101: refactor: remove ::vpwallets and related global variables std::string wallet_name; if (GetWalletNameFromJSONRPCRequest(request, wallet_name)) { - std::shared_ptr pwallet = GetWallet(wallet_name); + std::shared_ptr pwallet = GetWallet(context, wallet_name); if (!pwallet) throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded"); return pwallet; } - std::vector> wallets = GetWallets(); + std::vector> wallets = GetWallets(context); if (wallets.size() == 1) { return wallets[0]; } @@ -2715,7 +2719,8 @@ static RPCHelpMan listwallets() { UniValue obj(UniValue::VARR); - for (const std::shared_ptr& wallet : GetWallets()) { + WalletContext& context = EnsureWalletContext(request.context); + for (const std::shared_ptr& wallet : GetWallets(context)) { LOCK(wallet->cs_wallet); obj.push_back(wallet->GetName()); } @@ -3039,7 +3044,7 @@ static RPCHelpMan createwallet() options.create_passphrase = passphrase; bilingual_str error; std::optional load_on_start = request.params[6].isNull() ? std::nullopt : std::optional(request.params[6].get_bool()); - std::shared_ptr wallet = CreateWallet(*context.chain, *context.m_coinjoin_loader, request.params[0].get_str(), load_on_start, options, status, error, warnings); + std::shared_ptr wallet = CreateWallet(context, *context.m_coinjoin_loader, request.params[0].get_str(), load_on_start, options, status, error, warnings); if (!wallet) { RPCErrorCode code = status == DatabaseStatus::FAILED_ENCRYPT ? RPC_WALLET_ENCRYPTION_FAILED : RPC_WALLET_ERROR; throw JSONRPCError(code, error.original); @@ -3134,7 +3139,8 @@ static RPCHelpMan unloadwallet() wallet_name = request.params[0].get_str(); } - std::shared_ptr wallet = GetWallet(wallet_name); + WalletContext& context = EnsureWalletContext(request.context); + std::shared_ptr wallet = GetWallet(context, wallet_name); if (!wallet) { throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded"); } @@ -3144,7 +3150,7 @@ static RPCHelpMan unloadwallet() // is destroyed (see CheckUniqueFileid). std::vector warnings; std::optional load_on_start = request.params[1].isNull() ? std::nullopt : std::optional(request.params[1].get_bool()); - if (!RemoveWallet(wallet, load_on_start, warnings)) { + if (!RemoveWallet(context, wallet, load_on_start, warnings)) { throw JSONRPCError(RPC_MISC_ERROR, "Requested wallet already unloaded"); } diff --git a/src/wallet/test/wallet_tests.cpp b/src/wallet/test/wallet_tests.cpp index ef6218ed6f7127..6d3e665e2c6fcd 100644 --- a/src/wallet/test/wallet_tests.cpp +++ b/src/wallet/test/wallet_tests.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -38,8 +39,6 @@ extern RPCHelpMan getrawchangeaddress(); extern RPCHelpMan getaddressinfo(); extern RPCHelpMan addmultisigaddress(); -extern RecursiveMutex cs_wallets; - // Ensure that fee levels defined in the wallet are at least as high // as the default levels for node policy. static_assert(DEFAULT_TRANSACTION_MINFEE >= DEFAULT_MIN_RELAY_TX_FEE, "wallet minimum fee is smaller than default relay fee"); @@ -51,15 +50,15 @@ namespace { constexpr CAmount fallbackFee = 1000; } // anonymous namespace -static std::shared_ptr TestLoadWallet(interfaces::Chain* chain, interfaces::CoinJoin::Loader* coinjoin_loader) +static std::shared_ptr TestLoadWallet(WalletContext& context, interfaces::CoinJoin::Loader* coinjoin_loader) { DatabaseOptions options; DatabaseStatus status; bilingual_str error; std::vector warnings; auto database = MakeWalletDatabase("", options, status, error); - auto wallet = CWallet::Create(chain, coinjoin_loader, "", std::move(database), options.create_flags, error, warnings); - if (chain) { + auto wallet = CWallet::Create(context, coinjoin_loader, "", std::move(database), options.create_flags, error, warnings); + if (context.chain) { wallet->postInitProcess(); } return wallet; @@ -220,7 +219,8 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) std::shared_ptr wallet = std::make_shared(m_node.chain.get(), m_node.coinjoin_loader.get(), "", CreateDummyWalletDatabase()); wallet->SetupLegacyScriptPubKeyMan(); WITH_LOCK(wallet->cs_wallet, wallet->SetLastBlockProcessed(newTip->nHeight, newTip->GetBlockHash())); - AddWallet(wallet); + WalletContext context; + AddWallet(context, wallet); UniValue keys; keys.setArray(); UniValue key; @@ -238,6 +238,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) key.pushKV("internal", UniValue(true)); keys.push_back(key); JSONRPCRequest request; + request.context = &context; request.params.setArray(); request.params.push_back(keys); @@ -251,7 +252,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) "downloading and rescanning the relevant blocks (see -reindex and -rescan " "options).\"}},{\"success\":true}]", 0, oldTip->GetBlockTimeMax(), TIMESTAMP_WINDOW)); - RemoveWallet(wallet, std::nullopt); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); } } @@ -278,6 +279,7 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) // Import key into wallet and call dumpwallet to create backup file. { + WalletContext context; std::shared_ptr wallet = std::make_shared(m_node.chain.get(), m_node.coinjoin_loader.get(), "", CreateDummyWalletDatabase()); { auto spk_man = wallet->GetOrCreateLegacyScriptPubKeyMan(); @@ -285,15 +287,16 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) spk_man->mapKeyMetadata[coinbaseKey.GetPubKey().GetID()].nCreateTime = KEY_TIME; spk_man->AddKeyPubKey(coinbaseKey, coinbaseKey.GetPubKey()); - AddWallet(wallet); + AddWallet(context, wallet); wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash()); } JSONRPCRequest request; + request.context = &context; request.params.setArray(); request.params.push_back(backup_file); ::dumpwallet().HandleRequest(request); - RemoveWallet(wallet, std::nullopt); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); } // Call importwallet RPC and verify all blocks with timestamps >= BLOCK_TIME @@ -303,13 +306,15 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) LOCK(wallet->cs_wallet); wallet->SetupLegacyScriptPubKeyMan(); + WalletContext context; JSONRPCRequest request; + request.context = &context; request.params.setArray(); request.params.push_back(backup_file); - AddWallet(wallet); + AddWallet(context, wallet); wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash()); ::importwallet().HandleRequest(request); - RemoveWallet(wallet, std::nullopt); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); BOOST_CHECK_EQUAL(wallet->mapWallet.size(), 3U); BOOST_CHECK_EQUAL(m_coinbase_txns.size(), 103U); @@ -1232,7 +1237,9 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) { gArgs.ForceSetArg("-unsafesqlitesync", "1"); // Create new wallet with known key and unload it. - auto wallet = TestLoadWallet(m_node.chain.get(), m_node.coinjoin_loader.get()); + WalletContext context; + context.chain = m_node.chain.get(); + auto wallet = TestLoadWallet(context, m_node.coinjoin_loader.get()); CKey key; key.MakeNewKey(true); AddKey(*wallet, key); @@ -1272,7 +1279,7 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) // Reload wallet and make sure new transactions are detected despite events // being blocked - wallet = TestLoadWallet(m_node.chain.get(), m_node.coinjoin_loader.get()); + wallet = TestLoadWallet(context, m_node.coinjoin_loader.get()); BOOST_CHECK(rescan_completed); BOOST_CHECK_EQUAL(addtx_count, 2); { @@ -1298,7 +1305,7 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) // deadlock during the sync and simulates a new block notification happening // as soon as possible. addtx_count = 0; - auto handler = HandleLoadWallet([&](std::unique_ptr wallet) { + auto handler = HandleLoadWallet(context, [&](std::unique_ptr wallet) { BOOST_CHECK(rescan_completed); m_coinbase_txns.push_back(CreateAndProcessBlock({}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]); block_tx = TestSimpleSpend(*m_coinbase_txns[2], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey())); @@ -1307,7 +1314,7 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) BOOST_CHECK(m_node.chain->broadcastTransaction(MakeTransactionRef(mempool_tx), DEFAULT_TRANSACTION_MAXFEE, false, error)); SyncWithValidationInterfaceQueue(); }); - wallet = TestLoadWallet(m_node.chain.get(), m_node.coinjoin_loader.get()); + wallet = TestLoadWallet(context, m_node.coinjoin_loader.get()); BOOST_CHECK_EQUAL(addtx_count, 4); { LOCK(wallet->cs_wallet); @@ -1388,8 +1395,8 @@ BOOST_FIXTURE_TEST_CASE(wallet_descriptor_test, BasicTestingSetup) BOOST_FIXTURE_TEST_CASE(CreateWalletWithoutChain, BasicTestingSetup) { - // TODO: FIX FIX FIX - coinjoin_loader is null heere! - auto wallet = TestLoadWallet(nullptr, nullptr); + WalletContext context; + auto wallet = TestLoadWallet(context, nullptr); BOOST_CHECK(wallet); UnloadWallet(std::move(wallet)); } @@ -1398,7 +1405,9 @@ BOOST_FIXTURE_TEST_CASE(ZapSelectTx, TestChain100Setup) { gArgs.ForceSetArg("-unsafesqlitesync", "1"); auto chain = interfaces::MakeChain(m_node); - auto wallet = TestLoadWallet(m_node.chain.get(), m_node.coinjoin_loader.get()); + WalletContext context; + context.chain = m_node.chain.get(); + auto wallet = TestLoadWallet(context, m_node.coinjoin_loader.get()); CKey key; key.MakeNewKey(true); AddKey(*wallet, key); diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 14ba64998d6ac6..3ad1fb8ec2209d 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -38,6 +38,7 @@ #endif #include #include +#include #include #include @@ -62,9 +63,6 @@ const std::map WALLET_FLAG_CAVEATS{ static constexpr size_t OUTPUT_GROUP_MAX_ENTRIES{100}; -RecursiveMutex cs_wallets; -static std::vector> vpwallets GUARDED_BY(cs_wallets); -static std::list g_load_wallet_fns GUARDED_BY(cs_wallets); bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name) { @@ -112,14 +110,14 @@ static void RefreshMempoolStatus(CWalletTx& tx, interfaces::Chain& chain) tx.fInMempool = chain.isInMempool(tx.GetHash()); } -bool AddWallet(const std::shared_ptr& wallet) +bool AddWallet(WalletContext& context, const std::shared_ptr& wallet) { { - LOCK(cs_wallets); + LOCK(context.wallets_mutex); assert(wallet); - std::vector>::const_iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet); - if (i != vpwallets.end()) return false; - vpwallets.push_back(wallet); + std::vector>::const_iterator i = std::find(context.wallets..begin(), context.wallets.end(), wallet); + if (i != context.wallets..end()) return false; + context.wallets.push_back(wallet); } wallet->ConnectScriptPubKeyManNotifiers(); wallet->AutoLockMasternodeCollaterals(); @@ -128,7 +126,7 @@ bool AddWallet(const std::shared_ptr& wallet) return true; } -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings) +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings) { assert(wallet); @@ -138,10 +136,10 @@ bool RemoveWallet(const std::shared_ptr& wallet, std::optional lo // Unregister with the validation interface which also drops shared pointers. wallet->m_chain_notifications_handler.reset(); { - LOCK(cs_wallets); - std::vector>::iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet); - if (i == vpwallets.end()) return false; - vpwallets.erase(i); + LOCK(context.wallets_mutex); + std::vector>::iterator i = std::find(context.wallets.begin(), context.wallets.end(), wallet); + if (i == end()) return false; + context.wallets.erase(i); } wallet->coinjoin_loader().RemoveWallet(name); @@ -152,32 +150,32 @@ bool RemoveWallet(const std::shared_ptr& wallet, std::optional lo return true; } -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start) +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start) { std::vector warnings; - return RemoveWallet(wallet, load_on_start, warnings); + return RemoveWallet(context, wallet, load_on_start, warnings); } -std::vector> GetWallets() +std::vector> GetWallets(WalletContext& context) { - LOCK(cs_wallets); - return vpwallets; + LOCK(context.wallets_mutex); + return context.wallets; } -std::shared_ptr GetWallet(const std::string& name) +std::shared_ptr GetWallet(WalletContext& context, const std::string& name) { - LOCK(cs_wallets); - for (const std::shared_ptr& wallet : vpwallets) { + LOCK(context.wallets_mutex); + for (const std::shared_ptr& wallet : context.wallets) { if (wallet->GetName() == name) return wallet; } return nullptr; } -std::unique_ptr HandleLoadWallet(LoadWalletFn load_wallet) +std::unique_ptr HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet) { - LOCK(cs_wallets); - auto it = g_load_wallet_fns.emplace(g_load_wallet_fns.end(), std::move(load_wallet)); - return interfaces::MakeHandler([it] { LOCK(cs_wallets); g_load_wallet_fns.erase(it); }); + LOCK(context.wallets_mutex); + auto it = context.wallet_load_fns.emplace(context.wallet_load_fns.end(), std::move(load_wallet)); + return interfaces::MakeHandler([&context, it] { LOCK(context.wallets_mutex); context.wallet_load_fns.erase(it); }); } static Mutex g_loading_wallet_mutex; @@ -229,7 +227,7 @@ void UnloadWallet(std::shared_ptr&& wallet) } namespace { -std::shared_ptr LoadWalletInternal(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr LoadWalletInternal(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { try { std::unique_ptr database = MakeWalletDatabase(name, options, status, error); @@ -238,18 +236,18 @@ std::shared_ptr LoadWalletInternal(interfaces::Chain& chain, interfaces return nullptr; } - chain.initMessage(_("Loading wallet...").translated); - std::shared_ptr wallet = CWallet::Create(&chain, &coinjoin_loader, name, std::move(database), options.create_flags, error, warnings); + context.chain.initMessage(_("Loading wallet...").translated); + std::shared_ptr wallet = CWallet::Create(context, &coinjoin_loader, name, std::move(database), options.create_flags, error, warnings); if (!wallet) { error = Untranslated("Wallet loading failed.") + Untranslated(" ") + error; status = DatabaseStatus::FAILED_LOAD; return nullptr; } - AddWallet(wallet); + AddWallet(context, wallet); wallet->postInitProcess(); // Write the wallet setting - UpdateWalletSetting(chain, name, load_on_start, warnings); + UpdateWalletSetting(*context.chain, name, load_on_start, warnings); return wallet; } catch (const std::runtime_error& e) { @@ -260,7 +258,7 @@ std::shared_ptr LoadWalletInternal(interfaces::Chain& chain, interfaces } } // namespace -std::shared_ptr LoadWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr LoadWallet(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { auto result = WITH_LOCK(g_loading_wallet_mutex, return g_loading_wallet_set.insert(name)); if (!result.second) { @@ -268,12 +266,12 @@ std::shared_ptr LoadWallet(interfaces::Chain& chain, interfaces::CoinJo status = DatabaseStatus::FAILED_LOAD; return nullptr; } - auto wallet = LoadWalletInternal(chain, coinjoin_loader, name, load_on_start, options, status, error, warnings); + auto wallet = LoadWalletInternal(context, coinjoin_loader, name, load_on_start, options, status, error, warnings); WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(result.first)); return wallet; } -std::shared_ptr CreateWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr CreateWallet(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { uint64_t wallet_creation_flags = options.create_flags; const SecureString& passphrase = options.create_passphrase; @@ -304,8 +302,8 @@ std::shared_ptr CreateWallet(interfaces::Chain& chain, interfaces::Coin } // Make the wallet - chain.initMessage(_("Loading wallet...").translated); - std::shared_ptr wallet = CWallet::Create(&chain, &coinjoin_loader, name, std::move(database), wallet_creation_flags, error, warnings); + context.chain->initMessage(_("Loading wallet...").translated); + std::shared_ptr wallet = CWallet::Create(context, &coinjoin_loader, name, std::move(database), wallet_creation_flags, error, warnings); if (!wallet) { error = Untranslated("Wallet creation failed.") + Untranslated(" ") + error; status = DatabaseStatus::FAILED_CREATE; @@ -357,17 +355,17 @@ std::shared_ptr CreateWallet(interfaces::Chain& chain, interfaces::Coin wallet->Lock(); } } - AddWallet(wallet); + AddWallet(context, wallet); wallet->postInitProcess(); // Write the wallet settings - UpdateWalletSetting(chain, name, load_on_start, warnings); + UpdateWalletSetting(*context.chain, name, load_on_start, warnings); status = DatabaseStatus::SUCCESS; return wallet; } -std::shared_ptr RestoreWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr RestoreWallet(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { DatabaseOptions options; options.require_existing = true; @@ -389,7 +387,7 @@ std::shared_ptr RestoreWallet(interfaces::Chain& chain, interfaces::Coi auto wallet_file = wallet_path / "wallet.dat"; fs::copy_file(backup_file, wallet_file, fs::copy_options::none); - auto wallet = LoadWallet(chain, coinjoin_loader, wallet_name, load_on_start, options, status, error, warnings); + auto wallet = LoadWallet(context, coinjoin_loader, wallet_name, load_on_start, options, status, error, warnings); if (!wallet) { fs::remove(wallet_file); @@ -2497,9 +2495,9 @@ void CWallet::ResendWalletTransactions() /** @} */ // end of mapWallet -void MaybeResendWalletTxs() +void MaybeResendWalletTxs(WalletContext& context) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->ResendWalletTransactions(); } } @@ -4694,8 +4692,9 @@ std::unique_ptr MakeWalletDatabase(const std::string& name, cons return MakeDatabase(wallet_path, options, status, error_string); } -std::shared_ptr CWallet::Create(interfaces::Chain* chain, interfaces::CoinJoin::Loader* coinjoin_loader, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings) +std::shared_ptr CWallet::Create(WalletContext& context, interfaces::CoinJoin::Loader* coinjoin_loader, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings) { + interfaces::Chain* chain = context.chain; const std::string& walletFile = database->Filename(); int64_t nStart = GetTimeMillis(); @@ -4958,9 +4957,9 @@ std::shared_ptr CWallet::Create(interfaces::Chain* chain, interfaces::C } { - LOCK(cs_wallets); - for (auto& load_wallet : g_load_wallet_fns) { - load_wallet(interfaces::MakeWallet(walletInstance)); + LOCK(context.wallets_mutex); + for (auto& load_wallet : context.wallet_load_fns) { + load_wallet(interfaces::MakeWallet(context, walletInstance)); } } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 194d734061595d..2529234e77f138 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -45,6 +45,8 @@ #include +struct WalletContext; + struct bilingual_str; using LoadWalletFn = std::function wallet)>; @@ -56,15 +58,15 @@ using LoadWalletFn = std::function wall // by the shared pointer deleter. void UnloadWallet(std::shared_ptr&& wallet); -bool AddWallet(const std::shared_ptr& wallet); -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings); -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start); -std::vector> GetWallets(); -std::shared_ptr GetWallet(const std::string& name); -std::shared_ptr LoadWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); -std::shared_ptr CreateWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); -std::shared_ptr RestoreWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); -std::unique_ptr HandleLoadWallet(LoadWalletFn load_wallet); +bool AddWallet(WalletContext& context, const std::shared_ptr& wallet); +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings); +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start); +std::vector> GetWallets(WalletContext& context); +std::shared_ptr GetWallet(WalletContext& context, const std::string& name); +std::shared_ptr LoadWallet(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::shared_ptr CreateWallet(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::shared_ptr RestoreWallet(WalletContext& context, interfaces::CoinJoin::Loader& coinjoin_loader, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::unique_ptr HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet); std::unique_ptr MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error); //! -paytxfee default @@ -1346,7 +1348,7 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati bool ResendTransaction(const uint256& hashTx); /* Initializes the wallet, returns a new CWallet instance or a null pointer in case of an error */ - static std::shared_ptr Create(interfaces::Chain* chain, interfaces::CoinJoin::Loader* coinjoin_loader, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings); + static std::shared_ptr Create(WalletContext& context, interfaces::CoinJoin::Loader* coinjoin_loader, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings); /** * Wallet post-init setup @@ -1515,7 +1517,7 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati * Called periodically by the schedule thread. Prompts individual wallets to resend * their transactions. Actual rebroadcast schedule is managed by the wallets themselves. */ -void MaybeResendWalletTxs(); +void MaybeResendWalletTxs(WalletContext& context); /** RAII object to check and reserve a wallet rescan */ class WalletRescanReserver diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 9042bd25bcd68d..ca0790c7d6e828 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -991,14 +991,14 @@ DBErrors WalletBatch::ZapSelectTx(std::vector& vTxHashIn, std::vector fOneThread(false); if (fOneThread.exchange(true)) { return; } - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { WalletDatabase& dbh = pwallet->GetDatabase(); unsigned int nUpdateCounter = dbh.nUpdateCounter; diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index 7c4427123dc6d8..92def04e2f18c3 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -33,6 +33,7 @@ static const bool DEFAULT_FLUSHWALLET = true; struct CBlockLocator; class CHDChain; class CHDPubKey; +struct WalletContext; class CKeyPool; class CMasterKey; class CScript; @@ -251,7 +252,7 @@ class WalletBatch }; //! Compacts BDB state so that wallet.dat is self-contained (if there are changes) -void MaybeCompactWalletDB(); +void MaybeCompactWalletDB(WalletContext& context); //! Callback for filtering key types to deserialize in ReadKeyValue using KeyFilterFn = std::function;