From 956ae5509dc4711d775c44f7779e0e77ce8ee46f Mon Sep 17 00:00:00 2001 From: t-bast Date: Wed, 7 Aug 2024 12:29:10 +0200 Subject: [PATCH] Simplify outgoing payment state machine We previously supported having multiple channels with our peer, because we didn't yet support splicing. Now that we support splicing, we always have at most one active channel with our peer. This lets us simplify greatly the outgoing payment state machine: payments are always made with a single outgoing HTLC instead of potentially multiple HTLCs (MPP). We don't need any kind of path-finding: we simply need to check the balance of our active channel, if any. We may introduce support for connecting to multiple peers in the future. When that happens, we will still have a single active channel per peer, but we may allow splitting outgoing payments across our peers. We will need to re-work the outgoing payment state machine when this happens, but it is too early to support this now anyway. This refactoring makes it easier to create payment onion, by creating the trampoline onion *and* the outer onion in the same function call. This will make it simpler to migrate to the version of trampoline that is currently specified in https://github.com/lightning/bolts/pull/836 where some fields will be included in the payment onion instead of the trampoline onion. --- .../lightning/channel/ChannelException.kt | 2 + .../kotlin/fr/acinq/lightning/io/Peer.kt | 8 +- .../payment/OutgoingPaymentHandler.kt | 539 +++++------- .../payment/OutgoingPaymentPacket.kt | 192 ++--- .../lightning/payment/RouteCalculation.kt | 61 -- .../fr/acinq/lightning/wire/PaymentOnion.kt | 15 +- .../fr/acinq/lightning/channel/TestsHelper.kt | 14 +- .../fr/acinq/lightning/io/peer/PeerTest.kt | 101 +-- .../IncomingPaymentHandlerTestsCommon.kt | 70 +- .../OutgoingPaymentHandlerTestsCommon.kt | 779 +++++------------ .../payment/PaymentPacketTestsCommon.kt | 806 ++++++++---------- .../payment/RouteCalculationTestsCommon.kt | 154 ---- 12 files changed, 942 insertions(+), 1799 deletions(-) delete mode 100644 src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt delete mode 100644 src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt index 0c28e038c..a005ac003 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt @@ -83,9 +83,11 @@ data class CannotAffordFirstCommitFees (override val channelId: Byte data class CannotAffordFees (override val channelId: ByteVector32, val missing: Satoshi, val reserve: Satoshi, val fees: Satoshi) : ChannelException(channelId, "can't pay the fee: missing=$missing reserve=$reserve fees=$fees") data class CannotSignWithoutChanges (override val channelId: ByteVector32) : ChannelException(channelId, "cannot sign when there are no change") data class CannotSignBeforeRevocation (override val channelId: ByteVector32) : ChannelException(channelId, "cannot sign until next revocation hash is received") +data class CannotSignDisconnected (override val channelId: ByteVector32) : ChannelException(channelId, "disconnected before signing outgoing payments") data class UnexpectedRevocation (override val channelId: ByteVector32) : ChannelException(channelId, "received unexpected RevokeAndAck message") data class InvalidRevocation (override val channelId: ByteVector32) : ChannelException(channelId, "invalid revocation") data class InvalidFailureCode (override val channelId: ByteVector32) : ChannelException(channelId, "UpdateFailMalformedHtlc message doesn't have BADONION bit set") +data class CannotDecryptFailure (override val channelId: ByteVector32, val details: String) : ChannelException(channelId, "cannot decrypt failure message: $details") data class PleasePublishYourCommitment (override val channelId: ByteVector32) : ChannelException(channelId, "please publish your local commitment") data class CommandUnavailableInThisState (override val channelId: ByteVector32, val state: String) : ChannelException(channelId, "cannot execute command in state=$state") data class ForbiddenDuringSplice (override val channelId: ByteVector32, val command: String?) : ChannelException(channelId, "cannot process $command while splicing") diff --git a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt index a24a9d5cb..a6d2d2fd9 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt @@ -812,12 +812,7 @@ class Peer( is ChannelAction.ProcessIncomingHtlc -> processIncomingPayment(Either.Right(action.add)) is ChannelAction.ProcessCmdRes.NotExecuted -> logger.warning(action.t) { "command not executed" } is ChannelAction.ProcessCmdRes.AddFailed -> { - when (val result = outgoingPaymentHandler.processAddFailed(actualChannelId, action, _channels)) { - is OutgoingPaymentHandler.Progress -> { - _eventsFlow.emit(PaymentProgress(result.request, result.fees)) - result.actions.forEach { input.send(it) } - } - + when (val result = outgoingPaymentHandler.processAddFailed(actualChannelId, action)) { is OutgoingPaymentHandler.Failure -> _eventsFlow.emit(PaymentNotSent(result.request, result.failure)) null -> logger.debug { "non-final error, more partial payments are still pending: ${action.error.message}" } } @@ -838,7 +833,6 @@ class Peer( is ChannelAction.ProcessCmdRes.AddSettledFulfill -> { when (val result = outgoingPaymentHandler.processAddSettled(action)) { is OutgoingPaymentHandler.Success -> _eventsFlow.emit(PaymentSent(result.request, result.payment)) - is OutgoingPaymentHandler.PreimageReceived -> logger.debug(mapOf("paymentId" to result.request.paymentId)) { "payment preimage received: ${result.preimage}" } null -> logger.debug { "unknown payment" } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt index 8c0c3d042..f6d53fc75 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt @@ -4,11 +4,10 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.utils.Either import fr.acinq.bitcoin.utils.Try import fr.acinq.lightning.* -import fr.acinq.lightning.channel.ChannelAction -import fr.acinq.lightning.channel.ChannelException -import fr.acinq.lightning.channel.states.Channel -import fr.acinq.lightning.channel.states.ChannelState +import fr.acinq.lightning.channel.* +import fr.acinq.lightning.channel.states.* import fr.acinq.lightning.crypto.sphinx.FailurePacket +import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets import fr.acinq.lightning.crypto.sphinx.SharedSecrets import fr.acinq.lightning.db.HopDesc import fr.acinq.lightning.db.LightningOutgoingPayment @@ -18,12 +17,13 @@ import fr.acinq.lightning.io.WrappedChannelCommand import fr.acinq.lightning.logging.MDCLogger import fr.acinq.lightning.logging.error import fr.acinq.lightning.logging.mdc -import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.utils.UUID import fr.acinq.lightning.utils.msat -import fr.acinq.lightning.utils.sum -import fr.acinq.lightning.wire.* +import fr.acinq.lightning.wire.FailureMessage +import fr.acinq.lightning.wire.TrampolineExpiryTooSoon +import fr.acinq.lightning.wire.TrampolineFeeInsufficient +import fr.acinq.lightning.wire.UnknownNextPeer class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: WalletParams, val db: OutgoingPaymentsDb) { @@ -37,16 +37,32 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle /** The payment could not be sent. */ data class Failure(val request: PayInvoice, val failure: OutgoingPaymentFailure) : SendPaymentResult, ProcessFailureResult - /** The recipient released the preimage, but we are still waiting for some partial payments to settle. */ - data class PreimageReceived(val request: PayInvoice, val preimage: ByteVector32) : ProcessFulfillResult - /** The payment was successfully made. */ data class Success(val request: PayInvoice, val payment: LightningOutgoingPayment, val preimage: ByteVector32) : ProcessFailureResult, ProcessFulfillResult private val logger = nodeParams.loggerFactory.newLogger(this::class) private val childToParentId = mutableMapOf() private val pending = mutableMapOf() - private val routeCalculation = RouteCalculation(nodeParams.loggerFactory) + + /** + * While a payment is in progress, we wait for the outgoing HTLC to settle. + * When we receive a failure, we retry with a different channel or fees. + * + * @param request payment request containing the total amount to send. + * @param attemptNumber number of failed previous payment attempts. + * @param pending pending outgoing payment. + * @param sharedSecrets payment onion shared secrets, used to decrypt failures. + * @param failures previous payment failures. + */ + data class PaymentAttempt( + val request: PayInvoice, + val attemptNumber: Int, + val pending: LightningOutgoingPayment.Part, + val sharedSecrets: SharedSecrets, + val failures: List> + ) { + val fees: MilliSatoshi = pending.amount - request.amount + } // NB: this function should only be used in tests. fun getPendingPayment(parentId: UUID): PaymentAttempt? = pending[parentId] @@ -72,404 +88,273 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle logger.error { "invoice has already been paid" } return Failure(request, FinalFailure.AlreadyPaid.toPaymentFailure()) } - val trampolineFees = request.trampolineFeesOverride ?: walletParams.trampolineFees - val (trampolineAmount, trampolineExpiry, trampolinePacket) = createTrampolinePayload(request, trampolineFees.first(), currentBlockHeight) - return when (val result = routeCalculation.findRoutes(request.paymentId, trampolineAmount, channels)) { + val trampolineFees = (request.trampolineFeesOverride ?: walletParams.trampolineFees).first() + val trampolineAmount = request.amount + trampolineFees.calculateFees(request.amount) + return when (val result = selectChannel(trampolineAmount, channels)) { is Either.Left -> { logger.warning { "payment failed: ${result.value}" } db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails)) - val finalFailure = result.value - db.completeOutgoingPaymentOffchain(request.paymentId, finalFailure) - Failure(request, finalFailure.toPaymentFailure()) + db.completeOutgoingPaymentOffchain(request.paymentId, result.value) + Failure(request, result.value.toPaymentFailure()) } is Either.Right -> { - // We generate a random secret for this payment to avoid leaking the invoice secret to the trampoline node. - val trampolinePaymentSecret = Lightning.randomBytes32() - val trampolinePayload = PaymentAttempt.TrampolinePayload(trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolinePacket) - val childPayments = createChildPayments(request, result.value, trampolinePayload) - db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, childPayments.map { it.first }, LightningOutgoingPayment.Status.Pending)) - val payment = PaymentAttempt.PaymentInProgress(request, 0, trampolinePayload, childPayments.associate { it.first.id to Pair(it.first, it.second) }, setOf(), listOf()) + val hop = NodeHop(walletParams.trampolineNode.id, request.recipient, trampolineFees.cltvExpiryDelta, trampolineFees.calculateFees(request.amount)) + val (childPayment, sharedSecrets, cmd) = createOutgoingPayment(request, result.value, hop, currentBlockHeight) + val payment = PaymentAttempt(request, 0, childPayment, sharedSecrets, listOf()) + db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, listOf(childPayment), LightningOutgoingPayment.Status.Pending)) pending[request.paymentId] = payment - Progress(request, payment.fees, childPayments.map { it.third }) + Progress(request, payment.fees, listOf(cmd)) } } } - private fun createChildPayments(request: PayInvoice, routes: List, trampolinePayload: PaymentAttempt.TrampolinePayload): List> { - val logger = MDCLogger(logger, staticMdc = request.mdc()) - val childPayments = routes.map { createOutgoingPart(request, it, trampolinePayload) } - childToParentId.putAll(childPayments.map { it.first.id to request.paymentId }) - childPayments.forEach { logger.info(mapOf("childPaymentId" to it.first.id)) { "sending ${it.first.amount} to channel ${it.third.channelId}" } } - return childPayments - } - - suspend fun processAddFailed(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddFailed, channels: Map): ProcessFailureResult? { - val add = event.cmd - val payment = getPaymentAttempt(add.paymentId) ?: return processPostRestartFailure(add.paymentId, Either.Left(event.error)) - val logger = MDCLogger(logger, staticMdc = mapOf("channelId" to channelId, "childPaymentId" to add.paymentId) + payment.request.mdc()) - - logger.debug { "could not send HTLC: ${event.error.message}" } - db.completeOutgoingLightningPart(add.paymentId, OutgoingPaymentFailure.convertFailure(Either.Left(event.error))) - - val (updated, result) = when (payment) { - is PaymentAttempt.PaymentInProgress -> { - val ignore = payment.ignore + channelId // we ignore the failing channel in retries - when (val routes = routeCalculation.findRoutes(payment.request.paymentId, add.amount, channels - ignore)) { - is Either.Left -> PaymentAttempt.PaymentAborted(payment.request, routes.value, payment.pending, payment.failures).failChild(add.paymentId, Either.Left(event.error), db, logger) - is Either.Right -> { - val newPayments = createChildPayments(payment.request, routes.value, payment.trampolinePayload) - db.addOutgoingLightningParts(payment.request.paymentId, newPayments.map { it.first }) - val updatedPayments = payment.pending - add.paymentId + newPayments.map { it.first.id to Pair(it.first, it.second) } - val updated = payment.copy(ignore = ignore, failures = payment.failures + Either.Left(event.error), pending = updatedPayments) - val result = Progress(payment.request, updated.fees, newPayments.map { it.third }) - Pair(updated, result) - } - } - } - is PaymentAttempt.PaymentAborted -> payment.failChild(add.paymentId, Either.Left(event.error), db, logger) - is PaymentAttempt.PaymentSucceeded -> payment.failChild(add.paymentId, db, logger) + /** + * This may happen if we hit channel limits (e.g. max-accepted-htlcs). + * This is a temporary failure that we cannot automatically resolve: we must wait for the channel to be usable again. + */ + suspend fun processAddFailed(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddFailed): Failure? { + val payment = getPaymentAttempt(event.cmd.paymentId) ?: return processPostRestartFailure(event.cmd.paymentId, Either.Left(event.error)) + val logger = MDCLogger(logger, staticMdc = mapOf("channelId" to channelId, "childPaymentId" to event.cmd.paymentId) + payment.request.mdc()) + + if (payment.pending.id != event.cmd.paymentId) { + logger.warning { "ignoring HTLC that does not match pending payment part (${event.cmd.paymentId} != ${payment.pending.id})" } + return null } - updateGlobalState(add.paymentId, updated) - - return result + logger.info { "could not send HTLC: ${event.error.message}" } + db.completeOutgoingLightningPart(event.cmd.paymentId, OutgoingPaymentFailure.convertFailure(Either.Left(event.error))) + db.completeOutgoingPaymentOffchain(payment.request.paymentId, FinalFailure.NoAvailableChannels) + removeFromState(payment.request.paymentId) + return Failure(payment.request, OutgoingPaymentFailure(FinalFailure.NoAvailableChannels, payment.failures + Either.Left(event.error))) } suspend fun processAddSettled(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddSettledFail, channels: Map, currentBlockHeight: Int): ProcessFailureResult? { - val payment = getPaymentAttempt(event.paymentId) ?: return processPostRestartFailure(event.paymentId, Either.Right(UnknownFailureMessage(0))) + val payment = getPaymentAttempt(event.paymentId) ?: return processPostRestartFailure(event.paymentId, Either.Left(CannotDecryptFailure(channelId, "restarted"))) val logger = MDCLogger(logger, staticMdc = mapOf("channelId" to channelId, "childPaymentId" to event.paymentId) + payment.request.mdc()) - val failure: FailureMessage = when (event.result) { - is ChannelAction.HtlcResult.Fail.RemoteFail -> when (val part = payment.pending[event.paymentId]) { - null -> UnknownFailureMessage(0) - else -> when (val decrypted = FailurePacket.decrypt(event.result.fail.reason.toByteArray(), part.second)) { - is Try.Failure -> UnknownFailureMessage(1) - is Try.Success -> decrypted.result.failureMessage + if (payment.pending.id != event.paymentId) { + logger.warning { "ignoring HTLC that does not match latest payment part (${event.paymentId} != ${payment.pending.id})" } + return null + } + + val failure = when (event.result) { + is ChannelAction.HtlcResult.Fail.RemoteFail -> when (val decrypted = FailurePacket.decrypt(event.result.fail.reason.toByteArray(), payment.sharedSecrets)) { + is Try.Failure -> { + logger.warning { "could not decrypt failure packet: ${decrypted.error.message}" } + Either.Left(CannotDecryptFailure(channelId, decrypted.error.message ?: "unknown")) + } + is Try.Success -> { + logger.debug { "HTLC failed: ${decrypted.result.failureMessage.message}" } + Either.Right(decrypted.result.failureMessage) } } - else -> UnknownFailureMessage(FailureMessage.BADONION) + is ChannelAction.HtlcResult.Fail.RemoteFailMalformed -> { + logger.warning { "our peer couldn't decrypt our payment onion (failureCode=${event.result.fail.failureCode})" } + Either.Left(CannotDecryptFailure(channelId, "malformed onion")) + } + is ChannelAction.HtlcResult.Fail.OnChainFail -> { + logger.warning { "channel closed while our HTLC was pending: ${event.result.cause.message}" } + Either.Left(event.result.cause) + } + is ChannelAction.HtlcResult.Fail.Disconnected -> { + logger.warning { "we got disconnected before signing outgoing HTLC" } + Either.Left(CannotSignDisconnected(channelId)) + } } - logger.debug { "HTLC failed: ${failure.message}" } - db.completeOutgoingLightningPart(event.paymentId, OutgoingPaymentFailure.convertFailure(Either.Right(failure))) + // We update the status in our DB. + db.completeOutgoingLightningPart(event.paymentId, OutgoingPaymentFailure.convertFailure(failure)) - val (updated, result) = when (payment) { - is PaymentAttempt.PaymentInProgress -> { - val trampolineFees = payment.request.trampolineFeesOverride ?: walletParams.trampolineFees - val finalError = when { - trampolineFees.size <= payment.attemptNumber + 1 -> FinalFailure.RetryExhausted - failure == UnknownNextPeer -> FinalFailure.RecipientUnreachable - failure != TrampolineExpiryTooSoon && failure != TrampolineFeeInsufficient -> FinalFailure.UnknownError // non-retriable error - else -> null + val trampolineFees = payment.request.trampolineFeesOverride ?: walletParams.trampolineFees + val finalError = when { + trampolineFees.size <= payment.attemptNumber + 1 -> FinalFailure.RetryExhausted + failure == Either.Right(UnknownNextPeer) -> FinalFailure.RecipientUnreachable + failure != Either.Right(TrampolineExpiryTooSoon) && failure != Either.Right(TrampolineFeeInsufficient) -> FinalFailure.UnknownError // non-retriable error + else -> null + } + return if (finalError != null) { + db.completeOutgoingPaymentOffchain(payment.request.paymentId, finalError) + removeFromState(payment.request.paymentId) + Failure(payment.request, OutgoingPaymentFailure(finalError, payment.failures + failure)) + } else { + // The trampoline node is asking us to retry the payment with more fees or a larger expiry delta. + val nextFees = trampolineFees[payment.attemptNumber + 1] + logger.info { "retrying payment with higher fees (base=${nextFees.feeBase}, proportional=${nextFees.feeProportional})..." } + val trampolineAmount = payment.request.amount + nextFees.calculateFees(payment.request.amount) + when (val result = selectChannel(trampolineAmount, channels)) { + is Either.Left -> { + logger.warning { "payment failed: ${result.value}" } + db.completeOutgoingPaymentOffchain(payment.request.paymentId, result.value) + removeFromState(payment.request.paymentId) + Failure(payment.request, OutgoingPaymentFailure(result.value, payment.failures + failure)) } - if (finalError != null) { - PaymentAttempt.PaymentAborted(payment.request, finalError, payment.pending, listOf()).failChild(event.paymentId, Either.Right(failure), db, logger) - } else { - // The trampoline node is asking us to retry the payment with more fees. - logger.debug { "child payment failed because of fees" } - val updated = payment.copy(pending = payment.pending - event.paymentId) - if (updated.pending.isNotEmpty()) { - // We wait for all pending HTLCs to be settled before retrying. - // NB: we don't update failures here to avoid duplicate trampoline errors - Pair(updated, null) - } else { - val nextFees = trampolineFees[payment.attemptNumber + 1] - logger.info { "retrying payment with higher fees (base=${nextFees.feeBase}, proportional=${nextFees.feeProportional})..." } - val (trampolineAmount, trampolineExpiry, trampolinePacket) = createTrampolinePayload(payment.request, nextFees, currentBlockHeight) - when (val routes = routeCalculation.findRoutes(payment.request.paymentId, trampolineAmount, channels)) { - is Either.Left -> { - logger.warning { "payment failed: ${routes.value}" } - val aborted = PaymentAttempt.PaymentAborted(payment.request, routes.value, mapOf(), payment.failures + Either.Right(failure)) - val result = Failure(payment.request, OutgoingPaymentFailure(aborted.reason, aborted.failures)) - db.completeOutgoingPaymentOffchain(payment.request.paymentId, result.failure.reason) - Pair(aborted, result) - } - is Either.Right -> { - // We generate a random secret for this payment to avoid leaking the invoice secret to the trampoline node. - val trampolinePaymentSecret = Lightning.randomBytes32() - val trampolinePayload = PaymentAttempt.TrampolinePayload(trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolinePacket) - val childPayments = createChildPayments(payment.request, routes.value, trampolinePayload) - db.addOutgoingLightningParts(payment.request.paymentId, childPayments.map { it.first }) - val newAttempt = PaymentAttempt.PaymentInProgress( - payment.request, - payment.attemptNumber + 1, - trampolinePayload, - childPayments.associate { it.first.id to Pair(it.first, it.second) }, - setOf(), // we reset ignored channels - payment.failures + Either.Right(failure) - ) - val result = Progress(newAttempt.request, newAttempt.fees, childPayments.map { it.third }) - Pair(newAttempt, result) - } - } - } + is Either.Right -> { + val hop = NodeHop(walletParams.trampolineNode.id, payment.request.recipient, nextFees.cltvExpiryDelta, nextFees.calculateFees(payment.request.amount)) + val (childPayment, sharedSecrets, cmd) = createOutgoingPayment(payment.request, result.value, hop, currentBlockHeight) + db.addOutgoingLightningParts(payment.request.paymentId, listOf(childPayment)) + val payment1 = PaymentAttempt( + request = payment.request, + attemptNumber = payment.attemptNumber + 1, + pending = childPayment, + sharedSecrets = sharedSecrets, + failures = payment.failures + failure + ) + pending[payment1.request.paymentId] = payment1 + Progress(payment1.request, payment1.fees, listOf(cmd)) } } - is PaymentAttempt.PaymentAborted -> payment.failChild(event.paymentId, Either.Right(failure), db, logger) - is PaymentAttempt.PaymentSucceeded -> payment.failChild(event.paymentId, db, logger) } - - updateGlobalState(event.paymentId, updated) - - return result } - private suspend fun processPostRestartFailure(partId: UUID, failure: Either): ProcessFailureResult? { - when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { + private suspend fun processPostRestartFailure(partId: UUID, failure: Either): Failure? { + return when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { null -> { logger.error { "paymentId=$partId doesn't match any known payment attempt" } - return null + null } else -> { val logger = MDCLogger(logger, staticMdc = mapOf("childPaymentId" to partId) + payment.mdc()) logger.debug { "could not send HTLC (wallet restart): ${failure.fold({ it.message }, { it.message })}" } val status = LightningOutgoingPayment.Part.Status.Failed(OutgoingPaymentFailure.convertFailure(failure)) db.completeOutgoingLightningPart(partId, status.failure) - val hasMorePendingParts = payment.parts.any { it.status == LightningOutgoingPayment.Part.Status.Pending && it.id != partId } - return if (!hasMorePendingParts) { - logger.warning { "payment failed: ${FinalFailure.WalletRestarted}" } - db.completeOutgoingPaymentOffchain(payment.id, FinalFailure.WalletRestarted) - val request = when (payment.details) { - is LightningOutgoingPayment.Details.Normal -> PayInvoice(payment.id, payment.recipientAmount, payment.details) - else -> { - logger.debug { "cannot recreate send-payment-request failure from db data with details=${payment.details}" } - return null + when (payment.details) { + is LightningOutgoingPayment.Details.Normal -> { + val request = PayInvoice(payment.id, payment.recipientAmount, payment.details) + val remainingParts = payment.parts.filter { it.id != partId && it.status !is LightningOutgoingPayment.Part.Status.Failed } + if (remainingParts.isEmpty()) { + logger.warning { "payment failed: ${FinalFailure.WalletRestarted}" } + db.completeOutgoingPaymentOffchain(payment.id, FinalFailure.WalletRestarted) + removeFromState(payment.id) + val failures = payment.parts.map { it.status }.filterIsInstance() + status + Failure(request, OutgoingPaymentFailure(FinalFailure.WalletRestarted, failures)) + } else { + logger.warning { "some payment parts haven't been failed after restart: ${remainingParts.map { it.id }.joinToString(", ")}" } + null } } - Failure( - request = request, - failure = OutgoingPaymentFailure( - reason = FinalFailure.WalletRestarted, - failures = payment.parts.map { it.status }.filterIsInstance() + status - ) - ) - } else { - null + else -> { + logger.debug { "cannot recreate payment request from db data with details=${payment.details}" } + null + } } } } } - suspend fun processAddSettled(event: ChannelAction.ProcessCmdRes.AddSettledFulfill): ProcessFulfillResult? { + suspend fun processAddSettled(event: ChannelAction.ProcessCmdRes.AddSettledFulfill): Success? { val preimage = event.result.paymentPreimage val payment = getPaymentAttempt(event.paymentId) ?: return processPostRestartFulfill(event.paymentId, preimage) val logger = MDCLogger(logger, staticMdc = mapOf("childPaymentId" to event.paymentId) + payment.request.mdc()) - logger.debug { "HTLC fulfilled" } - val part = payment.pending[event.paymentId]?.first?.copy(status = LightningOutgoingPayment.Part.Status.Succeeded(preimage)) - db.completeOutgoingLightningPart(event.paymentId, preimage) - - val updated = when (payment) { - is PaymentAttempt.PaymentInProgress -> PaymentAttempt.PaymentSucceeded(payment.request, preimage, part?.let { listOf(it) } ?: listOf(), payment.pending - event.paymentId) - is PaymentAttempt.PaymentSucceeded -> payment.copy(pending = payment.pending - event.paymentId, parts = part?.let { payment.parts + it } ?: payment.parts) - is PaymentAttempt.PaymentAborted -> { - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - logger.warning { "payment succeeded after partial failure: we may have paid less than the full amount" } - PaymentAttempt.PaymentSucceeded(payment.request, preimage, part?.let { listOf(it) } ?: listOf(), payment.pending - event.paymentId) - } + if (payment.pending.id != event.paymentId) { + logger.warning { "ignoring HTLC that does not match latest payment part (${event.paymentId} != ${payment.pending.id})" } + return null } - updateGlobalState(event.paymentId, updated) - - return if (updated.isComplete()) { - logger.info { "payment successfully sent (fees=${updated.fees})" } - db.completeOutgoingPaymentOffchain(payment.request.paymentId, preimage) - val r = payment.request - Success(r, LightningOutgoingPayment(r.paymentId, r.amount, r.recipient, r.paymentDetails, updated.parts, LightningOutgoingPayment.Status.Completed.Succeeded.OffChain(preimage)), preimage) - } else { - PreimageReceived(payment.request, preimage) - } + logger.info { "payment successfully sent (fees=${payment.fees})" } + db.completeOutgoingLightningPart(event.paymentId, preimage) + db.completeOutgoingPaymentOffchain(payment.request.paymentId, preimage) + removeFromState(payment.request.paymentId) + val status = LightningOutgoingPayment.Status.Completed.Succeeded.OffChain(preimage) + val part = payment.pending.copy(status = LightningOutgoingPayment.Part.Status.Succeeded(preimage)) + val result = LightningOutgoingPayment(payment.request.paymentId, payment.request.amount, payment.request.recipient, payment.request.paymentDetails, listOf(part), status) + return Success(payment.request, result, preimage) } - private suspend fun processPostRestartFulfill(partId: UUID, preimage: ByteVector32): ProcessFulfillResult? { - when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { + private suspend fun processPostRestartFulfill(partId: UUID, preimage: ByteVector32): Success? { + return when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { null -> { logger.error { "paymentId=$partId doesn't match any known payment attempt" } - return null + null } else -> { val logger = MDCLogger(logger, staticMdc = mapOf("childPaymentId" to partId) + payment.mdc()) - logger.debug { "HTLC succeeded (wallet restart): $preimage" } db.completeOutgoingLightningPart(partId, preimage) - // We try to re-create the request from what we have in the DB. - val request = when (payment.details) { - is LightningOutgoingPayment.Details.Normal -> PayInvoice(payment.id, payment.recipientAmount, payment.details) + when (payment.details) { + is LightningOutgoingPayment.Details.Normal -> { + logger.info { "payment successfully sent (wallet restart)" } + val request = PayInvoice(payment.id, payment.recipientAmount, payment.details) + db.completeOutgoingPaymentOffchain(payment.id, preimage) + removeFromState(payment.id) + // NB: we reload the payment to ensure all parts status are updated + // this payment cannot be null + val succeeded = db.getLightningOutgoingPayment(payment.id)!! + Success(request, succeeded, preimage) + } else -> { logger.warning { "cannot recreate send-payment-request fulfill from db data with details=${payment.details}" } - return null + null } } - val hasMorePendingParts = payment.parts.any { it.status == LightningOutgoingPayment.Part.Status.Pending && it.id != partId } - return if (!hasMorePendingParts) { - logger.info { "payment successfully sent (wallet restart)" } - db.completeOutgoingPaymentOffchain(payment.id, preimage) - // NB: we reload the payment to ensure all parts status are updated - // this payment cannot be null - val succeeded = db.getLightningOutgoingPayment(payment.id)!! - Success(request, succeeded, preimage) - } else { - PreimageReceived(request, preimage) - } } } } - private fun updateGlobalState(processedChildId: UUID, updatedPayment: PaymentAttempt) { - childToParentId.remove(processedChildId) - if (updatedPayment.isComplete()) { - pending.remove(updatedPayment.request.paymentId) - } else { - pending[updatedPayment.request.paymentId] = updatedPayment + private fun removeFromState(paymentId: UUID) { + val children = childToParentId.filterValues { it == paymentId }.keys + children.forEach { childToParentId.remove(it) } + pending.remove(paymentId) + } + + /** + * We assume that we have at most one channel with our trampoline node. + * We return it if it's ready and has enough balance for the payment, otherwise we return a failure. + */ + private fun selectChannel(toSend: MilliSatoshi, channels: Map): Either { + return when (val available = channels.values.firstOrNull { it is Normal }) { + is Normal -> when { + toSend < available.channelUpdate.htlcMinimumMsat -> Either.Left(FinalFailure.InvalidPaymentAmount) + available.commitments.availableBalanceForSend() < toSend -> Either.Left(FinalFailure.InsufficientBalance) + else -> Either.Right(available) + } + else -> { + val failure = when { + channels.values.any { it is Syncing || it is Offline } -> FinalFailure.ChannelNotConnected + channels.values.any { it is WaitForOpenChannel || it is WaitForAcceptChannel || it is WaitForFundingCreated || it is WaitForFundingSigned || it is WaitForFundingConfirmed || it is WaitForChannelReady } -> FinalFailure.ChannelOpening + channels.values.any { it is ShuttingDown || it is Negotiating || it is Closing || it is Closed || it is WaitForRemotePublishFutureCommitment } -> FinalFailure.ChannelClosing + else -> FinalFailure.NoAvailableChannels + } + Either.Left(failure) + } } } - private fun createOutgoingPart(request: PayInvoice, route: RouteCalculation.Route, trampolinePayload: PaymentAttempt.TrampolinePayload): Triple { + private fun createOutgoingPayment(request: PayInvoice, channel: Normal, hop: NodeHop, currentBlockHeight: Int): Triple { + val logger = MDCLogger(logger, staticMdc = request.mdc()) val childId = UUID.randomUUID() + childToParentId[childId] = request.paymentId + val (amount, expiry, onion) = createPaymentOnion(request, hop, currentBlockHeight) val outgoingPayment = LightningOutgoingPayment.Part( id = childId, - amount = route.amount, - route = listOf(HopDesc(nodeParams.nodeId, route.channel.commitments.remoteNodeId, route.channel.shortChannelId), HopDesc(route.channel.commitments.remoteNodeId, request.recipient)), + amount = amount, + route = listOf(HopDesc(nodeParams.nodeId, hop.nodeId, channel.shortChannelId), HopDesc(hop.nodeId, hop.nextNodeId)), status = LightningOutgoingPayment.Part.Status.Pending ) - val channelHops: List = listOf(ChannelHop(nodeParams.nodeId, route.channel.commitments.remoteNodeId, route.channel.channelUpdate)) - val (cmdAdd, secrets) = OutgoingPaymentPacket.buildCommand(childId, request.paymentHash, channelHops, trampolinePayload.createFinalPayload(route.amount)) - return Triple(outgoingPayment, secrets, WrappedChannelCommand(route.channel.channelId, cmdAdd)) + logger.info { "sending $amount to channel ${channel.shortChannelId}" } + val add = ChannelCommand.Htlc.Add(amount, request.paymentHash, expiry, onion.packet, paymentId = childId, commit = true) + return Triple(outgoingPayment, onion.sharedSecrets, WrappedChannelCommand(channel.channelId, add)) } - private fun createTrampolinePayload(request: PayInvoice, fees: TrampolineFees, currentBlockHeight: Int): Triple { - // We are either directly paying our peer (the trampoline node) or a remote node via our peer (using trampoline). - val trampolineRoute = when (request.recipient) { - walletParams.trampolineNode.id -> listOf( - NodeHop(nodeParams.nodeId, request.recipient, /* ignored */ CltvExpiryDelta(0), /* ignored */ 0.msat) - ) - else -> listOf( - NodeHop(nodeParams.nodeId, walletParams.trampolineNode.id, /* ignored */ CltvExpiryDelta(0), /* ignored */ 0.msat), - NodeHop(walletParams.trampolineNode.id, request.recipient, fees.cltvExpiryDelta, fees.calculateFees(request.amount)) - ) - } - when (val paymentRequest = request.paymentDetails.paymentRequest) { + private fun createPaymentOnion(request: PayInvoice, hop: NodeHop, currentBlockHeight: Int): Triple { + return when (val paymentRequest = request.paymentDetails.paymentRequest) { is Bolt11Invoice -> { val minFinalExpiryDelta = paymentRequest.minFinalExpiryDelta ?: Channel.MIN_CLTV_EXPIRY_DELTA - val finalExpiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, minFinalExpiryDelta) - val finalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(request.amount, finalExpiry, paymentRequest.paymentSecret, paymentRequest.paymentMetadata) + val expiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, minFinalExpiryDelta) val invoiceFeatures = paymentRequest.features - val (trampolineAmount, trampolineExpiry, trampolineOnion) = if (invoiceFeatures.hasFeature(Feature.TrampolinePayment) || invoiceFeatures.hasFeature(Feature.ExperimentalTrampolinePayment)) { - // We may be paying an older version of lightning-kmp that only supports trampoline packets of size 400. - OutgoingPaymentPacket.buildPacket(request.paymentHash, trampolineRoute, finalPayload, 400) + if (request.recipient == walletParams.trampolineNode.id) { + // We are directly paying our trampoline node. + OutgoingPaymentPacket.buildPacketToTrampolinePeer(paymentRequest, request.amount, expiry) + } else if (invoiceFeatures.hasFeature(Feature.TrampolinePayment) || invoiceFeatures.hasFeature(Feature.ExperimentalTrampolinePayment)) { + OutgoingPaymentPacket.buildPacketToTrampolineRecipient(paymentRequest, request.amount, expiry, hop) } else { - OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(paymentRequest, trampolineRoute, finalPayload) + OutgoingPaymentPacket.buildPacketToLegacyRecipient(paymentRequest, request.amount, expiry, hop) } - return Triple(trampolineAmount, trampolineExpiry, trampolineOnion.packet) } is Bolt12Invoice -> { - val finalExpiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, CltvExpiryDelta(0)) - val (trampolineAmount, trampolineExpiry, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(paymentRequest, trampolineRoute.last(), request.amount, finalExpiry) - return Triple(trampolineAmount, trampolineExpiry, trampolineOnion.packet) - } - } - } - - sealed class PaymentAttempt { - abstract val request: PayInvoice - abstract val pending: Map> - abstract val fees: MilliSatoshi - - fun isComplete(): Boolean = pending.isEmpty() - - /** - * @param totalAmount total amount that the trampoline node should receive. - * @param expiry expiry at the trampoline node. - * @param paymentSecret trampoline payment secret (should be different from the invoice payment secret). - * @param packet trampoline onion packet. - */ - data class TrampolinePayload(val totalAmount: MilliSatoshi, val expiry: CltvExpiry, val paymentSecret: ByteVector32, val packet: OnionRoutingPacket) { - fun createFinalPayload(partialAmount: MilliSatoshi): PaymentOnion.FinalPayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(partialAmount, totalAmount, expiry, paymentSecret, packet) - } - - /** - * While a payment is in progress, we listen to child payments failures. - * When we receive failures, we retry the failed amount with different routes/fees. - * - * @param request payment request containing the total amount to send. - * @param attemptNumber number of failed previous payment attempts. - * @param trampolinePayload trampoline payload for the current payment attempt. - * @param pending pending child payments (HTLCs were sent, we are waiting for a fulfill or a failure). - * @param ignore channels that should be ignored (previously returned an error). - * @param failures previous child payment failures. - */ - data class PaymentInProgress( - override val request: PayInvoice, - val attemptNumber: Int, - val trampolinePayload: TrampolinePayload, - override val pending: Map>, - val ignore: Set, - val failures: List> - ) : PaymentAttempt() { - override val fees: MilliSatoshi = pending.values.map { it.first.amount }.sum() - request.amount - } - - /** - * When we exhaust our retry attempts without success or encounter a non-recoverable error, we abort the payment. - * Once we're in that state, we wait for all the pending child payments to settle. - * - * @param request payment request containing the total amount to send. - * @param reason failure reason. - * @param pending pending child payments (we are waiting for them to be failed downstream). - * @param failures child payment failures. - */ - data class PaymentAborted( - override val request: PayInvoice, - val reason: FinalFailure, - override val pending: Map>, - val failures: List> - ) : PaymentAttempt() { - override val fees: MilliSatoshi = 0.msat - - suspend fun failChild(childId: UUID, failure: Either, db: OutgoingPaymentsDb, logger: MDCLogger): Pair { - val updated = copy(pending = pending - childId, failures = failures + failure) - val result = if (updated.isComplete()) { - logger.warning { "payment failed: ${updated.reason}" } - db.completeOutgoingPaymentOffchain(request.paymentId, updated.reason) - Failure(request, OutgoingPaymentFailure(updated.reason, updated.failures)) - } else { - null - } - return Pair(updated, result) - } - } - - /** - * Once we receive a first fulfill for a child payment, we can consider that the whole payment succeeded (because we - * received the payment preimage that we can use as a proof of payment). - * Once we're in that state, we wait for all the pending child payments to fulfill. - * - * @param request payment request containing the total amount to send. - * @param preimage payment preimage. - * @param parts fulfilled child payments. - * @param pending pending child payments (we are waiting for them to be fulfilled downstream). - */ - data class PaymentSucceeded( - override val request: PayInvoice, - val preimage: ByteVector32, - val parts: List, - override val pending: Map> - ) : PaymentAttempt() { - override val fees: MilliSatoshi = parts.map { it.amount }.sum() + pending.values.map { it.first.amount }.sum() - request.amount - - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - suspend fun failChild(childId: UUID, db: OutgoingPaymentsDb, logger: MDCLogger): Pair { - logger.warning { "partial payment failure after fulfill: we may have paid less than the full amount" } - val updated = copy(pending = pending - childId) - val result = if (updated.isComplete()) { - logger.info { "payment successfully sent (fees=${updated.fees})" } - db.completeOutgoingPaymentOffchain(request.paymentId, preimage) - Success(request, LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, parts, LightningOutgoingPayment.Status.Completed.Succeeded.OffChain(preimage)), preimage) - } else { - null - } - return Pair(updated, result) + // The recipient already included a final cltv-expiry-delta in their invoice blinded paths. + val minFinalExpiryDelta = CltvExpiryDelta(0) + val expiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, minFinalExpiryDelta) + OutgoingPaymentPacket.buildPacketToBlindedRecipient(paymentRequest, request.amount, expiry, hop) } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt index 920ac64d1..6c09586c1 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt @@ -6,17 +6,14 @@ import fr.acinq.bitcoin.PrivateKey import fr.acinq.bitcoin.PublicKey import fr.acinq.bitcoin.utils.Either import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.Feature import fr.acinq.lightning.Lightning import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.channel.ChannelCommand import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets -import fr.acinq.lightning.crypto.sphinx.SharedSecrets import fr.acinq.lightning.crypto.sphinx.Sphinx -import fr.acinq.lightning.router.ChannelHop -import fr.acinq.lightning.router.Hop import fr.acinq.lightning.router.NodeHop -import fr.acinq.lightning.utils.UUID import fr.acinq.lightning.wire.* object OutgoingPaymentPacket { @@ -24,128 +21,127 @@ object OutgoingPaymentPacket { /** * Build an encrypted onion packet from onion payloads and node public keys. */ - private fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int?): PacketAndSecrets { - require(nodes.size == payloads.size) + fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int? = null): PacketAndSecrets { val sessionKey = Lightning.randomKey() + return buildOnion(sessionKey, nodes, payloads, associatedData, payloadLength) + } + + private fun buildOnion(sessionKey: PrivateKey, nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int? = null): PacketAndSecrets { + require(nodes.size == payloads.size) val payloadsBin = payloads.map { it.write() } val totalPayloadLength = payloadLength ?: payloadsBin.sumOf { it.size + Sphinx.MacLength } return Sphinx.create(sessionKey, nodes, payloadsBin, associatedData, totalPayloadLength) } /** - * Build the onion payloads for each hop. + * Build an encrypted payment onion packet when the final recipient supports trampoline. + * The trampoline node will receive instructions on how much to relay to the final recipient. * - * @param hops the hops as computed by the router + extra routes from payment request - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, payloads) tuple where: - * - firstAmount is the amount for the first htlc in the route - * - firstExpiry is the cltv expiry for the first htlc in the route - * - a sequence of payloads that will be used to build the onion + * @param invoice a Bolt 11 invoice that contains the trampoline feature bit. + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. + * @param hop the trampoline hop from the trampoline node to the recipient. */ - private fun buildPayloads(hops: List, finalPayload: PaymentOnion.FinalPayload): Triple> { - return hops.reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> - val (amount, expiry, payloads) = triple - val payload = when (hop) { - // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. - is ChannelHop -> PaymentOnion.ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) - is NodeHop -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) - } - Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) + fun buildPacketToTrampolineRecipient(invoice: Bolt11Invoice, amount: MilliSatoshi, expiry: CltvExpiry, hop: NodeHop): Triple { + require(invoice.features.hasFeature(Feature.ExperimentalTrampolinePayment) || invoice.features.hasFeature(Feature.TrampolinePayment)) { "invoice must support trampoline" } + val trampolineOnion = run { + val finalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, invoice.paymentSecret, invoice.paymentMetadata) + val trampolinePayload = PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) + // We may be paying an older version of lightning-kmp that only supports trampoline packets of size 400. + buildOnion(listOf(hop.nodeId, hop.nextNodeId), listOf(trampolinePayload, finalPayload), invoice.paymentHash, payloadLength = 400) } + val trampolineAmount = amount + hop.fee(amount) + val trampolineExpiry = expiry + hop.cltvExpiryDelta + // We generate a random secret to avoid leaking the invoice secret to the trampoline node. + val trampolinePaymentSecret = Lightning.randomBytes32() + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion) } /** - * Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline. - * The next-to-last trampoline node payload will contain instructions to convert to a legacy payment. + * Build an encrypted payment onion packet when the final recipient is our trampoline node. * - * @param invoice a Bolt11 invoice (features and routing hints will be provided to the next-to-last node). - * @param hops the trampoline hops (including ourselves in the first hop, and the non-trampoline final recipient in the last hop). - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, onion) triple where: - * - firstAmount is the amount for the trampoline node in the route - * - firstExpiry is the cltv expiry for the first trampoline node in the route - * - the trampoline onion to include in final payload of a normal onion + * @param invoice a Bolt 11 invoice that contains the trampoline feature bit. + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. */ - fun buildTrampolineToNonTrampolinePacket(invoice: Bolt11Invoice, hops: List, finalPayload: PaymentOnion.FinalPayload.Standard): Triple { - // NB: the final payload will never reach the recipient, since the next-to-last trampoline hop will convert that to a legacy payment - // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. - val dummyFinalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalPayload.amount, finalPayload.expiry, finalPayload.paymentSecret, null) - val (firstAmount, firstExpiry, initialPayloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(dummyFinalPayload))) { triple, hop -> - val (amount, expiry, payloads) = triple - val payload = when (payloads.size) { - // The next-to-last trampoline hop must include invoice data to indicate the conversion to a legacy payment. - 1 -> PaymentOnion.RelayToNonTrampolinePayload.create(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) - else -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) - } - Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) - } - var payloads = initialPayloads - val nodes = hops.map { it.nextNodeId } - var onion = buildOnion(nodes, payloads, invoice.paymentHash, payloadLength = null) - // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. - while (onion.packet.payload.size() > 1000) { - payloads = payloads.map { payload -> when (payload) { - is PaymentOnion.RelayToNonTrampolinePayload -> payload.copy(records = payload.records.copy(records = payload.records.records.map { when (it) { - is OnionPaymentPayloadTlv.InvoiceRoutingInfo -> OnionPaymentPayloadTlv.InvoiceRoutingInfo(it.extraHops.dropLast(1)) - else -> it - } }.toSet())) - else -> payload - } } - onion = buildOnion(nodes, payloads, invoice.paymentHash, payloadLength = null) + fun buildPacketToTrampolinePeer(invoice: Bolt11Invoice, amount: MilliSatoshi, expiry: CltvExpiry): Triple { + require(invoice.features.hasFeature(Feature.ExperimentalTrampolinePayment) || invoice.features.hasFeature(Feature.TrampolinePayment)) { "invoice must support trampoline" } + val trampolineOnion = run { + val finalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, invoice.paymentSecret, invoice.paymentMetadata) + buildOnion(listOf(invoice.nodeId), listOf(finalPayload), invoice.paymentHash) } - return Triple(firstAmount, firstExpiry, onion) + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amount, amount, expiry, invoice.paymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(invoice.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(amount, expiry, paymentOnion) } /** - * Build an encrypted trampoline onion packet when the final recipient is using a blinded path. - * The trampoline payload will contain data from the invoice to allow the trampoline node to pay the blinded path. - * We only need a single trampoline node, who will find a route to the blinded path's introduction node without learning the recipient's identity. + * Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline. + * The trampoline node will receive instructions to convert to a legacy payment. + * This reveals to the trampoline node who the recipient is and details from the invoice. + * This must be deprecated once recipients support either trampoline or blinded paths. * - * @param invoice a Bolt12 invoice (blinded path data will be provided to the trampoline node). + * @param invoice a Bolt11 invoice (features and routing hints will be provided to the trampoline node). + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. * @param hop the trampoline hop from the trampoline node to the recipient. - * @param finalAmount amount that should be received by the final recipient. - * @param finalExpiry cltv expiry that should be received by the final recipient. */ - fun buildTrampolineToNonTrampolinePacket(invoice: Bolt12Invoice, hop: NodeHop, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry): Triple { - var payload = PaymentOnion.RelayToBlindedPayload.create(finalAmount, finalExpiry, invoice) - var onion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, payloadLength = null) - // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. - while (onion.packet.payload.size() > 1000) { - payload = payload.copy(records = payload.records.copy(records = payload.records.records.map { when (it) { - is OnionPaymentPayloadTlv.OutgoingBlindedPaths -> OnionPaymentPayloadTlv.OutgoingBlindedPaths(it.paths.dropLast(1)) - else -> it - } }.toSet())) - onion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, payloadLength = null) + fun buildPacketToLegacyRecipient(invoice: Bolt11Invoice, amount: MilliSatoshi, expiry: CltvExpiry, hop: NodeHop): Triple { + val trampolineOnion = run { + // NB: the final payload will never reach the recipient, since the trampoline node will convert that to a legacy payment. + // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. + val dummyFinalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, invoice.paymentSecret, null) + var routingInfo = invoice.routingInfo + var trampolinePayload = PaymentOnion.RelayToNonTrampolinePayload.create(amount, amount, expiry, hop.nextNodeId, invoice, routingInfo) + var trampolineOnion = buildOnion(listOf(hop.nodeId, hop.nextNodeId), listOf(trampolinePayload, dummyFinalPayload), invoice.paymentHash) + // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. + while (trampolineOnion.packet.payload.size() > 1000) { + routingInfo = routingInfo.dropLast(1) + trampolinePayload = PaymentOnion.RelayToNonTrampolinePayload.create(amount, amount, expiry, hop.nextNodeId, invoice, routingInfo) + trampolineOnion = buildOnion(listOf(hop.nodeId, hop.nextNodeId), listOf(trampolinePayload, dummyFinalPayload), invoice.paymentHash) + } + trampolineOnion } - return Triple(finalAmount + hop.fee(finalAmount), finalExpiry + hop.cltvExpiryDelta, onion) - } - - /** - * Build an encrypted onion packet with the given final payload. - * - * @param hops the hops as computed by the router + extra routes from payment request, including ourselves in the first hop - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, onion) tuple where: - * - firstAmount is the amount for the first htlc in the route - * - firstExpiry is the cltv expiry for the first htlc in the route - * - the onion to include in the HTLC - */ - fun buildPacket(paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload, payloadLength: Int?): Triple { - val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) - val nodes = hops.map { it.nextNodeId } - // BOLT 2 requires that associatedData == paymentHash - val onion = buildOnion(nodes, payloads, paymentHash, payloadLength) - return Triple(firstAmount, firstExpiry, onion) + val trampolineAmount = amount + hop.fee(amount) + val trampolineExpiry = expiry + hop.cltvExpiryDelta + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, invoice.paymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion) } /** - * Build the command to add an HTLC with the given final payload and using the provided hops. + * Build an encrypted trampoline onion packet when the final recipient is using a blinded path. + * The trampoline node will receive data from the invoice to allow them to pay the blinded path. + * The data revealed to the trampoline node doesn't leak anything about the recipient's identity. + * We only need a single trampoline node, who will find routes to the blinded paths. * - * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) + * @param invoice a Bolt12 invoice (blinded path data will be provided to the trampoline node). + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. + * @param hop the trampoline hop from the trampoline node to the recipient. */ - fun buildCommand(paymentId: UUID, paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload): Pair { - val (firstAmount, firstExpiry, onion) = buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) - return Pair(ChannelCommand.Htlc.Add(firstAmount, paymentHash, firstExpiry, onion.packet, paymentId, commit = true), onion.sharedSecrets) + fun buildPacketToBlindedRecipient(invoice: Bolt12Invoice, amount: MilliSatoshi, expiry: CltvExpiry, hop: NodeHop): Triple { + val trampolineOnion = run { + var blindedPaths = invoice.blindedPaths + var trampolinePayload = PaymentOnion.RelayToBlindedPayload.create(amount, expiry, invoice.features, blindedPaths) + var trampolineOnion = buildOnion(listOf(hop.nodeId), listOf(trampolinePayload), invoice.paymentHash) + // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. + while (trampolineOnion.packet.payload.size() > 1000) { + blindedPaths = blindedPaths.dropLast(1) + trampolinePayload = PaymentOnion.RelayToBlindedPayload.create(amount, expiry, invoice.features, blindedPaths) + trampolineOnion = buildOnion(listOf(hop.nodeId), listOf(trampolinePayload), invoice.paymentHash) + } + trampolineOnion + } + val trampolineAmount = amount + hop.fee(amount) + val trampolineExpiry = expiry + hop.cltvExpiryDelta + // We generate a random secret to avoid leaking the invoice secret to the trampoline node. + val trampolinePaymentSecret = Lightning.randomBytes32() + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion) } fun buildHtlcFailure(nodeSecret: PrivateKey, paymentHash: ByteVector32, onion: OnionRoutingPacket, reason: ChannelCommand.Htlc.Settlement.Fail.Reason): Either { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt deleted file mode 100644 index ac4e94654..000000000 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt +++ /dev/null @@ -1,61 +0,0 @@ -package fr.acinq.lightning.payment - -import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.bitcoin.Satoshi -import fr.acinq.bitcoin.utils.Either -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.channel.states.* -import fr.acinq.lightning.logging.LoggerFactory -import fr.acinq.lightning.logging.MDCLogger -import fr.acinq.lightning.utils.UUID -import fr.acinq.lightning.utils.msat - -class RouteCalculation(loggerFactory: LoggerFactory) { - - private val logger = loggerFactory.newLogger(this::class) - - data class Route(val amount: MilliSatoshi, val channel: Normal) - - data class ChannelBalance(val c: Normal) { - val balance: MilliSatoshi = c.commitments.availableBalanceForSend() - val capacity: Satoshi = c.commitments.latest.fundingAmount - } - - fun findRoutes(paymentId: UUID, amount: MilliSatoshi, channels: Map): Either> { - val logger = MDCLogger(logger, staticMdc = mapOf("paymentId" to paymentId, "amount" to amount)) - - val sortedChannels = channels.values.filterIsInstance().map { ChannelBalance(it) }.sortedBy { it.balance }.reversed() - if (sortedChannels.isEmpty()) { - val failure = when { - channels.values.any { it is Syncing || it is Offline } -> FinalFailure.ChannelNotConnected - channels.values.any { it is WaitForOpenChannel || it is WaitForAcceptChannel || it is WaitForFundingCreated || it is WaitForFundingSigned || it is WaitForFundingConfirmed || it is WaitForChannelReady } -> FinalFailure.ChannelOpening - channels.values.any { it is ShuttingDown || it is Negotiating || it is Closing || it is WaitForRemotePublishFutureCommitment } -> FinalFailure.ChannelClosing - // This may happen if adding an HTLC failed because we hit channel limits (e.g. max-accepted-htlcs) and we're retrying with this channel filtered out. - else -> FinalFailure.NoAvailableChannels - } - logger.warning { "no available channels: $failure" } - return Either.Left(failure) - } - - val filteredChannels = sortedChannels.filter { it.balance >= it.c.channelUpdate.htlcMinimumMsat } - var remaining = amount - val routes = mutableListOf() - for (channel in filteredChannels) { - val toSend = channel.balance.min(remaining) - routes.add(Route(toSend, channel.c)) - remaining -= toSend - if (remaining == 0.msat) { - break - } - } - - return if (remaining > 0.msat) { - logger.info { "insufficient balance: ${sortedChannels.joinToString { "${it.c.shortChannelId}->${it.balance}/${it.capacity}" }}" } - Either.Left(FinalFailure.InsufficientBalance) - } else { - logger.info { "routes found: ${routes.map { "${it.channel.shortChannelId}->${it.amount}" }}" } - Either.Right(routes) - } - } - -} \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt index bde0d2bc6..73c651239 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt @@ -9,10 +9,7 @@ import fr.acinq.bitcoin.io.Input import fr.acinq.bitcoin.io.Output import fr.acinq.bitcoin.utils.Either import fr.acinq.bitcoin.utils.flatMap -import fr.acinq.lightning.CltvExpiry -import fr.acinq.lightning.CltvExpiryDelta -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.ShortChannelId +import fr.acinq.lightning.* import fr.acinq.lightning.payment.Bolt11Invoice import fr.acinq.lightning.payment.Bolt12Invoice import fr.acinq.lightning.utils.msat @@ -498,7 +495,7 @@ object PaymentOnion { } } - fun create(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice): RelayToNonTrampolinePayload = + fun create(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice, routingInfo: List): RelayToNonTrampolinePayload = RelayToNonTrampolinePayload( TlvStream( buildSet { @@ -508,7 +505,7 @@ object PaymentOnion { add(OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, totalAmount)) invoice.paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } add(OnionPaymentPayloadTlv.InvoiceFeatures(invoice.features.toByteArray().toByteVector())) - add(OnionPaymentPayloadTlv.InvoiceRoutingInfo(invoice.routingInfo.map { it.hints })) + add(OnionPaymentPayloadTlv.InvoiceRoutingInfo(routingInfo.map { it.hints })) } ) ) @@ -538,14 +535,14 @@ object PaymentOnion { } } - fun create(amount: MilliSatoshi, expiry: CltvExpiry, invoice: Bolt12Invoice): RelayToBlindedPayload = + fun create(amount: MilliSatoshi, expiry: CltvExpiry, features: Features, blindedPaths: List): RelayToBlindedPayload = RelayToBlindedPayload( TlvStream( setOf( OnionPaymentPayloadTlv.AmountToForward(amount), OnionPaymentPayloadTlv.OutgoingCltv(expiry), - OnionPaymentPayloadTlv.OutgoingBlindedPaths(invoice.blindedPaths), - OnionPaymentPayloadTlv.InvoiceFeatures(invoice.features.toByteArray().toByteVector()) + OnionPaymentPayloadTlv.OutgoingBlindedPaths(blindedPaths), + OnionPaymentPayloadTlv.InvoiceFeatures(features.toByteArray().toByteVector()) ) ) ) diff --git a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt index c3b63ad94..7606a1eff 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt @@ -14,7 +14,6 @@ import fr.acinq.lightning.json.JsonSerializers import fr.acinq.lightning.logging.MDCLogger import fr.acinq.lightning.logging.mdc import fr.acinq.lightning.payment.OutgoingPaymentPacket -import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.serialization.Serialization import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.testLoggerFactory @@ -429,16 +428,11 @@ object TestsHelper { } fun makeCmdAdd(amount: MilliSatoshi, destination: PublicKey, currentBlockHeight: Long, paymentPreimage: ByteVector32 = randomBytes32(), paymentId: UUID = UUID.randomUUID()): Pair { - val paymentHash: ByteVector32 = Crypto.sha256(paymentPreimage).toByteVector32() + val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() val expiry = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) - val dummyKey = PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010101")).publicKey() - val dummyUpdate = ChannelUpdate(ByteVector64.Zeroes, BlockHash(ByteVector32.Zeroes), ShortChannelId(144, 0, 0), 0, 0, 0, CltvExpiryDelta(1), 0.msat, 0.msat, 0, null) - val cmd = OutgoingPaymentPacket.buildCommand( - paymentId, - paymentHash, - listOf(ChannelHop(dummyKey, destination, dummyUpdate)), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), null) - ).first.copy(commit = false) + val payload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), null) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destination), listOf(payload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet + val cmd = ChannelCommand.Htlc.Add(amount, paymentHash, expiry, onion, paymentId, commit = false) return Pair(paymentPreimage, cmd) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt b/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt index ffcfed201..cc2852efa 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt @@ -485,101 +485,52 @@ class PeerTest : LightningTestSuite() { @Test fun `payment between two nodes -- with disconnection`() = runSuspendTest { - // We create two channels between Alice and Bob to ensure that the payment is split in two parts. - val (aliceChan1, bobChan1) = TestsHelper.reachNormal(aliceFundingAmount = 100_000.sat, bobFundingAmount = 100_000.sat, alicePushAmount = 0.msat, bobPushAmount = 0.msat) - val (aliceChan2, bobChan2) = TestsHelper.reachNormal(aliceFundingAmount = 100_000.sat, bobFundingAmount = 100_000.sat, alicePushAmount = 0.msat, bobPushAmount = 0.msat) - val nodeParams = Pair(aliceChan1.staticParams.nodeParams, bobChan1.staticParams.nodeParams) + val (alice0, bob0) = TestsHelper.reachNormal() + val nodeParams = Pair(alice0.staticParams.nodeParams, bob0.staticParams.nodeParams) val walletParams = Pair( // Alice must declare Bob as her trampoline node to enable direct payments. - TestConstants.Alice.walletParams.copy(trampolineNode = NodeUri(nodeParams.second.nodeId, "bob.com", 9735)), + TestConstants.Alice.walletParams.copy(trampolineNode = NodeUri(bob0.staticParams.nodeParams.nodeId, "bob.com", 9735)), TestConstants.Bob.walletParams ) - // Bob sends a multipart payment to Alice. - val (alice, bob, alice2bob1, bob2alice1) = newPeers(this, nodeParams, walletParams, listOf(aliceChan1 to bobChan1, aliceChan2 to bobChan2), automateMessaging = false) + val (alice, bob, alice2bob1, bob2alice1) = newPeers(this, nodeParams, walletParams, listOf(alice0 to bob0), automateMessaging = false) val invoice = alice.createInvoice(randomBytes32(), 150_000_000.msat, Either.Left("test invoice"), null) bob.send(PayInvoice(UUID.randomUUID(), invoice.amount!!, LightningOutgoingPayment.Details.Normal(invoice))) - // Bob sends one HTLC on each channel. - val htlcs = listOf( - bob2alice1.expect(), - bob2alice1.expect(), - ) - assertEquals(2, htlcs.map { it.channelId }.toSet().size) - val commitSigsBob = listOf( - bob2alice1.expect(), - bob2alice1.expect(), - ) + // Bob sends an HTLC to Alice. + alice.forward(bob2alice1.expect()) + alice.forward(bob2alice1.expect()) - // We cross-sign the HTLC on the first channel. - run { - val htlc = htlcs.find { it.channelId == aliceChan1.channelId } - assertNotNull(htlc) - alice.forward(htlc) - val commitSigBob = commitSigsBob.find { it.channelId == aliceChan1.channelId } - assertNotNull(commitSigBob) - alice.forward(commitSigBob) - bob.forward(alice2bob1.expect()) - bob.forward(alice2bob1.expect()) - alice.forward(bob2alice1.expect()) - } - // We start cross-signing the HTLC on the second channel. - run { - val htlc = htlcs.find { it.channelId == aliceChan2.channelId } - assertNotNull(htlc) - alice.forward(htlc) - val commitSigBob = commitSigsBob.find { it.channelId == aliceChan2.channelId } - assertNotNull(commitSigBob) - alice.forward(commitSigBob) - bob.forward(alice2bob1.expect()) - bob.forward(alice2bob1.expect()) - bob2alice1.expect() // Alice doesn't receive Bob's revocation. - } + // We start cross-signing the HTLC. + bob.forward(alice2bob1.expect()) + bob.forward(alice2bob1.expect()) + bob2alice1.expect() // Alice doesn't receive Bob's revocation. - // We disconnect before Alice receives Bob's revocation on the second channel. + // We disconnect before Alice receives Bob's revocation. alice.disconnect() alice.send(Disconnected) bob.disconnect() bob.send(Disconnected) // On reconnection, Bob retransmits its revocation. - val (_, _, alice2bob2, bob2alice2) = connect(this, connectionId = 1, alice, bob, channelsCount = 2, expectChannelReady = false, automateMessaging = false) + val (_, _, alice2bob2, bob2alice2) = connect(this, connectionId = 1, alice, bob, channelsCount = 1, expectChannelReady = false, automateMessaging = false) alice.forward(bob2alice2.expect(), connectionId = 1) // Alice has now received the complete payment and fulfills it. - val fulfills = listOf( - alice2bob2.expect(), - alice2bob2.expect(), - ) - val commitSigsAlice = listOf( - alice2bob2.expect(), - alice2bob2.expect(), - ) + bob.forward(alice2bob2.expect(), connectionId = 1) + bob.forward(alice2bob2.expect(), connectionId = 1) + alice.forward(bob2alice2.expect(), connectionId = 1) + bob2alice2.expect() // Alice doesn't receive Bob's signature. - // We fulfill the first HTLC. - run { - val fulfill = fulfills.find { it.channelId == aliceChan1.channelId } - assertNotNull(fulfill) - bob.forward(fulfill, connectionId = 1) - val commitSigAlice = commitSigsAlice.find { it.channelId == aliceChan1.channelId } - assertNotNull(commitSigAlice) - bob.forward(commitSigAlice, connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - bob.forward(alice2bob2.expect(), connectionId = 1) - } + // We disconnect before Alice receives Bob's signature. + alice.disconnect() + alice.send(Disconnected) + bob.disconnect() + bob.send(Disconnected) - // We fulfill the second HTLC. - run { - val fulfill = fulfills.find { it.channelId == aliceChan2.channelId } - assertNotNull(fulfill) - bob.forward(fulfill, connectionId = 1) - val commitSigAlice = commitSigsAlice.find { it.channelId == aliceChan2.channelId } - assertNotNull(commitSigAlice) - bob.forward(commitSigAlice, connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - bob.forward(alice2bob2.expect(), connectionId = 1) - } + // On reconnection, Bob retransmits its signature. + val (_, _, alice2bob3, bob2alice3) = connect(this, connectionId = 2, alice, bob, channelsCount = 1, expectChannelReady = false, automateMessaging = false) + alice.forward(bob2alice3.expect(), connectionId = 2) + bob.forward(alice2bob3.expect(), connectionId = 2) assertEquals(invoice.amount, alice.db.payments.getIncomingPayment(invoice.paymentHash)?.received?.amount) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt index a3aa4855a..0eaf9bd00 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt @@ -16,8 +16,6 @@ import fr.acinq.lightning.db.IncomingPaymentsDb import fr.acinq.lightning.io.AddLiquidityForIncomingPayment import fr.acinq.lightning.io.SendOnTheFlyFundingMessage import fr.acinq.lightning.io.WrappedChannelCommand -import fr.acinq.lightning.router.ChannelHop -import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.tests.utils.runSuspendTest @@ -342,15 +340,16 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { val (paymentHandler, incomingPayment, paymentSecret) = createFixture(defaultAmount) checkDbPayment(incomingPayment, paymentHandler.db) val willAddHtlc = run { - // We simulate a trampoline-relay with a dummy channel hop between the liquidity provider and the wallet. - val (amount, expiry, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - incomingPayment.paymentHash, - listOf(NodeHop(TestConstants.Alice.nodeParams.nodeId, TestConstants.Bob.nodeParams.nodeId, CltvExpiryDelta(144), 0.msat)), - makeMppPayload(defaultAmount, defaultAmount, paymentSecret), - null - ) - assertTrue(trampolineOnion.packet.payload.size() < 500) - makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amount, amount, expiry, randomBytes32(), trampolineOnion.packet)) + // We simulate a trampoline-relay: the trampoline node will relay the a payment onion with a dummy channel hop. + val trampolinePayload = makeMppPayload(defaultAmount, defaultAmount, paymentSecret) + val trampolineOnion = OutgoingPaymentPacket.buildOnion( + nodes = listOf(TestConstants.Bob.nodeParams.nodeId), + payloads = listOf(trampolinePayload), + associatedData = incomingPayment.paymentHash, + ).packet + assertTrue(trampolineOnion.payload.size() < 500) + val finalPayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolinePayload.amount, trampolinePayload.totalAmount, trampolinePayload.expiry, randomBytes32(), trampolineOnion) + makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, finalPayload) } val result = paymentHandler.process(willAddHtlc, Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, TestConstants.fundingRates) assertIs(result) @@ -393,15 +392,16 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { val (paymentHandler, incomingPayment, _) = createFixture(defaultAmount) checkDbPayment(incomingPayment, paymentHandler.db) val willAddHtlc = run { - // We simulate a trampoline-relay with a dummy channel hop between the liquidity provider and the wallet. - val (amount, expiry, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - incomingPayment.paymentHash, - listOf(NodeHop(TestConstants.Alice.nodeParams.nodeId, TestConstants.Bob.nodeParams.nodeId, CltvExpiryDelta(144), 0.msat)), - makeMppPayload(defaultAmount, defaultAmount, randomBytes32()), - null - ) - assertTrue(trampolineOnion.packet.payload.size() < 500) - makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amount, amount, expiry, randomBytes32(), trampolineOnion.packet)) + // We simulate a trampoline-relay: the trampoline node will relay the a payment onion with a dummy channel hop. + val trampolinePayload = makeMppPayload(defaultAmount, defaultAmount, randomBytes32()) + val trampolineOnion = OutgoingPaymentPacket.buildOnion( + nodes = listOf(TestConstants.Bob.nodeParams.nodeId), + payloads = listOf(trampolinePayload), + associatedData = incomingPayment.paymentHash, + ).packet + assertTrue(trampolineOnion.payload.size() < 500) + val finalPayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolinePayload.amount, trampolinePayload.totalAmount, trampolinePayload.expiry, randomBytes32(), trampolineOnion) + makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, finalPayload) } val result = paymentHandler.process(willAddHtlc, Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, TestConstants.fundingRates) assertIs(result) @@ -1762,27 +1762,9 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { val defaultAmount = 150_000_000.msat val feeCreditFeatures = Features(Feature.ExperimentalSplice to FeatureSupport.Optional, Feature.OnTheFlyFunding to FeatureSupport.Optional, Feature.FundingFeeCredit to FeatureSupport.Optional) - private fun channelHops(destination: PublicKey): List { - val dummyKey = PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010101")).publicKey() - val dummyUpdate = ChannelUpdate( - signature = ByteVector64.Zeroes, - chainHash = BlockHash(ByteVector32.Zeroes), - shortChannelId = ShortChannelId(144, 0, 0), - timestampSeconds = 0, - messageFlags = 0, - channelFlags = 0, - cltvExpiryDelta = CltvExpiryDelta(144), - htlcMinimumMsat = 1000.msat, - feeBaseMsat = 1.msat, - feeProportionalMillionths = 10, - htlcMaximumMsat = null - ) - val channelHop = ChannelHop(dummyKey, destination, dummyUpdate) - return listOf(channelHop) - } - private fun makeCmdAddHtlc(destination: PublicKey, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload): ChannelCommand.Htlc.Add { - return OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, channelHops(destination), finalPayload).first.copy(commit = true) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destination), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet + return ChannelCommand.Htlc.Add(finalPayload.amount, paymentHash, finalPayload.expiry, onion, UUID.randomUUID(), commit = true) } private fun makeUpdateAddHtlc( @@ -1798,9 +1780,9 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { null -> destination.nodeParams.nodeId else -> RouteBlinding.derivePrivateKey(destination.nodeParams.nodePrivateKey, blinding).publicKey() } - val (_, _, packetAndSecrets) = OutgoingPaymentPacket.buildPacket(paymentHash, channelHops(destinationNodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destinationNodeId), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet val amount = finalPayload.amount - (fundingFee?.amount ?: 0.msat) - return UpdateAddHtlc(channelId, id, amount, paymentHash, finalPayload.expiry, packetAndSecrets.packet, blinding, fundingFee) + return UpdateAddHtlc(channelId, id, amount, paymentHash, finalPayload.expiry, onion, blinding, fundingFee) } private fun makeWillAddHtlc(destination: IncomingPaymentHandler, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload, blinding: PublicKey? = null): WillAddHtlc { @@ -1808,8 +1790,8 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { null -> destination.nodeParams.nodeId else -> RouteBlinding.derivePrivateKey(destination.nodeParams.nodePrivateKey, blinding).publicKey() } - val (_, _, packetAndSecrets) = OutgoingPaymentPacket.buildPacket(paymentHash, channelHops(destinationNodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) - return WillAddHtlc(destination.nodeParams.chainHash, randomBytes32(), finalPayload.amount, paymentHash, finalPayload.expiry, packetAndSecrets.packet, blinding) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destinationNodeId), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet + return WillAddHtlc(destination.nodeParams.chainHash, randomBytes32(), finalPayload.amount, paymentHash, finalPayload.expiry, onion, blinding) } private fun makeMppPayload( diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt index 1d3bbba2f..b8502d240 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt @@ -1,13 +1,14 @@ package fr.acinq.lightning.payment -import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Chain +import fr.acinq.bitcoin.Crypto +import fr.acinq.bitcoin.PrivateKey import fr.acinq.bitcoin.utils.Either import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.Lightning.randomKey -import fr.acinq.lightning.blockchain.fee.FeeratePerKw import fr.acinq.lightning.channel.* -import fr.acinq.lightning.channel.states.Normal import fr.acinq.lightning.channel.states.Offline import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.Sphinx @@ -15,11 +16,9 @@ import fr.acinq.lightning.db.InMemoryPaymentsDb import fr.acinq.lightning.db.LightningOutgoingPayment import fr.acinq.lightning.db.OutgoingPaymentsDb import fr.acinq.lightning.io.PayInvoice -import fr.acinq.lightning.io.WrappedChannelCommand import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.tests.utils.runSuspendTest -import fr.acinq.lightning.transactions.CommitmentSpec import fr.acinq.lightning.utils.* import fr.acinq.lightning.wire.* import kotlin.test.* @@ -57,10 +56,25 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { Feature.RouteBlinding to FeatureSupport.Optional, ) // The following invoice requires payment_metadata. - val invoice1 = - Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.PaymentMetadata to FeatureSupport.Mandatory)) + val invoice1 = run { + val unsupportedFeatures = Features( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Mandatory + ) + Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = unsupportedFeatures) + } // The following invoice requires unknown feature bit 188. - val invoice2 = Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = Features(mapOf(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory), setOf(UnknownFeature(188)))) + val invoice2 = run { + val unsupportedFeatures = Features( + mapOf( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory + ), + setOf(UnknownFeature(188)) + ) + Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = unsupportedFeatures) + } for (invoice in listOf(invoice1, invoice2)) { val outgoingPaymentHandler = OutgoingPaymentHandler(alice.staticParams.nodeParams.copy(features = features), defaultWalletParams, InMemoryPaymentsDb()) val payment = PayInvoice(UUID.randomUUID(), 15_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) @@ -111,16 +125,16 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { @Test fun `invoice already paid`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) val invoice = makeInvoice(amount = 100_000.msat, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val (channelId, add) = filterAddHtlcCommands(result).first() + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId, add, randomBytes32())) as OutgoingPaymentHandler.Success val duplicatePayment = payment.copy(paymentId = UUID.randomUUID()) - val error = outgoingPaymentHandler.sendPayment(duplicatePayment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure + val error = outgoingPaymentHandler.sendPayment(duplicatePayment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure assertEquals(error.failure.reason, FinalFailure.AlreadyPaid) } @@ -134,14 +148,13 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Send payment 1 of 2: this should work because we're still under the maxAcceptedHtlcs. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result is OutgoingPaymentHandler.Progress } - - val progress = result as OutgoingPaymentHandler.Progress - assertEquals(1, result.actions.size) - val processResult = alice.processSameState(progress.actions.first().channelCommand) - assertTrue { processResult.second.filterIsInstance().isEmpty() } - alice = processResult.first + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) + + val (alice1, actions1) = alice.processSameState(progress.actions.first().channelCommand) + assertTrue(actions1.filterIsInstance().isEmpty()) + alice = alice1 assertNotNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) val dbPayment = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) @@ -154,21 +167,18 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Send payment 2 of 2: this should exceed the configured maxAcceptedHtlcs. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 50_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result1 is OutgoingPaymentHandler.Progress } + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) - val progress = result1 as OutgoingPaymentHandler.Progress - assertEquals(1, result1.actions.size) val cmdAdd = progress.actions.first().channelCommand - val processResult = alice.processSameState(cmdAdd) - alice = processResult.first - - val addFailure = processResult.second.filterIsInstance().firstOrNull() + val (_, actions1) = alice.processSameState(cmdAdd) + val addFailure = actions1.filterIsInstance().firstOrNull() assertNotNull(addFailure) // Now the channel error gets sent back to the OutgoingPaymentHandler. - val result2 = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure, mapOf(alice.channelId to alice.state)) + val failure = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.NoAvailableChannels, listOf(Either.Left(TooManyAcceptedHtlcs(alice.channelId, 1))))) - assertFailureEquals(result2 as OutgoingPaymentHandler.Failure, expected) + assertFailureEquals(failure as OutgoingPaymentHandler.Failure, expected) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) @@ -179,22 +189,22 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { fun `channel restrictions -- maxHtlcValueInFlight`() = runSuspendTest { var (alice, _) = TestsHelper.reachNormal() val maxHtlcValueInFlightMsat = 150_000L - alice = - alice.copy(state = alice.state.copy(commitments = alice.commitments.copy(params = alice.commitments.params.copy(remoteParams = alice.commitments.params.remoteParams.copy(maxHtlcValueInFlightMsat = maxHtlcValueInFlightMsat))))) + alice = alice.copy( + state = alice.state.copy(commitments = alice.commitments.copy(params = alice.commitments.params.copy(remoteParams = alice.commitments.params.remoteParams.copy(maxHtlcValueInFlightMsat = maxHtlcValueInFlightMsat)))) + ) val outgoingPaymentHandler = OutgoingPaymentHandler(alice.staticParams.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) run { // Send payment 1 of 2: this should work because we're still under the maxHtlcValueInFlightMsat. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result is OutgoingPaymentHandler.Progress } - - val progress = result as OutgoingPaymentHandler.Progress - assertEquals(1, result.actions.size) - val processResult = alice.processSameState(progress.actions.first().channelCommand) - assertTrue { processResult.second.filterIsInstance().isEmpty() } - alice = processResult.first + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) + + val (alice1, actions1) = alice.processSameState(progress.actions.first().channelCommand) + assertTrue(actions1.filterIsInstance().isEmpty()) + alice = alice1 assertNotNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) val dbPayment = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) @@ -207,21 +217,18 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Send payment 2 of 2: this should exceed the configured maxHtlcValueInFlightMsat. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result1 is OutgoingPaymentHandler.Progress } + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) - val progress = result1 as OutgoingPaymentHandler.Progress - assertEquals(1, result1.actions.size) val cmdAdd = progress.actions.first().channelCommand - val processResult = alice.processSameState(cmdAdd) - alice = processResult.first - - val addFailure = processResult.second.filterIsInstance().firstOrNull() + val (_, actions1) = alice.processSameState(cmdAdd) + val addFailure = actions1.filterIsInstance().firstOrNull() assertNotNull(addFailure) // Now the channel error gets sent back to the OutgoingPaymentHandler. - val result2 = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure, mapOf(alice.channelId to alice.state)) + val failure = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.NoAvailableChannels, listOf(Either.Left(HtlcValueTooHighInFlight(alice.channelId, maxHtlcValueInFlightMsat.toULong(), 200_000.msat))))) - assertFailureEquals(result2 as OutgoingPaymentHandler.Failure, expected) + assertFailureEquals(failure as OutgoingPaymentHandler.Failure, expected) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) @@ -229,7 +236,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } @Test - fun `successful first attempt -- single part`() = runSuspendTest { + fun `successful first attempt`() = runSuspendTest { val recipientKey = randomKey() val invoice = makeInvoice(amount = 195_000.msat, supportsTrampoline = true, privKey = recipientKey) val payment = PayInvoice(UUID.randomUUID(), 200_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) // we slightly overpay the invoice amount @@ -237,7 +244,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } @Test - fun `successful first attempt -- single part + backwards-compatibility trampoline bit`() = runSuspendTest { + fun `successful first attempt -- backwards-compatibility trampoline bit`() = runSuspendTest { val recipientKey = randomKey() val invoice = run { // Invoices generated by older versions of wallets based on lightning-kmp will generate invoices with the following feature bits. @@ -262,19 +269,17 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } private suspend fun testSinglePartTrampolinePayment(payment: PayInvoice, invoice: Bolt11Invoice, recipientKey: PrivateKey) { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(3.sat, 10_000, CltvExpiryDelta(144)))) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(1, adds.size) - val (channelId, add) = adds.first() + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) assertEquals(205_000.msat, add.amount) assertEquals(payment.paymentHash, add.paymentHash) // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) + val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptRelayToTrampoline(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) assertEquals(205_000.msat, outerB.amount) assertEquals(205_000.msat, outerB.totalAmount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) @@ -310,65 +315,8 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } @Test - fun `successful first attempt -- multiple parts`() = runSuspendTest { - val channels = makeChannels() - val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(10.sat, 0, CltvExpiryDelta(144)))) - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) - val recipientKey = randomKey() - val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = recipientKey) - val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - - adds.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - assertEquals(310_000.msat, outerB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) - assertEquals(300_000.msat, innerB.amountToForward) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) - assertEquals(payment.recipient, innerB.outgoingNodeId) - - // The recipient should receive the right amount and expiry. - val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! - val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! - assertEquals(300_000.msat, payloadC.amount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) - assertEquals(payloadC.amount, payloadC.totalAmount) - assertEquals(invoice.paymentSecret, payloadC.paymentSecret) - } - - val preimage = randomBytes32() - val (channelId1, add1) = adds[0] - val fulfill1 = createRemoteFulfill(channelId1, add1, preimage) - val success1 = outgoingPaymentHandler.processAddSettled(fulfill1) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - val (channelId2, add2) = adds[1] - val fulfill2 = ChannelAction.ProcessCmdRes.AddSettledFulfill(add2.paymentId, makeUpdateAddHtlc(channelId2, add2), ChannelAction.HtlcResult.Fulfill.OnChainFulfill(preimage)) - val success2 = outgoingPaymentHandler.processAddSettled(fulfill2) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(10_000.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(invoice.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(310_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 10_000.msat, partsCount = 2) - } - - @Test - fun `successful first attempt -- multiple parts + legacy recipient`() = runSuspendTest { - val channels = makeChannels() + fun `successful first attempt -- legacy recipient`() = runSuspendTest { + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(10.sat, 0, CltvExpiryDelta(144)))) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) val recipientKey = randomKey() @@ -376,51 +324,46 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { val invoice = makeInvoice(amount = null, supportsTrampoline = false, privKey = recipientKey, extraHops = extraHops) val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - - adds.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, _) = PaymentPacketTestsCommon.decryptRelayToNonTrampolinePayload(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - assertEquals(310_000.msat, outerB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) - assertEquals(300_000.msat, innerB.amountToForward) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) - assertEquals(payment.recipient, innerB.outgoingNodeId) - assertEquals(invoice.paymentSecret, innerB.paymentSecret) - assertEquals(invoice.features.toByteArray().toByteVector(), innerB.invoiceFeatures) - assertFalse(innerB.invoiceRoutingInfo.isEmpty()) - assertEquals(invoice.routingInfo.map { it.hints }, innerB.invoiceRoutingInfo) - } + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) + assertEquals(310_000.msat, add.amount) + assertEquals(payment.paymentHash, add.paymentHash) + // The trampoline node should receive the right forwarding information. + val (outerB, innerB) = PaymentPacketTestsCommon.decryptRelayToLegacy(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) + assertEquals(add.amount, outerB.amount) + assertEquals(310_000.msat, outerB.totalAmount) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) + assertEquals(300_000.msat, innerB.amountToForward) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) + assertEquals(payment.recipient, innerB.outgoingNodeId) + assertEquals(invoice.paymentSecret, innerB.paymentSecret) + assertEquals(invoice.features.toByteArray().toByteVector(), innerB.invoiceFeatures) + assertFalse(innerB.invoiceRoutingInfo.isEmpty()) + assertEquals(invoice.routingInfo.map { it.hints }, innerB.invoiceRoutingInfo) val preimage = randomBytes32() - val (channelId1, add1) = adds[0] - val success1 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId1, add1, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - val (channelId2, add2) = adds[1] - val success2 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId2, add2, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(10_000.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(invoice.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(310_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) + val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId, add, preimage)) + assertNotNull(success) + assertEquals(preimage, success.preimage) + assertEquals(10_000.msat, success.payment.fees) + assertEquals(300_000.msat, success.payment.recipientAmount) + assertEquals(invoice.nodeId, success.payment.recipient) + assertEquals(invoice.paymentHash, success.payment.paymentHash) + assertEquals(invoice, (success.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) + assertEquals(preimage, (success.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) + assertEquals(1, success.payment.parts.size) + assertEquals(310_000.msat, success.payment.parts.first().amount) + val status = success.payment.parts.first().status + assertIs(status) + assertEquals(preimage, status.preimage) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 10_000.msat, partsCount = 2) + assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 10_000.msat, partsCount = 1) } @Test fun `successful first attempt -- random final expiry`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(25.sat, 0, CltvExpiryDelta(48)))) val recipientExpiryParams = RecipientCltvExpiryParams(CltvExpiryDelta(144), CltvExpiryDelta(288)) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams.copy(paymentRecipientExpiryParams = recipientExpiryParams), walletParams, InMemoryPaymentsDb()) @@ -428,230 +371,126 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = recipientKey) val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - - adds.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - val minFinalExpiry = CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA + recipientExpiryParams.min - assertTrue(minFinalExpiry + CltvExpiryDelta(48) <= outerB.expiry) - assertTrue(minFinalExpiry <= innerB.outgoingCltv) - - // The recipient should receive the right amount and expiry. - val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! - val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! - assertEquals(300_000.msat, payloadC.amount) - assertTrue(minFinalExpiry <= payloadC.expiry) - assertEquals(innerB.outgoingCltv, payloadC.expiry) - } - } - - @Test - fun `successful first attempt -- multiple parts + recipient is our peer`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - // The invoice comes from Bob, our direct peer (and trampoline node). - val preimage = randomBytes32() - val incomingPaymentHandler = IncomingPaymentHandler(TestConstants.Bob.nodeParams, InMemoryPaymentsDb()) - val invoice = incomingPaymentHandler.createInvoice(preimage, amount = null, Either.Left("phoenix to phoenix"), listOf()) - val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - assertEquals(300_000.msat, adds.map { it.second.amount }.sum()) - adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - adds.forEach { (channelId, add) -> - // Bob should receive the right final information. - val payloadB = IncomingPaymentPacket.decrypt(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey).right!! - assertIs(payloadB) - assertEquals(add.amount, payloadB.amount) - assertEquals(300_000.msat, payloadB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadB.expiry) - assertEquals(invoice.paymentSecret, payloadB.paymentSecret) - } - - // Bob receives these 2 HTLCs. - val process1 = incomingPaymentHandler.process(makeUpdateAddHtlc(adds[0].first, adds[0].second, 3), Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, remoteFundingRates = null) - assertTrue(process1 is IncomingPaymentHandler.ProcessAddResult.Pending) - val process2 = incomingPaymentHandler.process(makeUpdateAddHtlc(adds[1].first, adds[1].second, 5), Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, remoteFundingRates = null) - assertTrue(process2 is IncomingPaymentHandler.ProcessAddResult.Accepted) - val fulfills = process2.actions.filterIsInstance().mapNotNull { it.channelCommand as? ChannelCommand.Htlc.Settlement.Fulfill } - assertEquals(2, fulfills.size) - - // Alice receives the fulfill for these 2 HTLCs. - val (channelId1, add1) = adds[0] - val fulfill1 = createRemoteFulfill(channelId1, add1, preimage) - val success1 = outgoingPaymentHandler.processAddSettled(fulfill1) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - val (channelId2, add2) = adds[1] - val fulfill2 = ChannelAction.ProcessCmdRes.AddSettledFulfill(add2.paymentId, makeUpdateAddHtlc(channelId2, add2), ChannelAction.HtlcResult.Fulfill.OnChainFulfill(preimage)) - val success2 = outgoingPaymentHandler.processAddSettled(fulfill2) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(0.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(TestConstants.Bob.nodeParams.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(300_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) + // The trampoline node should receive the right forwarding information. + val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptRelayToTrampoline(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) + assertEquals(add.amount, outerB.amount) + val minFinalExpiry = CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA + recipientExpiryParams.min + assertTrue(minFinalExpiry + CltvExpiryDelta(48) <= outerB.expiry) + assertTrue(minFinalExpiry <= innerB.outgoingCltv) - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 0.msat, partsCount = 2) + // The recipient should receive the right amount and expiry. + val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! + val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! + assertEquals(300_000.msat, payloadC.amount) + assertTrue(minFinalExpiry <= payloadC.expiry) + assertEquals(innerB.outgoingCltv, payloadC.expiry) } @Test fun `successful second attempt`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) val recipientKey = randomKey() val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = recipientKey) val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(progress1) - assertEquals(2, adds1.size) - assertEquals(300_000.msat, adds1.map { it.second.amount }.sum()) + val progress1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress1) + val (channelId1, add1) = findAddHtlcCommand(progress1) + assertEquals(alice.channelId, channelId1) + assertEquals(300_000.msat, add1.amount) // This first attempt fails because fees are too low. val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail1 = outgoingPaymentHandler.processAddSettled(adds1[0].first, createRemoteFailure(adds1[0].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) - assertNull(fail1) - val progress2 = outgoingPaymentHandler.processAddSettled(adds1[1].first, createRemoteFailure(adds1[1].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds2 = filterAddHtlcCommands(progress2) - assertEquals(2, adds2.size) - assertEquals(301_030.msat, adds2.map { it.second.amount }.sum()) - adds2.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - adds2.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - assertEquals(301_030.msat, outerB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(576) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) - assertEquals(300_000.msat, innerB.amountToForward) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) - assertEquals(payment.recipient, innerB.outgoingNodeId) - - // The recipient should receive the right amount and expiry. - val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! - val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! - assertEquals(300_000.msat, payloadC.amount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) - assertEquals(payloadC.amount, payloadC.totalAmount) - assertEquals(invoice.paymentSecret, payloadC.paymentSecret) - } + val progress2 = outgoingPaymentHandler.processAddSettled(channelId1, createRemoteFailure(add1, attempt, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress2) + val (channelId2, add2) = findAddHtlcCommand(progress2) + assertEquals(channelId1, channelId2) + assertEquals(301_030.msat, add2.amount) + assertEquals(payment.paymentHash, add2.paymentHash) + // The trampoline node should receive the right forwarding information. + val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptRelayToTrampoline(makeUpdateAddHtlc(channelId2, add2), TestConstants.Bob.nodeParams.nodePrivateKey) + assertEquals(add2.amount, outerB.amount) + assertEquals(301_030.msat, outerB.totalAmount) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(576) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) + assertEquals(300_000.msat, innerB.amountToForward) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) + assertEquals(payment.recipient, innerB.outgoingNodeId) + + // The recipient should receive the right amount and expiry. + val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! + val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! + assertEquals(300_000.msat, payloadC.amount) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) + assertEquals(payloadC.amount, payloadC.totalAmount) + assertEquals(invoice.paymentSecret, payloadC.paymentSecret) val dbPayment1 = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) assertNotNull(dbPayment1) - assertTrue(dbPayment1.status is LightningOutgoingPayment.Status.Pending) - assertEquals(2, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Failed }.size) - assertEquals(2, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Pending }.size) + assertIs(dbPayment1.status) + assertEquals(1, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Failed }.size) + assertEquals(1, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Pending }.size) // The second attempt succeeds. val preimage = randomBytes32() - val success1 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds2[0].first, adds2[0].second, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - assertNotNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - val success2 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds2[1].first, adds2[1].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(1_030.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(invoice.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(301_030.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) + val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId2, add2, preimage)) + assertIs(success) + assertEquals(preimage, success.preimage) + assertEquals(1_030.msat, success.payment.fees) + assertEquals(300_000.msat, success.payment.recipientAmount) + assertEquals(invoice.nodeId, success.payment.recipient) + assertEquals(invoice.paymentHash, success.payment.paymentHash) + assertEquals(invoice, (success.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) + assertEquals(preimage, (success.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) + assertEquals(1, success.payment.parts.size) + assertEquals(301_030.msat, success.payment.parts.first().amount) + val status = success.payment.parts.first().status + assertIs(status) + assertEquals(preimage, status.preimage) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) val dbPayment2 = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) assertNotNull(dbPayment2) - assertTrue(dbPayment2.status is LightningOutgoingPayment.Status.Completed.Succeeded.OffChain) - assertEquals(2, dbPayment2.parts.size) + assertIs(dbPayment2.status) + assertEquals(1, dbPayment2.parts.size) assertTrue(dbPayment2.parts.all { it.status is LightningOutgoingPayment.Part.Status.Succeeded }) } - @Test - fun `successful second attempt -- recipient is our peer`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - // The invoice comes from Bob, our direct peer (and trampoline node). - val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = TestConstants.Bob.nodeParams.nodePrivateKey) - val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val result1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(result1) - assertEquals(2, adds1.size) - - // The first attempt fails because of a local channel error. - val result2 = outgoingPaymentHandler.processAddFailed(adds1[1].first, ChannelAction.ProcessCmdRes.AddFailed(adds1[1].second, TooManyAcceptedHtlcs(adds1[1].first, 10), null), channels) as OutgoingPaymentHandler.Progress - val adds2 = filterAddHtlcCommands(result2) - assertEquals(1, adds2.size) - - // The other HTLCs succeed. - val preimage = randomBytes32() - val (channelId1, add1) = adds1[0] - val fulfill1 = createRemoteFulfill(channelId1, add1, preimage) - val success1 = outgoingPaymentHandler.processAddSettled(fulfill1) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - - val (channelId2, add2) = adds2[0] - val fulfill2 = createRemoteFulfill(channelId2, add2, preimage) - val success2 = outgoingPaymentHandler.processAddSettled(fulfill2) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(0.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(TestConstants.Bob.nodeParams.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(300_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 0.msat, partsCount = 2) - } - @Test fun `insufficient funds when retrying with higher fees`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal(aliceFundingAmount = 100_000.sat, alicePushAmount = 0.msat, bobFundingAmount = 0.sat, bobPushAmount = 0.msat) + assertTrue(83_500_000.msat < alice.commitments.availableBalanceForSend()) + assertTrue(alice.commitments.availableBalanceForSend() < 84_000_000.msat) val walletParams = defaultWalletParams.copy( trampolineFees = listOf( - TrampolineFees(10.sat, 0, CltvExpiryDelta(144)), TrampolineFees(100.sat, 0, CltvExpiryDelta(144)), + TrampolineFees(1000.sat, 0, CltvExpiryDelta(144)), ) ) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 650_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) + val payment = PayInvoice(UUID.randomUUID(), 83_000_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(progress1) - assertEquals(4, adds1.size) - assertEquals(660_000.msat, adds1.map { it.second.amount }.sum()) + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + val (_, add1) = findAddHtlcCommand(progress) + assertEquals(83_100_000.msat, add1.amount) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - assertNull(outgoingPaymentHandler.processAddSettled(adds1[0].first, createRemoteFailure(adds1[0].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight)) - assertNull(outgoingPaymentHandler.processAddSettled(adds1[1].first, createRemoteFailure(adds1[1].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight)) - assertNull(outgoingPaymentHandler.processAddSettled(adds1[2].first, createRemoteFailure(adds1[2].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight)) - val fail = outgoingPaymentHandler.processAddSettled(adds1[3].first, createRemoteFailure(adds1[3].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add1, attempt, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.InsufficientBalance, listOf(Either.Right(TrampolineFeeInsufficient)))) assertFailureEquals(expected, fail) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 4) + assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) } @Test fun `retries exhausted`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy( trampolineFees = listOf( TrampolineFees(10.sat, 0, CltvExpiryDelta(144)), @@ -662,20 +501,21 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 220_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(progress1) - assertEquals(1, adds1.size) - assertEquals(230_000.msat, adds1.map { it.second.amount }.sum()) + val progress1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress1) + val (_, add1) = findAddHtlcCommand(progress1) + assertEquals(230_000.msat, add1.amount) val attempt1 = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val progress2 = outgoingPaymentHandler.processAddSettled(adds1[0].first, createRemoteFailure(adds1[0].second, attempt1, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds2 = filterAddHtlcCommands(progress2) - assertEquals(1, adds2.size) - assertEquals(240_000.msat, adds2.map { it.second.amount }.sum()) + val progress2 = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add1, attempt1, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress2) + val (_, add2) = findAddHtlcCommand(progress2) + assertEquals(240_000.msat, add2.amount) val attempt2 = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(adds2[0].first, createRemoteFailure(adds2[0].second, attempt2, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure - val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.RetryExhausted, listOf(Either.Right(TrampolineFeeInsufficient)))) + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add2, attempt2, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) + val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.RetryExhausted, listOf(Either.Right(TrampolineFeeInsufficient), Either.Right(TrampolineFeeInsufficient)))) assertFailureEquals(expected, fail) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) @@ -689,18 +529,19 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { Pair(IncorrectOrUnknownPaymentDetails(50_000.msat, TestConstants.defaultBlockHeight.toLong()), FinalFailure.UnknownError) ) fatalFailures.forEach { (remoteFailure, userFailure) -> - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 50_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(1, adds.size) - assertEquals(50_000.msat, adds.map { it.second.amount }.sum()) + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + val (_, add) = findAddHtlcCommand(progress) + assertEquals(50_000.msat, add.amount) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(adds[0].first, createRemoteFailure(adds[0].second, attempt, remoteFailure), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add, attempt, remoteFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(userFailure, listOf(Either.Right(remoteFailure)))) assertFailureEquals(expected, fail) @@ -709,212 +550,74 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } } - private suspend fun testLocalChannelFailures(invoice: Bolt11Invoice) { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val payment = PayInvoice(UUID.randomUUID(), 5_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - var progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - assertEquals(1, progress.actions.size) - assertEquals(5_000.msat, filterAddHtlcCommands(progress).map { it.second.amount }.sum()) - - // Channels fail, so we retry with different channels, without raising the fees. - val localFailures = listOf( - { channelId: ByteVector32 -> TooManyAcceptedHtlcs(channelId, 15) }, - { channelId: ByteVector32 -> InsufficientFunds(channelId, 5_000.msat, 1.sat, 20.sat, 1.sat) }, - { channelId: ByteVector32 -> HtlcValueTooHighInFlight(channelId, 150_000U, 155_000.msat) }, - { channelId: ByteVector32 -> ForbiddenDuringSplice(channelId, "update-add-htlc") } - ) - localFailures.forEach { localFailure -> - val (channelId, add) = filterAddHtlcCommands(progress).first() - progress = outgoingPaymentHandler.processAddFailed(channelId, ChannelAction.ProcessCmdRes.AddFailed(add, localFailure(channelId), null), channels) as OutgoingPaymentHandler.Progress - assertEquals(5_000.msat, add.amount) - } - - // The last channel fails: we don't have any channels available to retry. - val (channelId, add) = filterAddHtlcCommands(progress).first() - val fail = outgoingPaymentHandler.processAddFailed(channelId, ChannelAction.ProcessCmdRes.AddFailed(add, TooManyAcceptedHtlcs(channelId, 15), null), channels) as OutgoingPaymentHandler.Failure - assertEquals(FinalFailure.InsufficientBalance, fail.failure.reason) - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 5) - } - - @Test - fun `local channel failures`() = runSuspendTest { - testLocalChannelFailures(makeInvoice(amount = null, supportsTrampoline = true)) - } - - @Test - fun `local channel failures -- recipient is our peer`() = runSuspendTest { - // The invoice comes from Bob, our direct peer (and trampoline node). - testLocalChannelFailures(makeInvoice(amount = null, supportsTrampoline = true, privKey = TestConstants.Bob.nodeParams.nodePrivateKey)) - } - - @Test - fun `local channel failure followed by success`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 5_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - assertEquals(1, progress1.actions.size) - assertEquals(5_000.msat, filterAddHtlcCommands(progress1).map { it.second.amount }.sum()) - - // This first payment fails: - val (channelId, add) = filterAddHtlcCommands(progress1).first() - val progress2 = outgoingPaymentHandler.processAddFailed(channelId, ChannelAction.ProcessCmdRes.AddFailed(add, TooManyAcceptedHtlcs(channelId, 1), null), channels) as OutgoingPaymentHandler.Progress - assertEquals(1, progress2.actions.size) - val adds = filterAddHtlcCommands(progress2) - assertEquals(5_000.msat, adds.map { it.second.amount }.sum()) - - // This second attempt succeeds: - val preimage = randomBytes32() - val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[0].first, adds[0].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(0.msat, success.payment.fees) - assertEquals(5_000.msat, success.payment.recipientAmount) - assertEquals(preimage, success.preimage) - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 5_000.msat, fees = 0.msat, partsCount = 1) - } - - @Test - fun `partial failure then fulfill -- spec violation`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 310_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - - val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val remoteFailure = IncorrectOrUnknownPaymentDetails(310_000.msat, TestConstants.defaultBlockHeight.toLong()) - assertNull(outgoingPaymentHandler.processAddSettled(adds[0].first, createRemoteFailure(adds[0].second, attempt, remoteFailure), channels, TestConstants.defaultBlockHeight)) - - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - val preimage = randomBytes32() - val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[1].first, adds[1].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, success.preimage) - assertEquals((-250_000).msat, success.payment.fees) // since we paid much less than the expected amount, it results in negative fees - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 310_000.msat, fees = (-250_000).msat, partsCount = 1) - } - - @Test - fun `partial fulfill then failure -- spec violation`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 310_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - - val preimage = randomBytes32() - val expected = OutgoingPaymentHandler.PreimageReceived(payment, preimage) - val result = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[0].first, adds[0].second, preimage)) - assertEquals(expected, result) - - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val remoteFailure = IncorrectOrUnknownPaymentDetails(310_000.msat, TestConstants.defaultBlockHeight.toLong()) - val success = outgoingPaymentHandler.processAddSettled(adds[1].first, createRemoteFailure(adds[1].second, attempt, remoteFailure), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Success - assertEquals(preimage, success.preimage) - assertEquals((-60_000).msat, success.payment.fees) // since we paid much less than the expected amount, it results in negative fees - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 310_000.msat, fees = (-60_000).msat, partsCount = 1) - } - @Test fun `failure after a wallet restart`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val db = InMemoryPaymentsDb() // Step 1: a payment attempt is made. - val (adds, attempt) = run { + val (add, attempt) = run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 550_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val adds = filterAddHtlcCommands(progress) - assertEquals(3, adds.size) - Pair(adds, attempt) + val (_, add) = findAddHtlcCommand(progress) + Pair(add, attempt) } // Step 2: the wallet restarts and payment fails. run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - assertNull(outgoingPaymentHandler.processAddFailed(adds[0].first, ChannelAction.ProcessCmdRes.AddFailed(adds[0].second, ChannelUnavailable(adds[0].first), null), channels)) - assertNull(outgoingPaymentHandler.processAddSettled(adds[1].first, createRemoteFailure(adds[1].second, attempt, TemporaryNodeFailure), channels, TestConstants.defaultBlockHeight)) - val result = outgoingPaymentHandler.processAddSettled(adds[2].first, createRemoteFailure(adds[2].second, attempt, PermanentNodeFailure), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure - val failures: List> = listOf( - Either.Left(ChannelUnavailable(adds[0].first)), - // Since we've lost the shared secrets, we can't decrypt remote failures. - Either.Right(UnknownFailureMessage(0)), - Either.Right(UnknownFailureMessage(0)) - ) - assertFailureEquals(OutgoingPaymentHandler.Failure(attempt.request, OutgoingPaymentFailure(FinalFailure.WalletRestarted, failures)), result) + val fail = outgoingPaymentHandler.processAddSettled(alice.channelId, createRemoteFailure(add, attempt, TemporaryNodeFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) + assertEquals(attempt.request, fail.request) + assertEquals(FinalFailure.WalletRestarted, fail.failure.reason) + assertEquals(1, fail.failure.failures.size) + // Since we haven't stored the shared secrets, we can't decrypt remote failure. + assertIs(fail.failure.failures.first().failure) } } @Test fun `success after a wallet restart`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val preimage = randomBytes32() val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 550_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) val db = InMemoryPaymentsDb() // Step 1: a payment attempt is made. - val adds = run { + val add = run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(3, adds.size) - // A first part is fulfilled before the wallet restarts. - val result = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[0].first, adds[0].second, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), result) - adds + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + findAddHtlcCommand(progress).second } // Step 2: the wallet restarts and payment succeeds. run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - val result1 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[1].first, adds[1].second, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), result1) - val result2 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[2].first, adds[2].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, result2.preimage) - assertEquals(3, result2.payment.parts.size) - assertEquals(payment, PayInvoice(result2.payment.id, result2.payment.recipientAmount, result2.payment.details as LightningOutgoingPayment.Details.Normal)) - assertEquals(preimage, (result2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) + val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(alice.channelId, add, preimage)) + assertIs(success) + assertEquals(preimage, success.preimage) + assertEquals(1, success.payment.parts.size) + assertEquals(payment, success.request) + assertEquals(preimage, (success.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) } } private fun makeInvoice(amount: MilliSatoshi?, supportsTrampoline: Boolean, privKey: PrivateKey = randomKey(), extraHops: List> = listOf()): Bolt11Invoice { val paymentPreimage: ByteVector32 = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() - - val invoiceFeatures = mutableMapOf( - Feature.VariableLengthOnion to FeatureSupport.Optional, - Feature.PaymentSecret to FeatureSupport.Mandatory, - Feature.BasicMultiPartPayment to FeatureSupport.Optional - ) - if (supportsTrampoline) { - invoiceFeatures[Feature.ExperimentalTrampolinePayment] = FeatureSupport.Optional + val invoiceFeatures: Map = buildMap { + put(Feature.VariableLengthOnion, FeatureSupport.Optional) + put(Feature.PaymentSecret, FeatureSupport.Mandatory) + put(Feature.BasicMultiPartPayment, FeatureSupport.Optional) + if (supportsTrampoline) put(Feature.ExperimentalTrampolinePayment, FeatureSupport.Optional) } - return Bolt11Invoice.create( chain = Chain.Mainnet, amount = amount, @@ -922,50 +625,21 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { privateKey = privKey, description = Either.Left("unit test"), minFinalCltvExpiryDelta = Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, - features = Features(invoiceFeatures.toMap()), + features = Features(invoiceFeatures), extraHops = extraHops ) } - private fun makeChannels(): Map { - val (alice, _) = TestsHelper.reachNormal() - val reserve = alice.commitments.latest.localChannelReserve - val channelDetails = listOf( - Pair(ShortChannelId(1), 250_000.msat), - Pair(ShortChannelId(2), 150_000.msat), - Pair(ShortChannelId(3), 0.msat), - Pair(ShortChannelId(4), 10_000.msat), - Pair(ShortChannelId(5), 200_000.msat), - Pair(ShortChannelId(6), 100_000.msat), - ) - return channelDetails.associate { - val channelId = randomBytes32() - val channel = alice.state.copy( - shortChannelId = it.first, - commitments = alice.commitments.copy( - params = alice.commitments.params.copy(channelId = channelId), - active = alice.commitments.active.map { commitment -> - commitment.copy(remoteCommit = commitment.remoteCommit.copy(spec = CommitmentSpec(setOf(), FeeratePerKw(0.sat), 50_000.msat, (it.second + ((Commitments.ANCHOR_AMOUNT * 2) + reserve).toMilliSatoshi())))) - } - ) - ) - channelId to channel - } - } - - private fun filterAddHtlcCommands(progress: OutgoingPaymentHandler.Progress): List> { - val addCommands = mutableListOf>() - for (action in progress.actions) { - val addCommand = action.channelCommand as? ChannelCommand.Htlc.Add - if (addCommand != null) { - addCommands.add(Pair(action.channelId, addCommand)) + private fun findAddHtlcCommand(progress: OutgoingPaymentHandler.Progress): Pair { + return progress.actions.firstNotNullOf { + when (val cmd = it.channelCommand) { + is ChannelCommand.Htlc.Add -> Pair(it.channelId, cmd) + else -> null } } - return addCommands.toList() } - private fun makeUpdateAddHtlc(channelId: ByteVector32, cmd: ChannelCommand.Htlc.Add, htlcId: Long = 0): UpdateAddHtlc = - UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) + private fun makeUpdateAddHtlc(channelId: ByteVector32, cmd: ChannelCommand.Htlc.Add, htlcId: Long = 0): UpdateAddHtlc = UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) private fun createRemoteFulfill(channelId: ByteVector32, add: ChannelCommand.Htlc.Add, preimage: ByteVector32): ChannelAction.ProcessCmdRes.AddSettledFulfill { val updateAddHtlc = makeUpdateAddHtlc(channelId, add) @@ -973,8 +647,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } private fun createRemoteFailure(add: ChannelCommand.Htlc.Add, attempt: OutgoingPaymentHandler.PaymentAttempt, failureMessage: FailureMessage): ChannelAction.ProcessCmdRes.AddSettledFail { - val sharedSecrets = attempt.pending.getValue(add.paymentId).second - val reason = FailurePacket.create(sharedSecrets.perHopSecrets.last().first, failureMessage) + val reason = FailurePacket.create(attempt.sharedSecrets.perHopSecrets.last().first, failureMessage) val updateAddHtlc = makeUpdateAddHtlc(randomBytes32(), add) return ChannelAction.ProcessCmdRes.AddSettledFail( add.paymentId, diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt index 9b699dee3..e57d8390b 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt @@ -10,122 +10,70 @@ import fr.acinq.lightning.Lightning.randomBytes64 import fr.acinq.lightning.Lightning.randomKey import fr.acinq.lightning.channel.states.Channel import fr.acinq.lightning.crypto.RouteBlinding +import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets import fr.acinq.lightning.crypto.sphinx.Sphinx import fr.acinq.lightning.crypto.sphinx.Sphinx.hash -import fr.acinq.lightning.payment.OutgoingPaymentHandler.PaymentAttempt import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.tests.utils.LightningTestSuite -import fr.acinq.lightning.utils.* +import fr.acinq.lightning.utils.currentTimestampMillis +import fr.acinq.lightning.utils.msat +import fr.acinq.lightning.utils.toByteVector +import fr.acinq.lightning.utils.toByteVector32 import fr.acinq.lightning.wire.* import kotlin.test.* class PaymentPacketTestsCommon : LightningTestSuite() { companion object { - private val privA = randomKey() - private val a = privA.publicKey() private val privB = randomKey() private val b = privB.publicKey() private val privC = randomKey() private val c = privC.publicKey() private val privD = randomKey() private val d = privD.publicKey() - private val privE = randomKey() - private val e = privE.publicKey() private val defaultChannelUpdate = ChannelUpdate(randomBytes64(), Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0, 1, 0, CltvExpiryDelta(0), 42000.msat, 0.msat, 0, 500000000.msat) - private val channelUpdateAB = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1), cltvExpiryDelta = CltvExpiryDelta(4), feeBaseMsat = 642000.msat, feeProportionalMillionths = 7) - private val channelUpdateBC = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(2), cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 153000.msat, feeProportionalMillionths = 4) - private val channelUpdateCD = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(3), cltvExpiryDelta = CltvExpiryDelta(10), feeBaseMsat = 60000.msat, feeProportionalMillionths = 1) - private val channelUpdateDE = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(4), cltvExpiryDelta = CltvExpiryDelta(7), feeBaseMsat = 766000.msat, feeProportionalMillionths = 10) - - // simple route a -> b -> c -> d -> e - private val hops = listOf( - ChannelHop(a, b, channelUpdateAB), - ChannelHop(b, c, channelUpdateBC), - ChannelHop(c, d, channelUpdateCD), - ChannelHop(d, e, channelUpdateDE) - ) + private val channelUpdateBC = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1105), cltvExpiryDelta = CltvExpiryDelta(36), feeBaseMsat = 7_500.msat, feeProportionalMillionths = 250) + private val channelUpdateCD = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1729), cltvExpiryDelta = CltvExpiryDelta(24), feeBaseMsat = 5_000.msat, feeProportionalMillionths = 100) - private val finalAmount = 42000000.msat - private const val currentBlockCount = 400000L + private val finalAmount = 50_000_000.msat + private const val currentBlockCount = 400_000L private val finalExpiry = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA private val paymentPreimage = randomBytes32() private val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() private val paymentSecret = randomBytes32() private val paymentMetadata = randomBytes(64).toByteVector() - private val expiryDE = finalExpiry - private val amountDE = finalAmount - private val feeD = nodeFee(channelUpdateDE.feeBaseMsat, channelUpdateDE.feeProportionalMillionths, amountDE) + private val nonTrampolineFeatures = Features( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory, + Feature.BasicMultiPartPayment to FeatureSupport.Optional, + ) + private val trampolineFeatures = Features( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory, + Feature.BasicMultiPartPayment to FeatureSupport.Optional, + Feature.ExperimentalTrampolinePayment to FeatureSupport.Optional, + ) - private val expiryCD = expiryDE + channelUpdateDE.cltvExpiryDelta - private val amountCD = amountDE + feeD + // Amount and expiry sent by C to D. + private val expiryCD = finalExpiry + private val amountCD = finalAmount private val feeC = nodeFee(channelUpdateCD.feeBaseMsat, channelUpdateCD.feeProportionalMillionths, amountCD) + // Amount and expiry sent by B to C. private val expiryBC = expiryCD + channelUpdateCD.cltvExpiryDelta private val amountBC = amountCD + feeC private val feeB = nodeFee(channelUpdateBC.feeBaseMsat, channelUpdateBC.feeProportionalMillionths, amountBC) + // Amount and expiry sent by A to B. private val expiryAB = expiryBC + channelUpdateBC.cltvExpiryDelta private val amountAB = amountBC + feeB - // simple trampoline route to e: - // .--. .--. - // / \ / \ - // a -> b -> c d e - - private val trampolineHops = listOf( - NodeHop(a, c, channelUpdateAB.cltvExpiryDelta + channelUpdateBC.cltvExpiryDelta, feeB), - NodeHop(c, d, channelUpdateCD.cltvExpiryDelta, feeC), - NodeHop(d, e, channelUpdateDE.cltvExpiryDelta, feeD) - ) - - private val trampolineChannelHops = listOf( - ChannelHop(a, b, channelUpdateAB), - ChannelHop(b, c, channelUpdateBC) - ) - - private fun testBuildOnion() { - val finalPayload = PaymentOnion.FinalPayload.Standard(TlvStream(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount))) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) - assertEquals(amountAB, firstAmount) - assertEquals(expiryAB, firstExpiry) - assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) - // let's peel the onion - testPeelOnion(onion.packet) - } - - private fun testPeelOnion(packet_b: OnionRoutingPacket) { - val addB = UpdateAddHtlc(randomBytes32(), 0, amountAB, paymentHash, expiryAB, packet_b) - val (payloadB, packetC) = decryptChannelRelay(addB, privB) - assertEquals(OnionRoutingPacket.PaymentPacketLength, packetC.payload.size()) - assertEquals(amountBC, payloadB.amountToForward) - assertEquals(expiryBC, payloadB.outgoingCltv) - assertEquals(channelUpdateBC.shortChannelId, payloadB.outgoingChannelId) - - val addC = UpdateAddHtlc(randomBytes32(), 1, amountBC, paymentHash, expiryBC, packetC) - val (payloadC, packetD) = decryptChannelRelay(addC, privC) - assertEquals(OnionRoutingPacket.PaymentPacketLength, packetD.payload.size()) - assertEquals(amountCD, payloadC.amountToForward) - assertEquals(expiryCD, payloadC.outgoingCltv) - assertEquals(channelUpdateCD.shortChannelId, payloadC.outgoingChannelId) - - val addD = UpdateAddHtlc(randomBytes32(), 2, amountCD, paymentHash, expiryCD, packetD) - val (payloadD, packetE) = decryptChannelRelay(addD, privD) - assertEquals(OnionRoutingPacket.PaymentPacketLength, packetE.payload.size()) - assertEquals(amountDE, payloadD.amountToForward) - assertEquals(expiryDE, payloadD.outgoingCltv) - assertEquals(channelUpdateDE.shortChannelId, payloadD.outgoingChannelId) - - val addE = UpdateAddHtlc(randomBytes32(), 2, amountDE, paymentHash, expiryDE, packetE) - val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! - assertIs(payloadE) - assertEquals(finalAmount, payloadE.amount) - assertEquals(finalAmount, payloadE.totalAmount) - assertEquals(finalExpiry, payloadE.expiry) - assertEquals(paymentSecret, payloadE.paymentSecret) - } + // C is directly connected to the recipient D. + val nodeHop_cd = NodeHop(c, d, channelUpdateCD.cltvExpiryDelta, feeC) + // B is not directly connected to the recipient D. + val nodeHop_bd = NodeHop(b, d, channelUpdateBC.cltvExpiryDelta + channelUpdateCD.cltvExpiryDelta, feeB + feeC) // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. fun decryptChannelRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { @@ -135,14 +83,25 @@ class PaymentPacketTestsCommon : LightningTestSuite() { return Pair(decoded, decrypted.nextPacket) } + // Wallets don't need to create channel routes, but it's useful to test the end-to-end flow. + fun encryptChannelRelay(paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload): Triple { + val (firstAmount, firstExpiry, payloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { (amount, expiry, payloads), hop -> + val payload = PaymentOnion.ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) + Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) + } + val onion = OutgoingPaymentPacket.buildOnion(hops.map { it.nextNodeId }, payloads, paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(firstAmount, firstExpiry, onion) + } + // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptNodeRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptRelayToTrampoline(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket).right!! assertTrue(decrypted.isLastPacket) val outerPayload = PaymentOnion.FinalPayload.Standard.read(decrypted.payload).right!! val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet).right!! + assertFalse(decryptedInner.isLastPacket) val innerPayload = PaymentOnion.NodeRelayPayload.read(decryptedInner.payload).right!! assertNull(innerPayload.records.get()) assertNull(innerPayload.records.get()) @@ -152,34 +111,35 @@ class PaymentPacketTestsCommon : LightningTestSuite() { } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptRelayToNonTrampolinePayload(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptRelayToLegacy(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket).right!! assertTrue(decrypted.isLastPacket) val outerPayload = PaymentOnion.FinalPayload.Standard.read(decrypted.payload).right!! val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet).right!! + assertFalse(decryptedInner.isLastPacket) val innerPayload = PaymentOnion.RelayToNonTrampolinePayload.read(decryptedInner.payload).right!! - return Triple(outerPayload, innerPayload, decryptedInner.nextPacket) + return Pair(outerPayload, innerPayload) } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptRelayToBlinded(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptRelayToBlinded(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket).right!! assertTrue(decrypted.isLastPacket) val outerPayload = PaymentOnion.FinalPayload.Standard.read(decrypted.payload).right!! val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet).right!! + assertTrue(decryptedInner.isLastPacket) val innerPayload = PaymentOnion.RelayToBlindedPayload.read(decryptedInner.payload).right!! - return Triple(outerPayload, innerPayload, decryptedInner.nextPacket) + return Pair(outerPayload, innerPayload) } - // Create an HTLC paying an empty blinded path. - fun createBlindedHtlc(): Pair { - val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), null, 1, currentTimestampMillis()) - val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privE)))) - val blindedRoute = RouteBlinding.create(randomKey(), listOf(e), listOf(blindedPayload.write().byteVector())).route + fun createBlindedHtlcCD(): Pair { + val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)))) + val blindedRoute = RouteBlinding.create(randomKey(), listOf(d), listOf(blindedPayload.write().byteVector())).route val finalPayload = PaymentOnion.FinalPayload.Blinded( TlvStream( OnionPaymentPayloadTlv.AmountToForward(finalAmount), @@ -189,345 +149,299 @@ class PaymentPacketTestsCommon : LightningTestSuite() { ), blindedPayload ) - val blindedHop = ChannelHop(d, blindedRoute.blindedNodeIds.last(), channelUpdateDE) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket(paymentMetadata.paymentHash, listOf(blindedHop), finalPayload, OnionRoutingPacket.PaymentPacketLength) - val add = UpdateAddHtlc(randomBytes32(), 2, amountE, paymentMetadata.paymentHash, expiryE, onionE.packet, blindedRoute.blindingKey, null) - return Pair(add, finalPayload) + val onionD = OutgoingPaymentPacket.buildOnion(listOf(blindedRoute.blindedNodeIds.last()), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength) + val addD = UpdateAddHtlc(randomBytes32(), 1, finalAmount, paymentHash, finalExpiry, onionD.packet, blindedRoute.blindingKey, null) + return Pair(addD, paymentMetadata) } - } - @Test - fun `build onion`() { - testBuildOnion() - } - - @Test - fun `build a command including the onion`() { - val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, null)) - assertTrue(add.amount > finalAmount) - assertEquals(add.cltvExpiry, finalExpiry + channelUpdateDE.cltvExpiryDelta + channelUpdateCD.cltvExpiryDelta + channelUpdateBC.cltvExpiryDelta) - assertEquals(add.paymentHash, paymentHash) - assertEquals(add.onion.payload.size(), OnionRoutingPacket.PaymentPacketLength) - - // let's peel the onion - testPeelOnion(add.onion) - } - - @Test - fun `build a command with no hops`() { - val paymentSecret = randomBytes32() - val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, paymentMetadata)) - assertEquals(add.amount, finalAmount) - assertEquals(add.cltvExpiry, finalExpiry) - assertEquals(add.paymentHash, paymentHash) - assertEquals(add.onion.payload.size(), OnionRoutingPacket.PaymentPacketLength) - - // let's peel the onion - val addB = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, add.onion) - val finalPayload = IncomingPaymentPacket.decrypt(addB, privB).right!! - assertIs(finalPayload) - assertEquals(finalPayload.amount, finalAmount) - assertEquals(finalPayload.totalAmount, finalAmount) - assertEquals(finalPayload.expiry, finalExpiry) - assertEquals(paymentSecret, finalPayload.paymentSecret) - assertEquals(paymentMetadata, finalPayload.paymentMetadata) + fun createBlindedPaymentInfo(u: ChannelUpdate): OfferTypes.PaymentInfo { + return OfferTypes.PaymentInfo(u.feeBaseMsat, u.feeProportionalMillionths, u.cltvExpiryDelta, u.htlcMinimumMsat, u.htlcMaximumMsat!!, Features.empty) + } } @Test - fun `build a trampoline payment`() { - // simple trampoline route to e: - // .--. .--. - // / \ / \ - // a -> b -> c d e - - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, paymentMetadata), - null - ) - assertEquals(amountBC, amountAC) - assertEquals(expiryBC, expiryAC) - - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) - assertEquals(amountAB, firstAmount) - assertEquals(expiryAB, firstExpiry) - - val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (payloadB, packetC) = decryptChannelRelay(addB, privB) - assertEquals(PaymentOnion.ChannelRelayPayload.create(channelUpdateBC.shortChannelId, amountBC, expiryBC), payloadB) - - val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val (outerC, innerC, packetD) = decryptNodeRelay(addC, privC) + fun `send a trampoline payment -- recipient connected to trampoline node`() { + // .--. + // / \ + // b -> c d + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + assertEquals(amountBC, firstAmount) + assertEquals(expiryBC, firstExpiry) + + // B sends an HTLC to its trampoline node C. + val addC = UpdateAddHtlc(randomBytes32(), 1, amountBC, paymentHash, expiryBC, onion.packet) + val (outerC, innerC, packetD) = decryptRelayToTrampoline(addC, privC) assertEquals(amountBC, outerC.amount) assertEquals(amountBC, outerC.totalAmount) assertEquals(expiryBC, outerC.expiry) - assertEquals(amountCD, innerC.amountToForward) - assertEquals(expiryCD, innerC.outgoingCltv) + assertEquals(finalAmount, innerC.amountToForward) + assertEquals(finalExpiry, innerC.outgoingCltv) assertEquals(d, innerC.outgoingNodeId) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength - ) + // C forwards the trampoline payment to D over its direct channel. + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } assertEquals(amountCD, amountD) assertEquals(expiryCD, expiryD) - val addD = UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet) - val (outerD, innerD, packetE) = decryptNodeRelay(addD, privD) - assertEquals(amountCD, outerD.amount) - assertEquals(amountCD, outerD.totalAmount) - assertEquals(expiryCD, outerD.expiry) - assertEquals(amountDE, innerD.amountToForward) - assertEquals(expiryDE, innerD.outgoingCltv) - assertEquals(e, innerD.outgoingNodeId) - - // d forwards the trampoline payment to e. - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(d, e, channelUpdateDE)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountDE, amountDE, expiryDE, randomBytes32(), packetE), - OnionRoutingPacket.PaymentPacketLength - ) - assertEquals(amountDE, amountE) - assertEquals(expiryDE, expiryE) - val addE = UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet) - val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD, paymentHash, expiryD, onionD.packet) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! val expectedFinalPayload = PaymentOnion.FinalPayload.Standard( TlvStream( OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), - OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount * 3), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount), OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata) ) ) - assertEquals(payloadE, expectedFinalPayload) + assertEquals(payloadD, expectedFinalPayload) } @Test - fun `build a trampoline payment with non-trampoline recipient`() { - // simple trampoline route to e where e doesn't support trampoline: - // .--. - // / \ - // a -> b -> c d -> e - - val routingHints = listOf(Bolt11Invoice.TaggedField.ExtraHop(randomKey().publicKey(), ShortChannelId(42), 10.msat, 100, CltvExpiryDelta(144))) - val invoiceFeatures = Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.BasicMultiPartPayment to FeatureSupport.Optional) - val invoice = Bolt11Invoice( - "lnbcrt", finalAmount, currentTimestampSeconds(), e, listOf( - Bolt11Invoice.TaggedField.PaymentHash(paymentHash), - Bolt11Invoice.TaggedField.PaymentSecret(paymentSecret), - Bolt11Invoice.TaggedField.PaymentMetadata(paymentMetadata), - Bolt11Invoice.TaggedField.DescriptionHash(randomBytes32()), - Bolt11Invoice.TaggedField.Features(invoiceFeatures.toByteArray().toByteVector()), - Bolt11Invoice.TaggedField.RoutingInfo(routingHints) - ), ByteVector.empty - ) - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket( - invoice, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null) - ) - assertEquals(amountBC, amountAC) - assertEquals(expiryBC, expiryAC) - - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) + fun `send a trampoline payment -- recipient not connected to trampoline node`() { + // .-> c -. + // / \ + // a -> b d + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_bd) assertEquals(amountAB, firstAmount) assertEquals(expiryAB, firstExpiry) - val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (_, packetC) = decryptChannelRelay(addB, privB) + // A sends an HTLC to its trampoline node B. + val addB = UpdateAddHtlc(randomBytes32(), 1, amountAB, paymentHash, expiryAB, onion.packet) + val (outerB, innerB, packetC) = decryptRelayToTrampoline(addB, privB) + assertEquals(amountAB, outerB.amount) + assertEquals(amountAB, outerB.totalAmount) + assertEquals(expiryAB, outerB.expiry) + assertEquals(finalAmount, innerB.amountToForward) + assertEquals(finalExpiry, innerB.outgoingCltv) + assertEquals(d, innerB.outgoingNodeId) + + // B forwards the trampoline payment to D over an indirect channel route. + val (amountC, expiryC, onionC) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerB.amountToForward, innerB.amountToForward, innerB.outgoingCltv, randomBytes32(), packetC) + encryptChannelRelay(paymentHash, listOf(ChannelHop(b, c, channelUpdateBC), ChannelHop(c, d, channelUpdateCD)), payloadD) + } + assertEquals(amountBC, amountC) + assertEquals(expiryBC, expiryC) + + // C relays the payment to D. + val addC = UpdateAddHtlc(randomBytes32(), 2, amountC, paymentHash, expiryC, onionC.packet) + val (payloadC, packetD) = decryptChannelRelay(addC, privC) + val addD = UpdateAddHtlc(randomBytes32(), 3, payloadC.amountToForward, paymentHash, payloadC.outgoingCltv, packetD) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + val expectedFinalPayload = PaymentOnion.FinalPayload.Standard( + TlvStream( + OnionPaymentPayloadTlv.AmountToForward(finalAmount), + OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount), + OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata) + ) + ) + assertEquals(payloadD, expectedFinalPayload) + } - val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val (outerC, innerC, packetD) = decryptNodeRelay(addC, privC) - assertIs(outerC) - assertEquals(amountBC, outerC.amount) - assertEquals(amountBC, outerC.totalAmount) - assertEquals(expiryBC, outerC.expiry) - assertNotEquals(invoice.paymentSecret, outerC.paymentSecret) - assertEquals(amountCD, innerC.amountToForward) - assertEquals(expiryCD, innerC.outgoingCltv) - assertEquals(d, innerC.outgoingNodeId) + @Test + fun `send a trampoline payment to legacy non-trampoline recipient`() { + // .-> c -. + // / \ + // a -> b d + val routingHint = listOf(Bolt11Invoice.TaggedField.ExtraHop(c, channelUpdateCD.shortChannelId, channelUpdateCD.feeBaseMsat, channelUpdateCD.feeProportionalMillionths, channelUpdateCD.cltvExpiryDelta)) + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), nonTrampolineFeatures, paymentSecret, paymentMetadata, extraHops = listOf(routingHint)) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToLegacyRecipient(invoice, finalAmount, finalExpiry, nodeHop_bd) + assertEquals(amountAB, firstAmount) + assertEquals(expiryAB, firstExpiry) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength + // A sends an HTLC to its trampoline node B. + val addB = UpdateAddHtlc(randomBytes32(), 1, amountAB, paymentHash, expiryAB, onion.packet) + val (outerB, innerB) = decryptRelayToLegacy(addB, privB) + assertEquals(amountAB, outerB.amount) + assertEquals(amountAB, outerB.totalAmount) + assertEquals(expiryAB, outerB.expiry) + assertEquals(invoice.paymentSecret, outerB.paymentSecret) + assertEquals(finalAmount, innerB.amountToForward) + assertEquals(finalAmount, innerB.totalAmount) + assertEquals(finalExpiry, innerB.outgoingCltv) + assertEquals(d, innerB.outgoingNodeId) + assertEquals(invoice.paymentSecret, innerB.paymentSecret) + assertEquals(invoice.paymentMetadata, innerB.paymentMetadata) + assertEquals(ByteVector("024100"), innerB.invoiceFeatures) // var_onion_optin, payment_secret, basic_mpp + assertEquals(listOf(routingHint), innerB.invoiceRoutingInfo) + + // B forwards the trampoline payment to D over an indirect channel route. + val (amountC, expiryC, onionC) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(innerB.amountToForward, innerB.outgoingCltv, innerB.paymentSecret, innerB.paymentMetadata) + encryptChannelRelay(paymentHash, listOf(ChannelHop(b, c, channelUpdateBC), ChannelHop(c, d, channelUpdateCD)), payloadD) + } + assertEquals(amountBC, amountC) + assertEquals(expiryBC, expiryC) + + // C relays the payment to D. + val addC = UpdateAddHtlc(randomBytes32(), 2, amountC, paymentHash, expiryC, onionC.packet) + val (payloadC, packetD) = decryptChannelRelay(addC, privC) + val addD = UpdateAddHtlc(randomBytes32(), 3, payloadC.amountToForward, paymentHash, payloadC.outgoingCltv, packetD) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + val expectedFinalPayload = PaymentOnion.FinalPayload.Standard( + TlvStream( + OnionPaymentPayloadTlv.AmountToForward(finalAmount), + OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount), + OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata), + ) ) - assertEquals(amountCD, amountD) - assertEquals(expiryCD, expiryD) - val addD = UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet) - val (outerD, innerD, _) = decryptRelayToNonTrampolinePayload(addD, privD) - assertIs(outerD) - assertEquals(amountCD, outerD.amount) - assertEquals(amountCD, outerD.totalAmount) - assertEquals(expiryCD, outerD.expiry) - assertNotEquals(invoice.paymentSecret, outerD.paymentSecret) - assertEquals(finalAmount, innerD.amountToForward) - assertEquals(expiryDE, innerD.outgoingCltv) - assertEquals(e, innerD.outgoingNodeId) - assertEquals(finalAmount, innerD.totalAmount) - assertEquals(invoice.paymentSecret, innerD.paymentSecret) - assertEquals(invoice.paymentMetadata, innerD.paymentMetadata) - assertEquals(ByteVector("024100"), innerD.invoiceFeatures) // var_onion_optin, payment_secret, basic_mpp - assertEquals(listOf(routingHints), innerD.invoiceRoutingInfo) + assertEquals(payloadD, expectedFinalPayload) } @Test - fun `build a trampoline payment to blinded paths`() { + fun `send a trampoline payment to blinded paths`() { val features = Features(Feature.BasicMultiPartPayment to FeatureSupport.Optional) - val offer = OfferTypes.Offer.createNonBlindedOffer(finalAmount, "test offer", e, features, Block.LivenetGenesisBlock.hash) - // E uses a 1-hop blinded path from its LSP. + val offer = OfferTypes.Offer.createNonBlindedOffer(finalAmount, "test offer", d, features, Block.RegtestGenesisBlock.hash) + // D uses a 1-hop blinded path from its trampoline node C. val (invoice, blindedRoute) = run { val payerKey = randomKey() - val request = OfferTypes.InvoiceRequest(offer, finalAmount, 1, features, payerKey, null, Block.LivenetGenesisBlock.hash) - val paymentMetadata = OfferPaymentMetadata.V1(offer.offerId, finalAmount, paymentPreimage, payerKey.publicKey(), null, 1, currentTimestampMillis()) - val blindedPayloads = listOf( - RouteBlindingEncryptedData( - TlvStream( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(channelUpdateDE.shortChannelId), - RouteBlindingEncryptedDataTlv.PaymentRelay(channelUpdateDE.cltvExpiryDelta, channelUpdateDE.feeProportionalMillionths, channelUpdateDE.feeBaseMsat), - RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiry, 1.msat), - ) - ), - RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privE)))), - ).map { it.write().byteVector() } - val blindedRouteDetails = RouteBlinding.create(randomKey(), listOf(d, e), blindedPayloads) - val paymentInfo = OfferTypes.PaymentInfo(channelUpdateDE.feeBaseMsat, channelUpdateDE.feeProportionalMillionths, channelUpdateDE.cltvExpiryDelta, channelUpdateDE.htlcMinimumMsat, channelUpdateDE.htlcMaximumMsat!!, Features.empty) + val request = OfferTypes.InvoiceRequest(offer, finalAmount, 1, features, payerKey, "hello", Block.RegtestGenesisBlock.hash) + val paymentMetadata = OfferPaymentMetadata.V1(offer.offerId, finalAmount, paymentPreimage, payerKey.publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayloadC = RouteBlindingEncryptedData( + TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(channelUpdateCD.shortChannelId), + RouteBlindingEncryptedDataTlv.PaymentRelay(channelUpdateCD.cltvExpiryDelta, channelUpdateCD.feeProportionalMillionths, channelUpdateCD.feeBaseMsat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiry, 1.msat), + ) + ) + val blindedPayloadD = RouteBlindingEncryptedData( + TlvStream( + RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)) + ) + ) + val blindedRouteDetails = RouteBlinding.create(randomKey(), listOf(c, d), listOf(blindedPayloadC, blindedPayloadD).map { it.write().byteVector() }) + val paymentInfo = createBlindedPaymentInfo(channelUpdateCD) val path = Bolt12Invoice.Companion.PaymentBlindedContactInfo(OfferTypes.ContactInfo.BlindedPath(blindedRouteDetails.route), paymentInfo) - val invoice = Bolt12Invoice(request, paymentPreimage, blindedRouteDetails.blindedPrivateKey(privE), 600, features, listOf(path)) + val invoice = Bolt12Invoice(request, paymentPreimage, blindedRouteDetails.blindedPrivateKey(privD), 600, features, listOf(path)) assertEquals(invoice.nodeId, blindedRouteDetails.route.blindedNodeIds.last()) + assertNotEquals(invoice.nodeId, d) Pair(invoice, blindedRouteDetails.route) } - // C pays that invoice using a trampoline node to relay to the invoice's blinded path. - val (firstAmount, firstExpiry, onion) = run { - val trampolineHop = NodeHop(d, invoice.nodeId, channelUpdateDE.cltvExpiryDelta, feeD) - val (trampolineAmount, trampolineExpiry, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(invoice, trampolineHop, finalAmount, finalExpiry) - val trampolinePayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet) - OutgoingPaymentPacket.buildPacket(invoice.paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), trampolinePayload, OnionRoutingPacket.PaymentPacketLength) - } - assertEquals(amountCD, firstAmount) - assertEquals(expiryCD, firstExpiry) + // B pays that invoice using its trampoline node C to relay to the invoice's blinded path. + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToBlindedRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + assertEquals(amountBC, firstAmount) + assertEquals(expiryBC, firstExpiry) - // D decrypts the onion that contains a blinded path in the trampoline onion. - val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (outerD, innerD, _) = decryptRelayToBlinded(addD, privD) - assertEquals(amountCD, outerD.amount) - assertEquals(amountCD, outerD.totalAmount) - assertEquals(expiryCD, outerD.expiry) - assertEquals(finalAmount, innerD.amountToForward) - assertEquals(expiryDE, innerD.outgoingCltv) - assertEquals(listOf(blindedRoute), innerD.outgoingBlindedPaths.map { it.route.route }) - assertEquals(invoice.features.toByteArray().toByteVector(), innerD.invoiceFeatures) - - // D is the introduction node of the blinded path: it can decrypt the first blinded payload and relay to E. - val addE = run { - val (dataD, blindingE) = RouteBlinding.decryptPayload(privD, blindedRoute.blindingKey, blindedRoute.encryptedPayloads.first()).right!! - val payloadD = RouteBlindingEncryptedData.read(dataD.toByteArray()).right!! - assertEquals(channelUpdateDE.shortChannelId, payloadD.outgoingChannelId) - // D would normally create this payload based on the blinded path's payment_info field. - val payloadE = PaymentOnion.FinalPayload.Blinded( + // C decrypts the onion that contains a blinded path in the trampoline onion. + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (outerC, innerC) = decryptRelayToBlinded(addC, privC) + assertEquals(amountBC, outerC.amount) + assertEquals(amountBC, outerC.totalAmount) + assertEquals(expiryBC, outerC.expiry) + assertEquals(finalAmount, innerC.amountToForward) + assertEquals(finalExpiry, innerC.outgoingCltv) + assertEquals(listOf(blindedRoute), innerC.outgoingBlindedPaths.map { it.route.route }) + assertEquals(invoice.features.toByteArray().toByteVector(), innerC.invoiceFeatures) + + // C is the introduction node of the blinded path: it can decrypt the first blinded payload and relay to D. + val addD = run { + val (dataC, blindingD) = RouteBlinding.decryptPayload(privC, blindedRoute.blindingKey, blindedRoute.encryptedPayloads.first()).right!! + val payloadC = RouteBlindingEncryptedData.read(dataC.toByteArray()).right!! + assertEquals(channelUpdateCD.shortChannelId, payloadC.outgoingChannelId) + // C would normally create this payload based on the payment_relay field it received. + val payloadD = PaymentOnion.FinalPayload.Blinded( TlvStream( - OnionPaymentPayloadTlv.AmountToForward(finalAmount), - OnionPaymentPayloadTlv.TotalAmount(finalAmount), - OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.AmountToForward(innerC.amountToForward), + OnionPaymentPayloadTlv.TotalAmount(innerC.amountToForward), + OnionPaymentPayloadTlv.OutgoingCltv(innerC.outgoingCltv), OnionPaymentPayloadTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads.last()), ), - // This dummy value is ignored when creating the htlc (D is not the recipient). + // This dummy value is ignored when creating the htlc (C is not the recipient). RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(ByteVector("deadbeef")))) ) - val blindedHop = ChannelHop(d, blindedRoute.blindedNodeIds.last(), channelUpdateDE) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket(addD.paymentHash, listOf(blindedHop), payloadE, OnionRoutingPacket.PaymentPacketLength) - UpdateAddHtlc(randomBytes32(), 2, amountE, addD.paymentHash, expiryE, onionE.packet, blindingE, null) + val onionD = OutgoingPaymentPacket.buildOnion(listOf(blindedRoute.blindedNodeIds.last()), listOf(payloadD), addC.paymentHash, OnionRoutingPacket.PaymentPacketLength) + UpdateAddHtlc(randomBytes32(), 2, innerC.amountToForward, addC.paymentHash, innerC.outgoingCltv, onionD.packet, blindingD, null) } - // E can correctly decrypt the blinded payment. - val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! - assertIs(payloadE) - val paymentMetadata = OfferPaymentMetadata.fromPathId(e, payloadE.pathId) + // D can correctly decrypt the blinded payment. + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + assertIs(payloadD) + assertEquals(finalAmount, payloadD.amount) + assertEquals(finalExpiry, payloadD.expiry) + val paymentMetadata = OfferPaymentMetadata.fromPathId(d, payloadD.pathId) assertNotNull(paymentMetadata) assertEquals(offer.offerId, paymentMetadata.offerId) assertEquals(paymentMetadata.paymentHash, invoice.paymentHash) } @Test - fun `build a payment to a blinded path`() { - val (addE, payloadE) = createBlindedHtlc() - // E can correctly decrypt the blinded payment. - assertEquals(payloadE, IncomingPaymentPacket.decrypt(addE, privE).right) + fun `receive a channel payment`() { + // c -> d + val finalPayload = PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 1.5, finalExpiry, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), finalPayload) + val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + assertEquals(finalPayload, payloadD) } @Test - fun `fail to decrypt when the onion is invalid`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + fun `receive a blinded channel payment`() { + // c -> d + val (addD, paymentMetadata) = createBlindedHtlcCD() + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + assertIs(payloadD) + assertEquals(finalAmount, payloadD.amount) + assertEquals(finalAmount, payloadD.totalAmount) + assertEquals(finalExpiry, payloadD.expiry) + assertEquals(paymentMetadata, OfferPaymentMetadata.fromPathId(d, payloadD.pathId)) + } + + @Test + fun `fail to decrypt when the payment onion is invalid`() { + val (firstAmount, firstExpiry, onion) = encryptChannelRelay( paymentHash, - hops, - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null), - OnionRoutingPacket.PaymentPacketLength + listOf(ChannelHop(c, d, channelUpdateCD)), + PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, null) ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reversed())) - val failure = IncomingPaymentPacket.decrypt(add, privB) + val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reversed())) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when the trampoline onion is invalid`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, null), - null - ) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reversed())), - OnionRoutingPacket.PaymentPacketLength - ) - val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (_, packetC) = decryptChannelRelay(addB, privB) - val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val failure = IncomingPaymentPacket.decrypt(addC, privC) + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC, packetD) = decryptRelayToTrampoline(addC, privC) + // C modifies the trampoline onion before forwarding the trampoline payment to D. + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD.copy(payload = packetD.payload.reversed())) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD, paymentHash, expiryD, onionD.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when payment hash doesn't match associated data`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = encryptChannelRelay( paymentHash.reversed(), - hops, - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), - OnionRoutingPacket.PaymentPacketLength + listOf(ChannelHop(c, d, channelUpdateCD)), + PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, paymentMetadata) ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPaymentPacket.decrypt(add, privB) + val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when blinded route data is invalid`() { - val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), null, 1, currentTimestampMillis()) - val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privE)))) - val blindedRoute = RouteBlinding.create(randomKey(), listOf(e), listOf(blindedPayload.write().byteVector())).route - val payloadE = PaymentOnion.FinalPayload.Blinded( + val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)))) + val blindedRoute = RouteBlinding.create(randomKey(), listOf(d), listOf(blindedPayload.write().byteVector())).route + val payloadD = PaymentOnion.FinalPayload.Blinded( TlvStream( OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.TotalAmount(finalAmount), @@ -537,135 +451,105 @@ class PaymentPacketTestsCommon : LightningTestSuite() { ), blindedPayload ) - val blindedHop = ChannelHop(d, blindedRoute.blindedNodeIds.last(), channelUpdateDE) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket(paymentMetadata.paymentHash, listOf(blindedHop), payloadE, OnionRoutingPacket.PaymentPacketLength) - val addE = UpdateAddHtlc(randomBytes32(), 2, amountE, paymentMetadata.paymentHash, expiryE, onionE.packet, blindedRoute.blindingKey, null) - val failure = IncomingPaymentPacket.decrypt(addE, privE) + val onionD = OutgoingPaymentPacket.buildOnion(listOf(blindedRoute.blindedNodeIds.last()), listOf(payloadD), paymentHash, OnionRoutingPacket.PaymentPacketLength) + val addD = UpdateAddHtlc(randomBytes32(), 1, finalAmount, paymentHash, finalExpiry, onionD.packet, blindedRoute.blindingKey, null) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) - assertEquals(failure.left, InvalidOnionBlinding(hash(addE.onionRoutingPacket))) + assertEquals(failure.left, InvalidOnionBlinding(hash(addD.onionRoutingPacket))) } @Test - fun `fail to decrypt at the final node when amount has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - hops.take(1), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), - OnionRoutingPacket.PaymentPacketLength - ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPaymentPacket.decrypt(add, privB) - assertEquals(Either.Left(FinalIncorrectHtlcAmount(firstAmount - 100.msat)), failure) + fun `fail to decrypt when amount has been modified by trampoline node`() { + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC, packetD) = decryptRelayToTrampoline(addC, privC) + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD - 100.msat, paymentHash, expiryD, onionD.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) + assertEquals(Either.Left(FinalIncorrectHtlcAmount(finalAmount - 100.msat)), failure) } @Test - fun `fail to decrypt at the final node when expiry has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - hops.take(1), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), - OnionRoutingPacket.PaymentPacketLength - ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet) - val failure = IncomingPaymentPacket.decrypt(add, privB) - assertEquals(Either.Left(FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))), failure) + fun `fail to decrypt when expiry has been modified by trampoline node`() { + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC, packetD) = decryptRelayToTrampoline(addC, privC) + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD, paymentHash, expiryD - CltvExpiryDelta(12), onionD.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) + assertEquals(Either.Left(FinalIncorrectCltvExpiry(finalExpiry - CltvExpiryDelta(12))), failure) } @Test - fun `fail to decrypt blinded payment at the final node when amount is too low`() { - val (addE, _) = createBlindedHtlc() - // E receives a smaller amount than expected and rejects the payment. - val failure = IncomingPaymentPacket.decrypt(addE.copy(amountMsat = addE.amountMsat - 1.msat), privE).left - assertEquals(InvalidOnionBlinding(hash(addE.onionRoutingPacket)), failure) + fun `fail to decrypt blinded payment when amount is too low`() { + val (addD, _) = createBlindedHtlcCD() + // D receives a smaller amount than expected and rejects the payment. + val failure = IncomingPaymentPacket.decrypt(addD.copy(amountMsat = addD.amountMsat - 1.msat), privD).left + assertEquals(InvalidOnionBlinding(hash(addD.onionRoutingPacket)), failure) } @Test - fun `fail to decrypt blinded payment at the final node when expiry is too low`() { - val (addE, _) = createBlindedHtlc() + fun `fail to decrypt blinded payment when expiry is too low`() { + val (addD, _) = createBlindedHtlcCD() // E receives a smaller expiry than expected and rejects the payment. - val failure = IncomingPaymentPacket.decrypt(addE.copy(cltvExpiry = addE.cltvExpiry - CltvExpiryDelta(1)), privE).left - assertEquals(InvalidOnionBlinding(hash(addE.onionRoutingPacket)), failure) - } - - @Test - fun `fail to decrypt at the final trampoline node when amount has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), - null - ) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) - val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) - // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). - val invalidTotalAmount = amountDE + 100.msat - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(d, e, channelUpdateDE)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountDE, invalidTotalAmount, expiryDE, randomBytes32(), packetE), - OnionRoutingPacket.PaymentPacketLength - ) - val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) - assertEquals(Either.Left(FinalIncorrectHtlcAmount(invalidTotalAmount)), failure) + val failure = IncomingPaymentPacket.decrypt(addD.copy(cltvExpiry = addD.cltvExpiry - CltvExpiryDelta(1)), privD).left + assertEquals(InvalidOnionBlinding(hash(addD.onionRoutingPacket)), failure) } @Test - fun `fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), - null - ) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) - val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) - // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). - val invalidExpiry = expiryDE - CltvExpiryDelta(12) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(d, e, channelUpdateDE)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountDE, amountDE, invalidExpiry, randomBytes32(), packetE), - OnionRoutingPacket.PaymentPacketLength - ) - val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) - assertEquals(Either.Left(FinalIncorrectCltvExpiry(invalidExpiry)), failure) + fun `prune outgoing blinded paths`() { + // We create an invoice with a large number of blinded paths. + val features = Features(Feature.BasicMultiPartPayment to FeatureSupport.Optional) + val offer = OfferTypes.Offer.createNonBlindedOffer(finalAmount, "test offer", d, features, Block.RegtestGenesisBlock.hash) + val invoice = run { + val payerKey = randomKey() + val request = OfferTypes.InvoiceRequest(offer, finalAmount, 1, features, payerKey, "hello", Block.RegtestGenesisBlock.hash) + val paymentMetadata = OfferPaymentMetadata.V1(offer.offerId, finalAmount, paymentPreimage, payerKey.publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayloadD = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)))) + val blindedPaths = (0 until 20L).map { i -> + val blindedPayloadC = RouteBlindingEncryptedData( + TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(i)), + RouteBlindingEncryptedDataTlv.PaymentRelay(channelUpdateCD.cltvExpiryDelta, channelUpdateCD.feeProportionalMillionths, channelUpdateCD.feeBaseMsat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiry, 1.msat), + ) + ) + val blindedRouteDetails = RouteBlinding.create(randomKey(), listOf(c, d), listOf(blindedPayloadC, blindedPayloadD).map { it.write().byteVector() }) + val paymentInfo = createBlindedPaymentInfo(channelUpdateCD) + Bolt12Invoice.Companion.PaymentBlindedContactInfo(OfferTypes.ContactInfo.BlindedPath(blindedRouteDetails.route), paymentInfo) + } + Bolt12Invoice(request, paymentPreimage, randomKey(), 600, features, blindedPaths) + } + // B sends an HTLC to its trampoline node C: we prune the blinded paths to fit inside the onion. + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToBlindedRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC) = decryptRelayToBlinded(addC, privC) + assertTrue(innerC.outgoingBlindedPaths.size < invoice.blindedPaths.size) + innerC.outgoingBlindedPaths.forEach { assertTrue(invoice.blindedPaths.contains(it)) } } @Test - fun `relay to blinded with many large blinded routes`() { - val invoice = Bolt12Invoice.fromString("lni1qqs0sehhttf0swv6sxsxuefk9q23yj9h0cl4wyn324jlt2gll4fzk0syzquta4dy4m9jgp5s970w2lw8928ppnqz662mq8r6dyy7w9kgv0ann0zlk7aacwpykl7uu5adckf7t0sgpees93j6nwr4xle4ck4d2wup6j044fx6f5zkdacvjen3herm493en2lmqgprfesng92cpl68qvlxv96prguvtl7yv49nwq9a92zl5n2pv73u4wgqx08jupl9mcphw7ww5mv50j7xkcjxvjamrysy3wsfmgj4cu9aq4y6hszdk5rys4qwa4muxczk4nuj54gyw3jqygaqqk2edh9c90ergwvaxu93jynty9z0gzrpcwuhkzt05epatv45qqg93d3x5hx8jgwgkwpnz85y3narw93pqvmp27mzgwjaswd76lgxm88qw2ehqszd2mlgntvnwq0vzrk4ehq325pqdl3gcz4k7xeh9sdx5fr2uclhf7f3aqm9u9dq38rg6cvsqqqqqqq9yqjwyp2qxqsqqpvzzqnfhw87mvwxx0qrljmlxppxr7h4hmdgdzpsxnj0rw2jckjtc93g44v3gmrfva58gmnfdenjq6tnyp3k7mmvypq5dg8aqsqsxhj07sv0ez642nzanm4xvwtvyfaag2dry5wge0r3zqpt5g2mls3xqwsadj5tn7gm50l43xmsu2dszx8s672wrkqn48qp7ftw6kvhltlrkqsrwc9agnmlg8dzefxe4vlz9euhhwjzeflzrrkkurluwhzfjqf62trsq2dm4984vptwe3zpea0e3n27w75pyt2ym8g734awf99ep759a4xdf29epp5z3e303e08q2e6qdycnlzdkv4c80u2pr7ql2ndpa9zfafwyqkf9r9cqtaza35txq9rppm56ajy97nema0j8vtuld097u0hjhh5cmwhlp8du95va7wvq2emxkvafy2wg8ghl7vpkw2w0w0pp2tr6gqqmnvxy9u9jg9rkwn2flxn9d3vd9x0g94wdhgap22l780d0t9zywpj7yu949347hsag0s6pgpkuaf2pg8xqzjk336e8azw8672qtvvljye6gtevmnrygav24jvsxmypf25rqs8ersr9wfsqhdxjl6lmy3lqac4q96rjn02vtk4n6zcaqzr5qhlqlywrga7lmwt99gyj23ndzrl2y8380nxydqntsh7gpajau2wrvpshzpryq37n72ujf26djermgcf47wfsfdts3p97u0lnt6szck8zlqzqfz6ex8p9xtugj6jtuyf3flkqk9ahu8e6dxgvp49xl24x6hdmljcjqptk8erxq4uwnkv4g32p4d8un0f9az5ft59mwvgtf3xxm3xepx68x6vk46pf5xh4lyr35m32q5ukd5xj4kj7vka67uggjd4saahn3dswdwsqe8epg2pvele4cz0zvq288t94n69cj87ewal7yqvvpq6tj6mpjw2d36uhlf2khtjlpdfv672gvkszdpqht37shtf0ytyfnl7xkeudm2xscvpack3whk29hsfr36cv95ddzc2klk2yc93s09y2uh2xz5k5jap4550z2vd7gy9faxd2gcr4xjmke8yhn0wdsqxpe78qyqy0elprd0swe9yxc82as23vm5fyumc7nv470xfsjy3zjg4zfvpa6kv3ndl0w3ukc6627dua9g7pmzfupesqvpjfw3rjt39hlmk40gtralyk5a4lq42207acdqekpgmdjqpmw0zy3cz8d3cpvw4a28lqnf3etu8adl43zra0zk5m094h86g4yfm8cxnz7zqyqa5q4dz6ceqp0jeuakvyx4653h0kjllhhpfduquh4ch8rtu9lf20gqznhxe60n3m45k80vcgnvut8cfgpvy37ae93rpah9sa4ddyzyt9gu0uj5zny00vrkkxycrxm8wk3gzyrn3q8caxrqwa48fwqzlagph95y5tw555px0s97y4lgspgedkg4q9gxgv4qyc7x8433j0a5fnntdjhlsmx4wje53dj8y4c3757yt87kenvej8ugwkg05flugshd906tvtqcldrcgxfqqw8ecz926ft3ehdahpjhueag8znklv6qvz7w3htszz5k3we0dd8r2grxpz5edgtuh28u0c5ugrf8autgu3crkjt58ev9gcvnetn5kjwk4wehr8yq7lulsug7dhyg0reyqlhhsxs8ueghtm02g7jd2p9ytwstsacry3y8mv202y4qqqqqqqqqqqq9qq6qqqqqqqqqqqqqsqqqqqf2q5htpqqqqqqqraqqqqqqpqp5qqqqqqqqqqqqpqqqqqqqqh6uzcqqqqqqqqqqqqqqeqqrgqqqqqqqqqqqqzqqqqqqjnems5qqqpfqyv6g47m4gyqw3qa7penn7yv39ufk8j0mpj0ldfyu9qd7j2vcgt27rqe9l5ydpm2szfcs2uqczqqqtqggrxc2hkcjr5hvrn0kh6pkeecrjkdcyqn2kl6y6mymsrmqsa4wdcy2lqsxw9n37kdz8zq0ytckjvcy7jcwqrqj5rgd56q4e3g76k0lja96hzkdda7z5sxqwemrdjje2rqm7jcv6ll39dut0mqjvrl4skedkfz5mu").get() - val (trampolineAmount, trampolineExpiry, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(invoice, NodeHop(randomKey().publicKey(), randomKey().publicKey(), CltvExpiryDelta(444), 0.msat), finalAmount, finalExpiry) - val trampolinePayload = PaymentAttempt.TrampolinePayload(trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet) - val channelHops: List = listOf(ChannelHop(randomKey().publicKey(), randomKey().publicKey(), defaultChannelUpdate)) - OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), randomBytes32(), channelHops, trampolinePayload.createFinalPayload(finalAmount)) + fun `prune outgoing routing info`() { + // We create an invoice with a large number of routing hints. + val routingHints = (0 until 50L).map { i -> listOf(Bolt11Invoice.TaggedField.ExtraHop(randomKey().publicKey(), ShortChannelId(i), 5.msat, 25, CltvExpiryDelta(36))) } + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), nonTrampolineFeatures, paymentSecret, paymentMetadata, extraHops = routingHints) + // A sends an HTLC to its trampoline node B: we prune the routing hints to fit inside the onion. + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToLegacyRecipient(invoice, finalAmount, finalExpiry, nodeHop_bd) + assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) + val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerB) = decryptRelayToLegacy(addB, privB) + assertTrue(innerB.invoiceRoutingInfo.isNotEmpty()) + assertTrue(innerB.invoiceRoutingInfo.size < routingHints.size) + innerB.invoiceRoutingInfo.forEach { assertTrue(routingHints.contains(it)) } } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt deleted file mode 100644 index e04098ea0..000000000 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt +++ /dev/null @@ -1,154 +0,0 @@ -package fr.acinq.lightning.payment - -import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.bitcoin.utils.Either -import fr.acinq.lightning.Lightning.randomBytes32 -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.ShortChannelId -import fr.acinq.lightning.blockchain.fee.FeeratePerKw -import fr.acinq.lightning.channel.Commitments -import fr.acinq.lightning.channel.states.Normal -import fr.acinq.lightning.channel.states.Offline -import fr.acinq.lightning.channel.states.Syncing -import fr.acinq.lightning.channel.TestsHelper.reachNormal -import fr.acinq.lightning.tests.utils.LightningTestSuite -import fr.acinq.lightning.transactions.CommitmentSpec -import fr.acinq.lightning.utils.* -import kotlin.random.Random -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertTrue - -class RouteCalculationTestsCommon : LightningTestSuite() { - - private val defaultChannel = reachNormal().first - private val paymentId = UUID.randomUUID() - private val routeCalculation = RouteCalculation(loggerFactory) - - private fun makeChannel(channelId: ByteVector32, balance: MilliSatoshi, htlcMin: MilliSatoshi): Normal { - val shortChannelId = ShortChannelId(Random.nextLong()) - val reserve = defaultChannel.commitments.latest.localChannelReserve - val commitments = defaultChannel.commitments.copy( - params = defaultChannel.commitments.params.copy(channelId = channelId), - active = defaultChannel.commitments.active.map { - it.copy(remoteCommit = it.remoteCommit.copy(spec = CommitmentSpec(setOf(), FeeratePerKw(0.sat), 50_000.msat, balance + ((Commitments.ANCHOR_AMOUNT * 2) + reserve).toMilliSatoshi()))) - } - ) - val channelUpdate = defaultChannel.state.channelUpdate.copy(htlcMinimumMsat = htlcMin) - return defaultChannel.state.copy(shortChannelId = shortChannelId, commitments = commitments, channelUpdate = channelUpdate) - } - - @Test - fun `make channel fixture`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - val offlineChannels = mapOf( - channelId1 to Offline(makeChannel(channelId1, 15_000.msat, 10.msat)), - channelId2 to Offline(makeChannel(channelId2, 20_000.msat, 5.msat)), - channelId3 to Offline(makeChannel(channelId3, 10_000.msat, 10.msat)), - ) - assertEquals(setOf(10_000.msat, 15_000.msat, 20_000.msat), offlineChannels.map { (it.value.state as Normal).commitments.availableBalanceForSend() }.toSet()) - - val normalChannels = mapOf( - channelId1 to makeChannel(channelId1, 15_000.msat, 10.msat), - channelId2 to makeChannel(channelId2, 15_000.msat, 5.msat), - channelId3 to makeChannel(channelId3, 10_000.msat, 10.msat), - ) - assertEquals(setOf(10_000.msat, 15_000.msat), normalChannels.map { it.value.commitments.availableBalanceForSend() }.toSet()) - } - - @Test - fun `no available channels`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to Offline(makeChannel(channelId1, 15_000.msat, 10.msat)), - channelId2 to Syncing(makeChannel(channelId2, 20_000.msat, 5.msat), channelReestablishSent = true), - channelId3 to Offline(makeChannel(channelId3, 10_000.msat, 10.msat)), - ) - assertEquals(Either.Left(FinalFailure.ChannelNotConnected), routeCalculation.findRoutes(paymentId, 5_000.msat, channels)) - } - - @Test - fun `insufficient balance`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to makeChannel(channelId1, 15_000.msat, 10.msat), - channelId2 to makeChannel(channelId2, 18_000.msat, 5.msat), - channelId3 to makeChannel(channelId3, 12_000.msat, 10.msat), - ) - assertEquals(Either.Left(FinalFailure.InsufficientBalance), routeCalculation.findRoutes(paymentId, 50_000.msat, channels)) - } - - @Test - fun `single payment`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - run { - val channels = mapOf( - channelId1 to makeChannel(channelId1, 35_000.msat, 10.msat), - channelId2 to makeChannel(channelId2, 30_000.msat, 5.msat), - channelId3 to makeChannel(channelId3, 38_000.msat, 10.msat), - ) - val routes = routeCalculation.findRoutes(paymentId, 38_000.msat, channels).right!! - assertEquals(listOf(RouteCalculation.Route(38_000.msat, channels.getValue(channelId3))), routes) - } - run { - val channels = mapOf(channelId3 to makeChannel(channelId3, 38_000.msat, 10.msat)) - val routes = routeCalculation.findRoutes(paymentId, 38_000.msat, channels).right!! - assertEquals(listOf(RouteCalculation.Route(38_000.msat, channels.getValue(channelId3))), routes) - } - } - - @Test - fun `ignore empty channels`() { - val (channelId1, channelId2, channelId3, channelId4) = listOf(randomBytes32(), randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to makeChannel(channelId1, 0.msat, 10.msat), - channelId2 to makeChannel(channelId2, 50.msat, 100.msat), - channelId3 to makeChannel(channelId3, 30_000.msat, 15.msat), - channelId4 to makeChannel(channelId4, 20_000.msat, 50.msat), - ) - val routes = routeCalculation.findRoutes(paymentId, 50_000.msat, channels).right!! - val expected = setOf( - RouteCalculation.Route(30_000.msat, channels.getValue(channelId3)), - RouteCalculation.Route(20_000.msat, channels.getValue(channelId4)), - ) - assertEquals(expected, routes.toSet()) - assertEquals(Either.Left(FinalFailure.InsufficientBalance), routeCalculation.findRoutes(paymentId, 50010.msat, channels)) - } - - @Test - fun `split payment across many channels`() { - val (channelId1, channelId2, channelId3, channelId4) = listOf(randomBytes32(), randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to makeChannel(channelId1, 50.msat, 10.msat), - channelId2 to makeChannel(channelId2, 150.msat, 100.msat), - channelId3 to makeChannel(channelId3, 25.msat, 15.msat), - channelId4 to makeChannel(channelId4, 75.msat, 50.msat), - ) - run { - val routes = routeCalculation.findRoutes(paymentId, 300.msat, channels).right!! - val expected = setOf( - RouteCalculation.Route(50.msat, channels.getValue(channelId1)), - RouteCalculation.Route(150.msat, channels.getValue(channelId2)), - RouteCalculation.Route(25.msat, channels.getValue(channelId3)), - RouteCalculation.Route(75.msat, channels.getValue(channelId4)), - ) - assertEquals(expected, routes.toSet()) - } - run { - val routes = routeCalculation.findRoutes(paymentId, 250.msat, channels).right!! - assertTrue(routes.size >= 3) - assertEquals(250.msat, routes.map { it.amount }.sum()) - } - run { - val routes = routeCalculation.findRoutes(paymentId, 200.msat, channels).right!! - assertTrue(routes.size >= 2) - assertEquals(200.msat, routes.map { it.amount }.sum()) - } - run { - val routes = routeCalculation.findRoutes(paymentId, 50.msat, channels).right!! - assertTrue(routes.size == 1) - assertEquals(50.msat, routes.map { it.amount }.sum()) - } - } - -} \ No newline at end of file