From 6d35c438764d22b6d16e30bd697c7de3c718b2e6 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 27 Jun 2024 17:57:51 +0200 Subject: [PATCH] Implement on-the-fly funding Implement the on-the-fly funding protocol: when a payment cannot be relayed because of a liquidity issue, we notify the `Peer` actor that we'd like to trigger on-the-fly funding if available. If available, we we send a funding proposal to our peer and keep track of its status. Once a matching funding transaction is signed, we persist this funding attempt and wait for the additional liquidity to be available (once the channel is ready or the splice locked). We will then frequently try to relay the payment to get paid our liquidity fees. If the payment keeps getting rejected, or we cannot connect to our peer, we abandon the payment when it reaches its CLTV expiry, which ensures that the upstream channels are not at risk. When using on-the-fly funding, we use a single channel with our peer. If they try to open another channel while one is available, we reject their request and expect a splice instead. --- eclair-core/src/main/resources/reference.conf | 8 + .../scala/fr/acinq/eclair/NodeParams.scala | 8 +- .../fr/acinq/eclair/channel/ChannelData.scala | 6 +- .../acinq/eclair/channel/ChannelEvents.scala | 3 + .../fr/acinq/eclair/channel/fsm/Channel.scala | 18 +- .../channel/fsm/ChannelOpenDualFunded.scala | 12 +- .../channel/fsm/CommonFundingHandlers.scala | 1 + .../channel/fund/InteractiveTxBuilder.scala | 16 +- .../scala/fr/acinq/eclair/db/Databases.scala | 5 + .../fr/acinq/eclair/db/DbEventHandler.scala | 3 + .../fr/acinq/eclair/db/DualDatabases.scala | 42 +- .../acinq/eclair/db/OnTheFlyFundingDb.scala | 47 + .../fr/acinq/eclair/db/pg/PgAuditDb.scala | 47 +- .../eclair/db/pg/PgOnTheFlyFundingDb.scala | 147 +++ .../eclair/db/sqlite/SqliteAuditDb.scala | 48 +- .../db/sqlite/SqliteOnTheFlyFundingDb.scala | 129 +++ .../fr/acinq/eclair/io/MessageRelay.scala | 2 +- .../scala/fr/acinq/eclair/io/Monitoring.scala | 16 + .../eclair/io/OpenChannelInterceptor.scala | 54 +- .../main/scala/fr/acinq/eclair/io/Peer.scala | 343 ++++++- .../fr/acinq/eclair/io/Switchboard.scala | 13 +- .../fr/acinq/eclair/payment/Monitoring.scala | 2 + .../acinq/eclair/payment/PaymentEvents.scala | 8 + .../eclair/payment/relay/ChannelRelay.scala | 65 +- .../eclair/payment/relay/NodeRelay.scala | 88 +- .../payment/relay/OnTheFlyFunding.scala | 318 +++++++ .../relay/PostRestartHtlcCleaner.scala | 21 +- .../acinq/eclair/payment/relay/Relayer.scala | 13 +- .../scala/fr/acinq/eclair/TestConstants.scala | 55 +- .../scala/fr/acinq/eclair/TestDatabases.scala | 1 + .../channel/InteractiveTxBuilderSpec.scala | 7 +- .../b/WaitForDualFundingSignedStateSpec.scala | 49 +- .../c/WaitForDualFundingReadyStateSpec.scala | 20 +- .../states/e/NormalSplicesStateSpec.scala | 84 +- .../channel/states/e/NormalStateSpec.scala | 2 +- .../channel/states/e/OfflineStateSpec.scala | 31 + .../eclair/db/OnTheFlyFundingDbSpec.scala | 138 +++ .../fr/acinq/eclair/io/MessageRelaySpec.scala | 4 +- .../io/OpenChannelInterceptorSpec.scala | 79 +- .../fr/acinq/eclair/io/SwitchboardSpec.scala | 23 +- .../eclair/payment/PaymentPacketSpec.scala | 12 +- .../payment/PostRestartHtlcCleanerSpec.scala | 95 +- .../payment/relay/ChannelRelayerSpec.scala | 88 +- .../payment/relay/NodeRelayerSpec.scala | 167 +++- .../payment/relay/OnTheFlyFundingSpec.scala | 876 ++++++++++++++++++ .../eclair/payment/relay/RelayerSpec.scala | 32 +- .../channel/version1/ChannelCodecs1Spec.scala | 2 +- 47 files changed, 3069 insertions(+), 179 deletions(-) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/db/OnTheFlyFundingDb.scala create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgOnTheFlyFundingDb.scala create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteOnTheFlyFundingDb.scala create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/OnTheFlyFunding.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/db/OnTheFlyFundingDbSpec.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/OnTheFlyFundingSpec.scala diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index b1005133c4..fbdc8ee089 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -339,6 +339,14 @@ eclair { ] } + // On-the-fly funding leverages liquidity ads to fund channels with wallet peers based on their payment patterns. + on-the-fly-funding { + wake-up-timeout = 30 seconds + // If our peer doesn't respond to our funding proposal, we must fail the corresponding upstream HTLCs. + // Since MPP may be used, we should use a timeout greater than the MPP timeout. + proposal-timeout = 90 seconds + } + peer-connection { auth-timeout = 15 seconds // will disconnect if connection authentication doesn't happen within that timeframe init-timeout = 15 seconds // will disconnect if initialization doesn't happen within that timeframe diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 41f7e87a0d..deb31dee7b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -30,6 +30,7 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.io.MessageRelay.{RelayAll, RelayChannelsOnly, RelayPolicy} import fr.acinq.eclair.io.PeerConnection import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig +import fr.acinq.eclair.payment.relay.OnTheFlyFunding import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Announcements.AddressException import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios} @@ -90,7 +91,7 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, purgeInvoicesInterval: Option[FiniteDuration], revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, willFundRates_opt: Option[LiquidityAds.WillFundRates], - wakeUpTimeout: FiniteDuration) { + onTheFlyFundingConfig: OnTheFlyFunding.Config) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -663,7 +664,10 @@ object NodeParams extends Logging { interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) ), willFundRates_opt = willFundRates_opt, - wakeUpTimeout = 30 seconds, + onTheFlyFundingConfig = OnTheFlyFunding.Config( + wakeUpTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.wake-up-timeout").getSeconds, TimeUnit.SECONDS), + proposalTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.proposal-timeout").getSeconds, TimeUnit.SECONDS), + ), ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala index 2f5da6f47c..4865374874 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala @@ -146,7 +146,7 @@ object Upstream { val expiryIn: CltvExpiry = add.cltvExpiry } /** Our node is forwarding a payment based on a set of HTLCs from potentially multiple upstream channels. */ - case class Trampoline(received: Seq[Channel]) extends Hot { + case class Trampoline(received: List[Channel]) extends Hot { override val amountIn: MilliSatoshi = received.map(_.add.amountMsat).sum // We must use the lowest expiry of the incoming HTLC set. val expiryIn: CltvExpiry = received.map(_.add.cltvExpiry).min @@ -165,6 +165,10 @@ object Upstream { /** Our node is forwarding a single incoming HTLC. */ case class Channel(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi) extends Cold + object Channel { + def apply(add: UpdateAddHtlc): Channel = Channel(add.channelId, add.id, add.amountMsat) + } + /** Our node is forwarding a payment based on a set of HTLCs from potentially multiple upstream channels. */ case class Trampoline(originHtlcs: List[Channel]) extends Cold { override val amountIn: MilliSatoshi = originHtlcs.map(_.amountIn).sum } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala index ff36bd5fb1..b068099a79 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala @@ -54,6 +54,9 @@ case class ChannelAborted(channel: ActorRef, remoteNodeId: PublicKey, channelId: /** This event will be sent once a channel has been successfully opened and is ready to process payments. */ case class ChannelOpened(channel: ActorRef, remoteNodeId: PublicKey, channelId: ByteVector32) extends ChannelEvent +/** This event is sent once channel_ready or splice_locked have been exchanged. */ +case class ChannelReadyForPayments(channel: ActorRef, remoteNodeId: PublicKey, channelId: ByteVector32, fundingTxIndex: Long) extends ChannelEvent + case class LocalChannelUpdate(channel: ActorRef, channelId: ByteVector32, shortIds: ShortIds, remoteNodeId: PublicKey, channelAnnouncement_opt: Option[ChannelAnnouncement], channelUpdate: ChannelUpdate, commitments: Commitments) extends ChannelEvent { /** * We always include the local alias because we must always be able to route based on it. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala index a104bdbc8b..a067b04df4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala @@ -44,6 +44,7 @@ import fr.acinq.eclair.crypto.keymanager.ChannelKeyManager import fr.acinq.eclair.db.DbEventHandler.ChannelEvent.EventType import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.io.Peer +import fr.acinq.eclair.io.Peer.LiquidityPurchaseSigned import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.payment.{Bolt11Invoice, PaymentSettlingOnChain} import fr.acinq.eclair.router.Announcements @@ -1095,10 +1096,13 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder with log.info("ignoring outgoing interactive-tx message {} from previous session", msg.getClass.getSimpleName) stay() } - case InteractiveTxBuilder.Succeeded(signingSession, commitSig) => + case InteractiveTxBuilder.Succeeded(signingSession, commitSig, liquidityPurchase_opt) => log.info(s"splice tx created with fundingTxIndex=${signingSession.fundingTxIndex} fundingTxId=${signingSession.fundingTx.txId}") cmd_opt.foreach(cmd => cmd.replyTo ! RES_SPLICE(fundingTxIndex = signingSession.fundingTxIndex, signingSession.fundingTx.txId, signingSession.fundingParams.fundingAmount, signingSession.localCommit.fold(_.spec, _.spec).toLocal)) remoteCommitSig_opt.foreach(self ! _) + liquidityPurchase_opt.collect { + case purchase if !signingSession.fundingParams.isInitiator => peer ! LiquidityPurchaseSigned(d.channelId, signingSession.fundingTx.txId, signingSession.fundingTxIndex, d.commitments.params.remoteParams.htlcMinimum, purchase) + } val d1 = d.copy(spliceStatus = SpliceStatus.SpliceWaitingForSigs(signingSession)) stay() using d1 storing() sending commitSig case f: InteractiveTxBuilder.Failed => @@ -2139,6 +2143,12 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder with } } + // We tell the peer that the channel is ready to process payments that may be queued. + if (!shutdownInProgress) { + val fundingTxIndex = commitments1.active.map(_.fundingTxIndex).min + peer ! ChannelReadyForPayments(self, remoteNodeId, d.channelId, fundingTxIndex) + } + goto(NORMAL) using d.copy(commitments = commitments1, spliceStatus = spliceStatus1) sending sendQueue } @@ -2710,6 +2720,12 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder with if (oldCommitments.availableBalanceForSend != newCommitments.availableBalanceForSend || oldCommitments.availableBalanceForReceive != newCommitments.availableBalanceForReceive) { context.system.eventStream.publish(AvailableBalanceChanged(self, newCommitments.channelId, shortIds, newCommitments)) } + if (oldCommitments.active.size != newCommitments.active.size) { + // Some commitments have been deactivated, which means our available balance changed, which may allow forwarding + // payments that couldn't be forwarded before. + val fundingTxIndex = newCommitments.active.map(_.fundingTxIndex).min + peer ! ChannelReadyForPayments(self, remoteNodeId, newCommitments.channelId, fundingTxIndex) + } } private def maybeUpdateMaxHtlcAmount(currentMaxHtlcAmount: MilliSatoshi, newCommitments: Commitments): Unit = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/ChannelOpenDualFunded.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/ChannelOpenDualFunded.scala index 77250a01db..3e2821b60e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/ChannelOpenDualFunded.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/ChannelOpenDualFunded.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.channel.fund.InteractiveTxBuilder.{FullySignedSharedTrans import fr.acinq.eclair.channel.fund.{InteractiveTxBuilder, InteractiveTxSigningSession} import fr.acinq.eclair.channel.publish.TxPublisher.SetChannelId import fr.acinq.eclair.crypto.ShaChain -import fr.acinq.eclair.io.Peer.OpenChannelResponse +import fr.acinq.eclair.io.Peer.{LiquidityPurchaseSigned, OpenChannelResponse} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{MilliSatoshiLong, RealShortChannelId, ToMilliSatoshiConversion, UInt64, randomBytes32} @@ -339,9 +339,12 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers { case Event(msg: InteractiveTxBuilder.Response, d: DATA_WAIT_FOR_DUAL_FUNDING_CREATED) => msg match { case InteractiveTxBuilder.SendMessage(_, msg) => stay() sending msg - case InteractiveTxBuilder.Succeeded(status, commitSig) => + case InteractiveTxBuilder.Succeeded(status, commitSig, liquidityPurchase_opt) => d.deferred.foreach(self ! _) d.replyTo_opt.foreach(_ ! OpenChannelResponse.Created(d.channelId, status.fundingTx.txId, status.fundingTx.tx.localFees.truncateToSatoshi)) + liquidityPurchase_opt.collect { + case purchase if !status.fundingParams.isInitiator => peer ! LiquidityPurchaseSigned(d.channelId, status.fundingTx.txId, status.fundingTxIndex, d.channelParams.remoteParams.htlcMinimum, purchase) + } val d1 = DATA_WAIT_FOR_DUAL_FUNDING_SIGNED(d.channelParams, d.secondRemotePerCommitmentPoint, d.localPushAmount, d.remotePushAmount, status, None) goto(WAIT_FOR_DUAL_FUNDING_SIGNED) using d1 storing() sending commitSig case f: InteractiveTxBuilder.Failed => @@ -687,9 +690,12 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers { case RbfStatus.RbfInProgress(cmd_opt, _, remoteCommitSig_opt) => msg match { case InteractiveTxBuilder.SendMessage(_, msg) => stay() sending msg - case InteractiveTxBuilder.Succeeded(signingSession, commitSig) => + case InteractiveTxBuilder.Succeeded(signingSession, commitSig, liquidityPurchase_opt) => cmd_opt.foreach(cmd => cmd.replyTo ! RES_BUMP_FUNDING_FEE(rbfIndex = d.previousFundingTxs.length, signingSession.fundingTx.txId, signingSession.fundingTx.tx.localFees.truncateToSatoshi)) remoteCommitSig_opt.foreach(self ! _) + liquidityPurchase_opt.collect { + case purchase if !signingSession.fundingParams.isInitiator => peer ! LiquidityPurchaseSigned(d.channelId, signingSession.fundingTx.txId, signingSession.fundingTxIndex, d.commitments.params.remoteParams.htlcMinimum, purchase) + } val d1 = d.copy(rbfStatus = RbfStatus.RbfWaitingForSigs(signingSession)) stay() using d1 storing() sending commitSig case f: InteractiveTxBuilder.Failed => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/CommonFundingHandlers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/CommonFundingHandlers.scala index e827205890..da66aa0253 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/CommonFundingHandlers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/CommonFundingHandlers.scala @@ -135,6 +135,7 @@ trait CommonFundingHandlers extends CommonHandlers { // used to get the final shortChannelId, used in announcements (if minDepth >= ANNOUNCEMENTS_MINCONF this event will fire instantly) blockchain ! WatchFundingDeeplyBuried(self, commitments.latest.fundingTxId, ANNOUNCEMENTS_MINCONF) val commitments1 = commitments.modify(_.remoteNextCommitInfo).setTo(Right(channelReady.nextPerCommitmentPoint)) + peer ! ChannelReadyForPayments(self, remoteNodeId, commitments.channelId, fundingTxIndex = 0) DATA_NORMAL(commitments1, shortIds1, None, initialChannelUpdate, None, None, None, SpliceStatus.NoSplice) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fund/InteractiveTxBuilder.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fund/InteractiveTxBuilder.scala index 0b9cfd32ca..b85ca1b15d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fund/InteractiveTxBuilder.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fund/InteractiveTxBuilder.scala @@ -90,7 +90,7 @@ object InteractiveTxBuilder { sealed trait Response case class SendMessage(sessionId: ByteVector32, msg: LightningMessage) extends Response - case class Succeeded(signingSession: InteractiveTxSigningSession.WaitingForSigs, commitSig: CommitSig) extends Response + case class Succeeded(signingSession: InteractiveTxSigningSession.WaitingForSigs, commitSig: CommitSig, liquidityPurchase_opt: Option[LiquidityAds.Purchase]) extends Response sealed trait Failed extends Response { def cause: ChannelException } case class LocalFailure(cause: ChannelException) extends Failed case class RemoteFailure(cause: ChannelException) extends Failed @@ -370,12 +370,24 @@ object InteractiveTxBuilder { // Note that pending HTLCs are ignored: splices only affect the main outputs. val nextLocalBalance = purpose.previousLocalBalance + fundingParams.localContribution - localPushAmount + remotePushAmount - liquidityFee val nextRemoteBalance = purpose.previousRemoteBalance + fundingParams.remoteContribution - remotePushAmount + localPushAmount + liquidityFee + val liquidityPaymentTypeOk = liquidityPurchase_opt match { + case Some(l) if !fundingParams.isInitiator => l.paymentDetails match { + case LiquidityAds.PaymentDetails.FromChannelBalance | _: LiquidityAds.PaymentDetails.FromChannelBalanceForFutureHtlc => true + // If our peer has enough balance to pay the liquidity fees, they shouldn't use future HTLCs which + // involves trust: they should directly pay from their channel balance. + case _: LiquidityAds.PaymentDetails.FromFutureHtlc | _: LiquidityAds.PaymentDetails.FromFutureHtlcWithPreimage => nextRemoteBalance < l.fees.total + } + case _ => true + } if (fundingParams.fundingAmount < fundingParams.dustLimit) { replyTo ! LocalFailure(FundingAmountTooLow(channelParams.channelId, fundingParams.fundingAmount, fundingParams.dustLimit)) Behaviors.stopped } else if (nextLocalBalance < 0.msat || nextRemoteBalance < 0.msat) { replyTo ! LocalFailure(InvalidFundingBalances(channelParams.channelId, fundingParams.fundingAmount, nextLocalBalance, nextRemoteBalance)) Behaviors.stopped + } else if (!liquidityPaymentTypeOk) { + replyTo ! LocalFailure(InvalidLiquidityAdsPaymentType(channelParams.channelId, liquidityPurchase_opt.get.paymentDetails.paymentType, Set(LiquidityAds.PaymentType.FromChannelBalance, LiquidityAds.PaymentType.FromChannelBalanceForFutureHtlc))) + Behaviors.stopped } else { val actor = new InteractiveTxBuilder(replyTo, sessionId, nodeParams, channelParams, fundingParams, purpose, localPushAmount, remotePushAmount, liquidityPurchase_opt, wallet, stash, context) actor.start() @@ -805,7 +817,7 @@ private class InteractiveTxBuilder(replyTo: ActorRef[InteractiveTxBuilder.Respon Behaviors.receiveMessagePartial { case SignTransactionResult(signedTx) => log.info(s"interactive-tx txid=${signedTx.txId} partially signed with {} local inputs, {} remote inputs, {} local outputs and {} remote outputs", signedTx.tx.localInputs.length, signedTx.tx.remoteInputs.length, signedTx.tx.localOutputs.length, signedTx.tx.remoteOutputs.length) - replyTo ! Succeeded(InteractiveTxSigningSession.WaitingForSigs(fundingParams, purpose.fundingTxIndex, signedTx, Left(localCommit), remoteCommit), commitSig) + replyTo ! Succeeded(InteractiveTxSigningSession.WaitingForSigs(fundingParams, purpose.fundingTxIndex, signedTx, Left(localCommit), remoteCommit), commitSig, liquidityPurchase_opt) Behaviors.stopped case WalletFailure(t) => log.error("could not sign funding transaction: ", t) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala index 9713cfbf1b..f145785ce9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala @@ -44,6 +44,7 @@ trait Databases { def peers: PeersDb def payments: PaymentsDb def pendingCommands: PendingCommandsDb + def onTheFlyFunding: OnTheFlyFundingDb //@formatter:on } @@ -65,6 +66,7 @@ object Databases extends Logging { peers: SqlitePeersDb, payments: SqlitePaymentsDb, pendingCommands: SqlitePendingCommandsDb, + onTheFlyFunding: SqliteOnTheFlyFundingDb, private val backupConnection: Connection) extends Databases with FileBackup { override def backup(backupFile: File): Unit = SqliteUtils.using(backupConnection.createStatement()) { statement => { @@ -83,6 +85,7 @@ object Databases extends Logging { peers = new SqlitePeersDb(eclairJdbc), payments = new SqlitePaymentsDb(eclairJdbc), pendingCommands = new SqlitePendingCommandsDb(eclairJdbc), + onTheFlyFunding = new SqliteOnTheFlyFundingDb(eclairJdbc), backupConnection = eclairJdbc ) } @@ -94,6 +97,7 @@ object Databases extends Logging { peers: PgPeersDb, payments: PgPaymentsDb, pendingCommands: PgPendingCommandsDb, + onTheFlyFunding: PgOnTheFlyFundingDb, dataSource: HikariDataSource, lock: PgLock) extends Databases with ExclusiveLock { override def obtainExclusiveLock(): Unit = lock.obtainExclusiveLock(dataSource) @@ -154,6 +158,7 @@ object Databases extends Logging { peers = new PgPeersDb, payments = new PgPaymentsDb, pendingCommands = new PgPendingCommandsDb, + onTheFlyFunding = new PgOnTheFlyFundingDb, dataSource = ds, lock = lock) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala index 0a66c56b43..87234cddb9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala @@ -89,6 +89,9 @@ class DbEventHandler(nodeParams: NodeParams) extends Actor with DiagnosticActorL case ChannelPaymentRelayed(_, _, _, fromChannelId, toChannelId, _, _) => channelsDb.updateChannelMeta(fromChannelId, ChannelEvent.EventType.PaymentReceived) channelsDb.updateChannelMeta(toChannelId, ChannelEvent.EventType.PaymentSent) + case OnTheFlyFundingPaymentRelayed(_, incoming, outgoing) => + incoming.foreach(p => channelsDb.updateChannelMeta(p.channelId, ChannelEvent.EventType.PaymentReceived)) + outgoing.foreach(p => channelsDb.updateChannelMeta(p.channelId, ChannelEvent.EventType.PaymentSent)) } auditDb.add(e) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index bf6eb6e669..7b6a906768 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -9,6 +9,7 @@ import fr.acinq.eclair.db.Databases.{FileBackup, PostgresDatabases, SqliteDataba import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.DualDatabases.runAsync import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.relay.OnTheFlyFunding import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement} @@ -30,16 +31,12 @@ import scala.util.{Failure, Success, Try} case class DualDatabases(primary: Databases, secondary: Databases) extends Databases with FileBackup { override val network: NetworkDb = DualNetworkDb(primary.network, secondary.network) - override val audit: AuditDb = DualAuditDb(primary.audit, secondary.audit) - override val channels: ChannelsDb = DualChannelsDb(primary.channels, secondary.channels) - override val peers: PeersDb = DualPeersDb(primary.peers, secondary.peers) - override val payments: PaymentsDb = DualPaymentsDb(primary.payments, secondary.payments) - override val pendingCommands: PendingCommandsDb = DualPendingCommandsDb(primary.pendingCommands, secondary.pendingCommands) + override val onTheFlyFunding: OnTheFlyFundingDb = DualOnTheFlyFundingDb(primary.onTheFlyFunding, secondary.onTheFlyFunding) /** if one of the database supports file backup, we use it */ override def backup(backupFile: File): Unit = (primary, secondary) match { @@ -411,3 +408,38 @@ case class DualPendingCommandsDb(primary: PendingCommandsDb, secondary: PendingC primary.listSettlementCommands() } } + +case class DualOnTheFlyFundingDb(primary: OnTheFlyFundingDb, secondary: OnTheFlyFundingDb) extends OnTheFlyFundingDb { + + private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("db-on-the-fly-funding").build())) + + override def addPreimage(preimage: ByteVector32): Unit = { + runAsync(secondary.addPreimage(preimage)) + primary.addPreimage(preimage) + } + + override def getPreimage(paymentHash: ByteVector32): Option[ByteVector32] = { + runAsync(secondary.getPreimage(paymentHash)) + primary.getPreimage(paymentHash) + } + + override def addPending(remoteNodeId: PublicKey, pending: OnTheFlyFunding.Pending): Unit = { + runAsync(secondary.addPending(remoteNodeId, pending)) + primary.addPending(remoteNodeId, pending) + } + + override def removePending(remoteNodeId: PublicKey, paymentHash: ByteVector32): Unit = { + runAsync(secondary.removePending(remoteNodeId, paymentHash)) + primary.removePending(remoteNodeId, paymentHash) + } + + override def listPending(remoteNodeId: PublicKey): Map[ByteVector32, OnTheFlyFunding.Pending] = { + runAsync(secondary.listPending(remoteNodeId)) + primary.listPending(remoteNodeId) + } + + override def listPendingPayments(): Map[PublicKey, Set[ByteVector32]] = { + runAsync(secondary.listPendingPayments()) + primary.listPendingPayments() + } +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/OnTheFlyFundingDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/OnTheFlyFundingDb.scala new file mode 100644 index 0000000000..ee729d3bf8 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/OnTheFlyFundingDb.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db + +import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.payment.relay.OnTheFlyFunding + +/** + * Created by t-bast on 25/06/2024. + */ + +trait OnTheFlyFundingDb { + + /** When we receive the preimage for an on-the-fly payment, we save it to protect against replays. */ + def addPreimage(preimage: ByteVector32): Unit + + /** Check if we received the preimage for the given payment hash. */ + def getPreimage(paymentHash: ByteVector32): Option[ByteVector32] + + /** We save funded proposals to allow completing the payment after a disconnection or a restart. */ + def addPending(remoteNodeId: PublicKey, pending: OnTheFlyFunding.Pending): Unit + + /** Once complete (succeeded or failed), we forget the pending on-the-fly funding proposal. */ + def removePending(remoteNodeId: PublicKey, paymentHash: ByteVector32): Unit + + /** List pending proposals we funded for the given remote node. */ + def listPending(remoteNodeId: PublicKey): Map[ByteVector32, OnTheFlyFunding.Pending] + + /** List the payment_hashes of pending proposals we funded for all remote nodes. */ + def listPendingPayments(): Map[PublicKey, Set[ByteVector32]] + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index b7af1fa566..95d2c19855 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -238,7 +238,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { val payments = e match { case ChannelPaymentRelayed(amountIn, amountOut, _, fromChannelId, toChannelId, startedAt, settledAt) => // non-trampoline relayed payments have one input and one output - Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", startedAt), RelayedPart(toChannelId, amountOut, "OUT", "channel", settledAt)) + val in = Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", startedAt)) + val out = Seq(RelayedPart(toChannelId, amountOut, "OUT", "channel", settledAt)) + in ++ out case TrampolinePaymentRelayed(_, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount) => using(pg.prepareStatement("INSERT INTO audit.relayed_trampoline VALUES (?, ?, ?, ?)")) { statement => statement.setString(1, e.paymentHash.toHex) @@ -248,7 +250,13 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate() } // trampoline relayed payments do MPP aggregation and may have M inputs and N outputs - incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", i.receivedAt)) ++ outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", o.settledAt)) + val in = incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", i.receivedAt)) + val out = outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", o.settledAt)) + in ++ out + case OnTheFlyFundingPaymentRelayed(_, incoming, outgoing) => + val in = incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "on-the-fly-funding", i.receivedAt)) + val out = outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "on-the-fly-funding", o.settledAt)) + in ++ out } for (p <- payments) { using(pg.prepareStatement("INSERT INTO audit.relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => @@ -453,6 +461,8 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { case Some(RelayedPart(_, _, _, "trampoline", _)) => val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey)) TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount) :: Nil + case Some(RelayedPart(_, _, _, "on-the-fly-funding", _)) => + Seq(OnTheFlyFundingPaymentRelayed(paymentHash, incoming, outgoing)) case _ => Nil } }.toSeq.sortBy(_.timestamp) @@ -480,10 +490,21 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { } override def stats(from: TimestampMilli, to: TimestampMilli): Seq[Stats] = { - val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) => - feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) - } case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String) + + def aggregateRelayStats(previous: Map[ByteVector32, Seq[Relayed]], incoming: PaymentRelayed.Incoming, outgoing: PaymentRelayed.Outgoing): Map[ByteVector32, Seq[Relayed]] = { + // We ensure trampoline payments are counted only once per channel and per direction (if multiple HTLCs were sent + // from/to the same channel, we group them). + val amountIn = incoming.map(_.amount).sum + val amountOut = outgoing.map(_.amount).sum + val in = incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq + val out = outgoing.groupBy(_.channelId).map { case (channelId, parts) => + val fee = (amountIn - amountOut) * parts.length / outgoing.length // we split the fee among outgoing channels + (channelId, Relayed(parts.map(_.amount).sum, fee, "OUT")) + }.toSeq + (in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) } + } + val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) => // NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones. val current = e match { @@ -492,17 +513,17 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { c.toChannelId -> (Relayed(c.amountOut, c.amountIn - c.amountOut, "OUT") +: previous.getOrElse(c.toChannelId, Nil)), ) case t: TrampolinePaymentRelayed => - // We ensure a trampoline payment is counted only once per channel and per direction (if multiple HTLCs were - // sent from/to the same channel, we group them). - val in = t.incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq - val out = t.outgoing.groupBy(_.channelId).map { case (channelId, parts) => - val fee = (t.amountIn - t.amountOut) * parts.length / t.outgoing.length // we split the fee among outgoing channels - (channelId, Relayed(parts.map(_.amount).sum, fee, "OUT")) - }.toSeq - (in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) } + aggregateRelayStats(previous, t.incoming, t.outgoing) + case f: OnTheFlyFundingPaymentRelayed => + aggregateRelayStats(previous, f.incoming, f.outgoing) } previous ++ current } + + val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) => + feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) + } + // Channels opened by our peers won't have any network fees paid by us, but we still want to compute stats for them. val allChannels = networkFees.keySet ++ relayed.keySet allChannels.toSeq.flatMap(channelId => { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgOnTheFlyFundingDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgOnTheFlyFundingDb.scala new file mode 100644 index 0000000000..b5c59f9769 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgOnTheFlyFundingDb.scala @@ -0,0 +1,147 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db.pg + +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, TxId} +import fr.acinq.eclair.MilliSatoshiLong +import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics +import fr.acinq.eclair.db.Monitoring.Tags.DbBackends +import fr.acinq.eclair.db.OnTheFlyFundingDb +import fr.acinq.eclair.db.pg.PgUtils.PgLock +import fr.acinq.eclair.payment.relay.OnTheFlyFunding +import scodec.bits.BitVector + +import java.sql.Timestamp +import java.time.Instant +import javax.sql.DataSource + +/** + * Created by t-bast on 25/06/2024. + */ + +object PgOnTheFlyFundingDb { + val DB_NAME = "on_the_fly_funding" + val CURRENT_VERSION = 1 +} + +class PgOnTheFlyFundingDb(implicit ds: DataSource, lock: PgLock) extends OnTheFlyFundingDb { + + import PgOnTheFlyFundingDb._ + import PgUtils.ExtendedResultSet._ + import PgUtils._ + import lock._ + + inTransaction { pg => + using(pg.createStatement()) { statement => + getVersion(statement, DB_NAME) match { + case None => + statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS on_the_fly_funding") + statement.executeUpdate("CREATE TABLE on_the_fly_funding.preimages (payment_hash TEXT NOT NULL PRIMARY KEY, preimage TEXT NOT NULL, received_at TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE on_the_fly_funding.pending (remote_node_id TEXT NOT NULL, payment_hash TEXT NOT NULL, channel_id TEXT NOT NULL, tx_id TEXT NOT NULL, funding_tx_index BIGINT NOT NULL, remaining_fees_msat BIGINT NOT NULL, proposed BYTEA NOT NULL, funded_at TIMESTAMP WITH TIME ZONE NOT NULL, PRIMARY KEY (remote_node_id, payment_hash))") + case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do + case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") + } + setVersion(statement, DB_NAME, CURRENT_VERSION) + } + } + + override def addPreimage(preimage: ByteVector32): Unit = withMetrics("on-the-fly-funding/add-preimage", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("INSERT INTO on_the_fly_funding.preimages (payment_hash, preimage, received_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement => + statement.setString(1, Crypto.sha256(preimage).toHex) + statement.setString(2, preimage.toHex) + statement.setTimestamp(3, Timestamp.from(Instant.now())) + statement.executeUpdate() + } + } + } + + override def getPreimage(paymentHash: ByteVector32): Option[ByteVector32] = withMetrics("on-the-fly-funding/get-preimage", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("SELECT preimage FROM on_the_fly_funding.preimages WHERE payment_hash = ?")) { statement => + statement.setString(1, paymentHash.toHex) + statement.executeQuery().map { rs => rs.getByteVector32FromHex("preimage") }.lastOption + } + } + } + + override def addPending(remoteNodeId: Crypto.PublicKey, pending: OnTheFlyFunding.Pending): Unit = withMetrics("on-the-fly-funding/add-pending", DbBackends.Postgres) { + pending.status match { + case _: OnTheFlyFunding.Status.Proposed => () + case status: OnTheFlyFunding.Status.Funded => withLock { pg => + using(pg.prepareStatement("INSERT INTO on_the_fly_funding.pending (remote_node_id, payment_hash, channel_id, tx_id, funding_tx_index, remaining_fees_msat, proposed, funded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT DO NOTHING")) { statement => + statement.setString(1, remoteNodeId.toHex) + statement.setString(2, pending.paymentHash.toHex) + statement.setString(3, status.channelId.toHex) + statement.setString(4, status.txId.value.toHex) + statement.setLong(5, status.fundingTxIndex) + statement.setLong(6, status.remainingFees.toLong) + statement.setBytes(7, OnTheFlyFunding.Codecs.proposals.encode(pending.proposed).require.bytes.toArray) + statement.setTimestamp(8, Timestamp.from(Instant.now())) + statement.executeUpdate() + } + } + } + } + + override def removePending(remoteNodeId: Crypto.PublicKey, paymentHash: ByteVector32): Unit = withMetrics("on-the-fly-funding/remove-pending", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("DELETE FROM on_the_fly_funding.pending WHERE remote_node_id = ? AND payment_hash = ?")) { statement => + statement.setString(1, remoteNodeId.toHex) + statement.setString(2, paymentHash.toHex) + statement.executeUpdate() + } + } + } + + override def listPending(remoteNodeId: Crypto.PublicKey): Map[ByteVector32, OnTheFlyFunding.Pending] = withMetrics("on-the-fly-funding/list-pending", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("SELECT * FROM on_the_fly_funding.pending WHERE remote_node_id = ?")) { statement => + statement.setString(1, remoteNodeId.toHex) + statement.executeQuery().map { rs => + val paymentHash = rs.getByteVector32FromHex("payment_hash") + val pending = OnTheFlyFunding.Pending( + proposed = OnTheFlyFunding.Codecs.proposals.decode(BitVector(rs.getBytes("proposed"))).require.value, + status = OnTheFlyFunding.Status.Funded( + channelId = rs.getByteVector32FromHex("channel_id"), + txId = TxId(rs.getByteVector32FromHex("tx_id")), + fundingTxIndex = rs.getLong("funding_tx_index"), + remainingFees = rs.getLong("remaining_fees_msat").msat + ) + ) + paymentHash -> pending + }.toMap + } + } + } + + override def listPendingPayments(): Map[Crypto.PublicKey, Set[ByteVector32]] = withMetrics("on-the-fly-funding/list-pending-payments", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("SELECT remote_node_id, payment_hash FROM on_the_fly_funding.pending")) { statement => + statement.executeQuery().map { rs => + val remoteNodeId = PublicKey(rs.getByteVectorFromHex("remote_node_id")) + val paymentHash = rs.getByteVector32FromHex("payment_hash") + remoteNodeId -> paymentHash + }.groupMap(_._1)(_._2).map { + case (remoteNodeId, payments) => remoteNodeId -> payments.toSet + } + } + } + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 9b418451bc..ec1914e290 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -226,7 +226,9 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { val payments = e match { case ChannelPaymentRelayed(amountIn, amountOut, _, fromChannelId, toChannelId, startedAt, settledAt) => // non-trampoline relayed payments have one input and one output - Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", startedAt), RelayedPart(toChannelId, amountOut, "OUT", "channel", settledAt)) + val in = Seq(RelayedPart(fromChannelId, amountIn, "IN", "channel", startedAt)) + val out = Seq(RelayedPart(toChannelId, amountOut, "OUT", "channel", settledAt)) + in ++ out case TrampolinePaymentRelayed(_, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount) => using(sqlite.prepareStatement("INSERT INTO relayed_trampoline VALUES (?, ?, ?, ?)")) { statement => statement.setBytes(1, e.paymentHash.toArray) @@ -236,8 +238,13 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate() } // trampoline relayed payments do MPP aggregation and may have M inputs and N outputs - incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", i.receivedAt)) ++ - outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", o.settledAt)) + val in = incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "trampoline", i.receivedAt)) + val out = outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "trampoline", o.settledAt)) + in ++ out + case OnTheFlyFundingPaymentRelayed(_, incoming, outgoing) => + val in = incoming.map(i => RelayedPart(i.channelId, i.amount, "IN", "on-the-fly-funding", i.receivedAt)) + val out = outgoing.map(o => RelayedPart(o.channelId, o.amount, "OUT", "on-the-fly-funding", o.settledAt)) + in ++ out } for (p <- payments) { using(sqlite.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => @@ -423,6 +430,8 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { case Some(RelayedPart(_, _, _, "trampoline", _)) => val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey)) TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount) :: Nil + case Some(RelayedPart(_, _, _, "on-the-fly-funding", _)) => + Seq(OnTheFlyFundingPaymentRelayed(paymentHash, incoming, outgoing)) case _ => Nil } }.toSeq.sortBy(_.timestamp) @@ -449,10 +458,21 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { } override def stats(from: TimestampMilli, to: TimestampMilli): Seq[Stats] = { - val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) => - feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) - } case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String) + + def aggregateRelayStats(previous: Map[ByteVector32, Seq[Relayed]], incoming: PaymentRelayed.Incoming, outgoing: PaymentRelayed.Outgoing): Map[ByteVector32, Seq[Relayed]] = { + // We ensure trampoline payments are counted only once per channel and per direction (if multiple HTLCs were sent + // from/to the same channel, we group them). + val amountIn = incoming.map(_.amount).sum + val amountOut = outgoing.map(_.amount).sum + val in = incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq + val out = outgoing.groupBy(_.channelId).map { case (channelId, parts) => + val fee = (amountIn - amountOut) * parts.length / outgoing.length // we split the fee among outgoing channels + (channelId, Relayed(parts.map(_.amount).sum, fee, "OUT")) + }.toSeq + (in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) } + } + val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) => // NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones. val current = e match { @@ -461,17 +481,17 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { c.toChannelId -> (Relayed(c.amountOut, c.amountIn - c.amountOut, "OUT") +: previous.getOrElse(c.toChannelId, Nil)), ) case t: TrampolinePaymentRelayed => - // We ensure a trampoline payment is counted only once per channel and per direction (if multiple HTLCs were - // sent from/to the same channel, we group them). - val in = t.incoming.groupBy(_.channelId).map { case (channelId, parts) => (channelId, Relayed(parts.map(_.amount).sum, 0 msat, "IN")) }.toSeq - val out = t.outgoing.groupBy(_.channelId).map { case (channelId, parts) => - val fee = (t.amountIn - t.amountOut) * parts.length / t.outgoing.length // we split the fee among outgoing channels - (channelId, Relayed(parts.map(_.amount).sum, fee, "OUT")) - }.toSeq - (in ++ out).groupBy(_._1).map { case (channelId, payments) => (channelId, payments.map(_._2) ++ previous.getOrElse(channelId, Nil)) } + aggregateRelayStats(previous, t.incoming, t.outgoing) + case f: OnTheFlyFundingPaymentRelayed => + aggregateRelayStats(previous, f.incoming, f.outgoing) } previous ++ current } + + val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) => + feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) + } + // Channels opened by our peers won't have any network fees paid by us, but we still want to compute stats for them. val allChannels = networkFees.keySet ++ relayed.keySet allChannels.toSeq.flatMap(channelId => { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteOnTheFlyFundingDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteOnTheFlyFundingDb.scala new file mode 100644 index 0000000000..7ed667179c --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteOnTheFlyFundingDb.scala @@ -0,0 +1,129 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db.sqlite + +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, TxId} +import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics +import fr.acinq.eclair.db.Monitoring.Tags.DbBackends +import fr.acinq.eclair.db.OnTheFlyFundingDb +import fr.acinq.eclair.payment.relay.OnTheFlyFunding +import fr.acinq.eclair.{MilliSatoshiLong, TimestampMilli} +import scodec.bits.BitVector + +import java.sql.Connection + +/** + * Created by t-bast on 25/06/2024. + */ + +object SqliteOnTheFlyFundingDb { + val DB_NAME = "on_the_fly_funding" + val CURRENT_VERSION = 1 +} + +class SqliteOnTheFlyFundingDb(val sqlite: Connection) extends OnTheFlyFundingDb { + + import SqliteOnTheFlyFundingDb._ + import SqliteUtils.ExtendedResultSet._ + import SqliteUtils._ + + using(sqlite.createStatement(), inTransaction = true) { statement => + getVersion(statement, DB_NAME) match { + case None => + statement.executeUpdate("CREATE TABLE on_the_fly_funding_preimages (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, received_at INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE on_the_fly_funding_pending (remote_node_id BLOB NOT NULL, payment_hash BLOB NOT NULL, channel_id BLOB NOT NULL, tx_id BLOB NOT NULL, funding_tx_index INTEGER NOT NULL, remaining_fees_msat INTEGER NOT NULL, proposed BLOB NOT NULL, funded_at INTEGER NOT NULL, PRIMARY KEY (remote_node_id, payment_hash))") + case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do + case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") + } + setVersion(statement, DB_NAME, CURRENT_VERSION) + } + + override def addPreimage(preimage: ByteVector32): Unit = withMetrics("on-the-fly-funding/add-preimage", DbBackends.Sqlite) { + using(sqlite.prepareStatement("INSERT OR IGNORE INTO on_the_fly_funding_preimages (payment_hash, preimage, received_at) VALUES (?, ?, ?)")) { statement => + statement.setBytes(1, Crypto.sha256(preimage).toArray) + statement.setBytes(2, preimage.toArray) + statement.setLong(3, TimestampMilli.now().toLong) + statement.executeUpdate() + } + } + + override def getPreimage(paymentHash: ByteVector32): Option[ByteVector32] = withMetrics("on-the-fly-funding/get-preimage", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT preimage FROM on_the_fly_funding_preimages WHERE payment_hash = ?")) { statement => + statement.setBytes(1, paymentHash.toArray) + statement.executeQuery().map { rs => rs.getByteVector32("preimage") }.lastOption + } + } + + override def addPending(remoteNodeId: Crypto.PublicKey, pending: OnTheFlyFunding.Pending): Unit = withMetrics("on-the-fly-funding/add-pending", DbBackends.Sqlite) { + pending.status match { + case _: OnTheFlyFunding.Status.Proposed => () + case status: OnTheFlyFunding.Status.Funded => + using(sqlite.prepareStatement("INSERT OR IGNORE INTO on_the_fly_funding_pending (remote_node_id, payment_hash, channel_id, tx_id, funding_tx_index, remaining_fees_msat, proposed, funded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, remoteNodeId.value.toArray) + statement.setBytes(2, pending.paymentHash.toArray) + statement.setBytes(3, status.channelId.toArray) + statement.setBytes(4, status.txId.value.toArray) + statement.setLong(5, status.fundingTxIndex) + statement.setLong(6, status.remainingFees.toLong) + statement.setBytes(7, OnTheFlyFunding.Codecs.proposals.encode(pending.proposed).require.bytes.toArray) + statement.setLong(8, TimestampMilli.now().toLong) + statement.executeUpdate() + } + } + } + + override def removePending(remoteNodeId: Crypto.PublicKey, paymentHash: ByteVector32): Unit = withMetrics("on-the-fly-funding/remove-pending", DbBackends.Sqlite) { + using(sqlite.prepareStatement("DELETE FROM on_the_fly_funding_pending WHERE remote_node_id = ? AND payment_hash = ?")) { statement => + statement.setBytes(1, remoteNodeId.value.toArray) + statement.setBytes(2, paymentHash.toArray) + statement.executeUpdate() + } + } + + override def listPending(remoteNodeId: Crypto.PublicKey): Map[ByteVector32, OnTheFlyFunding.Pending] = withMetrics("on-the-fly-funding/list-pending", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT * FROM on_the_fly_funding_pending WHERE remote_node_id = ?")) { statement => + statement.setBytes(1, remoteNodeId.value.toArray) + statement.executeQuery().map { rs => + val paymentHash = rs.getByteVector32("payment_hash") + val pending = OnTheFlyFunding.Pending( + proposed = OnTheFlyFunding.Codecs.proposals.decode(BitVector(rs.getBytes("proposed"))).require.value, + status = OnTheFlyFunding.Status.Funded( + channelId = rs.getByteVector32("channel_id"), + txId = TxId(rs.getByteVector32("tx_id")), + fundingTxIndex = rs.getLong("funding_tx_index"), + remainingFees = rs.getLong("remaining_fees_msat").msat + ) + ) + paymentHash -> pending + }.toMap + } + } + + override def listPendingPayments(): Map[Crypto.PublicKey, Set[ByteVector32]] = withMetrics("on-the-fly-funding/list-pending-payments", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT remote_node_id, payment_hash FROM on_the_fly_funding_pending")) { statement => + statement.executeQuery().map { rs => + val remoteNodeId = PublicKey(rs.getByteVector("remote_node_id")) + val paymentHash = rs.getByteVector32("payment_hash") + remoteNodeId -> paymentHash + }.groupMap(_._1)(_._2).map { + case (remoteNodeId, payments) => remoteNodeId -> payments.toSet + } + } + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 841b10a2f0..fa93848203 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -148,7 +148,7 @@ private class MessageRelay(nodeParams: NodeParams, waitForConnection(msg, nodeId) } case EncodedNodeId.WithPublicKey.Wallet(nodeId) => - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.onTheFlyFundingConfig.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) waitForWalletNodeUp(msg, nodeId) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Monitoring.scala index 71140c1144..d60a10cfe7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Monitoring.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.io +import fr.acinq.eclair.payment.relay.OnTheFlyFunding import kamon.Kamon object Monitoring { @@ -36,6 +37,9 @@ object Monitoring { val IncomingConnectionsNoChannels = Kamon.gauge("incomingconnections.nochannels") val IncomingConnectionsDisconnected = Kamon.counter("incomingconnections.disconnected") + + val OnTheFlyFunding = Kamon.counter("on-the-fly-funding.attempts") + val OnTheFlyFundingFees = Kamon.histogram("on-the-fly-funding.fees-msat") } object Tags { @@ -64,6 +68,18 @@ object Monitoring { val NoChannelWithNextPeer = "NoChannelWithNextPeer" val ConnectionFailure = "ConnectionFailure" } + + val OnTheFlyFundingState = "state" + object OnTheFlyFundingStates { + val Proposed = "proposed" + val Rejected = "rejected" + val Expired = "expired" + val Timeout = "timeout" + val Funded = "funded" + val RelaySucceeded = "relay-succeeded" + + def relayFailed(failure: OnTheFlyFunding.PaymentRelayer.RelayFailure) = s"relay-failed-${failure.getClass.getSimpleName}" + } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/OpenChannelInterceptor.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/OpenChannelInterceptor.scala index e3310f7711..9ab2627f58 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/OpenChannelInterceptor.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/OpenChannelInterceptor.scala @@ -30,7 +30,7 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.io.Peer.{OpenChannelResponse, SpawnChannelNonInitiator} import fr.acinq.eclair.io.PendingChannelsRateLimiter.AddOrRejectChannel import fr.acinq.eclair.wire.protocol -import fr.acinq.eclair.wire.protocol.{Error, NodeAddress} +import fr.acinq.eclair.wire.protocol.{Error, LiquidityAds, NodeAddress} import fr.acinq.eclair.{AcceptOpenChannel, CltvExpiryDelta, Features, InitFeature, InterceptOpenChannelPlugin, InterceptOpenChannelReceived, InterceptOpenChannelResponse, Logs, MilliSatoshi, NodeParams, RejectOpenChannel, ToMilliSatoshiConversion} import scodec.bits.ByteVector @@ -63,6 +63,8 @@ object OpenChannelInterceptor { private sealed trait CheckRateLimitsCommands extends Command private case class PendingChannelsRateLimiterResponse(response: PendingChannelsRateLimiter.Response) extends CheckRateLimitsCommands + private case class WrappedPeerChannels(channels: Seq[Peer.ChannelInfo]) extends Command + private sealed trait QueryPluginCommands extends Command private case class PluginOpenChannelResponse(pluginResponse: InterceptOpenChannelResponse) extends QueryPluginCommands private case object PluginTimeout extends QueryPluginCommands @@ -176,9 +178,17 @@ private class OpenChannelInterceptor(peer: ActorRef[Any], nodeParams.pluginOpenChannelInterceptor match { case Some(plugin) => queryPlugin(plugin, request, localParams, ChannelConfig.standard, channelType) case None => - // We don't honor liquidity ads for new channels: we let the node operator's plugin decide what to do. - peer ! SpawnChannelNonInitiator(request.open, ChannelConfig.standard, channelType, addFunding_opt = None, localParams, request.peerConnection.toClassic) - waitForRequest() + request.open.fold(_ => None, _.requestFunding_opt) match { + case Some(requestFunding) if request.localFeatures.hasFeature(Features.OnTheFlyFunding) && request.remoteFeatures.hasFeature(Features.OnTheFlyFunding) && request.channelFlags.nonInitiatorPaysCommitFees => + val addFunding = LiquidityAds.AddFunding(requestFunding.requestedAmount, nodeParams.willFundRates_opt) + val localParams1 = localParams.copy(paysCommitTxFees = true) + val accept = SpawnChannelNonInitiator(request.open, ChannelConfig.standard, channelType, Some(addFunding), localParams1, request.peerConnection.toClassic) + checkNoExistingChannel(request, accept) + case _ => + // We don't honor liquidity ads for new channels: node operators should use plugin for that. + peer ! SpawnChannelNonInitiator(request.open, ChannelConfig.standard, channelType, addFunding_opt = None, localParams, request.peerConnection.toClassic) + waitForRequest() + } } case PendingChannelsRateLimiterResponse(PendingChannelsRateLimiter.ChannelRateLimited) => context.log.warn(s"ignoring remote channel open: rate limited") @@ -187,6 +197,25 @@ private class OpenChannelInterceptor(peer: ActorRef[Any], } } + /** + * In some cases we want to reject additional channels when we already have one: it is usually better to splice the + * existing channel instead of opening another one. + */ + private def checkNoExistingChannel(request: OpenChannelNonInitiator, accept: SpawnChannelNonInitiator): Behavior[Command] = { + peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](r => WrappedPeerChannels(r.channels))) + receiveCommandMessage[WrappedPeerChannels](context, "checkNoExistingChannel") { + case WrappedPeerChannels(channels) => + if (channels.forall(isClosing)) { + peer ! accept + waitForRequest() + } else { + context.log.warn("we already have an active channel, so we won't accept another one: our peer should request a splice instead") + sendFailure("we already have an active channel: you should splice instead of requesting another channel", request) + waitForRequest() + } + } + } + private def queryPlugin(plugin: InterceptOpenChannelPlugin, request: OpenChannelInterceptor.OpenChannelNonInitiator, localParams: LocalParams, channelConfig: ChannelConfig, channelType: SupportedChannelType): Behavior[Command] = Behaviors.withTimers { timers => timers.startSingleTimer(PluginTimeout, pluginTimeout) @@ -210,6 +239,23 @@ private class OpenChannelInterceptor(peer: ActorRef[Any], } } + private def isClosing(channel: Peer.ChannelInfo): Boolean = channel.state match { + case CLOSED => true + case _ => channel.data match { + case _: TransientChannelData => false + case _: ChannelDataWithoutCommitments => false + case _: DATA_WAIT_FOR_FUNDING_CONFIRMED => false + case _: DATA_WAIT_FOR_CHANNEL_READY => false + case _: DATA_WAIT_FOR_DUAL_FUNDING_CONFIRMED => false + case _: DATA_WAIT_FOR_DUAL_FUNDING_READY => false + case _: DATA_NORMAL => false + case _: DATA_SHUTDOWN => true + case _: DATA_NEGOTIATING => true + case _: DATA_CLOSING => true + case _: DATA_WAIT_FOR_REMOTE_PUBLISH_FUTURE_COMMITMENT => true + } + } + private def sendFailure(error: String, request: OpenChannelNonInitiator): Unit = { peer ! Peer.OutgoingMessage(Error(request.temporaryChannelId, error), request.peerConnection.toClassic) context.system.eventStream ! Publish(ChannelAborted(actor.ActorRef.noSender, request.remoteNodeId, request.temporaryChannelId)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index dcc9fc6c9c..0f5fec711d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -23,24 +23,29 @@ import akka.event.Logging.MDC import akka.event.{BusLogging, DiagnosticLoggingAdapter} import akka.util.Timeout import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.bitcoin.scalacompat.{ByteVector32, Satoshi, SatoshiLong, TxId} +import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi, SatoshiLong, TxId} import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.NotificationsLogger.NotifyNodeOperator import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher import fr.acinq.eclair.blockchain.fee.FeeratePerKw -import fr.acinq.eclair.blockchain.{CurrentFeerates, OnChainChannelFunder, OnchainPubkeyCache} +import fr.acinq.eclair.blockchain.{CurrentBlockHeight, CurrentFeerates, OnChainChannelFunder, OnchainPubkeyCache} import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.io.MessageRelay.Status -import fr.acinq.eclair.io.Monitoring.Metrics +import fr.acinq.eclair.io.Monitoring.{Metrics, Tags} import fr.acinq.eclair.io.OpenChannelInterceptor.{OpenChannelInitiator, OpenChannelNonInitiator} import fr.acinq.eclair.io.PeerConnection.KillReason import fr.acinq.eclair.message.OnionMessages +import fr.acinq.eclair.payment.relay.OnTheFlyFunding +import fr.acinq.eclair.payment.{OnTheFlyFundingPaymentRelayed, PaymentRelayed} import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol -import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnionMessage, RoutingMessage, UnknownMessage, Warning} +import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure +import fr.acinq.eclair.wire.protocol.LiquidityAds.PaymentDetails +import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, RoutingMessage, SpliceInit, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc} /** * This actor represents a logical peer. There is one [[Peer]] per unique remote node id at all time. @@ -63,7 +68,10 @@ class Peer(val nodeParams: NodeParams, import Peer._ + private var pendingOnTheFlyFunding = nodeParams.db.onTheFlyFunding.listPending(remoteNodeId) + context.system.eventStream.subscribe(self, classOf[CurrentFeerates]) + context.system.eventStream.subscribe(self, classOf[CurrentBlockHeight]) startWith(INSTANTIATING, Nothing) @@ -91,10 +99,10 @@ class Peer(val nodeParams: NodeParams, val channelIds = d.channels.filter(_._2 == actor).keys log.info(s"channel closed: channelId=${channelIds.mkString("/")}") val channels1 = d.channels -- channelIds - if (channels1.isEmpty) { + if (channels1.isEmpty && !pendingSignedOnTheFlyFunding()) { log.info("that was the last open channel") context.system.eventStream.publish(LastChannelClosed(self, remoteNodeId)) - // we have no existing channels, we can forget about this peer + // We have no existing channels or pending signed transaction, we can forget about this peer. stopPeer() } else { stay() using d.copy(channels = channels1) @@ -104,8 +112,8 @@ class Peer(val nodeParams: NodeParams, Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION))) { log.debug("connection lost while negotiating connection") } - if (d.channels.isEmpty) { - // we have no existing channels, we can forget about this peer + if (d.channels.isEmpty && !pendingSignedOnTheFlyFunding()) { + // We have no existing channels or pending signed transaction, we can forget about this peer. stopPeer() } else { stay() @@ -205,23 +213,152 @@ class Peer(val nodeParams: NodeParams, case Event(SpawnChannelNonInitiator(open, channelConfig, channelType, addFunding_opt, localParams, peerConnection), d: ConnectedData) => val temporaryChannelId = open.fold(_.temporaryChannelId, _.temporaryChannelId) if (peerConnection == d.peerConnection) { - val channel = spawnChannel() - log.info(s"accepting a new channel with type=$channelType temporaryChannelId=$temporaryChannelId localParams=$localParams") - open match { - case Left(open) => - channel ! INPUT_INIT_CHANNEL_NON_INITIATOR(open.temporaryChannelId, None, dualFunded = false, None, localParams, d.peerConnection, d.remoteInit, channelConfig, channelType) - channel ! open - case Right(open) => - channel ! INPUT_INIT_CHANNEL_NON_INITIATOR(open.temporaryChannelId, addFunding_opt, dualFunded = true, None, localParams, d.peerConnection, d.remoteInit, channelConfig, channelType) - channel ! open + OnTheFlyFunding.validateOpen(open, pendingOnTheFlyFunding) match { + case reject: OnTheFlyFunding.ValidationResult.Reject => + log.warning("rejecting on-the-fly channel: {}", reject.cancel.toAscii) + self ! Peer.OutgoingMessage(reject.cancel, d.peerConnection) + cancelUnsignedOnTheFlyFunding(reject.paymentHashes) + context.system.eventStream.publish(ChannelAborted(ActorRef.noSender, remoteNodeId, temporaryChannelId)) + stay() + case accept: OnTheFlyFunding.ValidationResult.Accept => + val channel = spawnChannel() + log.info(s"accepting a new channel with type=$channelType temporaryChannelId=$temporaryChannelId localParams=$localParams") + open match { + case Left(open) => + channel ! INPUT_INIT_CHANNEL_NON_INITIATOR(open.temporaryChannelId, None, dualFunded = false, None, localParams, d.peerConnection, d.remoteInit, channelConfig, channelType) + channel ! open + case Right(open) => + channel ! INPUT_INIT_CHANNEL_NON_INITIATOR(open.temporaryChannelId, addFunding_opt, dualFunded = true, None, localParams, d.peerConnection, d.remoteInit, channelConfig, channelType) + channel ! open + } + fulfillOnTheFlyFundingHtlcs(accept.preimages) + stay() using d.copy(channels = d.channels + (TemporaryChannelId(temporaryChannelId) -> channel)) } - stay() using d.copy(channels = d.channels + (TemporaryChannelId(temporaryChannelId) -> channel)) } else { log.warning("ignoring open_channel request that reconnected during channel intercept, temporaryChannelId={}", temporaryChannelId) context.system.eventStream.publish(ChannelAborted(ActorRef.noSender, remoteNodeId, temporaryChannelId)) stay() } + case Event(cmd: ProposeOnTheFlyFunding, d: ConnectedData) if !d.remoteFeatures.hasFeature(Features.OnTheFlyFunding) => + cmd.replyTo ! ProposeOnTheFlyFundingResponse.NotAvailable("peer does not support on-the-fly funding") + stay() + + case Event(cmd: ProposeOnTheFlyFunding, d: ConnectedData) => + // We send the funding proposal to our peer, and report it to the sender. + val htlc = WillAddHtlc(nodeParams.chainHash, randomBytes32(), cmd.amount, cmd.paymentHash, cmd.expiry, cmd.onion, cmd.nextBlindingKey_opt) + cmd.replyTo ! ProposeOnTheFlyFundingResponse.Proposed + // We update our list of pending proposals for that payment. + val pending = pendingOnTheFlyFunding.get(htlc.paymentHash) match { + case Some(pending) => + pending.status match { + case status: OnTheFlyFunding.Status.Proposed => + self ! Peer.OutgoingMessage(htlc, d.peerConnection) + // We extend the previous timer. + status.timer.cancel() + val timer = context.system.scheduler.scheduleOnce(nodeParams.onTheFlyFundingConfig.proposalTimeout, self, OnTheFlyFundingTimeout(cmd.paymentHash))(context.dispatcher) + pending.copy( + proposed = pending.proposed :+ OnTheFlyFunding.Proposal(htlc, cmd.upstream), + status = OnTheFlyFunding.Status.Proposed(timer) + ) + case status: OnTheFlyFunding.Status.Funded => + log.info("received extra payment for on-the-fly funding that has already been funded with txId={} (payment_hash={}, amount={})", status.txId, cmd.paymentHash, cmd.amount) + pending.copy(proposed = pending.proposed :+ OnTheFlyFunding.Proposal(htlc, cmd.upstream)) + } + case None => + self ! Peer.OutgoingMessage(htlc, d.peerConnection) + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.Proposed).increment() + val timer = context.system.scheduler.scheduleOnce(nodeParams.onTheFlyFundingConfig.proposalTimeout, self, OnTheFlyFundingTimeout(cmd.paymentHash))(context.dispatcher) + OnTheFlyFunding.Pending(Seq(OnTheFlyFunding.Proposal(htlc, cmd.upstream)), OnTheFlyFunding.Status.Proposed(timer)) + } + pendingOnTheFlyFunding += (htlc.paymentHash -> pending) + stay() + + case Event(msg: OnTheFlyFundingFailureMessage, d: ConnectedData) => + pendingOnTheFlyFunding.get(msg.paymentHash) match { + case Some(pending) => + pending.status match { + case status: OnTheFlyFunding.Status.Proposed => + pending.proposed.find(_.htlc.id == msg.id) match { + case Some(htlc) => + val failure = msg match { + case msg: WillFailHtlc => Left(msg.reason) + case msg: WillFailMalformedHtlc => Right(createBadOnionFailure(msg.onionHash, msg.failureCode)) + } + htlc.createFailureCommands(Some(failure)).foreach { case (channelId, cmd) => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) } + val proposed1 = pending.proposed.filterNot(_.htlc.id == msg.id) + if (proposed1.isEmpty) { + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.Rejected).increment() + status.timer.cancel() + pendingOnTheFlyFunding -= msg.paymentHash + } else { + pendingOnTheFlyFunding += (msg.paymentHash -> pending.copy(proposed = proposed1)) + } + case None => + log.warning("ignoring will_fail_htlc: no matching proposal for id={}", msg.id) + self ! Peer.OutgoingMessage(Warning(s"ignoring will_fail_htlc: no matching proposal for id=${msg.id}"), d.peerConnection) + } + case status: OnTheFlyFunding.Status.Funded => + log.warning("ignoring will_fail_htlc: on-the-fly funding already signed with txId={}", status.txId) + self ! Peer.OutgoingMessage(Warning(s"ignoring will_fail_htlc: on-the-fly funding already signed with txId=${status.txId}"), d.peerConnection) + } + case None => + log.warning("ignoring will_fail_htlc: no matching proposal for payment_hash={}", msg.paymentHash) + self ! Peer.OutgoingMessage(Warning(s"ignoring will_fail_htlc: no matching proposal for payment_hash=${msg.paymentHash}"), d.peerConnection) + } + stay() + + case Event(timeout: OnTheFlyFundingTimeout, d: ConnectedData) => + pendingOnTheFlyFunding.get(timeout.paymentHash) match { + case Some(pending) => + pending.status match { + case _: OnTheFlyFunding.Status.Proposed => + log.warning("on-the-fly funding proposal timed out for payment_hash={}", timeout.paymentHash) + pending.createFailureCommands().foreach { case (channelId, cmd) => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) } + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.Expired).increment() + pendingOnTheFlyFunding -= timeout.paymentHash + self ! Peer.OutgoingMessage(Warning(s"on-the-fly funding proposal timed out for payment_hash=${timeout.paymentHash}"), d.peerConnection) + case status: OnTheFlyFunding.Status.Funded => + log.warning("ignoring on-the-fly funding proposal timeout, already funded with txId={}", status.txId) + } + case None => + log.debug("ignoring on-the-fly funding timeout for payment_hash={} (already completed)", timeout.paymentHash) + } + stay() + + case Event(msg: SpliceInit, d: ConnectedData) => + d.channels.get(FinalChannelId(msg.channelId)) match { + case Some(channel) => + OnTheFlyFunding.validateSplice(msg, nodeParams.channelConf.htlcMinimum, pendingOnTheFlyFunding) match { + case reject: OnTheFlyFunding.ValidationResult.Reject => + log.warning("rejecting on-the-fly splice: {}", reject.cancel.toAscii) + self ! Peer.OutgoingMessage(reject.cancel, d.peerConnection) + cancelUnsignedOnTheFlyFunding(reject.paymentHashes) + case accept: OnTheFlyFunding.ValidationResult.Accept => + fulfillOnTheFlyFundingHtlcs(accept.preimages) + channel forward msg + } + case None => replyUnknownChannel(d.peerConnection, msg.channelId) + } + stay() + + case Event(e: ChannelReadyForPayments, _: ConnectedData) => + pendingOnTheFlyFunding.foreach { + case (paymentHash, pending) => + pending.status match { + case _: OnTheFlyFunding.Status.Proposed => () + case status: OnTheFlyFunding.Status.Funded => + context.child(paymentHash.toHex) match { + case Some(_) => log.debug("already relaying payment_hash={}", paymentHash) + case None if e.fundingTxIndex < status.fundingTxIndex => log.debug("too early to relay payment_hash={}, funding not locked ({} < {})", paymentHash, e.fundingTxIndex, status.fundingTxIndex) + case None => + val relayer = context.spawn(Behaviors.supervise(OnTheFlyFunding.PaymentRelayer(nodeParams, remoteNodeId, e.channelId, paymentHash)).onFailure(typed.SupervisorStrategy.stop), paymentHash.toHex) + relayer ! OnTheFlyFunding.PaymentRelayer.TryRelay(self.toTyped, e.channel.toTyped, pending.proposed, status) + } + } + } + stay() + case Event(msg: HasChannelId, d: ConnectedData) => d.channels.get(FinalChannelId(msg.channelId)) match { case Some(channel) => channel forward msg @@ -257,8 +394,8 @@ class Peer(val nodeParams: NodeParams, Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION))) { log.debug("connection lost") } - if (d.channels.isEmpty) { - // we have no existing channels, we can forget about this peer + if (d.channels.isEmpty && !pendingSignedOnTheFlyFunding()) { + // We have no existing channels or pending signed transaction, we can forget about this peer. stopPeer() } else { d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) @@ -325,6 +462,10 @@ class Peer(val nodeParams: NodeParams, sender() ! Status.Failure(new RuntimeException("not connected")) stay() + case Event(r: Peer.ProposeOnTheFlyFunding, _) => + r.replyTo ! ProposeOnTheFlyFundingResponse.NotAvailable("peer not connected") + stay() + case Event(Disconnect(nodeId, replyTo_opt), _) => val replyTo = replyTo_opt.getOrElse(sender().toTyped) replyTo ! NotConnected(nodeId) @@ -354,6 +495,112 @@ class Peer(val nodeParams: NodeParams, } stay() + case Event(current: CurrentBlockHeight, d) => + // If we have pending will_add_htlc that are timing out, it doesn't make any sense to keep them, even if we have + // already funded the corresponding channel: our peer will force-close if we relay them. + // Our only option is to fail the upstream HTLCs to ensure that the upstream channels don't force-close. + // Note that we won't be paid for the liquidity we've provided, but we don't have a choice. + val expired = pendingOnTheFlyFunding.filter { + case (_, pending) => pending.proposed.exists(_.htlc.expiry.blockHeight <= current.blockHeight) + } + expired.foreach { + case (paymentHash, pending) => + log.warning("will_add_htlc expired for payment_hash={}, our peer may be malicious", paymentHash) + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.Timeout).increment() + pending.createFailureCommands().foreach { case (channelId, cmd) => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) } + } + expired.foreach { + case (paymentHash, pending) => pending.status match { + case _: OnTheFlyFunding.Status.Proposed => () + case _: OnTheFlyFunding.Status.Funded => nodeParams.db.onTheFlyFunding.removePending(remoteNodeId, paymentHash) + } + } + pendingOnTheFlyFunding = pendingOnTheFlyFunding.removedAll(expired.keys) + d match { + case Peer.Nothing => stay() + case d if d.channels.isEmpty && pendingOnTheFlyFunding.isEmpty => stopPeer() + case _ => stay() + } + + case Event(e: LiquidityPurchaseSigned, _: ConnectedData) => + // We signed a liquidity purchase from our peer. At that point we're not 100% sure yet it will succeed: if + // we disconnect before our peer sends their signature, the funding attempt may be cancelled when reconnecting. + // If that happens, the on-the-fly proposal will stay in our state until we reach the CLTV expiry, at which + // point we will forget it and fail the upstream HTLCs. This is also what would happen if we successfully + // funded the channel, but it closed before we could relay the HTLCs. + val (paymentHashes, fees) = e.purchase.paymentDetails match { + case PaymentDetails.FromChannelBalance => (Nil, 0 sat) + case p: PaymentDetails.FromChannelBalanceForFutureHtlc => (p.paymentHashes, 0 sat) + case p: PaymentDetails.FromFutureHtlc => (p.paymentHashes, e.purchase.fees.total) + case p: PaymentDetails.FromFutureHtlcWithPreimage => (p.preimages.map(preimage => Crypto.sha256(preimage)), e.purchase.fees.total) + } + // We split the fees across payments. We could dynamically re-split depending on whether some payments are failed + // instead of fulfilled, but that's overkill: if our peer fails one of those payment, they're likely malicious + // and will fail anyway, even if we try to be clever with fees splitting. + var remainingFees = fees.toMilliSatoshi + pendingOnTheFlyFunding + .filter { case (paymentHash, _) => paymentHashes.contains(paymentHash) } + .values.toSeq + // In case our peer goes offline, we start with payments that are as far as possible from timing out. + .sortBy(_.expiry).reverse + .foreach(payment => { + payment.status match { + case status: OnTheFlyFunding.Status.Proposed => + status.timer.cancel() + val paymentFees = remainingFees.min(payment.maxFees(e.htlcMinimum)) + remainingFees -= paymentFees + log.info("liquidity purchase signed for payment_hash={}, waiting to relay HTLCs (txId={}, fundingTxIndex={}, fees={})", payment.paymentHash, e.txId, e.fundingTxIndex, paymentFees) + val payment1 = payment.copy(status = OnTheFlyFunding.Status.Funded(e.channelId, e.txId, e.fundingTxIndex, paymentFees)) + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.Funded).increment() + nodeParams.db.onTheFlyFunding.addPending(remoteNodeId, payment1) + pendingOnTheFlyFunding += payment.paymentHash -> payment1 + case status: OnTheFlyFunding.Status.Funded => + log.warning("liquidity purchase was already signed for payment_hash={} (previousTxId={}, currentTxId={})", payment.paymentHash, status.txId, e.txId) + } + }) + stay() + + case Event(e: OnTheFlyFunding.PaymentRelayer.RelayResult, _) => + e match { + case success: OnTheFlyFunding.PaymentRelayer.RelaySuccess => + pendingOnTheFlyFunding.get(success.paymentHash) match { + case Some(pending) => + log.info("successfully relayed on-the-fly HTLC for payment_hash={}", success.paymentHash) + // We've been paid for our liquidity fees: we can now fulfill upstream. + pending.createFulfillCommands(success.preimage).foreach { + case (channelId, cmd) => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) + } + // We emit a relay event: since we waited for on-chain funding before relaying the payment, the timestamps + // won't be accurate, but everything else is. + pending.proposed.foreach { + case OnTheFlyFunding.Proposal(htlc, upstream) => upstream match { + case _: Upstream.Local => () + case u: Upstream.Hot.Channel => + val incoming = PaymentRelayed.IncomingPart(u.add.amountMsat, u.add.channelId, u.receivedAt) + val outgoing = PaymentRelayed.OutgoingPart(htlc.amount, success.channelId, TimestampMilli.now()) + context.system.eventStream.publish(OnTheFlyFundingPaymentRelayed(htlc.paymentHash, Seq(incoming), Seq(outgoing))) + case u: Upstream.Hot.Trampoline => + val incoming = u.received.map(r => PaymentRelayed.IncomingPart(r.add.amountMsat, r.add.channelId, r.receivedAt)) + val outgoing = PaymentRelayed.OutgoingPart(htlc.amount, success.channelId, TimestampMilli.now()) + context.system.eventStream.publish(OnTheFlyFundingPaymentRelayed(htlc.paymentHash, incoming, Seq(outgoing))) + } + } + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.RelaySucceeded).increment() + Metrics.OnTheFlyFundingFees.withoutTags().record(success.fees.toLong) + nodeParams.db.onTheFlyFunding.removePending(remoteNodeId, success.paymentHash) + pendingOnTheFlyFunding -= success.paymentHash + case None => () + } + stay() + case OnTheFlyFunding.PaymentRelayer.RelayFailed(paymentHash, failure) => + log.warning("on-the-fly HTLC failure for payment_hash={}: {}", paymentHash, failure.toString) + Metrics.OnTheFlyFunding.withTag(Tags.OnTheFlyFundingState, Tags.OnTheFlyFundingStates.relayFailed(failure)).increment() + // We don't give up yet by relaying the failure upstream: we may have simply been disconnected, or the added + // liquidity may have been consumed by concurrent HTLCs. We'll retry at the next reconnection with that peer + // or after the next splice, and will only give up when the outgoing will_add_htlc timeout. + stay() + } + case Event(_: Peer.OutgoingMessage, _) => stay() // we got disconnected or reconnected and this message was for the previous connection case Event(RelayOnionMessage(messageId, _, replyTo_opt), _) => @@ -373,9 +620,11 @@ class Peer(val nodeParams: NodeParams, context.system.eventStream.publish(PeerConnected(self, remoteNodeId, nextStateData.asInstanceOf[Peer.ConnectedData].connectionInfo)) case CONNECTED -> CONNECTED => // connection switch context.system.eventStream.publish(PeerConnected(self, remoteNodeId, nextStateData.asInstanceOf[Peer.ConnectedData].connectionInfo)) + cancelUnsignedOnTheFlyFunding() case CONNECTED -> DISCONNECTED => Metrics.PeersConnected.withoutTags().decrement() context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId)) + cancelUnsignedOnTheFlyFunding() } onTermination { @@ -427,11 +676,50 @@ class Peer(val nodeParams: NodeParams, self ! Peer.OutgoingMessage(msg, peerConnection) } + private def cancelUnsignedOnTheFlyFunding(): Unit = cancelUnsignedOnTheFlyFunding(pendingOnTheFlyFunding.keySet) + + private def cancelUnsignedOnTheFlyFunding(paymentHashes: Set[ByteVector32]): Unit = { + val unsigned = pendingOnTheFlyFunding.filter { + case (paymentHash, pending) if paymentHashes.contains(paymentHash) => + pending.status match { + case status: OnTheFlyFunding.Status.Proposed => + status.timer.cancel() + true + case _: OnTheFlyFunding.Status.Funded => false + } + case _ => false + } + unsigned.foreach { + case (paymentHash, pending) => + log.info("cancelling on-the-fly funding for payment_hash={}", paymentHash) + pending.createFailureCommands().foreach { case (channelId, cmd) => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) } + } + pendingOnTheFlyFunding = pendingOnTheFlyFunding.removedAll(unsigned.keys) + } + + private def fulfillOnTheFlyFundingHtlcs(preimages: Set[ByteVector32]): Unit = { + preimages.foreach(preimage => pendingOnTheFlyFunding.get(Crypto.sha256(preimage)) match { + case Some(pending) => pending.createFulfillCommands(preimage).foreach { case (channelId, cmd) => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) } + case None => () + }) + } + + /** Return true if we have signed on-the-fly funding transactions and haven't settled the corresponding HTLCs yet. */ + private def pendingSignedOnTheFlyFunding(): Boolean = { + pendingOnTheFlyFunding.exists { + case (_, pending) => pending.status match { + case _: OnTheFlyFunding.Status.Proposed => false + case _: OnTheFlyFunding.Status.Funded => true + } + } + } + // resume the openChannelInterceptor in case of failure, we always want the open channel request to succeed or fail private val openChannelInterceptor = context.spawnAnonymous(Behaviors.supervise(OpenChannelInterceptor(context.self.toTyped, nodeParams, remoteNodeId, wallet, pendingChannelsRateLimiter)).onFailure(typed.SupervisorStrategy.resume)) private def stopPeer(): State = { log.info("removing peer from db") + cancelUnsignedOnTheFlyFunding() nodeParams.db.peers.removePeer(remoteNodeId) stop(FSM.Normal) } @@ -561,6 +849,20 @@ object Peer { case class SpawnChannelInitiator(replyTo: akka.actor.typed.ActorRef[OpenChannelResponse], cmd: Peer.OpenChannel, channelConfig: ChannelConfig, channelType: SupportedChannelType, localParams: LocalParams) case class SpawnChannelNonInitiator(open: Either[protocol.OpenChannel, protocol.OpenDualFundedChannel], channelConfig: ChannelConfig, channelType: SupportedChannelType, addFunding_opt: Option[LiquidityAds.AddFunding], localParams: LocalParams, peerConnection: ActorRef) + /** If [[Features.OnTheFlyFunding]] is supported and we're connected, relay a funding proposal to our peer. */ + case class ProposeOnTheFlyFunding(replyTo: typed.ActorRef[ProposeOnTheFlyFundingResponse], amount: MilliSatoshi, paymentHash: ByteVector32, expiry: CltvExpiry, onion: OnionRoutingPacket, nextBlindingKey_opt: Option[PublicKey], upstream: Upstream.Hot) + + sealed trait ProposeOnTheFlyFundingResponse + object ProposeOnTheFlyFundingResponse { + case object Proposed extends ProposeOnTheFlyFundingResponse + case class NotAvailable(reason: String) extends ProposeOnTheFlyFundingResponse + } + + /** We signed a funding transaction where our peer purchased some liquidity. */ + case class LiquidityPurchaseSigned(channelId: ByteVector32, txId: TxId, fundingTxIndex: Long, htlcMinimum: MilliSatoshi, purchase: LiquidityAds.Purchase) + + case class OnTheFlyFundingTimeout(paymentHash: ByteVector32) + case class GetPeerInfo(replyTo: Option[typed.ActorRef[PeerInfoResponse]]) sealed trait PeerInfoResponse { def nodeId: PublicKey } case class PeerInfo(peer: ActorRef, nodeId: PublicKey, state: State, address: Option[NodeAddress], channels: Set[ActorRef]) extends PeerInfoResponse @@ -571,7 +873,6 @@ object Peer { case class ChannelInfo(channel: typed.ActorRef[Command], state: ChannelState, data: ChannelData) case class PeerChannels(nodeId: PublicKey, channels: Seq[ChannelInfo]) - case class PeerRoutingMessage(peerConnection: ActorRef, remoteNodeId: PublicKey, message: RoutingMessage) extends RemoteTypes /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index 4f6e7e6cb1..2d388ec57d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -57,11 +57,16 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) log.info(s"closing channel ${c.channelId}") nodeParams.db.channels.removeChannel(c.channelId) }) - val peerChannels = channels.groupBy(_.remoteNodeId) - peerChannels.foreach { case (remoteNodeId, states) => createOrGetPeer(remoteNodeId, offlineChannels = states.toSet) } - log.info("restoring {} peer(s) with {} channel(s)", peerChannels.size, channels.size) + val peersWithChannels = channels.groupBy(_.remoteNodeId) + peersWithChannels.foreach { case (remoteNodeId, states) => createOrGetPeer(remoteNodeId, offlineChannels = states.toSet) } + // We must re-create peers that have a funded on-the-fly payment, even if they don't have a channel yet. + // It will retry relaying that payment and complete the on-the-fly funding. + // Note that since we use createOrGetPeer, we will automatically skip peers that have channels. + val peersWithOnTheFlyFunding = nodeParams.db.onTheFlyFunding.listPendingPayments().keySet + peersWithOnTheFlyFunding.foreach(remoteNodeId => createOrGetPeer(remoteNodeId, Set.empty)) + log.info("restoring {} peer(s) with {} channel(s)", peersWithChannels.size, channels.size) unstashAll() - context.become(normal(peerChannels.keySet)) + context.become(normal(peersWithChannels.keySet)) case _ => stash() } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala index 20ae3fe823..133ea7e1d7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala @@ -96,10 +96,12 @@ object Monitoring { object RelayType { val Channel = "channel" val Trampoline = "trampoline" + val OnTheFly = "on-the-fly" def apply(e: PaymentRelayed): String = e match { case _: ChannelPaymentRelayed => Channel case _: TrampolinePaymentRelayed => Trampoline + case _: OnTheFlyFundingPaymentRelayed => OnTheFly } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index a312963781..030d5de45e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -102,6 +102,14 @@ case class TrampolinePaymentRelayed(paymentHash: ByteVector32, incoming: Payment override val timestamp: TimestampMilli = settledAt } +case class OnTheFlyFundingPaymentRelayed(paymentHash: ByteVector32, incoming: PaymentRelayed.Incoming, outgoing: PaymentRelayed.Outgoing) extends PaymentRelayed { + override val amountIn: MilliSatoshi = incoming.map(_.amount).sum + override val amountOut: MilliSatoshi = outgoing.map(_.amount).sum + override val startedAt: TimestampMilli = incoming.map(_.receivedAt).minOption.getOrElse(TimestampMilli.now()) + override val settledAt: TimestampMilli = outgoing.map(_.settledAt).maxOption.getOrElse(TimestampMilli.now()) + override val timestamp: TimestampMilli = settledAt +} + object PaymentRelayed { case class IncomingPart(amount: MilliSatoshi, channelId: ByteVector32, receivedAt: TimestampMilli) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 2d87f5d4f6..d4f145606c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -26,14 +26,15 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb -import fr.acinq.eclair.io.PeerReadyNotifier +import fr.acinq.eclair.io.Peer.ProposeOnTheFlyFundingResponse +import fr.acinq.eclair.io.{Peer, PeerReadyNotifier} import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, TimestampMilli, TimestampSecond, channel, nodeFee} +import fr.acinq.eclair.{Features, Logs, NodeParams, ShortChannelId, TimestampMilli, TimestampSecond, channel, nodeFee} import java.util.UUID import java.util.concurrent.TimeUnit @@ -48,11 +49,13 @@ object ChannelRelay { private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command + private case class WrappedOnTheFlyFundingResponse(result: Peer.ProposeOnTheFlyFundingResponse) extends Command // @formatter:on // @formatter:off sealed trait RelayResult case class RelayFailure(cmdFail: CMD_FAIL_HTLC) extends RelayResult + case class RelayNeedsFunding(nextNode: PublicKey, cmdFail: CMD_FAIL_HTLC) extends RelayResult case class RelaySuccess(selectedChannelId: ByteVector32, cmdAdd: CMD_ADD_HTLC) extends RelayResult // @formatter:on @@ -114,6 +117,8 @@ class ChannelRelay private(nodeParams: NodeParams, private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) + private val forwardNodeIdFailureAdapter = context.messageAdapter[Register.ForwardNodeIdFailure[Peer.ProposeOnTheFlyFunding]](_ => WrappedOnTheFlyFundingResponse(Peer.ProposeOnTheFlyFundingResponse.NotAvailable("peer not found"))) + private val onTheFlyFundingResponseAdapter = context.messageAdapter[Peer.ProposeOnTheFlyFundingResponse](WrappedOnTheFlyFundingResponse) private val nextBlindingKey_opt = r.payload match { case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding) @@ -136,13 +141,13 @@ class ChannelRelay private(nodeParams: NodeParams, case Left(walletNodeId) => wakeUp(walletNodeId) case Right(requestedShortChannelId) => context.self ! DoRelay - relay(Some(requestedShortChannelId), Seq.empty) + relay(Some(requestedShortChannelId), None, Seq.empty) } } private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.onTheFlyFundingConfig.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => @@ -151,11 +156,11 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => context.self ! DoRelay - relay(None, Seq.empty) + relay(None, Some(walletNodeId), Seq.empty) } } - def relay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { + def relay(requestedShortChannelId_opt: Option[ShortChannelId], walletNodeId_opt: Option[PublicKey], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { Behaviors.receiveMessagePartial { case DoRelay => if (previousFailures.isEmpty) { @@ -163,20 +168,24 @@ class ChannelRelay private(nodeParams: NodeParams, context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse("")) } context.log.debug("attempting relay previousAttempts={}", previousFailures.size) - handleRelay(requestedShortChannelId_opt, previousFailures) match { + handleRelay(requestedShortChannelId_opt, walletNodeId_opt, previousFailures) match { case RelayFailure(cmdFail) => Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel) context.log.info("rejecting htlc reason={}", cmdFail.reason) safeSendAndStop(r.add.channelId, cmdFail) + case RelayNeedsFunding(nextNodeId, cmdFail) => + val cmd = Peer.ProposeOnTheFlyFunding(onTheFlyFundingResponseAdapter, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, upstream) + register ! Register.ForwardNodeId(forwardNodeIdFailureAdapter, nextNodeId, cmd) + waitForOnTheFlyFundingResponse(cmdFail) case RelaySuccess(selectedChannelId, cmdAdd) => context.log.info("forwarding htlc #{} from channelId={} to channelId={}", r.add.id, r.add.channelId, selectedChannelId) register ! Register.Forward(forwardFailureAdapter, selectedChannelId, cmdAdd) - waitForAddResponse(selectedChannelId, requestedShortChannelId_opt, previousFailures) + waitForAddResponse(selectedChannelId, requestedShortChannelId_opt, walletNodeId_opt, previousFailures) } } } - private def waitForAddResponse(selectedChannelId: ByteVector32, requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = + private def waitForAddResponse(selectedChannelId: ByteVector32, requestedShortChannelId_opt: Option[ShortChannelId], walletNodeId_opt: Option[PublicKey], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) => context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}") @@ -187,7 +196,7 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(addFailed: RES_ADD_FAILED[_]) => context.log.info("attempt failed with reason={}", addFailed.t.getClass.getSimpleName) context.self ! DoRelay - relay(requestedShortChannelId_opt, previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) + relay(requestedShortChannelId_opt, walletNodeId_opt, previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) case WrappedAddResponse(_: RES_SUCCESS[_]) => context.log.debug("sent htlc to the downstream channel") @@ -211,6 +220,21 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(upstream.add.channelId, cmd) } + private def waitForOnTheFlyFundingResponse(cmdFail: CMD_FAIL_HTLC): Behavior[Command] = Behaviors.receiveMessagePartial { + case WrappedOnTheFlyFundingResponse(response) => + response match { + case ProposeOnTheFlyFundingResponse.Proposed => + context.log.info("on-the-fly funding proposed for htlc #{} from channelId={}", r.add.id, r.add.channelId) + // We're not responsible for the payment relay anymore: another actor will take care of relaying the payment + // once on-the-fly funding completes. + Behaviors.stopped + case ProposeOnTheFlyFundingResponse.NotAvailable(reason) => + context.log.warn("could not propose on-the-fly funding for htlc #{} from channelId={}: {}", r.add.id, r.add.channelId, reason) + Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel) + safeSendAndStop(r.add.channelId, cmdFail) + } + } + private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { val toSend = cmd match { case _: CMD_FULFILL_HTLC => cmd @@ -242,7 +266,7 @@ class ChannelRelay private(nodeParams: NodeParams, * - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_ADD_HTLC to propagate downstream */ - private def handleRelay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): RelayResult = { + private def handleRelay(requestedShortChannelId_opt: Option[ShortChannelId], walletNodeId_opt: Option[PublicKey], previousFailures: Seq[PreviouslyTried]): RelayResult = { val alreadyTried = previousFailures.map(_.channelId) selectPreferredChannel(requestedShortChannelId_opt, alreadyTried) match { case Some(outgoingChannel) => relayOrFail(outgoingChannel) @@ -259,7 +283,10 @@ class ChannelRelay private(nodeParams: NodeParams, } else { CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true) } - RelayFailure(cmdFail) + walletNodeId_opt match { + case Some(walletNodeId) if shouldAttemptOnTheFlyFunding(previousFailures) => RelayNeedsFunding(walletNodeId, cmdFail) + case _ => RelayFailure(cmdFail) + } } } @@ -358,6 +385,20 @@ class ChannelRelay private(nodeParams: NodeParams, } } + /** If we fail to relay a payment, we may want to attempt on-the-fly funding. */ + private def shouldAttemptOnTheFlyFunding(previousFailures: Seq[PreviouslyTried]): Boolean = { + val featureOk = nodeParams.features.hasFeature(Features.OnTheFlyFunding) + // If we have a channel with the next node, we only want to perform on-the-fly funding for liquidity issues. + val liquidityIssue = previousFailures.forall { + case PreviouslyTried(_, RES_ADD_FAILED(_, _: InsufficientFunds, _)) => true + case _ => false + } + // If we have a channel with the next peer, but we skipped it because the sender is using invalid relay parameters, + // we don't want to perform on-the-fly funding: the sender should send a valid payment first. + val relayParamsOk = channels.values.forall(c => validateRelayParams(c).isEmpty) + featureOk && liquidityIssue && relayParamsOk + } + private def recordRelayDuration(isSuccess: Boolean): Unit = Metrics.RelayedPaymentDuration .withTag(Tags.Relay, Tags.RelayType.Channel) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 84cb79c2ec..8de4138fff 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -24,9 +24,10 @@ import akka.actor.{ActorRef, typed} import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} +import fr.acinq.eclair.channel._ import fr.acinq.eclair.db.PendingCommandsDb -import fr.acinq.eclair.io.PeerReadyNotifier +import fr.acinq.eclair.io.Peer.ProposeOnTheFlyFundingResponse +import fr.acinq.eclair.io.{Peer, PeerReadyNotifier} import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ @@ -37,11 +38,11 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode import fr.acinq.eclair.payment.send._ -import fr.acinq.eclair.router.Router.RouteParams +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, Route, RouteParams} import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} +import fr.acinq.eclair.{Alias, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} import java.util.UUID import java.util.concurrent.TimeUnit @@ -65,6 +66,7 @@ object NodeRelay { private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command + private case class WrappedOnTheFlyFundingResponse(result: Peer.ProposeOnTheFlyFundingResponse) extends Command // @formatter:on trait OutgoingPaymentFactory { @@ -161,6 +163,14 @@ object NodeRelay { .modify(_.includeLocalChannelCost).setTo(true) } + /** If we fail to relay a payment, we may want to attempt on-the-fly funding if it makes sense. */ + private def shouldAttemptOnTheFlyFunding(nodeParams: NodeParams, failures: Seq[PaymentFailure]): Boolean = { + val featureOk = nodeParams.features.hasFeature(Features.OnTheFlyFunding) + val balanceTooLow = failures.collectFirst { case f@LocalFailure(_, _, BalanceTooLow) => f }.nonEmpty + val routeNotFound = failures.collectFirst { case f@LocalFailure(_, _, RouteNotFound) => f }.nonEmpty + featureOk && (balanceTooLow || routeNotFound) + } + /** * This helper method translates relaying errors (returned by the downstream nodes) to a BOLT 4 standard error that we * should return upstream. @@ -230,7 +240,7 @@ class NodeRelay private(nodeParams: NodeParams, stopping() case WrappedMultiPartPaymentSucceeded(MultiPartPaymentFSM.MultiPartPaymentSucceeded(_, parts)) => context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.amount).sum) - val upstream = Upstream.Hot.Trampoline(htlcs) + val upstream = Upstream.Hot.Trampoline(htlcs.toList) validateRelay(nodeParams, upstream, nextPayload) match { case Some(failure) => context.log.warn(s"rejecting trampoline payment reason=$failure") @@ -298,7 +308,7 @@ class NodeRelay private(nodeParams: NodeParams, */ private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { context.log.info("trying to wake up next peer (nodeId={})", walletNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.onTheFlyFundingConfig.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { @@ -331,7 +341,7 @@ class NodeRelay private(nodeParams: NodeParams, } val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart) payFSM ! payment - sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false) + sending(upstream, recipient, payloadOut, TimestampMilli.now(), fulfilledUpstream = false) } /** @@ -341,7 +351,7 @@ class NodeRelay private(nodeParams: NodeParams, * @param nextPayload relay instructions. * @param fulfilledUpstream true if we already fulfilled the payment upstream. */ - private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = + private def sending(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { // this is the fulfill that arrives from downstream channels @@ -350,7 +360,7 @@ class NodeRelay private(nodeParams: NodeParams, // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). context.log.debug("got preimage from downstream") fulfillPayment(upstream, paymentPreimage) - sending(upstream, nextPayload, startedAt, fulfilledUpstream = true) + sending(upstream, recipient, nextPayload, startedAt, fulfilledUpstream = true) } else { // we don't want to fulfill multiple times Behaviors.same @@ -360,16 +370,66 @@ class NodeRelay private(nodeParams: NodeParams, success(upstream, fulfilledUpstream, paymentSent) recordRelayDuration(startedAt, isSuccess = true) stopping() + case _: WrappedPaymentFailed if fulfilledUpstream => + context.log.warn("trampoline payment failed downstream but was fulfilled upstream") + recordRelayDuration(startedAt, isSuccess = true) + stopping() case WrappedPaymentFailed(PaymentFailed(_, _, failures, _)) => - context.log.debug(s"trampoline payment failed downstream") - if (!fulfilledUpstream) { - rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload)) + nextWalletNodeId(nodeParams, recipient) match { + case Some(walletNodeId) if shouldAttemptOnTheFlyFunding(nodeParams, failures) => + context.log.info("trampoline payment failed, attempting on-the-fly funding") + attemptOnTheFlyFunding(upstream, walletNodeId, recipient, nextPayload, failures, startedAt) + case _ => + rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload)) + recordRelayDuration(startedAt, isSuccess = false) + stopping() } - recordRelayDuration(startedAt, isSuccess = fulfilledUpstream) - stopping() } } + /** We couldn't forward the payment, but the next node may accept on-the-fly funding. */ + private def attemptOnTheFlyFunding(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, failures: Seq[PaymentFailure], startedAt: TimestampMilli): Behavior[Command] = { + // We create a payment onion, using a dummy channel hop between our node and the wallet node. + val dummyEdge = Invoice.ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(0), 0 msat, 0, CltvExpiryDelta(0), 1 msat, None) + val dummyHop = ChannelHop(Alias(0), nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(dummyEdge)) + val finalHop_opt = recipient match { + case _: ClearRecipient => None + case _: SpontaneousRecipient => None + case _: TrampolineRecipient => None + case r: BlindedRecipient => r.blindedHops.headOption + } + val dummyRoute = Route(nextPayload.amountToForward, Seq(dummyHop), finalHop_opt) + OutgoingPaymentPacket.buildOutgoingPayment(Origin.Hot(ActorRef.noSender, upstream), paymentHash, dummyRoute, recipient) match { + case Left(f) => + context.log.warn("could not create payment onion for on-the-fly funding: {}", f.getMessage) + rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload)) + recordRelayDuration(startedAt, isSuccess = false) + stopping() + case Right(nextPacket) => + val forwardNodeIdFailureAdapter = context.messageAdapter[Register.ForwardNodeIdFailure[Peer.ProposeOnTheFlyFunding]](_ => WrappedOnTheFlyFundingResponse(Peer.ProposeOnTheFlyFundingResponse.NotAvailable("peer not found"))) + val onTheFlyFundingResponseAdapter = context.messageAdapter[Peer.ProposeOnTheFlyFundingResponse](WrappedOnTheFlyFundingResponse) + val cmd = Peer.ProposeOnTheFlyFunding(onTheFlyFundingResponseAdapter, nextPayload.amountToForward, paymentHash, nextPayload.outgoingCltv, nextPacket.cmd.onion, nextPacket.cmd.nextBlindingKey_opt, upstream) + register ! Register.ForwardNodeId(forwardNodeIdFailureAdapter, walletNodeId, cmd) + Behaviors.receiveMessagePartial { + rejectExtraHtlcPartialFunction orElse { + case WrappedOnTheFlyFundingResponse(response) => + response match { + case ProposeOnTheFlyFundingResponse.Proposed => + context.log.info("on-the-fly funding proposed") + // We're not responsible for the payment relay anymore: another actor will take care of relaying the + // payment once on-the-fly funding completes. + stopping() + case ProposeOnTheFlyFundingResponse.NotAvailable(reason) => + context.log.warn("could not propose on-the-fly funding: {}", reason) + rejectPayment(upstream, Some(UnknownNextPeer())) + recordRelayDuration(startedAt, isSuccess = false) + stopping() + } + } + } + } + } + /** * Once the downstream payment is settled (fulfilled or failed), we reject new upstream payments while we wait for our parent to stop us. */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/OnTheFlyFunding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/OnTheFlyFunding.scala new file mode 100644 index 0000000000..8a79f20783 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/OnTheFlyFunding.scala @@ -0,0 +1,318 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.payment.relay + +import akka.actor.Cancellable +import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{ActorRef, Behavior} +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, TxId} +import fr.acinq.eclair.blockchain.fee.FeeratePerKw +import fr.acinq.eclair.channel._ +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.wire.protocol.LiquidityAds.PaymentDetails +import fr.acinq.eclair.wire.protocol._ +import fr.acinq.eclair.{Logs, MilliSatoshi, NodeParams, TimestampMilli, ToMilliSatoshiConversion} +import scodec.bits.ByteVector + +import scala.concurrent.duration.FiniteDuration + +/** + * Created by t-bast on 19/06/2024. + */ + +object OnTheFlyFunding { + + case class Config(wakeUpTimeout: FiniteDuration, proposalTimeout: FiniteDuration) + + // @formatter:off + sealed trait Status + object Status { + /** We sent will_add_htlc, but didn't fund a transaction yet. */ + case class Proposed(timer: Cancellable) extends Status + /** + * We signed a transaction matching the on-the-fly funding proposed. We're waiting for the liquidity to be + * available (channel ready or splice locked) to relay the HTLCs and complete the payment. + */ + case class Funded(channelId: ByteVector32, txId: TxId, fundingTxIndex: Long, remainingFees: MilliSatoshi) extends Status + } + // @formatter:on + + /** An on-the-fly funding proposal sent to our peer. */ + case class Proposal(htlc: WillAddHtlc, upstream: Upstream.Hot) { + /** Maximum fees that can be collected from this HTLC. */ + def maxFees(htlcMinimum: MilliSatoshi): MilliSatoshi = htlc.amount - htlcMinimum + + /** Create commands to fail all upstream HTLCs. */ + def createFailureCommands(failure_opt: Option[Either[ByteVector, FailureMessage]]): Seq[(ByteVector32, CMD_FAIL_HTLC)] = upstream match { + case _: Upstream.Local => Nil + case u: Upstream.Hot.Channel => + val failure = htlc.blinding_opt match { + case Some(_) => Right(InvalidOnionBlinding(Sphinx.hash(u.add.onionRoutingPacket))) + case None => failure_opt.getOrElse(Right(UnknownNextPeer())) + } + Seq(u.add.channelId -> CMD_FAIL_HTLC(u.add.id, failure, commit = true)) + case u: Upstream.Hot.Trampoline => + // In the trampoline case, we currently ignore downstream failures: we should add dedicated failures to the + // BOLTs to better handle those cases. + val failure = failure_opt match { + case Some(f) => f.getOrElse(TemporaryNodeFailure()) + case None => UnknownNextPeer() + } + u.received.map(_.add).map(add => add.channelId -> CMD_FAIL_HTLC(add.id, Right(failure), commit = true)) + } + + /** Create commands to fulfill all upstream HTLCs. */ + def createFulfillCommands(preimage: ByteVector32): Seq[(ByteVector32, CMD_FULFILL_HTLC)] = upstream match { + case _: Upstream.Local => Nil + case u: Upstream.Hot.Channel => Seq(u.add.channelId -> CMD_FULFILL_HTLC(u.add.id, preimage, commit = true)) + case u: Upstream.Hot.Trampoline => u.received.map(_.add).map(add => add.channelId -> CMD_FULFILL_HTLC(add.id, preimage, commit = true)) + } + } + + /** A set of funding proposals for a given payment. */ + case class Pending(proposed: Seq[Proposal], status: Status) { + val paymentHash = proposed.head.htlc.paymentHash + val expiry = proposed.map(_.htlc.expiry).min + + /** Maximum fees that can be collected from this HTLC set. */ + def maxFees(htlcMinimum: MilliSatoshi): MilliSatoshi = proposed.map(_.maxFees(htlcMinimum)).sum + + /** Create commands to fail all upstream HTLCs. */ + def createFailureCommands(): Seq[(ByteVector32, CMD_FAIL_HTLC)] = proposed.flatMap(_.createFailureCommands(None)) + + /** Create commands to fulfill all upstream HTLCs. */ + def createFulfillCommands(preimage: ByteVector32): Seq[(ByteVector32, CMD_FULFILL_HTLC)] = proposed.flatMap(_.createFulfillCommands(preimage)) + } + + // @formatter:off + sealed trait ValidationResult + object ValidationResult { + /** The incoming channel or splice cannot pay the liquidity fees: we must reject it and fail the corresponding upstream HTLCs. */ + case class Reject(cancel: CancelOnTheFlyFunding, paymentHashes: Set[ByteVector32]) extends ValidationResult + /** We are on-the-fly funding a channel: if we received preimages, we must fulfill the corresponding upstream HTLCs. */ + case class Accept(preimages: Set[ByteVector32]) extends ValidationResult + } + // @formatter:on + + /** Validate an incoming channel that may use on-the-fly funding. */ + def validateOpen(open: Either[OpenChannel, OpenDualFundedChannel], pendingOnTheFlyFunding: Map[ByteVector32, Pending]): ValidationResult = { + open match { + case Left(_) => ValidationResult.Accept(Set.empty) + case Right(open) => open.requestFunding_opt match { + case Some(requestFunding) => validate(open.temporaryChannelId, requestFunding, open.fundingFeerate, open.htlcMinimum, pendingOnTheFlyFunding) + case None => ValidationResult.Accept(Set.empty) + } + } + } + + /** Validate an incoming splice that may use on-the-fly funding. */ + def validateSplice(splice: SpliceInit, htlcMinimum: MilliSatoshi, pendingOnTheFlyFunding: Map[ByteVector32, Pending]): ValidationResult = { + splice.requestFunding_opt match { + case Some(requestFunding) => validate(splice.channelId, requestFunding, splice.feerate, htlcMinimum, pendingOnTheFlyFunding) + case None => ValidationResult.Accept(Set.empty) + } + } + + private def validate(channelId: ByteVector32, + requestFunding: LiquidityAds.RequestFunding, + feerate: FeeratePerKw, + htlcMinimum: MilliSatoshi, + pendingOnTheFlyFunding: Map[ByteVector32, Pending]): ValidationResult = { + val paymentHashes = requestFunding.paymentDetails match { + case PaymentDetails.FromChannelBalance => Nil + case PaymentDetails.FromChannelBalanceForFutureHtlc(paymentHashes) => paymentHashes + case PaymentDetails.FromFutureHtlc(paymentHashes) => paymentHashes + case PaymentDetails.FromFutureHtlcWithPreimage(preimages) => preimages.map(preimage => Crypto.sha256(preimage)) + } + val pending = paymentHashes.flatMap(paymentHash => pendingOnTheFlyFunding.get(paymentHash)).filter(_.status.isInstanceOf[OnTheFlyFunding.Status.Proposed]) + val totalPaymentAmount = pending.flatMap(_.proposed.map(_.htlc.amount)).sum + // We will deduce fees from HTLCs: we check that the amount is large enough to cover the fees. + val availableAmountForFees = pending.map(_.maxFees(htlcMinimum)).sum + val fees = requestFunding.fees(feerate) + val cancelAmountTooLow = CancelOnTheFlyFunding(channelId, paymentHashes, s"requested amount is too low to relay HTLCs: ${requestFunding.requestedAmount} < $totalPaymentAmount") + val cancelFeesTooLow = CancelOnTheFlyFunding(channelId, paymentHashes, s"htlc amount is too low to pay liquidity fees: $availableAmountForFees < ${fees.total}") + requestFunding.paymentDetails match { + case PaymentDetails.FromChannelBalance => ValidationResult.Accept(Set.empty) + case _ if requestFunding.requestedAmount.toMilliSatoshi < totalPaymentAmount => ValidationResult.Reject(cancelAmountTooLow, paymentHashes.toSet) + case _: PaymentDetails.FromChannelBalanceForFutureHtlc => ValidationResult.Accept(Set.empty) + case _: PaymentDetails.FromFutureHtlc if availableAmountForFees < fees.total => ValidationResult.Reject(cancelFeesTooLow, paymentHashes.toSet) + case _: PaymentDetails.FromFutureHtlc => ValidationResult.Accept(Set.empty) + case _: PaymentDetails.FromFutureHtlcWithPreimage if availableAmountForFees < fees.total => ValidationResult.Reject(cancelFeesTooLow, paymentHashes.toSet) + case p: PaymentDetails.FromFutureHtlcWithPreimage => ValidationResult.Accept(p.preimages.toSet) + } + } + + /** + * This actor relays HTLCs that were proposed with [[WillAddHtlc]] once funding is complete. + * It verifies that this payment was not previously relayed, to protect against over-paying and paying multiple times. + */ + object PaymentRelayer { + // @formatter:off + sealed trait Command + case class TryRelay(replyTo: ActorRef[RelayResult], channel: ActorRef[fr.acinq.eclair.channel.Command], proposed: Seq[Proposal], status: Status.Funded) extends Command + private case class WrappedChannelInfo(state: ChannelState, data: ChannelData) extends Command + private case class WrappedCommandResponse(response: CommandResponse[CMD_ADD_HTLC]) extends Command + private case class WrappedHtlcSettled(result: RES_ADD_SETTLED[Origin.Hot, HtlcResult]) extends Command + + sealed trait RelayResult + case class RelaySuccess(channelId: ByteVector32, paymentHash: ByteVector32, preimage: ByteVector32, fees: MilliSatoshi) extends RelayResult + case class RelayFailed(paymentHash: ByteVector32, failure: RelayFailure) extends RelayResult + + sealed trait RelayFailure + case object ExpiryTooClose extends RelayFailure { override def toString: String = "htlcs are too close to expiry to be relayed" } + case class ChannelNotAvailable(state: ChannelState) extends RelayFailure { override def toString: String = s"channel is not ready for payments (state=${state.toString})" } + case class CannotAddToChannel(t: Throwable) extends RelayFailure { override def toString: String = s"could not relay on-the-fly HTLC: ${t.getMessage}" } + case class RemoteFailure(failure: HtlcResult.Fail) extends RelayFailure { override def toString: String = s"relayed on-the-fly HTLC was failed: ${failure.getClass.getSimpleName}" } + // @formatter:on + + def apply(nodeParams: NodeParams, remoteNodeId: PublicKey, channelId: ByteVector32, paymentHash: ByteVector32): Behavior[Command] = + Behaviors.setup { context => + Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT), remoteNodeId_opt = Some(remoteNodeId), channelId_opt = Some(channelId), paymentHash_opt = Some(paymentHash))) { + Behaviors.receiveMessagePartial { + case cmd: TryRelay => new PaymentRelayer(nodeParams, channelId, paymentHash, cmd, context).start() + } + } + } + } + + class PaymentRelayer private(nodeParams: NodeParams, channelId: ByteVector32, paymentHash: ByteVector32, cmd: PaymentRelayer.TryRelay, context: ActorContext[PaymentRelayer.Command]) { + + import PaymentRelayer._ + + def start(): Behavior[Command] = { + if (cmd.proposed.exists(_.htlc.expiry.blockHeight <= nodeParams.currentBlockHeight + 12)) { + // The funding proposal expires soon: we shouldn't relay HTLCs to avoid risking a force-close. + cmd.replyTo ! RelayFailed(paymentHash, ExpiryTooClose) + Behaviors.stopped + } else { + checkChannelState() + } + } + + private def checkChannelState(): Behavior[Command] = { + cmd.channel ! CMD_GET_CHANNEL_INFO(context.messageAdapter[RES_GET_CHANNEL_INFO](r => WrappedChannelInfo(r.state, r.data))) + Behaviors.receiveMessagePartial { + case WrappedChannelInfo(_, data: DATA_NORMAL) if paymentAlreadyRelayed(paymentHash, data) => + context.log.warn("payment is already being relayed, waiting for it to be settled") + Behaviors.stopped + case WrappedChannelInfo(_, data: DATA_NORMAL) => + nodeParams.db.onTheFlyFunding.getPreimage(paymentHash) match { + case Some(preimage) => + // We have already received the preimage for that payment, but we probably restarted before removing the + // on-the-fly funding proposal from our DB. We must not relay the payment again, otherwise we will pay + // the next node twice. + cmd.replyTo ! RelaySuccess(channelId, paymentHash, preimage, cmd.status.remainingFees) + Behaviors.stopped + case None => relay(data) + } + case WrappedChannelInfo(state, _) => + cmd.replyTo ! RelayFailed(paymentHash, ChannelNotAvailable(state)) + Behaviors.stopped + } + } + + private def paymentAlreadyRelayed(paymentHash: ByteVector32, data: DATA_NORMAL): Boolean = { + data.commitments.changes.localChanges.all.exists { + case add: UpdateAddHtlc => add.paymentHash == paymentHash && add.fundingFee_opt.nonEmpty + case _ => false + } + } + + private def relay(data: DATA_NORMAL): Behavior[Command] = { + context.log.debug("relaying {} on-the-fly HTLCs that have been funded", cmd.proposed.size) + val htlcMinimum = data.commitments.params.remoteParams.htlcMinimum + val cmdAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedCommandResponse) + val htlcSettledAdapter = context.messageAdapter[RES_ADD_SETTLED[Origin.Hot, HtlcResult]](WrappedHtlcSettled) + cmd.proposed.foldLeft(cmd.status.remainingFees) { + case (remainingFees, p) => + // We always set the funding_fee field, even if the fee for this specific HTLC is 0. + // This lets us detect that this HTLC is an on-the-fly funded HTLC. + val htlcFees = LiquidityAds.FundingFee(remainingFees.min(p.maxFees(htlcMinimum)), cmd.status.txId) + val origin = Origin.Hot(htlcSettledAdapter.toClassic, p.upstream) + // We only sign at the end of the whole batch. + val commit = p.htlc.id == cmd.proposed.last.htlc.id + val add = CMD_ADD_HTLC(cmdAdapter.toClassic, p.htlc.amount - htlcFees.amount, paymentHash, p.htlc.expiry, p.htlc.finalPacket, p.htlc.blinding_opt, Some(htlcFees), origin, commit) + cmd.channel ! add + remainingFees - htlcFees.amount + } + waitForResult(cmd.proposed.size) + } + + private def waitForResult(remaining: Int): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedCommandResponse(response) => response match { + case _: CommandSuccess[_] => Behaviors.same + case failure: CommandFailure[_, _] => + cmd.replyTo ! RelayFailed(paymentHash, CannotAddToChannel(failure.t)) + stopping(remaining - 1) + } + case WrappedHtlcSettled(settled) => + settled.result match { + case fulfill: HtlcResult.Fulfill => cmd.replyTo ! RelaySuccess(channelId, paymentHash, fulfill.paymentPreimage, cmd.status.remainingFees) + case failure: HtlcResult.Fail => cmd.replyTo ! RelayFailed(paymentHash, RemoteFailure(failure)) + } + stopping(remaining - 1) + } + } + + private def stopping(remaining: Int): Behavior[Command] = { + if (remaining == 0) { + Behaviors.stopped + } else { + Behaviors.receiveMessagePartial { + case WrappedCommandResponse(response) => + response match { + case _: CommandSuccess[_] => Behaviors.same + case _: CommandFailure[_, _] => stopping(remaining - 1) + } + case WrappedHtlcSettled(settled) => + settled.result match { + case fulfill: HtlcResult.Fulfill => cmd.replyTo ! RelaySuccess(channelId, paymentHash, fulfill.paymentPreimage, cmd.status.remainingFees) + case _: HtlcResult.Fail => () + } + stopping(remaining - 1) + } + } + } + + } + + object Codecs { + + import fr.acinq.eclair.wire.protocol.CommonCodecs._ + import fr.acinq.eclair.wire.protocol.LightningMessageCodecs._ + import scodec.Codec + import scodec.codecs._ + + private val upstreamLocal: Codec[Upstream.Local] = uuid.as[Upstream.Local] + private val upstreamChannel: Codec[Upstream.Hot.Channel] = (lengthDelimited(updateAddHtlcCodec) :: uint64overflow.as[TimestampMilli] :: publicKey).as[Upstream.Hot.Channel] + private val upstreamTrampoline: Codec[Upstream.Hot.Trampoline] = listOfN(uint16, upstreamChannel).as[Upstream.Hot.Trampoline] + + val upstream: Codec[Upstream.Hot] = discriminated[Upstream.Hot].by(uint16) + .typecase(0x00, upstreamLocal) + .typecase(0x01, upstreamChannel) + .typecase(0x02, upstreamTrampoline) + + val proposal: Codec[Proposal] = (("willAddHtlc" | lengthDelimited(willAddHtlcCodec)) :: ("upstream" | upstream)).as[Proposal] + + val proposals: Codec[Seq[Proposal]] = listOfN(uint16, proposal).xmap(_.toSeq, _.toList) + + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala index 212226aacf..9d86b6b2bc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala @@ -68,6 +68,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial // result upstream to preserve channels. val brokenHtlcs: BrokenHtlcs = { val channels = listLocalChannels(init.channels) + val onTheFlyPayments = nodeParams.db.onTheFlyFunding.listPendingPayments().values.flatten.toSet val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs(nodeParams, log) }.flatten val htlcsIn: Seq[IncomingHtlc] = getIncomingHtlcs(channels, nodeParams.db.payments, nodeParams.privateKey, nodeParams.features) ++ nonStandardIncomingHtlcs val nonStandardRelayedOutHtlcs: Map[Origin.Cold, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn, nodeParams, log) }.flatten.toMap @@ -85,7 +86,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial log.info(s"htlcsIn=${htlcsIn.length} notRelayed=${notRelayed.length} relayedOut=${relayedOut.values.flatten.size}") log.info("notRelayed={}", notRelayed.map(htlc => (htlc.add.channelId, htlc.add.id))) log.info("relayedOut={}", relayedOut) - BrokenHtlcs(notRelayed, relayedOut, Set.empty) + BrokenHtlcs(notRelayed, relayedOut, Set.empty, onTheFlyPayments) } Metrics.PendingNotRelayed.update(brokenHtlcs.notRelayed.size) @@ -120,6 +121,10 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial } else { log.info(s"got preimage but upstream channel is closed for htlc=$htlc") } + case None if brokenHtlcs.pendingPayments.contains(htlc.paymentHash) => + // We don't fail on-the-fly HTLCs that have been funded: we haven't been paid our fee yet, so we will + // retry relaying them unless we reach the HTLC timeout. + log.info("htlc #{} from channelId={} wasn't relayed, but has a pending on-the-fly relay (paymentHash={})", htlc.id, htlc.channelId, htlc.paymentHash) case None => Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = false).increment() if (e.currentState != CLOSING && e.currentState != CLOSED) { @@ -150,12 +155,17 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial case _: ChannelStateChanged => // ignore other channel state changes case RES_ADD_SETTLED(o: Origin.Cold, htlc, fulfill: HtlcResult.Fulfill) => - log.info("htlc fulfilled downstream: ({},{})", htlc.channelId, htlc.id) + log.info("htlc #{} from channelId={} fulfilled downstream", htlc.id, htlc.channelId) handleDownstreamFulfill(brokenHtlcs, o, htlc, fulfill.paymentPreimage) case RES_ADD_SETTLED(o: Origin.Cold, htlc, fail: HtlcResult.Fail) => - log.info("htlc failed downstream: ({},{},{})", htlc.channelId, htlc.id, fail.getClass.getSimpleName) - handleDownstreamFailure(brokenHtlcs, o, htlc, fail) + if (htlc.fundingFee_opt.nonEmpty) { + log.info("htlc #{} from channelId={} failed downstream but has a pending on-the-fly funding", htlc.id, htlc.channelId) + // We don't fail upstream: we haven't been paid our funding fee yet, so we will try relaying again. + } else { + log.info("htlc #{} from channelId={} failed downstream: {}", htlc.id, htlc.channelId, fail.getClass.getSimpleName) + handleDownstreamFailure(brokenHtlcs, o, htlc, fail) + } case GetBrokenHtlcs => sender() ! brokenHtlcs } @@ -329,8 +339,9 @@ object PostRestartHtlcCleaner { * @param notRelayed incoming HTLCs that were committed upstream but not relayed downstream. * @param relayedOut outgoing HTLC sets that may have been incompletely sent and need to be watched. * @param settledUpstream upstream payments that have already been settled (failed or fulfilled) by this actor. + * @param pendingPayments payments that are pending and will be relayed: we mustn't fail them upstream. */ - case class BrokenHtlcs(notRelayed: Seq[IncomingHtlc], relayedOut: Map[Origin.Cold, Set[(ByteVector32, Long)]], settledUpstream: Set[Origin.Cold]) + case class BrokenHtlcs(notRelayed: Seq[IncomingHtlc], relayedOut: Map[Origin.Cold, Set[(ByteVector32, Long)]], settledUpstream: Set[Origin.Cold], pendingPayments: Set[ByteVector32]) /** Returns true if the given HTLC matches the given origin. */ private def matchesOrigin(htlcIn: UpdateAddHtlc, origin: Origin.Cold): Boolean = origin.upstream match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index d85f9876ac..471cf462b5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -96,10 +96,15 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) } - case r: RES_ADD_SETTLED[_, _] => r.origin match { - case _: Origin.Cold => postRestartCleaner ! r - case o: Origin.Hot => o.replyTo ! r - } + case r: RES_ADD_SETTLED[_, HtlcResult] => + r.result match { + case fulfill: HtlcResult.Fulfill if r.htlc.fundingFee_opt.nonEmpty => nodeParams.db.onTheFlyFunding.addPreimage(fulfill.paymentPreimage) + case _ => () + } + r.origin match { + case _: Origin.Cold => postRestartCleaner ! r + case o: Origin.Hot => o.replyTo ! r + } case g: GetOutgoingChannels => channelRelayer ! ChannelRelayer.GetOutgoingChannels(sender(), g) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 2bf58173d3..50329e4b72 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -18,8 +18,6 @@ package fr.acinq.eclair import akka.actor.ActorRef import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Satoshi, SatoshiLong} -import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} -import fr.acinq.eclair.Features._ import fr.acinq.eclair.blockchain.fee._ import fr.acinq.eclair.channel.fsm.Channel.{ChannelConf, RemoteRbfLimits, UnhandledExceptionStrategy} import fr.acinq.eclair.channel.{ChannelFlags, LocalParams, Origin, Upstream} @@ -28,6 +26,7 @@ import fr.acinq.eclair.db.RevokedHtlcInfoCleaner import fr.acinq.eclair.io.MessageRelay.RelayAll import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection} import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig +import fr.acinq.eclair.payment.relay.OnTheFlyFunding import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios} import fr.acinq.eclair.router.PathFindingExperimentConf @@ -97,16 +96,16 @@ object TestConstants { torAddress_opt = None, features = Features( Map( - DataLossProtect -> Optional, - ChannelRangeQueries -> Optional, - ChannelRangeQueriesExtended -> Optional, - VariableLengthOnion -> Mandatory, - PaymentSecret -> Mandatory, - BasicMultiPartPayment -> Optional, - Wumbo -> Optional, - PaymentMetadata -> Optional, - RouteBlinding -> Optional, - StaticRemoteKey -> Mandatory + Features.DataLossProtect -> FeatureSupport.Optional, + Features.ChannelRangeQueries -> FeatureSupport.Optional, + Features.ChannelRangeQueriesExtended -> FeatureSupport.Optional, + Features.VariableLengthOnion -> FeatureSupport.Mandatory, + Features.PaymentSecret -> FeatureSupport.Mandatory, + Features.BasicMultiPartPayment -> FeatureSupport.Optional, + Features.Wumbo -> FeatureSupport.Optional, + Features.PaymentMetadata -> FeatureSupport.Optional, + Features.RouteBlinding -> FeatureSupport.Optional, + Features.StaticRemoteKey -> FeatureSupport.Mandatory, ), unknown = Set(UnknownFeature(TestFeature.optional)) ), @@ -237,7 +236,10 @@ object TestConstants { purgeInvoicesInterval = None, revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), willFundRates_opt = Some(defaultLiquidityRates), - wakeUpTimeout = 30 seconds, + onTheFlyFundingConfig = OnTheFlyFunding.Config( + wakeUpTimeout = 30 seconds, + proposalTimeout = 90 seconds, + ) ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -270,17 +272,17 @@ object TestConstants { publicAddresses = NodeAddress.fromParts("localhost", 9732).get :: Nil, torAddress_opt = None, features = Features( - DataLossProtect -> Optional, - ChannelRangeQueries -> Optional, - ChannelRangeQueriesExtended -> Optional, - VariableLengthOnion -> Mandatory, - PaymentSecret -> Mandatory, - BasicMultiPartPayment -> Optional, - Wumbo -> Optional, - PaymentMetadata -> Optional, - RouteBlinding -> Optional, - StaticRemoteKey -> Mandatory, - AnchorOutputsZeroFeeHtlcTx -> Optional + Features.DataLossProtect -> FeatureSupport.Optional, + Features.ChannelRangeQueries -> FeatureSupport.Optional, + Features.ChannelRangeQueriesExtended -> FeatureSupport.Optional, + Features.VariableLengthOnion -> FeatureSupport.Mandatory, + Features.PaymentSecret -> FeatureSupport.Mandatory, + Features.BasicMultiPartPayment -> FeatureSupport.Optional, + Features.Wumbo -> FeatureSupport.Optional, + Features.PaymentMetadata -> FeatureSupport.Optional, + Features.RouteBlinding -> FeatureSupport.Optional, + Features.StaticRemoteKey -> FeatureSupport.Mandatory, + Features.AnchorOutputsZeroFeeHtlcTx -> FeatureSupport.Optional, ), pluginParams = Nil, overrideInitFeatures = Map.empty, @@ -409,7 +411,10 @@ object TestConstants { purgeInvoicesInterval = None, revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), willFundRates_opt = Some(defaultLiquidityRates), - wakeUpTimeout = 30 seconds, + onTheFlyFundingConfig = OnTheFlyFunding.Config( + wakeUpTimeout = 30 seconds, + proposalTimeout = 90 seconds, + ) ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala index 2b25448ce4..27d56c78e3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -34,6 +34,7 @@ sealed trait TestDatabases extends Databases { override def peers: PeersDb = db.peers override def payments: PaymentsDb = db.payments override def pendingCommands: PendingCommandsDb = db.pendingCommands + override def onTheFlyFunding: OnTheFlyFundingDb = db.onTheFlyFunding def close(): Unit // @formatter:on } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/InteractiveTxBuilderSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/InteractiveTxBuilderSpec.scala index fcff644b5b..6d3f8ba113 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/InteractiveTxBuilderSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/InteractiveTxBuilderSpec.scala @@ -2239,7 +2239,7 @@ class InteractiveTxBuilderSpec extends TestKitBaseClass with AnyFunSuiteLike wit val bob = params.spawnTxBuilderBob(wallet, params.fundingParamsB, Some(purchase)) bob ! Start(probe.ref) assert(probe.expectMsgType[LocalFailure].cause == InvalidFundingBalances(params.channelId, 524_000 sat, 525_000_000 msat, -1_000_000 msat)) - // Bob reject a splice proposed by Alice where she doesn't have enough funds to pay the liquidity fees. + // Bob rejects a splice proposed by Alice where she doesn't have enough funds to pay the liquidity fees. val previousCommitment = CommitmentsSpec.makeCommitments(450_000_000 msat, 50_000_000 msat).active.head val sharedInput = params.dummySharedInputB(500_000 sat) val spliceParams = params.fundingParamsB.copy(localContribution = 150_000 sat, remoteContribution = -30_000 sat, sharedInput_opt = Some(sharedInput)) @@ -2250,6 +2250,11 @@ class InteractiveTxBuilderSpec extends TestKitBaseClass with AnyFunSuiteLike wit val bobFutureHtlc = params.spawnTxBuilderBob(wallet, params.fundingParamsB, Some(purchase.copy(paymentDetails = LiquidityAds.PaymentDetails.FromFutureHtlc(Nil)))) bobFutureHtlc ! Start(probe.ref) probe.expectNoMessage(100 millis) + // Bob rejects a splice proposed by Alice where she has enough funds to pay the liquidity fees, but wants to pay + // them outside of the interactive-tx session, which requires some trust. + val bobFutureHtlcWithBalance = params.spawnTxBuilderSpliceBob(spliceParams, previousCommitment, wallet, Some(purchase.copy(fees = LiquidityAds.Fees(1000 sat, 4000 sat), paymentDetails = LiquidityAds.PaymentDetails.FromFutureHtlc(Nil)))) + bobFutureHtlcWithBalance ! Start(probe.ref) + assert(probe.expectMsgType[LocalFailure].cause == InvalidLiquidityAdsPaymentType(params.channelId, LiquidityAds.PaymentType.FromFutureHtlc, Set(LiquidityAds.PaymentType.FromChannelBalance, LiquidityAds.PaymentType.FromChannelBalanceForFutureHtlc))) } test("invalid input") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/b/WaitForDualFundingSignedStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/b/WaitForDualFundingSignedStateSpec.scala index a6d2a557c7..a690ffb8c1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/b/WaitForDualFundingSignedStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/b/WaitForDualFundingSignedStateSpec.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.channel.fund.InteractiveTxBuilder.{FullySignedSharedTransaction, PartiallySignedSharedTransaction} import fr.acinq.eclair.channel.publish.TxPublisher import fr.acinq.eclair.channel.states.{ChannelStateTestsBase, ChannelStateTestsTags} -import fr.acinq.eclair.io.Peer.OpenChannelResponse +import fr.acinq.eclair.io.Peer.{LiquidityPurchaseSigned, OpenChannelResponse} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{Features, MilliSatoshiLong, TestConstants, TestKitBaseClass, ToMilliSatoshiConversion} @@ -39,7 +39,7 @@ import scala.concurrent.duration.DurationInt class WaitForDualFundingSignedStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with ChannelStateTestsBase { - case class FixtureParam(alice: TestFSMRef[ChannelState, ChannelData, Channel], bob: TestFSMRef[ChannelState, ChannelData, Channel], alice2bob: TestProbe, bob2alice: TestProbe, alice2blockchain: TestProbe, bob2blockchain: TestProbe, wallet: SingleKeyOnChainWallet, aliceListener: TestProbe, bobListener: TestProbe) + case class FixtureParam(alice: TestFSMRef[ChannelState, ChannelData, Channel], bob: TestFSMRef[ChannelState, ChannelData, Channel], alicePeer: TestProbe, bobPeer: TestProbe, alice2bob: TestProbe, bob2alice: TestProbe, alice2blockchain: TestProbe, bob2blockchain: TestProbe, wallet: SingleKeyOnChainWallet, aliceListener: TestProbe, bobListener: TestProbe) override def withFixture(test: OneArgTest): Outcome = { val wallet = new SingleKeyOnChainWallet() @@ -50,7 +50,8 @@ class WaitForDualFundingSignedStateSpec extends TestKitBaseClass with FixtureAny val (aliceParams, bobParams, channelType) = computeFeatures(setup, test.tags, channelFlags) val aliceInit = Init(aliceParams.initFeatures) val bobInit = Init(bobParams.initFeatures) - val bobContribution = if (channelType.features.contains(Features.ZeroConf)) None else Some(LiquidityAds.AddFunding(TestConstants.nonInitiatorFundingSatoshis, None)) + val bobContribution = if (channelType.features.contains(Features.ZeroConf)) None else Some(LiquidityAds.AddFunding(TestConstants.nonInitiatorFundingSatoshis, Some(TestConstants.defaultLiquidityRates))) + val requestFunding_opt = if (test.tags.contains(ChannelStateTestsTags.LiquidityAds)) Some(LiquidityAds.RequestFunding(TestConstants.nonInitiatorFundingSatoshis, TestConstants.defaultLiquidityRates.fundingRates.head, LiquidityAds.PaymentDetails.FromChannelBalance)) else None val (initiatorPushAmount, nonInitiatorPushAmount) = if (test.tags.contains("both_push_amount")) (Some(TestConstants.initiatorPushAmount), Some(TestConstants.nonInitiatorPushAmount)) else (None, None) val commitFeerate = channelType.commitmentFormat match { case Transactions.DefaultCommitmentFormat => TestConstants.feeratePerKw @@ -61,7 +62,7 @@ class WaitForDualFundingSignedStateSpec extends TestKitBaseClass with FixtureAny within(30 seconds) { alice.underlying.system.eventStream.subscribe(aliceListener.ref, classOf[ChannelAborted]) bob.underlying.system.eventStream.subscribe(bobListener.ref, classOf[ChannelAborted]) - alice ! INPUT_INIT_CHANNEL_INITIATOR(ByteVector32.Zeroes, TestConstants.fundingSatoshis, dualFunded = true, commitFeerate, TestConstants.feeratePerKw, fundingTxFeeBudget_opt = None, initiatorPushAmount, requireConfirmedInputs = false, requestFunding_opt = None, aliceParams, alice2bob.ref, bobInit, channelFlags, channelConfig, channelType, replyTo = aliceOpenReplyTo.ref.toTyped) + alice ! INPUT_INIT_CHANNEL_INITIATOR(ByteVector32.Zeroes, TestConstants.fundingSatoshis, dualFunded = true, commitFeerate, TestConstants.feeratePerKw, fundingTxFeeBudget_opt = None, initiatorPushAmount, requireConfirmedInputs = false, requestFunding_opt, aliceParams, alice2bob.ref, bobInit, channelFlags, channelConfig, channelType, replyTo = aliceOpenReplyTo.ref.toTyped) bob ! INPUT_INIT_CHANNEL_NON_INITIATOR(ByteVector32.Zeroes, bobContribution, dualFunded = true, nonInitiatorPushAmount, bobParams, bob2alice.ref, aliceInit, channelConfig, channelType) alice2blockchain.expectMsgType[TxPublisher.SetChannelId] // temporary channel id bob2blockchain.expectMsgType[TxPublisher.SetChannelId] // temporary channel id @@ -95,7 +96,7 @@ class WaitForDualFundingSignedStateSpec extends TestKitBaseClass with FixtureAny aliceOpenReplyTo.expectMsgType[OpenChannelResponse.Created] awaitCond(alice.stateName == WAIT_FOR_DUAL_FUNDING_SIGNED) awaitCond(bob.stateName == WAIT_FOR_DUAL_FUNDING_SIGNED) - withFixture(test.toNoArgTest(FixtureParam(alice, bob, alice2bob, bob2alice, alice2blockchain, bob2blockchain, wallet, aliceListener, bobListener))) + withFixture(test.toNoArgTest(FixtureParam(alice, bob, alicePeer, bobPeer, alice2bob, bob2alice, alice2blockchain, bob2blockchain, wallet, aliceListener, bobListener))) } } @@ -196,6 +197,44 @@ class WaitForDualFundingSignedStateSpec extends TestKitBaseClass with FixtureAny assert(aliceData.commitments.latest.localCommit.spec.toRemote == expectedBalanceBob) } + test("complete interactive-tx protocol (with liquidity ads)", Tag(ChannelStateTestsTags.DualFunding), Tag(ChannelStateTestsTags.LiquidityAds)) { f => + import f._ + + val listener = TestProbe() + alice.underlyingActor.context.system.eventStream.subscribe(listener.ref, classOf[TransactionPublished]) + + bob2alice.expectMsgType[CommitSig] + bob2alice.forward(alice) + alice2bob.expectMsgType[CommitSig] + alice2bob.forward(bob) + + // Bob sends its signatures first as he contributed less than Alice. + bob2alice.expectMsgType[TxSignatures] + awaitCond(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED) + val bobData = bob.stateData.asInstanceOf[DATA_WAIT_FOR_DUAL_FUNDING_CONFIRMED] + val fundingTxId = bobData.latestFundingTx.sharedTx.asInstanceOf[PartiallySignedSharedTransaction].txId + assert(bob2blockchain.expectMsgType[WatchFundingConfirmed].txId == fundingTxId) + + // Bob signed a liquidity purchase. + bobPeer.fishForMessage() { + case l: LiquidityPurchaseSigned => + assert(l.purchase.paymentDetails == LiquidityAds.PaymentDetails.FromChannelBalance) + assert(l.fundingTxIndex == 0) + assert(l.txId == fundingTxId) + true + case _ => false + } + + // Alice receives Bob's signatures and sends her own signatures. + bob2alice.forward(alice) + assert(listener.expectMsgType[TransactionPublished].tx.txid == fundingTxId) + assert(alice2blockchain.expectMsgType[WatchFundingConfirmed].txId == fundingTxId) + alice2bob.expectMsgType[TxSignatures] + awaitCond(alice.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED) + val aliceData = alice.stateData.asInstanceOf[DATA_WAIT_FOR_DUAL_FUNDING_CONFIRMED] + assert(aliceData.latestFundingTx.sharedTx.asInstanceOf[FullySignedSharedTransaction].signedTx.txid == fundingTxId) + } + test("recv invalid CommitSig", Tag(ChannelStateTestsTags.DualFunding)) { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/c/WaitForDualFundingReadyStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/c/WaitForDualFundingReadyStateSpec.scala index 9a4bd93123..71453ed587 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/c/WaitForDualFundingReadyStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/c/WaitForDualFundingReadyStateSpec.scala @@ -38,7 +38,7 @@ import scala.concurrent.duration.DurationInt class WaitForDualFundingReadyStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with ChannelStateTestsBase { - case class FixtureParam(alice: TestFSMRef[ChannelState, ChannelData, Channel], bob: TestFSMRef[ChannelState, ChannelData, Channel], alice2bob: TestProbe, bob2alice: TestProbe, alice2blockchain: TestProbe, bob2blockchain: TestProbe, listener: TestProbe) + case class FixtureParam(alice: TestFSMRef[ChannelState, ChannelData, Channel], bob: TestFSMRef[ChannelState, ChannelData, Channel], alicePeer: TestProbe, bobPeer: TestProbe, alice2bob: TestProbe, bob2alice: TestProbe, alice2blockchain: TestProbe, bob2blockchain: TestProbe, listener: TestProbe) override def withFixture(test: OneArgTest): Outcome = { val setup = init(tags = test.tags) @@ -105,7 +105,7 @@ class WaitForDualFundingReadyStateSpec extends TestKitBaseClass with FixtureAnyF } awaitCond(alice.stateName == WAIT_FOR_DUAL_FUNDING_READY) awaitCond(bob.stateName == WAIT_FOR_DUAL_FUNDING_READY) - withFixture(test.toNoArgTest(FixtureParam(alice, bob, alice2bob, bob2alice, alice2blockchain, bob2blockchain, listener))) + withFixture(test.toNoArgTest(FixtureParam(alice, bob, alicePeer, bobPeer, alice2bob, bob2alice, alice2blockchain, bob2blockchain, listener))) } } @@ -129,6 +129,22 @@ class WaitForDualFundingReadyStateSpec extends TestKitBaseClass with FixtureAnyF listenerA.expectMsg(ChannelOpened(alice, bob.underlyingActor.nodeParams.nodeId, channelId(alice))) awaitCond(alice.stateName == NORMAL) + // The channel is now ready to process payments. + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => + assert(e.fundingTxIndex == 0) + assert(e.channelId == aliceChannelReady.channelId) + true + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => + assert(e.fundingTxIndex == 0) + assert(e.channelId == aliceChannelReady.channelId) + true + case _ => false + } + assert(alice.stateData.asInstanceOf[DATA_NORMAL].shortIds.real.isInstanceOf[RealScidStatus.Temporary]) val aliceCommitments = alice.stateData.asInstanceOf[DATA_NORMAL].commitments.latest val aliceUpdate = alice.stateData.asInstanceOf[DATA_NORMAL].channelUpdate diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala index da0afb500a..95e5d8590a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala @@ -36,6 +36,7 @@ import fr.acinq.eclair.channel.publish.TxPublisher.{PublishFinalTx, PublishRepla import fr.acinq.eclair.channel.states.ChannelStateTestsBase.{FakeTxPublisherFactory, PimpTestFSM} import fr.acinq.eclair.channel.states.{ChannelStateTestsBase, ChannelStateTestsTags} import fr.acinq.eclair.db.RevokedHtlcInfoCleaner.ForgetHtlcInfos +import fr.acinq.eclair.io.Peer.LiquidityPurchaseSigned import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.testutils.PimpTestProbe.convert import fr.acinq.eclair.transactions.DirectedHtlc.{incoming, outgoing} @@ -341,9 +342,19 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik assert(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.latest.capacity == 2_400_000.sat) assert(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.latest.localCommit.spec.toLocal < 1_300_000_000.msat) assert(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.latest.localCommit.spec.toRemote > 1_100_000_000.msat) + + // Bob signed a liquidity purchase. + bobPeer.fishForMessage() { + case l: LiquidityPurchaseSigned => + assert(l.purchase.paymentDetails == LiquidityAds.PaymentDetails.FromChannelBalance) + assert(l.fundingTxIndex == 1) + assert(l.txId == alice.stateData.asInstanceOf[DATA_NORMAL].commitments.latest.fundingTxId) + true + case _ => false + } } - test("recv CMD_SPLICE (splice-in, liquidity ads, invalid lease witness)", Tag(ChannelStateTestsTags.Quiescence)) { f => + test("recv CMD_SPLICE (splice-in, liquidity ads, invalid will_fund signature)", Tag(ChannelStateTestsTags.Quiescence)) { f => import f._ val sender = TestProbe() @@ -938,6 +949,16 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik aliceEvents.expectNoMessage(100 millis) bobEvents.expectNoMessage(100 millis) + // The channel is now ready to use liquidity from the first splice. + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 1 + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 1 + case _ => false + } + bob ! WatchFundingConfirmedTriggered(BlockHeight(400000), 42, fundingTx2) bob2alice.expectMsgType[SpliceLocked] bob2alice.forward(alice) @@ -953,6 +974,16 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik bobEvents.expectAvailableBalanceChanged(balance = 650_000_000.msat, capacity = 2_500_000.sat) aliceEvents.expectNoMessage(100 millis) bobEvents.expectNoMessage(100 millis) + + // The channel is now ready to use liquidity from the second splice. + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 2 + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 2 + case _ => false + } } test("recv CMD_ADD_HTLC with multiple commitments") { f => @@ -1514,6 +1545,16 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectNoMessage(100 millis) bob2alice.expectNoMessage(100 millis) + // splice transactions are not locked yet: we're still at the initial funding index + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 0 + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 0 + case _ => false + } + // splice 1 confirms on alice's side watchConfirmed1a.replyTo ! WatchFundingConfirmedTriggered(BlockHeight(400000), 42, fundingTx1) assert(alice2bob.expectMsgType[SpliceLocked].fundingTxId == fundingTx1.txid) @@ -1540,12 +1581,28 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectNoMessage(100 millis) bob2alice.expectNoMessage(100 millis) + // splice transactions are not locked by bob yet: we're still at the initial funding index + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 0 + case _ => false + } + // splice 1 confirms on bob's side watchConfirmed1b.replyTo ! WatchFundingConfirmedTriggered(BlockHeight(400000), 42, fundingTx1) assert(bob2alice.expectMsgType[SpliceLocked].fundingTxId == fundingTx1.txid) bob2alice.forward(alice) bob2blockchain.expectMsgType[WatchFundingSpent] + // splice 1 is locked on both sides + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 1 + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 1 + case _ => false + } + disconnect(f) reconnect(f) @@ -1556,11 +1613,26 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectNoMessage(100 millis) bob2alice.expectNoMessage(100 millis) + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 1 + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 1 + case _ => false + } + // splice 2 confirms on bob's side watchConfirmed2b.replyTo ! WatchFundingConfirmedTriggered(BlockHeight(400000), 42, fundingTx2) assert(bob2alice.expectMsgType[SpliceLocked].fundingTxId == fundingTx2.txid) bob2blockchain.expectMsgType[WatchFundingSpent] + // splice 2 is locked on both sides + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 2 + case _ => false + } + // NB: we disconnect *before* transmitting the splice_confirmed to alice disconnect(f) reconnect(f) @@ -1582,6 +1654,16 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik bob2alice.forward(alice) alice2bob.expectNoMessage(100 millis) bob2alice.expectNoMessage(100 millis) + + // splice 2 is locked on both sides + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 2 + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => e.fundingTxIndex == 2 + case _ => false + } } /** Check type of published transactions */ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index 2f2c3d106c..6629023b97 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -135,7 +135,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val h = randomBytes32() val originHtlc1 = UpdateAddHtlc(randomBytes32(), 47, 30000000 msat, h, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None, None) val originHtlc2 = UpdateAddHtlc(randomBytes32(), 32, 20000000 msat, h, CltvExpiryDelta(160).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None, None) - val origin = Origin.Hot(sender.ref, Upstream.Hot.Trampoline(Seq(originHtlc1, originHtlc2).map(htlc => Upstream.Hot.Channel(htlc, TimestampMilli.now(), randomKey().publicKey)))) + val origin = Origin.Hot(sender.ref, Upstream.Hot.Trampoline(List(originHtlc1, originHtlc2).map(htlc => Upstream.Hot.Channel(htlc, TimestampMilli.now(), randomKey().publicKey)))) val cmd = CMD_ADD_HTLC(sender.ref, originHtlc1.amountMsat + originHtlc2.amountMsat - 10000.msat, h, originHtlc2.cltvExpiry - CltvExpiryDelta(7), TestConstants.emptyOnionPacket, None, None, origin) alice ! cmd sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala index 4e7851fb9f..6c998d315c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala @@ -72,6 +72,37 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val aliceInit = Init(TestConstants.Alice.nodeParams.features.initFeatures()) val bobInit = Init(TestConstants.Bob.nodeParams.features.initFeatures()) + test("reconnect after creating channel", Tag(IgnoreChannelUpdates)) { f => + import f._ + + disconnect(alice, bob) + reconnect(alice, bob, alice2bob, bob2alice) + alice2bob.expectMsgType[ChannelReestablish] + alice2bob.forward(bob) + bob2alice.expectMsgType[ChannelReestablish] + bob2alice.forward(alice) + + // This is a new channel: peers exchange channel_ready again. + val channelId = alice2bob.expectMsgType[ChannelReady].channelId + bob2alice.expectMsgType[ChannelReady] + + // The channel is ready to process payments. + alicePeer.fishForMessage() { + case e: ChannelReadyForPayments => + assert(e.fundingTxIndex == 0) + assert(e.channelId == channelId) + true + case _ => false + } + bobPeer.fishForMessage() { + case e: ChannelReadyForPayments => + assert(e.fundingTxIndex == 0) + assert(e.channelId == channelId) + true + case _ => false + } + } + test("re-send lost htlc and signature after first commitment", Tag(IgnoreChannelUpdates)) { f => import f._ // alice bob diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/OnTheFlyFundingDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/OnTheFlyFundingDbSpec.scala new file mode 100644 index 0000000000..84d263b9bf --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/OnTheFlyFundingDbSpec.scala @@ -0,0 +1,138 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db + +import fr.acinq.bitcoin.scalacompat.{Crypto, TxId} +import fr.acinq.eclair.channel.Upstream +import fr.acinq.eclair.payment.relay.OnTheFlyFunding +import fr.acinq.eclair.payment.relay.OnTheFlyFundingSpec._ +import fr.acinq.eclair.wire.protocol.UpdateAddHtlc +import fr.acinq.eclair.{CltvExpiry, MilliSatoshiLong, TimestampMilli, randomBytes32, randomKey} +import org.scalatest.funsuite.AnyFunSuite + +import java.util.UUID + +class OnTheFlyFundingDbSpec extends AnyFunSuite { + + import fr.acinq.eclair.TestDatabases.forAllDbs + + test("add/get preimages") { + forAllDbs { dbs => + val db = dbs.onTheFlyFunding + + val preimage1 = randomBytes32() + val preimage2 = randomBytes32() + + db.addPreimage(preimage1) + db.addPreimage(preimage1) // no-op + db.addPreimage(preimage2) + + assert(db.getPreimage(Crypto.sha256(preimage1)).contains(preimage1)) + assert(db.getPreimage(Crypto.sha256(preimage2)).contains(preimage2)) + assert(db.getPreimage(randomBytes32()).isEmpty) + } + } + + test("add/list/remove pending proposals") { + forAllDbs { dbs => + val db = dbs.onTheFlyFunding + + val alice = randomKey().publicKey + val bob = randomKey().publicKey + val paymentHash1 = randomBytes32() + val paymentHash2 = randomBytes32() + val upstream = Seq( + Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 7, 25_000_000 msat, paymentHash1, CltvExpiry(750_000), randomOnion(), None, None), TimestampMilli(0), randomKey().publicKey), + Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 0, 1 msat, paymentHash1, CltvExpiry(750_000), randomOnion(), Some(randomKey().publicKey), None), TimestampMilli.now(), randomKey().publicKey), + Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 561, 100_000_000 msat, paymentHash2, CltvExpiry(799_999), randomOnion(), None, None), TimestampMilli.now(), randomKey().publicKey), + Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 1105, 100_000_000 msat, paymentHash2, CltvExpiry(799_999), randomOnion(), None, None), TimestampMilli.now(), randomKey().publicKey), + ) + val pendingAlice = Seq( + OnTheFlyFunding.Pending( + proposed = Seq( + OnTheFlyFunding.Proposal(createWillAdd(20_000 msat, paymentHash1, CltvExpiry(500)), upstream(0)), + OnTheFlyFunding.Proposal(createWillAdd(1 msat, paymentHash1, CltvExpiry(750), Some(randomKey().publicKey)), upstream(1)), + ), + status = OnTheFlyFunding.Status.Funded(randomBytes32(), TxId(randomBytes32()), 7, 500 msat) + ), + OnTheFlyFunding.Pending( + proposed = Seq( + OnTheFlyFunding.Proposal(createWillAdd(195_000_000 msat, paymentHash2, CltvExpiry(1000)), Upstream.Hot.Trampoline(upstream(2) :: upstream(3) :: Nil)), + ), + status = OnTheFlyFunding.Status.Funded(randomBytes32(), TxId(randomBytes32()), 3, 0 msat) + ) + ) + val pendingBob = Seq( + OnTheFlyFunding.Pending( + proposed = Seq( + OnTheFlyFunding.Proposal(createWillAdd(20_000 msat, paymentHash1, CltvExpiry(42)), upstream(0)), + ), + status = OnTheFlyFunding.Status.Funded(randomBytes32(), TxId(randomBytes32()), 11, 3_500 msat) + ), + OnTheFlyFunding.Pending( + proposed = Seq( + OnTheFlyFunding.Proposal(createWillAdd(24_000_000 msat, paymentHash2, CltvExpiry(800_000), Some(randomKey().publicKey)), Upstream.Local(UUID.randomUUID())), + ), + status = OnTheFlyFunding.Status.Funded(randomBytes32(), TxId(randomBytes32()), 0, 10_000 msat) + ) + ) + + assert(db.listPendingPayments().isEmpty) + assert(db.listPending(alice).isEmpty) + db.removePending(alice, paymentHash1) // no-op + + // Add pending proposals for Alice. + db.addPending(alice, pendingAlice(0)) + assert(db.listPending(alice) == Map(paymentHash1 -> pendingAlice(0))) + db.addPending(alice, pendingAlice(1).copy(status = OnTheFlyFunding.Status.Proposed(null))) + assert(db.listPending(alice) == Map(paymentHash1 -> pendingAlice(0))) + db.addPending(alice, pendingAlice(1)) + assert(db.listPending(alice) == Map(paymentHash1 -> pendingAlice(0), paymentHash2 -> pendingAlice(1))) + assert(db.listPendingPayments() == Map(alice -> Set(paymentHash1, paymentHash2))) + + // Add pending proposals for Bob. + assert(db.listPending(bob).isEmpty) + db.addPending(bob, pendingBob(0)) + db.addPending(bob, pendingBob(1)) + assert(db.listPending(alice) == Map(paymentHash1 -> pendingAlice(0), paymentHash2 -> pendingAlice(1))) + assert(db.listPending(bob) == Map(paymentHash1 -> pendingBob(0), paymentHash2 -> pendingBob(1))) + assert(db.listPendingPayments() == Map(alice -> Set(paymentHash1, paymentHash2), bob -> Set(paymentHash1, paymentHash2))) + + // Remove pending proposals that are completed. + db.removePending(alice, paymentHash1) + assert(db.listPending(alice) == Map(paymentHash2 -> pendingAlice(1))) + assert(db.listPending(bob) == Map(paymentHash1 -> pendingBob(0), paymentHash2 -> pendingBob(1))) + assert(db.listPendingPayments() == Map(alice -> Set(paymentHash2), bob -> Set(paymentHash1, paymentHash2))) + db.removePending(alice, paymentHash1) // no-op + db.removePending(bob, randomBytes32()) // no-op + assert(db.listPending(alice) == Map(paymentHash2 -> pendingAlice(1))) + assert(db.listPending(bob) == Map(paymentHash1 -> pendingBob(0), paymentHash2 -> pendingBob(1))) + assert(db.listPendingPayments() == Map(alice -> Set(paymentHash2), bob -> Set(paymentHash1, paymentHash2))) + db.removePending(alice, paymentHash2) + assert(db.listPending(alice).isEmpty) + assert(db.listPending(bob) == Map(paymentHash1 -> pendingBob(0), paymentHash2 -> pendingBob(1))) + assert(db.listPendingPayments() == Map(bob -> Set(paymentHash1, paymentHash2))) + db.removePending(bob, paymentHash2) + assert(db.listPending(bob) == Map(paymentHash1 -> pendingBob(0))) + assert(db.listPendingPayments() == Map(bob -> Set(paymentHash1))) + db.removePending(bob, paymentHash1) + assert(db.listPending(bob).isEmpty) + assert(db.listPendingPayments().isEmpty) + } + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 63ee150f2f..b81a532092 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -22,6 +22,7 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.testkit.TestProbe +import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.TestConstants.{Alice, Bob} @@ -56,7 +57,8 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(wakeUpTimeout = 100 millis) else Alice.nodeParams + val nodeParams = Alice.nodeParams + .modify(_.onTheFlyFundingConfig.wakeUpTimeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/OpenChannelInterceptorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/OpenChannelInterceptorSpec.scala index 8f88a078fe..b8489ab2d9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/OpenChannelInterceptorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/OpenChannelInterceptorSpec.scala @@ -22,22 +22,25 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory -import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, SatoshiLong} +import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, OutPoint, SatoshiLong, Transaction, TxId, TxOut} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.blockchain.DummyOnChainWallet import fr.acinq.eclair.channel.ChannelTypes.UnsupportedChannelType +import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.channel.states.ChannelStateTestsTags -import fr.acinq.eclair.channel.{ChannelAborted, ChannelTypes} import fr.acinq.eclair.io.OpenChannelInterceptor.{DefaultParams, OpenChannelInitiator, OpenChannelNonInitiator} import fr.acinq.eclair.io.Peer.{OpenChannelResponse, OutgoingMessage, SpawnChannelNonInitiator} -import fr.acinq.eclair.io.PeerSpec.createOpenChannelMessage +import fr.acinq.eclair.io.PeerSpec.{createOpenChannelMessage, createOpenDualFundedChannelMessage} import fr.acinq.eclair.io.PendingChannelsRateLimiter.AddOrRejectChannel -import fr.acinq.eclair.wire.protocol.{ChannelTlv, Error, IPAddress, LiquidityAds, NodeAddress, OpenChannel, OpenChannelTlv, TlvStream} -import fr.acinq.eclair.{AcceptOpenChannel, CltvExpiryDelta, FeatureSupport, Features, InitFeature, InterceptOpenChannelCommand, InterceptOpenChannelPlugin, InterceptOpenChannelReceived, MilliSatoshiLong, RejectOpenChannel, TestConstants, UnknownFeature, randomBytes32, randomKey} +import fr.acinq.eclair.transactions.Transactions.{ClosingTx, InputInfo} +import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec +import fr.acinq.eclair.wire.protocol.{ChannelReestablish, ChannelTlv, Error, IPAddress, LiquidityAds, NodeAddress, OpenChannel, OpenChannelTlv, Shutdown, TlvStream} +import fr.acinq.eclair.{AcceptOpenChannel, BlockHeight, CltvExpiryDelta, FeatureSupport, Features, InitFeature, InterceptOpenChannelCommand, InterceptOpenChannelPlugin, InterceptOpenChannelReceived, MilliSatoshiLong, RejectOpenChannel, TestConstants, UnknownFeature, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} +import scodec.bits.ByteVector import java.net.InetAddress import scala.concurrent.duration.DurationInt @@ -47,10 +50,12 @@ class OpenChannelInterceptorSpec extends ScalaTestWithActorTestKit(ConfigFactory val defaultParams: DefaultParams = DefaultParams(100 sat, 100000 msat, 100 msat, CltvExpiryDelta(288), 10) val openChannel: OpenChannel = createOpenChannelMessage() val remoteAddress: NodeAddress = IPAddress(InetAddress.getLoopbackAddress, 19735) - val acceptStaticRemoteKeyChannelsTag = "accept static_remote_key channels" val defaultFeatures: Features[InitFeature] = Features(Map[InitFeature, FeatureSupport](StaticRemoteKey -> Optional, AnchorOutputsZeroFeeHtlcTx -> Optional)) val staticRemoteKeyFeatures: Features[InitFeature] = Features(Map[InitFeature, FeatureSupport](StaticRemoteKey -> Optional)) + val acceptStaticRemoteKeyChannelsTag = "accept static_remote_key channels" + val noPlugin = "no plugin" + override def withFixture(test: OneArgTest): Outcome = { val peer = TestProbe[Any]() val peerConnection = TestProbe[Any]() @@ -58,12 +63,13 @@ class OpenChannelInterceptorSpec extends ScalaTestWithActorTestKit(ConfigFactory val wallet = new DummyOnChainWallet() val pendingChannelsRateLimiter = TestProbe[PendingChannelsRateLimiter.Command]() val plugin = new InterceptOpenChannelPlugin { + // @formatter:off override def name: String = "OpenChannelInterceptorPlugin" - override def openChannelInterceptor: ActorRef[InterceptOpenChannelCommand] = pluginInterceptor.ref + // @formatter:on } - val pluginParams = TestConstants.Alice.nodeParams.pluginParams :+ plugin - val nodeParams = TestConstants.Alice.nodeParams.copy(pluginParams = pluginParams) + val nodeParams = TestConstants.Alice.nodeParams + .modify(_.pluginParams).usingIf(!test.tags.contains(noPlugin))(_ :+ plugin) .modify(_.channelConf).usingIf(test.tags.contains(acceptStaticRemoteKeyChannelsTag))(_.copy(acceptIncomingStaticRemoteKeyChannels = true)) val eventListener = TestProbe[ChannelAborted]() @@ -75,6 +81,11 @@ class OpenChannelInterceptorSpec extends ScalaTestWithActorTestKit(ConfigFactory case class FixtureParam(openChannelInterceptor: ActorRef[OpenChannelInterceptor.Command], peer: TestProbe[Any], pluginInterceptor: TestProbe[InterceptOpenChannelCommand], pendingChannelsRateLimiter: TestProbe[PendingChannelsRateLimiter.Command], peerConnection: TestProbe[Any], eventListener: TestProbe[ChannelAborted], wallet: DummyOnChainWallet) + private def commitments(isOpener: Boolean = false): Commitments = { + val commitments = CommitmentsSpec.makeCommitments(500_000 msat, 400_000 msat, TestConstants.Alice.nodeParams.nodeId, remoteNodeId, announceChannel = false) + commitments.copy(params = commitments.params.copy(localParams = commitments.params.localParams.copy(isChannelOpener = isOpener, paysCommitTxFees = isOpener))) + } + test("reject channel open if timeout waiting for plugin to respond") { f => import f._ @@ -112,6 +123,33 @@ class OpenChannelInterceptorSpec extends ScalaTestWithActorTestKit(ConfigFactory assert(peer.expectMessageType[SpawnChannelNonInitiator].addFunding_opt.contains(addFunding)) } + test("add liquidity if on-the-fly funding is used", Tag(noPlugin)) { f => + import f._ + + val features = defaultFeatures.add(Features.SplicePrototype, FeatureSupport.Optional).add(Features.OnTheFlyFunding, FeatureSupport.Optional) + val requestFunding = LiquidityAds.RequestFunding(250_000 sat, TestConstants.defaultLiquidityRates.fundingRates.head, LiquidityAds.PaymentDetails.FromChannelBalanceForFutureHtlc(randomBytes32() :: Nil)) + val open = createOpenDualFundedChannelMessage().copy( + channelFlags = ChannelFlags(nonInitiatorPaysCommitFees = true, announceChannel = false), + tlvStream = TlvStream(ChannelTlv.RequestFundingTlv(requestFunding)) + ) + val openChannelNonInitiator = OpenChannelNonInitiator(remoteNodeId, Right(open), features, features, peerConnection.ref, remoteAddress) + openChannelInterceptor ! openChannelNonInitiator + pendingChannelsRateLimiter.expectMessageType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + // We check that all existing channels (if any) are closing before accepting the request. + val currentChannels = Seq( + Peer.ChannelInfo(TestProbe().ref, SHUTDOWN, DATA_SHUTDOWN(commitments(isOpener = true), Shutdown(randomBytes32(), ByteVector.empty), Shutdown(randomBytes32(), ByteVector.empty), None)), + Peer.ChannelInfo(TestProbe().ref, NEGOTIATING, DATA_NEGOTIATING(commitments(), Shutdown(randomBytes32(), ByteVector.empty), Shutdown(randomBytes32(), ByteVector.empty), List(Nil), None)), + Peer.ChannelInfo(TestProbe().ref, CLOSING, DATA_CLOSING(commitments(), BlockHeight(0), ByteVector.empty, Nil, ClosingTx(InputInfo(OutPoint(TxId(randomBytes32()), 5), TxOut(100_000 sat, Nil), Nil), Transaction(2, Nil, Nil, 0), None) :: Nil)), + Peer.ChannelInfo(TestProbe().ref, WAIT_FOR_REMOTE_PUBLISH_FUTURE_COMMITMENT, DATA_WAIT_FOR_REMOTE_PUBLISH_FUTURE_COMMITMENT(commitments(), ChannelReestablish(randomBytes32(), 0, 0, randomKey(), randomKey().publicKey))), + ) + peer.expectMessageType[Peer.GetPeerChannels].replyTo ! Peer.PeerChannels(remoteNodeId, currentChannels) + val result = peer.expectMessageType[SpawnChannelNonInitiator] + assert(!result.localParams.isChannelOpener) + assert(result.localParams.paysCommitTxFees) + assert(result.addFunding_opt.map(_.fundingAmount).contains(250_000 sat)) + assert(result.addFunding_opt.flatMap(_.rates_opt).contains(TestConstants.defaultLiquidityRates)) + } + test("continue channel open if no interceptor plugin registered and pending channels rate limiter accepts it") { f => import f._ @@ -195,6 +233,29 @@ class OpenChannelInterceptorSpec extends ScalaTestWithActorTestKit(ConfigFactory peer.expectMessageType[SpawnChannelNonInitiator] } + test("reject on-the-fly channel if another channel exists", Tag(noPlugin)) { f => + import f._ + + val features = defaultFeatures.add(Features.SplicePrototype, FeatureSupport.Optional).add(Features.OnTheFlyFunding, FeatureSupport.Optional) + val requestFunding = LiquidityAds.RequestFunding(250_000 sat, TestConstants.defaultLiquidityRates.fundingRates.head, LiquidityAds.PaymentDetails.FromChannelBalanceForFutureHtlc(randomBytes32() :: Nil)) + val open = createOpenDualFundedChannelMessage().copy( + channelFlags = ChannelFlags(nonInitiatorPaysCommitFees = true, announceChannel = false), + tlvStream = TlvStream(ChannelTlv.RequestFundingTlv(requestFunding)) + ) + val currentChannel = Seq( + Peer.ChannelInfo(TestProbe().ref, NORMAL, ChannelCodecsSpec.normal), + Peer.ChannelInfo(TestProbe().ref, OFFLINE, ChannelCodecsSpec.normal), + Peer.ChannelInfo(TestProbe().ref, SYNCING, ChannelCodecsSpec.normal), + ) + currentChannel.foreach(channel => { + val openChannelNonInitiator = OpenChannelNonInitiator(remoteNodeId, Right(open), features, features, peerConnection.ref, remoteAddress) + openChannelInterceptor ! openChannelNonInitiator + pendingChannelsRateLimiter.expectMessageType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + peer.expectMessageType[Peer.GetPeerChannels].replyTo ! Peer.PeerChannels(remoteNodeId, Seq(channel)) + assert(peer.expectMessageType[OutgoingMessage].msg.asInstanceOf[Error].channelId == open.temporaryChannelId) + }) + } + test("don't spawn a wumbo channel if wumbo feature isn't enabled", Tag(ChannelStateTestsTags.DisableWumbo)) { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala index 37f7a0e92c..14b227844c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/SwitchboardSpec.scala @@ -6,15 +6,17 @@ import akka.testkit.{TestActorRef, TestProbe} import fr.acinq.bitcoin.scalacompat.ByteVector64 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.TestConstants._ -import fr.acinq.eclair.channel.{ChannelIdAssigned, PersistentChannelData} +import fr.acinq.eclair.channel.{ChannelIdAssigned, PersistentChannelData, Upstream} import fr.acinq.eclair.io.Peer.PeerNotFound import fr.acinq.eclair.io.Switchboard._ +import fr.acinq.eclair.payment.relay.{OnTheFlyFunding, OnTheFlyFundingSpec} import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Features, InitFeature, NodeParams, TestKitBaseClass, TimestampSecondLong, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, Features, InitFeature, MilliSatoshiLong, NodeParams, TestKitBaseClass, TimestampSecondLong, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits._ +import java.util.UUID import scala.concurrent.duration.DurationInt class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { @@ -33,6 +35,23 @@ class SwitchboardSpec extends TestKitBaseClass with AnyFunSuiteLike { peer.expectMsg(Peer.Init(Set(ChannelCodecsSpec.normal))) } + test("on initialization create peers with pending on-the-fly funding proposals") { + val nodeParams = Alice.nodeParams + // We don't have channels yet with that remote peer, but we have a pending on-the-fly funding proposal. + val remoteNodeId = randomKey().publicKey + val pendingOnTheFly = OnTheFlyFunding.Pending( + proposed = Seq(OnTheFlyFunding.Proposal(OnTheFlyFundingSpec.createWillAdd(100_000 msat, randomBytes32(), CltvExpiry(600)), Upstream.Local(UUID.randomUUID()))), + status = OnTheFlyFundingSpec.createStatus() + ) + nodeParams.db.onTheFlyFunding.addPending(remoteNodeId, pendingOnTheFly) + + val (probe, peer) = (TestProbe(), TestProbe()) + val switchboard = TestActorRef(new Switchboard(nodeParams, FakePeerFactory(probe, peer))) + switchboard ! Switchboard.Init(Nil) + probe.expectMsg(remoteNodeId) + peer.expectMsg(Peer.Init(Set.empty)) + } + test("when connecting to a new peer forward Peer.Connect to it") { val nodeParams = Alice.nodeParams val (probe, peer) = (TestProbe(), TestProbe()) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index d163c04a48..f080a99bf3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -315,7 +315,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -367,7 +367,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -408,7 +408,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None, None) val Right(ChannelRelayPacket(add_d2, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -467,7 +467,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid trampoline onion to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e.copy(payload = trampolinePacket_e.payload.reverse))) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None, None) val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -617,7 +617,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer total amount doesn't match the inner amount). val invalidTotalAmount = inner_c.amountToForward - 1.msat val recipient_e = ClearRecipient(e, Features.empty, invalidTotalAmount, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -633,7 +633,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer expiry doesn't match the inner expiry). val invalidExpiry = inner_c.outgoingCltv - CltvExpiryDelta(12) val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, invalidExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 28be0a2715..09895a8efc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -18,11 +18,10 @@ package fr.acinq.eclair.payment import akka.Done import akka.actor.ActorRef -import akka.actor.typed.scaladsl.adapter.ClassicActorRefOps import akka.event.LoggingAdapter import akka.testkit.TestProbe import com.softwaremill.quicklens.{ModifyPimp, QuicklensAt} -import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, OutPoint, SatoshiLong, Script, Transaction, TxIn, TxOut} +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, OutPoint, SatoshiLong, Script, Transaction, TxId, TxIn, TxOut} import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.WatchTxConfirmedTriggered import fr.acinq.eclair.channel.Helpers.Closing import fr.acinq.eclair.channel._ @@ -30,7 +29,8 @@ import fr.acinq.eclair.channel.states.ChannelStateTestsBase import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentType} import fr.acinq.eclair.payment.OutgoingPaymentPacket.buildOutgoingPayment import fr.acinq.eclair.payment.PaymentPacketSpec._ -import fr.acinq.eclair.payment.relay.{PostRestartHtlcCleaner, Relayer} +import fr.acinq.eclair.payment.relay.OnTheFlyFundingSpec._ +import fr.acinq.eclair.payment.relay.{OnTheFlyFunding, PostRestartHtlcCleaner, Relayer} import fr.acinq.eclair.payment.send.SpontaneousRecipient import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate import fr.acinq.eclair.router.Router.Route @@ -153,6 +153,43 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit channel.expectNoMessage(100 millis) } + test("keep upstream HTLCs that weren't relayed downstream but use on-the-fly funding") { f => + import f._ + + val channelPaymentHash = randomBytes32() + val trampolinePaymentHash = randomBytes32() + + val htlc_ab_1 = Seq( + buildHtlcIn(0, channelId_ab_1, channelPaymentHash), + buildHtlcIn(1, channelId_ab_1, trampolinePaymentHash), + ) + val htlc_ab_2 = Seq( + buildHtlcIn(2, channelId_ab_2, trampolinePaymentHash), + ) + + val channels = Seq( + ChannelCodecsSpec.makeChannelDataNormal(htlc_ab_1, Map(0L -> Origin.Cold(Upstream.Local(UUID.randomUUID())), 1L -> Origin.Cold(Upstream.Local(UUID.randomUUID())))), + ChannelCodecsSpec.makeChannelDataNormal(htlc_ab_2, Map(2L -> Origin.Cold(Upstream.Local(UUID.randomUUID())))) + ) + + // The HTLCs were not relayed yet, but they match pending on-the-fly funding proposals. + val upstreamChannel = Upstream.Hot.Channel(htlc_ab_1.head.add, TimestampMilli.now(), a) + val downstreamChannel = OnTheFlyFunding.Pending(Seq(OnTheFlyFunding.Proposal(createWillAdd(100_000 msat, channelPaymentHash, CltvExpiry(500)), upstreamChannel)), createStatus()) + val upstreamTrampoline = Upstream.Hot.Trampoline(List(htlc_ab_1.last, htlc_ab_2.head).map(htlc => Upstream.Hot.Channel(htlc.add, TimestampMilli.now(), a))) + val downstreamTrampoline = OnTheFlyFunding.Pending(Seq(OnTheFlyFunding.Proposal(createWillAdd(100_000 msat, trampolinePaymentHash, CltvExpiry(500)), upstreamTrampoline)), createStatus()) + nodeParams.db.onTheFlyFunding.addPending(randomKey().publicKey, downstreamChannel) + nodeParams.db.onTheFlyFunding.addPending(randomKey().publicKey, downstreamTrampoline) + + val channel = TestProbe() + val (relayer, _) = f.createRelayer(nodeParams) + relayer ! PostRestartHtlcCleaner.Init(channels) + // Upstream channels go back to the NORMAL state, but HTLCs are kept because the on-the-fly proposal was funded. + system.eventStream.publish(ChannelStateChanged(channel.ref, channels.head.commitments.channelId, system.deadLetters, a, OFFLINE, NORMAL, Some(channels(0).commitments))) + channel.expectNoMessage(100 millis) + system.eventStream.publish(ChannelStateChanged(channel.ref, channels(1).commitments.channelId, system.deadLetters, a, OFFLINE, NORMAL, Some(channels(1).commitments))) + channel.expectNoMessage(100 millis) + } + test("clean up upstream HTLCs for which we're the final recipient") { f => import f._ @@ -512,7 +549,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit import f._ val htlc_ab = buildHtlcIn(0, channelId_ab_1, paymentHash1, blinded = true) - val upstream = Upstream.Cold.Channel(htlc_ab.add.channelId, htlc_ab.add.id, htlc_ab.add.amountMsat) + val upstream = Upstream.Cold.Channel(htlc_ab.add) val htlc_bc = buildHtlcOut(6, channelId_bc_1, paymentHash1, blinded = true) val data_ab = ChannelCodecsSpec.makeChannelDataNormal(Seq(htlc_ab), Map.empty) val data_bc = ChannelCodecsSpec.makeChannelDataNormal(Seq(htlc_bc), Map(6L -> Origin.Cold(upstream))) @@ -624,6 +661,54 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit eventListener.expectNoMessage(100 millis) } + test("ignore relayed htlc-fail for on-the-fly funding") { f => + import f._ + + // Upstream HTLCs. + val htlc_ab = Seq( + buildHtlcIn(0, channelId_ab_1, paymentHash1), // not relayed + buildHtlcIn(1, channelId_ab_1, paymentHash1), // channel relayed + buildHtlcIn(2, channelId_ab_1, paymentHash2), // trampoline relayed + ) + // The first upstream HTLC was not relayed but has a pending on-the-fly funding proposal. + val downstreamChannel = OnTheFlyFunding.Pending(Seq(OnTheFlyFunding.Proposal(createWillAdd(100_000 msat, paymentHash1, CltvExpiry(500)), Upstream.Hot.Channel(htlc_ab(0).add, TimestampMilli.now(), a))), createStatus()) + nodeParams.db.onTheFlyFunding.addPending(randomKey().publicKey, downstreamChannel) + // The other two HTLCs were relayed after completing on-the-fly funding. + val htlc_bc = Seq( + buildHtlcOut(1, channelId_bc_1, paymentHash1).modify(_.add.tlvStream).setTo(TlvStream(UpdateAddHtlcTlv.FundingFeeTlv(LiquidityAds.FundingFee(2500 msat, TxId(randomBytes32()))))), // channel relayed + buildHtlcOut(2, channelId_bc_1, paymentHash2).modify(_.add.tlvStream).setTo(TlvStream(UpdateAddHtlcTlv.FundingFeeTlv(LiquidityAds.FundingFee(1500 msat, TxId(randomBytes32()))))), // trampoline relayed + ) + + val upstreamChannel = Upstream.Cold.Channel(htlc_ab(1).add) + val upstreamTrampoline = Upstream.Cold.Trampoline(Upstream.Cold.Channel(htlc_ab(2).add) :: Nil) + val data_ab = ChannelCodecsSpec.makeChannelDataNormal(htlc_ab, Map.empty) + val data_bc = ChannelCodecsSpec.makeChannelDataNormal(htlc_bc, Map(1L -> Origin.Cold(upstreamChannel), 2L -> Origin.Cold(upstreamTrampoline))) + + val (relayer, _) = f.createRelayer(nodeParams) + relayer ! PostRestartHtlcCleaner.Init(Seq(data_ab, data_bc)) + + // HTLC failures are not relayed upstream, as we will retry until we reach the HTLC timeout. + sender.send(relayer, buildForwardFail(htlc_bc(0).add, Upstream.Cold.Channel(htlc_ab(0).add))) + sender.send(relayer, buildForwardFail(htlc_bc(0).add, upstreamChannel)) + sender.send(relayer, buildForwardOnChainFail(htlc_bc(0).add, upstreamChannel)) + sender.send(relayer, buildForwardFail(htlc_bc(1).add, upstreamTrampoline)) + sender.send(relayer, buildForwardOnChainFail(htlc_bc(1).add, upstreamTrampoline)) + register.expectNoMessage(100 millis) + + // HTLC fulfills are relayed upstream as soon as available. + sender.send(relayer, buildForwardFulfill(htlc_bc(0).add, upstreamChannel, preimage1)) + val fulfill1 = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fulfill1.channelId == channelId_ab_1) + assert(fulfill1.message.id == 1) + assert(fulfill1.message.r == preimage1) + sender.send(relayer, buildForwardFulfill(htlc_bc(1).add, upstreamTrampoline, preimage2)) + val fulfill2 = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fulfill2.channelId == channelId_ab_1) + assert(fulfill2.message.id == 2) + assert(fulfill2.message.r == preimage2) + register.expectNoMessage(100 millis) + } + test("relayed standard->non-standard HTLC is retained") { f => import f._ @@ -763,7 +848,7 @@ object PostRestartHtlcCleanerSpec { buildHtlcIn(0, channelId_ab_1, paymentHash1) ) - val upstream_1 = Upstream.Cold.Channel(htlc_ab_1.head.add.channelId, htlc_ab_1.head.add.id, htlc_ab_1.head.add.amountMsat) + val upstream_1 = Upstream.Cold.Channel(htlc_ab_1.head.add) val htlc_bc_1 = Seq( buildHtlcOut(6, channelId_bc_1, paymentHash1) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index 6d2f38b838..decb98ead7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -28,6 +28,7 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto, import fr.acinq.eclair.Features.ScidAlias import fr.acinq.eclair.TestConstants.emptyOnionPacket import fr.acinq.eclair.blockchain.fee.FeeratePerKw +import fr.acinq.eclair.channel.Register.ForwardNodeId import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.{Peer, Switchboard} @@ -52,12 +53,15 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a import ChannelRelayerSpec._ val wakeUpTimeout = "wake_up_timeout" + val onTheFlyFunding = "on_the_fly_funding" case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any]) override def withFixture(test: OneArgTest): Outcome = { // we are node B in the route A -> B -> C -> .... - val nodeParams = if (test.tags.contains(wakeUpTimeout)) TestConstants.Bob.nodeParams.copy(wakeUpTimeout = 100 millis) else TestConstants.Bob.nodeParams + val nodeParams = TestConstants.Bob.nodeParams + .modify(_.onTheFlyFundingConfig.wakeUpTimeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) + .modify(_.features.activated).usingIf(test.tags.contains(onTheFlyFunding))(_ + (Features.OnTheFlyFunding -> FeatureSupport.Optional)) val register = TestProbe[Any]("register") val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) try { @@ -195,6 +199,88 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } + test("relay blinded payment (on-the-fly funding)", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0) + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = false) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + // We try to use existing channels, but they don't have enough liquidity. + val fwd = expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) + fwd.message.replyTo ! RES_ADD_FAILED(fwd.message, InsufficientFunds(channelIds(realScid1), outgoingAmount, 100 sat, 0 sat, 0 sat), Some(u.channelUpdate)) + + val fwdNodeId = register.expectMessageType[ForwardNodeId[Peer.ProposeOnTheFlyFunding]] + assert(fwdNodeId.nodeId == outgoingNodeId) + assert(fwdNodeId.message.nextBlindingKey_opt.nonEmpty) + assert(fwdNodeId.message.amount == outgoingAmount) + assert(fwdNodeId.message.expiry == outgoingExpiry) + } + + test("relay blinded payment (on-the-fly funding failed)", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0) + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = false) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + // We don't have any channel, so we attempt on-the-fly funding, but the peer is not available. + val fwdNodeId = register.expectMessageType[ForwardNodeId[Peer.ProposeOnTheFlyFunding]] + assert(fwdNodeId.nodeId == outgoingNodeId) + fwdNodeId.replyTo ! Register.ForwardNodeIdFailure(fwdNodeId) + expectFwdFail(register, r.add.channelId, CMD_FAIL_MALFORMED_HTLC(r.add.id, Sphinx.hash(r.add.onionRoutingPacket), InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket)).code, commit = true)) + } + + test("relay blinded payment (on-the-fly funding not attempted)", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0) + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = false) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + // We try to use existing channels, but they reject the payment for a reason that isn't tied to the liquidity. + val fwd = expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) + fwd.message.replyTo ! RES_ADD_FAILED(fwd.message, TooManyAcceptedHtlcs(channelIds(realScid1), 10), Some(u.channelUpdate)) + + // We fail without attempting on-the-fly funding. + expectFwdFail(register, r.add.channelId, CMD_FAIL_MALFORMED_HTLC(r.add.id, Sphinx.hash(r.add.onionRoutingPacket), InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket)).code, commit = true)) + } + test("relay with retries") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 3fe24a0985..70aa9125f9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -68,6 +68,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import NodeRelayerSpec._ val wakeUpTimeout = "wake_up_timeout" + val onTheFlyFunding = "on_the_fly_funding" case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) { def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = { @@ -97,7 +98,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires - .modify(_.wakeUpTimeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) + .modify(_.onTheFlyFundingConfig.wakeUpTimeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) + .modify(_.features.activated).usingIf(test.tags.contains(onTheFlyFunding))(_ + (Features.OnTheFlyFunding -> FeatureSupport.Optional)) val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") @@ -251,7 +253,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) @@ -348,7 +350,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // upstream payment relayed val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) // those are adapters for pay-fsm messages @@ -556,7 +558,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayer ! NodeRelay.Relay(incomingMultiPart.last, randomKey().publicKey) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) // those are adapters for pay-fsm messages @@ -613,6 +615,81 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } + // The two tests below are disabled by default, since there is no default mechanism to flag the next trampoline node + // as being a wallet node. Feature branches that support wallet software should restore those tests and flag the + // outgoing node_id as being a wallet node. + ignore("relay incoming multi-part payment with on-the-fly funding", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + // Receive an upstream multi-part payment. + val (nodeRelayer, parent) = f.createNodeRelay(incomingMultiPart.head) + incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) + + // The remote node is a wallet node: we wake them up before relaying the payment. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + validateOutgoingPayment(outgoingPayment) + + // The outgoing payment fails because we don't have enough balance: we trigger on-the-fly funding. + outgoingPayment.replyTo ! PaymentFailed(relayId, paymentHash, LocalFailure(outgoingAmount, Nil, BalanceTooLow) :: Nil) + val fwd = register.expectMessageType[Register.ForwardNodeId[Peer.ProposeOnTheFlyFunding]] + assert(fwd.nodeId == outgoingNodeId) + assert(fwd.message.nextBlindingKey_opt.isEmpty) + assert(fwd.message.onion.payload.size == PaymentOnionCodecs.paymentOnionPayloadLength) + // We verify that the next node is able to decrypt the onion that we will send in will_add_htlc. + val dummyAdd = UpdateAddHtlc(randomBytes32(), 0, fwd.message.amount, fwd.message.paymentHash, fwd.message.expiry, fwd.message.onion, None, None) + val Right(incoming) = IncomingPaymentPacket.decrypt(dummyAdd, outgoingNodeKey, nodeParams.features) + assert(incoming.isInstanceOf[IncomingPaymentPacket.FinalPacket]) + val finalPayload = incoming.asInstanceOf[IncomingPaymentPacket.FinalPacket].payload.asInstanceOf[FinalPayload.Standard] + assert(finalPayload.amount == fwd.message.amount) + assert(finalPayload.expiry == fwd.message.expiry) + assert(finalPayload.paymentSecret == paymentSecret) + + // Once on-the-fly funding has been proposed, the payment isn't our responsibility anymore. + fwd.message.replyTo ! Peer.ProposeOnTheFlyFundingResponse.Proposed + parent.expectMessageType[NodeRelayer.RelayComplete] + } + + ignore("relay incoming multi-part payment with on-the-fly funding (non-liquidity failure)", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + // Receive an upstream multi-part payment. + val (nodeRelayer, parent) = f.createNodeRelay(incomingMultiPart.head) + incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) + + // The remote node is a wallet node: we wake them up before relaying the payment. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + validateOutgoingPayment(outgoingPayment) + + // The outgoing payment fails, but it's not a liquidity issue. + outgoingPayment.replyTo ! PaymentFailed(relayId, paymentHash, RemoteFailure(outgoingAmount, Nil, Sphinx.DecryptedFailurePacket(outgoingNodeId, TemporaryNodeFailure())) :: Nil) + incomingMultiPart.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) + } + parent.expectMessageType[NodeRelayer.RelayComplete] + } + test("relay to non-trampoline recipient supporting multi-part") { f => import f._ @@ -627,7 +704,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.nodeId == outgoingNodeId) assert(outgoingPayment.recipient.totalAmount == outgoingAmount) @@ -671,7 +748,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList)) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.recipient.nodeId == outgoingNodeId) assert(outgoingPayment.amount == outgoingAmount) @@ -729,7 +806,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -762,7 +839,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -804,7 +881,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -851,6 +928,76 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } } + test("relay to blinded path with on-the-fly funding", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) + + // The remote node is a wallet node: we wake them up before relaying the payment. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList), ignoreNodeId = true) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + + // The outgoing payment fails because we don't have enough balance: we trigger on-the-fly funding. + outgoingPayment.replyTo ! PaymentFailed(relayId, paymentHash, LocalFailure(outgoingAmount, Nil, BalanceTooLow) :: Nil) + val fwd = register.expectMessageType[Register.ForwardNodeId[Peer.ProposeOnTheFlyFunding]] + assert(fwd.nodeId == outgoingNodeId) + assert(fwd.message.nextBlindingKey_opt.nonEmpty) + assert(fwd.message.onion.payload.size == PaymentOnionCodecs.paymentOnionPayloadLength) + // We verify that the next node is able to decrypt the onion that we will send in will_add_htlc. + val dummyAdd = UpdateAddHtlc(randomBytes32(), 0, fwd.message.amount, fwd.message.paymentHash, fwd.message.expiry, fwd.message.onion, fwd.message.nextBlindingKey_opt, None) + val Right(incoming) = IncomingPaymentPacket.decrypt(dummyAdd, outgoingNodeKey, nodeParams.features) + assert(incoming.isInstanceOf[IncomingPaymentPacket.FinalPacket]) + val finalPayload = incoming.asInstanceOf[IncomingPaymentPacket.FinalPacket].payload.asInstanceOf[FinalPayload.Blinded] + assert(finalPayload.amount == fwd.message.amount) + assert(finalPayload.expiry == fwd.message.expiry) + assert(finalPayload.pathId == hex"deadbeef") + + // Once on-the-fly funding has been proposed, the payment isn't our responsibility anymore. + fwd.message.replyTo ! Peer.ProposeOnTheFlyFundingResponse.Proposed + parent.expectMessageType[NodeRelayer.RelayComplete] + } + + test("relay to blinded path with on-the-fly funding failure", Tag(onTheFlyFunding)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) + + // The remote node is a wallet node: we wake them up before relaying the payment. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList), ignoreNodeId = true) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + + // The outgoing payment fails because we don't have enough balance: we trigger on-the-fly funding, but can't reach our peer. + outgoingPayment.replyTo ! PaymentFailed(relayId, paymentHash, LocalFailure(outgoingAmount, Nil, BalanceTooLow) :: Nil) + val fwd = register.expectMessageType[Register.ForwardNodeId[Peer.ProposeOnTheFlyFunding]] + fwd.message.replyTo ! Peer.ProposeOnTheFlyFundingResponse.NotAvailable("peer disconnected") + // We fail the payments immediately since the recipient isn't available. + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) + } + } + test("relay to compact blinded paths") { f => import f._ @@ -865,7 +1012,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl getNodeId.replyTo ! Some(outgoingNodeId) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)).toList), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/OnTheFlyFundingSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/OnTheFlyFundingSpec.scala new file mode 100644 index 0000000000..4030f5bea8 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/OnTheFlyFundingSpec.scala @@ -0,0 +1,876 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.payment.relay + +import akka.actor.typed.scaladsl.adapter._ +import akka.actor.{ActorContext, ActorRef} +import akka.testkit.{TestFSMRef, TestProbe} +import com.softwaremill.quicklens.ModifyPimp +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, Satoshi, SatoshiLong, TxId} +import fr.acinq.eclair.blockchain.{CurrentBlockHeight, DummyOnChainWallet} +import fr.acinq.eclair.channel.Upstream.Hot +import fr.acinq.eclair.channel._ +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.Peer._ +import fr.acinq.eclair.io.PendingChannelsRateLimiter.AddOrRejectChannel +import fr.acinq.eclair.io.{Peer, PeerConnection, PendingChannelsRateLimiter} +import fr.acinq.eclair.wire.protocol +import fr.acinq.eclair.wire.protocol._ +import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, TestKitBaseClass, TimestampMilli, ToMilliSatoshiConversion, UInt64, randomBytes, randomBytes32, randomKey, randomLong} +import org.scalatest.Outcome +import org.scalatest.funsuite.FixtureAnyFunSuiteLike + +import java.util.UUID +import scala.concurrent.duration.DurationInt + +class OnTheFlyFundingSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike { + + import OnTheFlyFundingSpec._ + + val remoteFeatures = Features( + Features.StaticRemoteKey -> FeatureSupport.Optional, + Features.AnchorOutputsZeroFeeHtlcTx -> FeatureSupport.Optional, + Features.DualFunding -> FeatureSupport.Optional, + Features.SplicePrototype -> FeatureSupport.Optional, + Features.OnTheFlyFunding -> FeatureSupport.Optional, + ) + + case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channel: TestProbe, register: TestProbe, rateLimiter: TestProbe, probe: TestProbe) { + def connect(peer: TestFSMRef[Peer.State, Peer.Data, Peer], remoteInit: protocol.Init = protocol.Init(remoteFeatures.initFeatures()), channelCount: Int = 0): Unit = { + val localInit = protocol.Init(nodeParams.features.initFeatures()) + val address = NodeAddress.fromParts("0.0.0.0", 9735).get + peer ! PeerConnection.ConnectionReady(peerConnection.ref, remoteNodeId, address, outgoing = true, localInit, remoteInit) + peerConnection.expectMsgType[RecommendedFeerates] + (0 until channelCount).foreach(_ => channel.expectMsgType[INPUT_RECONNECTED]) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + val peerInfo = probe.expectMsgType[Peer.PeerInfo] + assert(peerInfo.nodeId == remoteNodeId) + assert(peerInfo.state == Peer.CONNECTED) + } + + def openChannel(fundingAmount: Satoshi): ByteVector32 = { + peer ! Peer.OpenChannel(remoteNodeId, fundingAmount, None, None, None, None, None, None, None) + val temporaryChannelId = channel.expectMsgType[INPUT_INIT_CHANNEL_INITIATOR].temporaryChannelId + val channelId = randomBytes32() + peer ! ChannelIdAssigned(channel.ref, remoteNodeId, temporaryChannelId, channelId) + peerConnection.expectMsgType[PeerConnection.DoSync] + channelId + } + + def disconnect(channelCount: Int = 0): Unit = { + peer ! Peer.ConnectionDown(peerConnection.ref) + (0 until channelCount).foreach(_ => channel.expectMsg(INPUT_DISCONNECTED)) + } + + def createProposal(amount: MilliSatoshi, expiry: CltvExpiry, paymentHash: ByteVector32 = randomBytes32(), upstream: Upstream.Hot = Upstream.Local(UUID.randomUUID())): ProposeOnTheFlyFunding = { + val blindingKey = upstream match { + case u: Upstream.Hot.Channel if u.add.blinding_opt.nonEmpty => Some(randomKey().publicKey) + case u: Upstream.Hot.Trampoline if u.received.exists(_.add.blinding_opt.nonEmpty) => Some(randomKey().publicKey) + case _ => None + } + ProposeOnTheFlyFunding(probe.ref, amount, paymentHash, expiry, TestConstants.emptyOnionPacket, blindingKey, upstream) + } + + def proposeFunding(amount: MilliSatoshi, expiry: CltvExpiry, paymentHash: ByteVector32 = randomBytes32(), upstream: Upstream.Hot = Upstream.Local(UUID.randomUUID())): WillAddHtlc = { + val proposal = createProposal(amount, expiry, paymentHash, upstream) + peer ! proposal + val willAdd = peerConnection.expectMsgType[WillAddHtlc] + assert(willAdd.amount == amount) + assert(willAdd.expiry == expiry) + assert(willAdd.paymentHash == paymentHash) + probe.expectMsg(ProposeOnTheFlyFundingResponse.Proposed) + willAdd + } + + /** This should be used when the sender is buggy and keeps adding HTLCs after the funding proposal has been accepted. */ + def proposeExtraFunding(amount: MilliSatoshi, expiry: CltvExpiry, paymentHash: ByteVector32 = randomBytes32(), upstream: Upstream.Hot = Upstream.Local(UUID.randomUUID())): Unit = { + val proposal = createProposal(amount, expiry, paymentHash, upstream) + peer ! proposal + probe.expectMsg(ProposeOnTheFlyFundingResponse.Proposed) + peerConnection.expectNoMessage(100 millis) + } + + def signLiquidityPurchase(amount: Satoshi, + paymentDetails: LiquidityAds.PaymentDetails, + channelId: ByteVector32 = randomBytes32(), + fees: LiquidityAds.Fees = LiquidityAds.Fees(0 sat, 0 sat), + fundingTxIndex: Long = 0, + htlcMinimum: MilliSatoshi = 1 msat): LiquidityPurchaseSigned = { + val purchase = LiquidityAds.Purchase.Standard(amount, fees, paymentDetails) + val event = LiquidityPurchaseSigned(channelId, TxId(randomBytes32()), fundingTxIndex, htlcMinimum, purchase) + peer ! event + event + } + + def makeChannelData(htlcMinimum: MilliSatoshi = 1 msat, localChanges: LocalChanges = LocalChanges(Nil, Nil, Nil)): DATA_NORMAL = { + val commitments = CommitmentsSpec.makeCommitments(500_000_000 msat, 500_000_000 msat, nodeParams.nodeId, remoteNodeId, announceChannel = false) + .modify(_.params.remoteParams.htlcMinimum).setTo(htlcMinimum) + .modify(_.changes.localChanges).setTo(localChanges) + DATA_NORMAL(commitments, ShortIds(RealScidStatus.Unknown, Alias(42), None), None, null, None, None, None, SpliceStatus.NoSplice) + } + } + + case class FakeChannelFactory(remoteNodeId: PublicKey, channel: TestProbe) extends ChannelFactory { + override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef = { + assert(remoteNodeId == remoteNodeId) + channel.ref + } + } + + override protected def withFixture(test: OneArgTest): Outcome = { + val nodeParams = TestConstants.Alice.nodeParams + .modify(_.features.activated).using(_ + (Features.AnchorOutputsZeroFeeHtlcTx -> FeatureSupport.Optional)) + .modify(_.features.activated).using(_ + (Features.DualFunding -> FeatureSupport.Optional)) + .modify(_.features.activated).using(_ + (Features.SplicePrototype -> FeatureSupport.Optional)) + .modify(_.features.activated).using(_ + (Features.OnTheFlyFunding -> FeatureSupport.Optional)) + val remoteNodeId = randomKey().publicKey + val register = TestProbe() + val channel = TestProbe() + val peerConnection = TestProbe() + val rateLimiter = TestProbe() + val probe = TestProbe() + val peer = TestFSMRef(new Peer(nodeParams, remoteNodeId, new DummyOnChainWallet(), FakeChannelFactory(remoteNodeId, channel), TestProbe().ref, register.ref, TestProbe().ref, rateLimiter.ref)) + peer ! Peer.Init(Set.empty) + withFixture(test.toNoArgTest(FixtureParam(nodeParams, remoteNodeId, peer, peerConnection, channel, register, rateLimiter, probe))) + } + + test("ignore requests when peer doesn't support on-the-fly funding") { f => + import f._ + + connect(peer, remoteInit = protocol.Init(Features.empty)) + peer ! createProposal(100_000_000 msat, CltvExpiry(561)) + probe.expectMsgType[ProposeOnTheFlyFundingResponse.NotAvailable] + } + + test("ignore requests when disconnected") { f => + import f._ + + peer ! createProposal(100_000_000 msat, CltvExpiry(561)) + probe.expectMsgType[ProposeOnTheFlyFundingResponse.NotAvailable] + } + + test("receive remote failure") { f => + import f._ + + connect(peer) + + val paymentHash = randomBytes32() + val upstream1 = upstreamChannel(75_000_000 msat, CltvExpiry(561), paymentHash) + val willAdd1 = proposeFunding(70_000_000 msat, CltvExpiry(550), paymentHash, upstream1) + val upstream2 = upstreamChannel(80_000_000 msat, CltvExpiry(561), paymentHash, blinded = true) + val willAdd2 = proposeFunding(75_000_000 msat, CltvExpiry(550), paymentHash, upstream2) + val upstream3 = upstreamChannel(50_000_000 msat, CltvExpiry(561), paymentHash) + val willAdd3 = proposeFunding(50_000_000 msat, CltvExpiry(550), paymentHash, upstream3) + val upstream4 = Upstream.Hot.Trampoline(List( + upstreamChannel(50_000_000 msat, CltvExpiry(561), paymentHash), + upstreamChannel(50_000_000 msat, CltvExpiry(561), paymentHash), + )) + val willAdd4 = proposeFunding(100_000_000 msat, CltvExpiry(550), paymentHash, upstream4) + val upstream5 = Upstream.Hot.Trampoline(List( + upstreamChannel(50_000_000 msat, CltvExpiry(561), paymentHash, blinded = true), + upstreamChannel(50_000_000 msat, CltvExpiry(561), paymentHash, blinded = true), + )) + val willAdd5 = proposeFunding(100_000_000 msat, CltvExpiry(550), paymentHash, upstream5) + + val fail1 = WillFailHtlc(willAdd1.id, paymentHash, randomBytes(42)) + peerConnection.send(peer, fail1) + val fwd1 = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd1.channelId == upstream1.add.channelId) + assert(fwd1.message.id == upstream1.add.id) + assert(fwd1.message.reason == Left(fail1.reason)) + register.expectNoMessage(100 millis) + + val fail2 = WillFailHtlc(willAdd2.id, paymentHash, randomBytes(50)) + peerConnection.send(peer, fail2) + val fwd2 = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd2.channelId == upstream2.add.channelId) + assert(fwd2.message.id == upstream2.add.id) + assert(fwd2.message.reason == Right(InvalidOnionBlinding(Sphinx.hash(upstream2.add.onionRoutingPacket)))) + + val fail3 = WillFailMalformedHtlc(willAdd3.id, paymentHash, randomBytes32(), InvalidOnionHmac(randomBytes32()).code) + peerConnection.send(peer, fail3) + val fwd3 = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd3.channelId == upstream3.add.channelId) + assert(fwd3.message.id == upstream3.add.id) + assert(fwd3.message.reason == Right(InvalidOnionHmac(fail3.onionHash))) + + val fail4 = WillFailHtlc(willAdd4.id, paymentHash, randomBytes(75)) + peerConnection.send(peer, fail4) + upstream4.received.map(_.add).foreach(add => { + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == add.channelId) + assert(fwd.message.id == add.id) + assert(fwd.message.reason == Right(TemporaryNodeFailure())) + }) + + val fail5 = WillFailHtlc(willAdd5.id, paymentHash, randomBytes(75)) + peerConnection.send(peer, fail5) + upstream5.received.map(_.add).foreach(add => { + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == add.channelId) + assert(fwd.message.id == add.id) + assert(fwd.message.reason == Right(TemporaryNodeFailure())) + }) + } + + test("proposed on-the-fly funding timeout") { f => + import f._ + + connect(peer) + + // A first funding is proposed coming from two upstream channels. + val paymentHash1 = randomBytes32() + val upstream1 = Seq( + upstreamChannel(60_000_000 msat, CltvExpiry(561), paymentHash1, blinded = true), + upstreamChannel(45_000_000 msat, CltvExpiry(561), paymentHash1, blinded = true), + ) + proposeFunding(50_000_000 msat, CltvExpiry(550), paymentHash1, upstream1.head) + proposeFunding(40_000_000 msat, CltvExpiry(550), paymentHash1, upstream1.last) + + // A second funding is proposed coming from a trampoline payment. + val paymentHash2 = randomBytes32() + val upstream2 = Upstream.Hot.Trampoline(List( + upstreamChannel(60_000_000 msat, CltvExpiry(561), paymentHash2), + upstreamChannel(45_000_000 msat, CltvExpiry(561), paymentHash2), + )) + proposeFunding(100_000_000 msat, CltvExpiry(550), paymentHash2, upstream2) + + // A third funding is signed coming from a trampoline payment. + val paymentHash3 = randomBytes32() + val upstream3 = Upstream.Hot.Trampoline(List( + upstreamChannel(60_000_000 msat, CltvExpiry(561), paymentHash3), + upstreamChannel(45_000_000 msat, CltvExpiry(561), paymentHash3), + )) + proposeFunding(100_000_000 msat, CltvExpiry(550), paymentHash3, upstream3) + signLiquidityPurchase(100_000 sat, LiquidityAds.PaymentDetails.FromChannelBalanceForFutureHtlc(paymentHash3 :: Nil)) + + // The funding timeout is reached, unsigned proposals are failed upstream. + peer ! OnTheFlyFundingTimeout(paymentHash1) + upstream1.foreach(u => { + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == u.add.channelId) + assert(fwd.message.id == u.add.id) + assert(fwd.message.reason == Right(InvalidOnionBlinding(Sphinx.hash(u.add.onionRoutingPacket)))) + assert(fwd.message.commit) + }) + peerConnection.expectMsgType[Warning] + + peer ! OnTheFlyFundingTimeout(paymentHash2) + upstream2.received.foreach(u => { + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == u.add.channelId) + assert(fwd.message.id == u.add.id) + assert(fwd.message.reason == Right(UnknownNextPeer())) + assert(fwd.message.commit) + }) + peerConnection.expectMsgType[Warning] + + peer ! OnTheFlyFundingTimeout(paymentHash3) + register.expectNoMessage(100 millis) + peerConnection.expectNoMessage(100 millis) + } + + test("proposed on-the-fly funding HTLC timeout") { f => + import f._ + + connect(peer) + + // A first funding is proposed coming from two upstream channels. + val paymentHash1 = randomBytes32() + val upstream1 = Seq( + upstreamChannel(60_000_000 msat, CltvExpiry(560), paymentHash1), + upstreamChannel(45_000_000 msat, CltvExpiry(550), paymentHash1), + ) + proposeFunding(50_000_000 msat, CltvExpiry(520), paymentHash1, upstream1.head) + proposeFunding(40_000_000 msat, CltvExpiry(510), paymentHash1, upstream1.last) + + // A second funding is signed coming from two upstream channels, one of them received after signing. + val paymentHash2 = randomBytes32() + val upstream2 = Seq( + upstreamChannel(45_000_000 msat, CltvExpiry(550), paymentHash2), + upstreamChannel(60_000_000 msat, CltvExpiry(560), paymentHash2), + ) + proposeFunding(40_000_000 msat, CltvExpiry(515), paymentHash2, upstream2.head) + signLiquidityPurchase(100_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(paymentHash2 :: Nil)) + proposeExtraFunding(50_000_000 msat, CltvExpiry(525), paymentHash2, upstream2.last) + + // A third funding is signed coming from a trampoline payment. + val paymentHash3 = randomBytes32() + val upstream3 = Upstream.Hot.Trampoline(List( + upstreamChannel(60_000_000 msat, CltvExpiry(560), paymentHash3), + upstreamChannel(45_000_000 msat, CltvExpiry(560), paymentHash3), + )) + proposeFunding(100_000_000 msat, CltvExpiry(512), paymentHash3, upstream3) + signLiquidityPurchase(100_000 sat, LiquidityAds.PaymentDetails.FromChannelBalanceForFutureHtlc(paymentHash3 :: Nil)) + + // A fourth funding is proposed coming from a trampoline payment. + val paymentHash4 = randomBytes32() + val upstream4 = Upstream.Hot.Trampoline(List(upstreamChannel(60_000_000 msat, CltvExpiry(560), paymentHash4))) + proposeFunding(50_000_000 msat, CltvExpiry(516), paymentHash4, upstream4) + + // The first three proposals reach their CLTV expiry. + peer ! CurrentBlockHeight(BlockHeight(515)) + val fwds = (0 until 6).map(_ => register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]]) + register.expectNoMessage(100 millis) + fwds.foreach(fwd => { + assert(fwd.message.reason == Right(UnknownNextPeer())) + assert(fwd.message.commit) + }) + assert(fwds.map(_.channelId).toSet == (upstream1 ++ upstream2 ++ upstream3.received).map(_.add.channelId).toSet) + assert(fwds.map(_.message.id).toSet == (upstream1 ++ upstream2 ++ upstream3.received).map(_.add.id).toSet) + awaitCond(nodeParams.db.onTheFlyFunding.listPending(remoteNodeId).isEmpty, interval = 100 millis) + } + + test("signed on-the-fly funding HTLC timeout after disconnection") { f => + import f._ + + connect(peer) + probe.watch(peer.ref) + + // A first funding proposal is signed. + val upstream1 = upstreamChannel(60_000_000 msat, CltvExpiry(560)) + proposeFunding(50_000_000 msat, CltvExpiry(520), upstream1.add.paymentHash, upstream1) + signLiquidityPurchase(75_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(upstream1.add.paymentHash :: Nil)) + + // A second funding proposal is signed. + val upstream2 = upstreamChannel(60_000_000 msat, CltvExpiry(560)) + proposeFunding(50_000_000 msat, CltvExpiry(525), upstream2.add.paymentHash, upstream2) + signLiquidityPurchase(80_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(upstream2.add.paymentHash :: Nil)) + + // We don't fail signed proposals on disconnection. + disconnect() + register.expectNoMessage(100 millis) + + // But if a funding proposal reaches its CLTV expiry, we fail it. + peer ! CurrentBlockHeight(BlockHeight(522)) + val fwd1 = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd1.channelId == upstream1.add.channelId) + assert(fwd1.message.id == upstream1.add.id) + register.expectNoMessage(100 millis) + // We still have one pending proposal, so we don't stop. + probe.expectNoMessage(100 millis) + + // When restarting, we watch for pending proposals. + val peerAfterRestart = TestFSMRef(new Peer(nodeParams, remoteNodeId, new DummyOnChainWallet(), FakeChannelFactory(remoteNodeId, channel), TestProbe().ref, register.ref, TestProbe().ref, TestProbe().ref)) + peerAfterRestart ! Peer.Init(Set.empty) + probe.watch(peerAfterRestart.ref) + + // The last funding proposal reaches its CLTV expiry. + peerAfterRestart ! CurrentBlockHeight(BlockHeight(525)) + val fwd2 = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd2.channelId == upstream2.add.channelId) + assert(fwd2.message.id == upstream2.add.id) + register.expectNoMessage(100 millis) + probe.expectTerminated(peerAfterRestart.ref) + } + + test("receive open_channel2") { f => + import f._ + + connect(peer) + + val upstream = upstreamChannel(60_000_000 msat, CltvExpiry(560), paymentHash) + proposeFunding(50_000_000 msat, CltvExpiry(520), paymentHash, upstream) + + val requestFunding = LiquidityAds.RequestFunding( + 100_000 sat, + LiquidityAds.FundingRate(10_000 sat, 500_000 sat, 0, 100, 1000 sat), + LiquidityAds.PaymentDetails.FromFutureHtlcWithPreimage(preimage :: Nil) + ) + val open = createOpenChannelMessage(requestFunding) + peerConnection.send(peer, open) + rateLimiter.expectMsgType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + val init = channel.expectMsgType[INPUT_INIT_CHANNEL_NON_INITIATOR] + assert(!init.localParams.isChannelOpener) + assert(init.localParams.paysCommitTxFees) + assert(init.fundingContribution_opt.contains(LiquidityAds.AddFunding(requestFunding.requestedAmount, nodeParams.willFundRates_opt))) + + // The preimage was provided, so we fulfill upstream HTLCs. + val fwd = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd.channelId == upstream.add.channelId) + assert(fwd.message.id == upstream.add.id) + assert(fwd.message.r == preimage) + } + + test("receive splice_init") { f => + import f._ + + connect(peer) + val channelId = openChannel(200_000 sat) + + val upstream = upstreamChannel(60_000_000 msat, CltvExpiry(560), paymentHash) + proposeFunding(50_000_000 msat, CltvExpiry(520), paymentHash, upstream) + + val requestFunding = LiquidityAds.RequestFunding( + 100_000 sat, + LiquidityAds.FundingRate(10_000 sat, 500_000 sat, 0, 100, 1000 sat), + LiquidityAds.PaymentDetails.FromFutureHtlcWithPreimage(preimage :: Nil) + ) + val splice = createSpliceMessage(channelId, requestFunding) + peerConnection.send(peer, splice) + channel.expectMsg(splice) + channel.expectNoMessage(100 millis) + + // The preimage was provided, so we fulfill upstream HTLCs. + val fwd = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd.channelId == upstream.add.channelId) + assert(fwd.message.id == upstream.add.id) + assert(fwd.message.r == preimage) + } + + test("reject invalid open_channel2") { f => + import f._ + + connect(peer) + + val requestFunding = LiquidityAds.RequestFunding( + 100_000 sat, + LiquidityAds.FundingRate(10_000 sat, 500_000 sat, 0, 0, 10_000 sat), + LiquidityAds.PaymentDetails.FromFutureHtlc(paymentHash :: Nil) + ) + val open = createOpenChannelMessage(requestFunding, htlcMinimum = 1_000_000 msat) + + // No matching will_add_htlc to pay fees. + peerConnection.send(peer, open) + rateLimiter.expectMsgType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == open.temporaryChannelId) + channel.expectNoMessage(100 millis) + + // Requested amount is too low. + val bigUpstream = upstreamChannel(200_000_000 msat, expiryIn, paymentHash) + proposeFunding(150_000_000 msat, expiryOut, paymentHash, bigUpstream) + peerConnection.send(peer, open) + rateLimiter.expectMsgType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == open.temporaryChannelId) + assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].channelId == bigUpstream.add.channelId) + channel.expectNoMessage(100 millis) + + // Not enough funds to pay fees. + val upstream = upstreamChannel(11_000_000 msat, expiryIn, paymentHash) + proposeFunding(10_999_999 msat, expiryOut, paymentHash, upstream) + peerConnection.send(peer, open) + rateLimiter.expectMsgType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == open.temporaryChannelId) + assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].channelId == upstream.add.channelId) + channel.expectNoMessage(100 millis) + + // Proposal already funded. + proposeFunding(11_000_000 msat, expiryOut, paymentHash, upstream) + signLiquidityPurchase(requestFunding.requestedAmount, requestFunding.paymentDetails, fees = requestFunding.fees(TestConstants.feeratePerKw)) + peerConnection.send(peer, open) + rateLimiter.expectMsgType[AddOrRejectChannel].replyTo ! PendingChannelsRateLimiter.AcceptOpenChannel + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == open.temporaryChannelId) + register.expectNoMessage(100 millis) + channel.expectNoMessage(100 millis) + } + + test("reject invalid splice_init") { f => + import f._ + + connect(peer) + val channelId = openChannel(500_000 sat) + + val requestFunding = LiquidityAds.RequestFunding( + 100_000 sat, + LiquidityAds.FundingRate(10_000 sat, 500_000 sat, 0, 0, 10_000 sat), + LiquidityAds.PaymentDetails.FromFutureHtlcWithPreimage(preimage :: Nil) + ) + val splice = createSpliceMessage(channelId, requestFunding) + + // No matching will_add_htlc to pay fees. + peerConnection.send(peer, splice) + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == channelId) + channel.expectNoMessage(100 millis) + + // Requested amount is too low. + val bigUpstream = upstreamChannel(200_000_000 msat, expiryIn, paymentHash) + proposeFunding(150_000_000 msat, expiryOut, paymentHash, bigUpstream) + peerConnection.send(peer, splice) + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == channelId) + assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].channelId == bigUpstream.add.channelId) + channel.expectNoMessage(100 millis) + + // Not enough funds to pay fees. + val upstream = upstreamChannel(11_000_000 msat, expiryIn, paymentHash) + proposeFunding(9_000_000 msat, expiryOut, paymentHash, upstream) + peerConnection.send(peer, splice) + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == channelId) + assert(register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].channelId == upstream.add.channelId) + channel.expectNoMessage(100 millis) + + // Proposal already funded. + proposeFunding(11_000_000 msat, expiryOut, paymentHash, upstream) + signLiquidityPurchase(requestFunding.requestedAmount, requestFunding.paymentDetails, fees = requestFunding.fees(TestConstants.feeratePerKw), fundingTxIndex = 1) + peerConnection.send(peer, splice) + assert(peerConnection.expectMsgType[CancelOnTheFlyFunding].channelId == channelId) + register.expectNoMessage(100 millis) + channel.expectNoMessage(100 millis) + } + + test("successfully relay HTLCs to on-the-fly funded channel") { f => + import f._ + + connect(peer) + + val preimage1 = randomBytes32() + val paymentHash1 = Crypto.sha256(preimage1) + val preimage2 = randomBytes32() + val paymentHash2 = Crypto.sha256(preimage2) + + val upstream1 = upstreamChannel(11_000_000 msat, expiryIn, paymentHash1) + proposeFunding(10_000_000 msat, expiryOut, paymentHash1, upstream1) + val upstream2 = upstreamChannel(16_000_000 msat, expiryIn, paymentHash2) + proposeFunding(15_000_000 msat, expiryOut, paymentHash2, upstream2) + + val htlcMinimum = 1_500_000 msat + val fees = LiquidityAds.Fees(10_000 sat, 5_000 sat) + val purchase = signLiquidityPurchase(200_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(List(paymentHash1, paymentHash2)), fees = fees, htlcMinimum = htlcMinimum) + + // Once the channel is ready to relay payments, we forward HTLCs matching the proposed will_add_htlc. + // We have two distinct payment hashes that are relayed independently. + val channelData = makeChannelData(htlcMinimum) + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, purchase.channelId, fundingTxIndex = 0) + val channelInfo = Seq( + channel.expectMsgType[CMD_GET_CHANNEL_INFO], + channel.expectMsgType[CMD_GET_CHANNEL_INFO], + ) + + // We relay the first payment. + channelInfo.head.replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, purchase.channelId, channel.ref, NORMAL, channelData) + val cmd1 = channel.expectMsgType[CMD_ADD_HTLC] + channel.expectNoMessage(100 millis) + + // We relay the second payment. + channelInfo.last.replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, purchase.channelId, channel.ref, NORMAL, channelData) + val cmd2 = channel.expectMsgType[CMD_ADD_HTLC] + channel.expectNoMessage(100 millis) + + // The fee is split across outgoing payments. + assert(Set(cmd1.paymentHash, cmd2.paymentHash) == Set(paymentHash1, paymentHash2)) + val fundingFees = Seq(cmd1, cmd2).map(cmd => { + assert(cmd.amount >= htlcMinimum) + assert(cmd.cltvExpiry == expiryOut) + assert(cmd.commit) + assert(cmd.fundingFee_opt.nonEmpty) + assert(cmd.fundingFee_opt.get.fundingTxId == purchase.txId) + assert(cmd.fundingFee_opt.get.amount > 0.msat) + cmd.fundingFee_opt.get + }) + val feesPaid = fundingFees.map(_.amount).sum + assert(feesPaid == fees.total.toMilliSatoshi) + assert(cmd1.amount + cmd2.amount + feesPaid == 25_000_000.msat) + + // The payments are fulfilled. + val (add1, add2) = if (cmd1.paymentHash == paymentHash1) (cmd1, cmd2) else (cmd2, cmd1) + val outgoing = Seq(add1, add2).map(add => UpdateAddHtlc(purchase.channelId, randomHtlcId(), add.amount, add.paymentHash, add.cltvExpiry, add.onion, add.nextBlindingKey_opt, add.fundingFee_opt)) + add1.replyTo ! RES_ADD_SETTLED(add1.origin, outgoing.head, HtlcResult.RemoteFulfill(UpdateFulfillHtlc(purchase.channelId, outgoing.head.id, preimage1))) + val fwd1 = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd1.channelId == upstream1.add.channelId) + assert(fwd1.message.id == upstream1.add.id) + assert(fwd1.message.r == preimage1) + add2.replyTo ! RES_ADD_SETTLED(add2.origin, outgoing.last, HtlcResult.OnChainFulfill(preimage2)) + val fwd2 = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd2.channelId == upstream2.add.channelId) + assert(fwd2.message.id == upstream2.add.id) + assert(fwd2.message.r == preimage2) + awaitCond(nodeParams.db.onTheFlyFunding.listPending(remoteNodeId).isEmpty, interval = 100 millis) + register.expectNoMessage(100 millis) + } + + test("successfully relay HTLCs to on-the-fly spliced channel") { f => + import f._ + + // We create a channel, that can later be spliced. + connect(peer) + val channelId = openChannel(250_000 sat) + + val htlcMinimum = 1_000_000 msat + val fees = LiquidityAds.Fees(1000 sat, 4000 sat) + val upstream = Seq( + upstreamChannel(50_000_000 msat, expiryIn, paymentHash), + upstreamChannel(60_000_000 msat, expiryIn, paymentHash), + Upstream.Hot.Trampoline(upstreamChannel(50_000_000 msat, expiryIn, paymentHash) :: Nil) + ) + proposeFunding(50_000_000 msat, expiryOut, paymentHash, upstream(0)) + proposeFunding(60_000_000 msat, expiryOut, paymentHash, upstream(1)) + val purchase = signLiquidityPurchase(200_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlcWithPreimage(preimage :: Nil), channelId, fees, fundingTxIndex = 5, htlcMinimum) + // We receive the last payment *after* signing the funding transaction. + proposeExtraFunding(50_000_000 msat, expiryOut, paymentHash, upstream(2)) + + // Once the splice with the right funding index is locked, we forward HTLCs matching the proposed will_add_htlc. + val channelData = makeChannelData(htlcMinimum) + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 4) + channel.expectNoMessage(100 millis) + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 5) + channel.expectMsgType[CMD_GET_CHANNEL_INFO].replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, channelId, channel.ref, NORMAL, channelData) + val adds1 = Seq( + channel.expectMsgType[CMD_ADD_HTLC], + channel.expectMsgType[CMD_ADD_HTLC], + channel.expectMsgType[CMD_ADD_HTLC], + ) + adds1.foreach(add => { + assert(add.paymentHash == paymentHash) + assert(add.fundingFee_opt.nonEmpty) + assert(add.fundingFee_opt.get.fundingTxId == purchase.txId) + }) + adds1.take(2).foreach(add => assert(!add.commit)) + assert(adds1.last.commit) + assert(adds1.map(_.fundingFee_opt.get.amount).sum == fees.total.toMilliSatoshi) + assert(adds1.map(add => add.amount + add.fundingFee_opt.get.amount).sum == 160_000_000.msat) + channel.expectNoMessage(100 millis) + + // The recipient fails the payments: we don't relay the failure upstream and will retry. + adds1.take(2).foreach(add => { + val htlc = UpdateAddHtlc(channelId, randomHtlcId(), add.amount, paymentHash, add.cltvExpiry, add.onion, add.nextBlindingKey_opt, add.fundingFee_opt) + val fail = UpdateFailHtlc(channelId, htlc.id, randomBytes(50)) + add.replyTo ! RES_ADD_SETTLED(add.origin, htlc, HtlcResult.RemoteFail(fail)) + }) + adds1.last.replyTo ! RES_ADD_FAILED(adds1.last, TooManyAcceptedHtlcs(channelId, 5), None) + register.expectNoMessage(100 millis) + + // When the next splice completes, we retry the payment. + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 6) + channel.expectMsgType[CMD_GET_CHANNEL_INFO].replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, channelId, channel.ref, NORMAL, channelData) + val adds2 = Seq( + channel.expectMsgType[CMD_ADD_HTLC], + channel.expectMsgType[CMD_ADD_HTLC], + channel.expectMsgType[CMD_ADD_HTLC], + ) + channel.expectNoMessage(100 millis) + + // The payment succeeds. + adds2.foreach(add => { + val htlc = UpdateAddHtlc(channelId, randomHtlcId(), add.amount, paymentHash, add.cltvExpiry, add.onion, add.nextBlindingKey_opt, add.fundingFee_opt) + add.replyTo ! RES_ADD_SETTLED(add.origin, htlc, HtlcResult.OnChainFulfill(preimage)) + }) + val fwds = Seq( + register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]], + register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]], + register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]], + ) + val (channelsIn, htlcsIn) = upstream.flatMap { + case u: Hot.Channel => Seq(u) + case u: Hot.Trampoline => u.received + case _: Upstream.Local => Nil + }.map(c => (c.add.channelId, c.add.id)).toSet.unzip + assert(fwds.map(_.channelId).toSet == channelsIn) + assert(fwds.map(_.message.id).toSet == htlcsIn) + fwds.foreach(fwd => assert(fwd.message.r == preimage)) + awaitCond(nodeParams.db.onTheFlyFunding.listPending(remoteNodeId).isEmpty, interval = 100 millis) + + // We don't retry anymore. + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 7) + channel.expectNoMessage(100 millis) + } + + test("successfully relay HTLCs after restart") { f => + import f._ + + // We create a channel, that can later be spliced. + connect(peer) + val channelId = openChannel(250_000 sat) + + // We relay an on-the-fly payment. + val upstream = upstreamChannel(50_000_000 msat, expiryIn, paymentHash) + proposeFunding(50_000_000 msat, expiryOut, paymentHash, upstream) + val fees = LiquidityAds.Fees(1000 sat, 1000 sat) + val purchase = signLiquidityPurchase(200_000 sat, LiquidityAds.PaymentDetails.FromChannelBalanceForFutureHtlc(paymentHash :: Nil), channelId, fees, fundingTxIndex = 1) + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 1) + val channelData1 = makeChannelData() + channel.expectMsgType[CMD_GET_CHANNEL_INFO].replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, channelId, channel.ref, NORMAL, channelData1) + // We don't collect additional fees if they were paid from our peer's channel balance already. + val cmd1 = channel.expectMsgType[CMD_ADD_HTLC] + val htlc = UpdateAddHtlc(channelId, 0, cmd1.amount, paymentHash, cmd1.cltvExpiry, cmd1.onion, cmd1.nextBlindingKey_opt, cmd1.fundingFee_opt) + assert(cmd1.fundingFee_opt.contains(LiquidityAds.FundingFee(0 msat, purchase.txId))) + channel.expectNoMessage(100 millis) + + // We disconnect: on reconnection, we don't attempt the payment again since it's already pending. + disconnect(channelCount = 1) + connect(peer, channelCount = 1) + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 1) + channel.expectNoMessage(100 millis) + register.expectNoMessage(100 millis) + + // On restart, we don't attempt the payment again: it's already pending. + val peerAfterRestart = TestFSMRef(new Peer(nodeParams, remoteNodeId, new DummyOnChainWallet(), FakeChannelFactory(remoteNodeId, channel), TestProbe().ref, register.ref, TestProbe().ref, TestProbe().ref)) + peerAfterRestart ! Peer.Init(Set.empty) + connect(peerAfterRestart) + peerAfterRestart ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 1) + val channelData2 = makeChannelData(localChanges = LocalChanges(Nil, htlc :: Nil, Nil)) + channel.expectMsgType[CMD_GET_CHANNEL_INFO].replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, channelId, channel.ref, NORMAL, channelData2) + channel.expectNoMessage(100 millis) + + // The payment is failed by our peer but we don't see it (it's a cold origin): we attempt it again. + val channelData3 = makeChannelData() + peerAfterRestart ! ChannelReadyForPayments(channel.ref, remoteNodeId, channelId, fundingTxIndex = 1) + channel.expectMsgType[CMD_GET_CHANNEL_INFO].replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, channelId, channel.ref, NORMAL, channelData3) + val cmd2 = channel.expectMsgType[CMD_ADD_HTLC] + assert(cmd2.paymentHash == paymentHash) + assert(cmd2.amount == cmd1.amount) + channel.expectNoMessage(100 millis) + + // The payment is fulfilled by our peer. + cmd2.replyTo ! RES_ADD_SETTLED(cmd2.origin, htlc, HtlcResult.OnChainFulfill(preimage)) + assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].channelId == upstream.add.channelId) + nodeParams.db.onTheFlyFunding.addPreimage(preimage) + register.expectNoMessage(100 millis) + awaitCond(nodeParams.db.onTheFlyFunding.listPending(remoteNodeId).isEmpty, interval = 100 millis) + } + + test("don't relay payments too close to expiry") { f => + import f._ + + connect(peer) + + val upstream = upstreamChannel(100_000_000 msat, expiryIn, paymentHash) + proposeFunding(100_000_000 msat, CltvExpiry(TestConstants.defaultBlockHeight), paymentHash, upstream) + val purchase = signLiquidityPurchase(200_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(List(paymentHash))) + + // We're too close the HTLC expiry to relay it. + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, purchase.channelId, fundingTxIndex = 0) + channel.expectNoMessage(100 millis) + peer ! CurrentBlockHeight(BlockHeight(TestConstants.defaultBlockHeight)) + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == upstream.add.channelId) + assert(fwd.message.id == upstream.add.id) + awaitCond(nodeParams.db.onTheFlyFunding.listPending(remoteNodeId).isEmpty, interval = 100 millis) + } + + test("don't relay payments for known preimage") { f => + import f._ + + connect(peer) + + val upstream = upstreamChannel(100_000_000 msat, expiryIn, paymentHash) + proposeFunding(100_000_000 msat, expiryOut, paymentHash, upstream) + val purchase = signLiquidityPurchase(200_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(List(paymentHash))) + + // We've already relayed that payment and have the matching preimage in our DB. + // We don't relay it again to avoid paying our peer twice. + nodeParams.db.onTheFlyFunding.addPreimage(preimage) + peer ! ChannelReadyForPayments(channel.ref, remoteNodeId, purchase.channelId, fundingTxIndex = 0) + channel.expectMsgType[CMD_GET_CHANNEL_INFO].replyTo ! RES_GET_CHANNEL_INFO(remoteNodeId, purchase.channelId, channel.ref, NORMAL, makeChannelData()) + channel.expectNoMessage(100 millis) + + val fwd = register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd.channelId == upstream.add.channelId) + assert(fwd.message.id == upstream.add.id) + assert(fwd.message.r == preimage) + register.expectNoMessage(100 millis) + } + + test("stop when disconnecting without pending proposals") { f => + import f._ + + connect(peer) + probe.watch(peer.ref) + disconnect() + probe.expectTerminated(peer.ref) + } + + test("stop when disconnecting with non-funded pending proposals") { f => + import f._ + + connect(peer) + probe.watch(peer.ref) + + // We have two distinct pending funding proposals. + val paymentHash1 = randomBytes32() + val upstream1 = upstreamChannel(300_000_000 msat, CltvExpiry(1200), paymentHash1) + proposeFunding(250_000_000 msat, CltvExpiry(1105), paymentHash1, upstream1) + val paymentHash2 = randomBytes32() + val upstream2 = Upstream.Hot.Trampoline(List( + upstreamChannel(100_000_000 msat, CltvExpiry(1250), paymentHash2), + upstreamChannel(150_000_000 msat, CltvExpiry(1240), paymentHash2), + )) + proposeFunding(225_000_000 msat, CltvExpiry(1105), paymentHash2, upstream2) + + // All incoming HTLCs are failed on disconnection. + disconnect() + (upstream1.add +: upstream2.received.map(_.add)).foreach(add => { + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == add.channelId) + assert(fwd.message.id == add.id) + assert(fwd.message.reason == Right(UnknownNextPeer())) + assert(fwd.message.commit) + }) + register.expectNoMessage(100 millis) + probe.expectTerminated(peer.ref) + } + + test("don't stop when disconnecting with funded pending proposals") { f => + import f._ + + connect(peer) + probe.watch(peer.ref) + + // We have one funded proposal and one that was not funded yet. + val upstream1 = upstreamChannel(300_000_000 msat, CltvExpiry(1200)) + proposeFunding(250_000_000 msat, CltvExpiry(1105), upstream1.add.paymentHash, upstream1) + val upstream2 = upstreamChannel(250_000_000 msat, CltvExpiry(1000)) + proposeFunding(220_000_000 msat, CltvExpiry(1105), upstream2.add.paymentHash, upstream2) + signLiquidityPurchase(500_000 sat, LiquidityAds.PaymentDetails.FromFutureHtlc(upstream2.add.paymentHash :: Nil)) + + // Only non-funded proposals are failed on disconnection, and we don't stop before the funded proposal completes. + disconnect() + val fwd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == upstream1.add.channelId) + assert(fwd.message.id == upstream1.add.id) + assert(fwd.message.reason == Right(UnknownNextPeer())) + assert(fwd.message.commit) + register.expectNoMessage(100 millis) + probe.expectNoMessage(100 millis) + } + +} + +object OnTheFlyFundingSpec { + + val expiryIn = CltvExpiry(TestConstants.defaultBlockHeight + 750) + val expiryOut = CltvExpiry(TestConstants.defaultBlockHeight + 500) + + val preimage = randomBytes32() + val paymentHash = Crypto.sha256(preimage) + + def randomOnion(): OnionRoutingPacket = OnionRoutingPacket(0, randomKey().publicKey.value, randomBytes(1300), randomBytes32()) + + def randomHtlcId(): Long = Math.abs(randomLong()) % 50_000 + + def upstreamChannel(amountIn: MilliSatoshi, expiryIn: CltvExpiry, paymentHash: ByteVector32 = randomBytes32(), blinded: Boolean = false): Upstream.Hot.Channel = { + val blindingKey = if (blinded) Some(randomKey().publicKey) else None + val add = UpdateAddHtlc(randomBytes32(), randomHtlcId(), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, blindingKey, None) + Upstream.Hot.Channel(add, TimestampMilli.now(), randomKey().publicKey) + } + + def createWillAdd(amount: MilliSatoshi, paymentHash: ByteVector32, expiry: CltvExpiry, blinding_opt: Option[PublicKey] = None): WillAddHtlc = { + WillAddHtlc(Block.RegtestGenesisBlock.hash, randomBytes32(), amount, paymentHash, expiry, randomOnion(), blinding_opt) + } + + def createStatus(): OnTheFlyFunding.Status = OnTheFlyFunding.Status.Funded(randomBytes32(), TxId(randomBytes32()), 0, 2500 msat) + + def createOpenChannelMessage(requestFunding: LiquidityAds.RequestFunding, fundingAmount: Satoshi = 250_000 sat, htlcMinimum: MilliSatoshi = 1 msat): OpenDualFundedChannel = { + val channelFlags = ChannelFlags(nonInitiatorPaysCommitFees = true, announceChannel = false) + val tlvs = TlvStream[OpenDualFundedChannelTlv](ChannelTlv.ChannelTypeTlv(ChannelTypes.AnchorOutputsZeroFeeHtlcTx()), ChannelTlv.RequestFundingTlv(requestFunding)) + OpenDualFundedChannel(Block.RegtestGenesisBlock.hash, randomBytes32(), TestConstants.feeratePerKw, TestConstants.anchorOutputsFeeratePerKw, fundingAmount, 483 sat, UInt64(100), htlcMinimum, CltvExpiryDelta(144), 10, 0, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, channelFlags, tlvs) + } + + def createSpliceMessage(channelId: ByteVector32, requestFunding: LiquidityAds.RequestFunding): SpliceInit = { + SpliceInit(channelId, 0 sat, 0, TestConstants.feeratePerKw, randomKey().publicKey, 0 msat, requireConfirmedInputs = false, Some(requestFunding)) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index 52ee222bbf..babe08980b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -20,8 +20,9 @@ import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.{TypedActorContextOps, TypedActorRefOps} +import akka.testkit.TestKit.awaitCond import com.typesafe.config.ConfigFactory -import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32} +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, TxId} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair._ @@ -212,10 +213,10 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat import f._ val replyTo = TestProbe[Any]() - val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 42, amountMsat = 11000000 msat, paymentHash = ByteVector32.Zeroes, CltvExpiry(4200), TestConstants.emptyOnionPacket, None, None) + val add_ab = UpdateAddHtlc(channelId_ab, 42, 11000000 msat, ByteVector32.Zeroes, CltvExpiry(4200), TestConstants.emptyOnionPacket, None, None) val add_bc = UpdateAddHtlc(channelId_bc, 72, 1000 msat, paymentHash, CltvExpiry(1), TestConstants.emptyOnionPacket, None, None) val channelOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Channel(add_ab, TimestampMilli.now(), priv_a.publicKey)) - val trampolineOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_ab, TimestampMilli.now(), priv_a.publicKey)))) + val trampolineOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Trampoline(List(Upstream.Hot.Channel(add_ab, TimestampMilli.now(), priv_a.publicKey)))) val addSettled = Seq( RES_ADD_SETTLED(channelOrigin, add_bc, HtlcResult.OnChainFulfill(randomBytes32())), @@ -234,4 +235,29 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } } + test("store on-the-fly funding preimage") { f => + import f._ + + val replyTo = TestProbe[Any]() + val add_ab = UpdateAddHtlc(channelId_ab, 17, 50_000 msat, paymentHash, CltvExpiry(800_000), TestConstants.emptyOnionPacket, None, None) + val add_bc = UpdateAddHtlc(channelId_bc, 21, 45_000 msat, paymentHash, CltvExpiry(799_000), TestConstants.emptyOnionPacket, None, Some(LiquidityAds.FundingFee(1000 msat, TxId(randomBytes32())))) + val originHot = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Channel(add_ab, TimestampMilli.now(), randomKey().publicKey)) + val originCold = Origin.Cold(originHot) + + val addFulfilled = Seq( + RES_ADD_SETTLED(originHot, add_bc, HtlcResult.OnChainFulfill(randomBytes32())), + RES_ADD_SETTLED(originHot, add_bc, HtlcResult.RemoteFulfill(UpdateFulfillHtlc(add_bc.channelId, add_bc.id, randomBytes32()))), + RES_ADD_SETTLED(originCold, add_bc, HtlcResult.OnChainFulfill(randomBytes32())), + RES_ADD_SETTLED(originCold, add_bc, HtlcResult.RemoteFulfill(UpdateFulfillHtlc(add_bc.channelId, add_bc.id, randomBytes32()))), + ) + + for (res <- addFulfilled) { + val preimage = res.result.paymentPreimage + val paymentHash = Crypto.sha256(preimage) + assert(nodeParams.db.onTheFlyFunding.getPreimage(paymentHash).isEmpty) + relayer ! res + awaitCond(nodeParams.db.onTheFlyFunding.getPreimage(paymentHash).contains(preimage), 10 seconds) + } + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala index 5db31e0600..ee7dd8847f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala @@ -136,7 +136,7 @@ class ChannelCodecs1Spec extends AnyFunSuite { UpdateAddHtlc(randomBytes32(), 1L, 2000 msat, randomBytes32(), CltvExpiry(400000), TestConstants.emptyOnionPacket, None, None), UpdateAddHtlc(randomBytes32(), 2L, 3000 msat, randomBytes32(), CltvExpiry(400000), TestConstants.emptyOnionPacket, None, None), ) - val trampolineRelayedHot = Origin.Hot(replyTo, Upstream.Hot.Trampoline(adds.map(add => Upstream.Hot.Channel(add, TimestampMilli(0), randomKey().publicKey)))) + val trampolineRelayedHot = Origin.Hot(replyTo, Upstream.Hot.Trampoline(adds.map(add => Upstream.Hot.Channel(add, TimestampMilli(0), randomKey().publicKey)).toList)) // We didn't encode the incoming HTLC amount. val trampolineRelayed = Origin.Cold(Upstream.Cold.Trampoline(adds.map(add => Upstream.Cold.Channel(add.channelId, add.id, 0 msat)).toList)) assert(originCodec.decodeValue(originCodec.encode(trampolineRelayedHot).require).require == trampolineRelayed)