diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt index 0c28e038c..a005ac003 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelException.kt @@ -83,9 +83,11 @@ data class CannotAffordFirstCommitFees (override val channelId: Byte data class CannotAffordFees (override val channelId: ByteVector32, val missing: Satoshi, val reserve: Satoshi, val fees: Satoshi) : ChannelException(channelId, "can't pay the fee: missing=$missing reserve=$reserve fees=$fees") data class CannotSignWithoutChanges (override val channelId: ByteVector32) : ChannelException(channelId, "cannot sign when there are no change") data class CannotSignBeforeRevocation (override val channelId: ByteVector32) : ChannelException(channelId, "cannot sign until next revocation hash is received") +data class CannotSignDisconnected (override val channelId: ByteVector32) : ChannelException(channelId, "disconnected before signing outgoing payments") data class UnexpectedRevocation (override val channelId: ByteVector32) : ChannelException(channelId, "received unexpected RevokeAndAck message") data class InvalidRevocation (override val channelId: ByteVector32) : ChannelException(channelId, "invalid revocation") data class InvalidFailureCode (override val channelId: ByteVector32) : ChannelException(channelId, "UpdateFailMalformedHtlc message doesn't have BADONION bit set") +data class CannotDecryptFailure (override val channelId: ByteVector32, val details: String) : ChannelException(channelId, "cannot decrypt failure message: $details") data class PleasePublishYourCommitment (override val channelId: ByteVector32) : ChannelException(channelId, "please publish your local commitment") data class CommandUnavailableInThisState (override val channelId: ByteVector32, val state: String) : ChannelException(channelId, "cannot execute command in state=$state") data class ForbiddenDuringSplice (override val channelId: ByteVector32, val command: String?) : ChannelException(channelId, "cannot process $command while splicing") diff --git a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt index 9d5aaca5f..e930ecd73 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt @@ -812,19 +812,14 @@ class Peer( is ChannelAction.ProcessIncomingHtlc -> processIncomingPayment(Either.Right(action.add)) is ChannelAction.ProcessCmdRes.NotExecuted -> logger.warning(action.t) { "command not executed" } is ChannelAction.ProcessCmdRes.AddFailed -> { - when (val result = outgoingPaymentHandler.processAddFailed(actualChannelId, action, _channels)) { - is OutgoingPaymentHandler.Progress -> { - _eventsFlow.emit(PaymentProgress(result.request, result.fees)) - result.actions.forEach { input.send(it) } - } - + when (val result = outgoingPaymentHandler.processAddFailed(actualChannelId, action)) { is OutgoingPaymentHandler.Failure -> _eventsFlow.emit(PaymentNotSent(result.request, result.failure)) null -> logger.debug { "non-final error, more partial payments are still pending: ${action.error.message}" } } } is ChannelAction.ProcessCmdRes.AddSettledFail -> { val currentTip = currentTipFlow.filterNotNull().first() - when (val result = outgoingPaymentHandler.processAddSettled(actualChannelId, action, _channels, currentTip)) { + when (val result = outgoingPaymentHandler.processAddSettledFailed(actualChannelId, action, _channels, currentTip)) { is OutgoingPaymentHandler.Progress -> { _eventsFlow.emit(PaymentProgress(result.request, result.fees)) result.actions.forEach { input.send(it) } @@ -832,14 +827,13 @@ class Peer( is OutgoingPaymentHandler.Success -> _eventsFlow.emit(PaymentSent(result.request, result.payment)) is OutgoingPaymentHandler.Failure -> _eventsFlow.emit(PaymentNotSent(result.request, result.failure)) - null -> logger.debug { "non-final error, more partial payments are still pending: ${action.result}" } + null -> logger.debug { "non-final error, another payment attempt (retry) is still pending: ${action.result}" } } } is ChannelAction.ProcessCmdRes.AddSettledFulfill -> { - when (val result = outgoingPaymentHandler.processAddSettled(action)) { + when (val result = outgoingPaymentHandler.processAddSettledFulfilled(action)) { is OutgoingPaymentHandler.Success -> _eventsFlow.emit(PaymentSent(result.request, result.payment)) - is OutgoingPaymentHandler.PreimageReceived -> logger.debug(mapOf("paymentId" to result.request.paymentId)) { "payment preimage received: ${result.preimage}" } - null -> logger.debug { "unknown payment" } + null -> logger.error { "unknown payment fulfilled: this should never happen" } } } is ChannelAction.Storage.StoreState -> { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt index 8c0c3d042..f0213252f 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt @@ -4,11 +4,10 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.utils.Either import fr.acinq.bitcoin.utils.Try import fr.acinq.lightning.* -import fr.acinq.lightning.channel.ChannelAction -import fr.acinq.lightning.channel.ChannelException -import fr.acinq.lightning.channel.states.Channel -import fr.acinq.lightning.channel.states.ChannelState +import fr.acinq.lightning.channel.* +import fr.acinq.lightning.channel.states.* import fr.acinq.lightning.crypto.sphinx.FailurePacket +import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets import fr.acinq.lightning.crypto.sphinx.SharedSecrets import fr.acinq.lightning.db.HopDesc import fr.acinq.lightning.db.LightningOutgoingPayment @@ -18,12 +17,13 @@ import fr.acinq.lightning.io.WrappedChannelCommand import fr.acinq.lightning.logging.MDCLogger import fr.acinq.lightning.logging.error import fr.acinq.lightning.logging.mdc -import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.utils.UUID import fr.acinq.lightning.utils.msat -import fr.acinq.lightning.utils.sum -import fr.acinq.lightning.wire.* +import fr.acinq.lightning.wire.FailureMessage +import fr.acinq.lightning.wire.TrampolineExpiryTooSoon +import fr.acinq.lightning.wire.TrampolineFeeInsufficient +import fr.acinq.lightning.wire.UnknownNextPeer class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: WalletParams, val db: OutgoingPaymentsDb) { @@ -37,21 +37,69 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle /** The payment could not be sent. */ data class Failure(val request: PayInvoice, val failure: OutgoingPaymentFailure) : SendPaymentResult, ProcessFailureResult - /** The recipient released the preimage, but we are still waiting for some partial payments to settle. */ - data class PreimageReceived(val request: PayInvoice, val preimage: ByteVector32) : ProcessFulfillResult - /** The payment was successfully made. */ data class Success(val request: PayInvoice, val payment: LightningOutgoingPayment, val preimage: ByteVector32) : ProcessFailureResult, ProcessFulfillResult private val logger = nodeParams.loggerFactory.newLogger(this::class) - private val childToParentId = mutableMapOf() + // Each outgoing HTLC will have its own ID, because its status will be recorded in the payments DB. + // Since we automatically retry on failure, we may have multiple child attempts for each payment. + private val childToPaymentId = mutableMapOf() private val pending = mutableMapOf() - private val routeCalculation = RouteCalculation(nodeParams.loggerFactory) + + /** + * While a payment is in progress, we wait for the outgoing HTLC to settle. + * When we receive a failure, we may retry with a different fee or expiry. + * + * @param request payment request containing the total amount to send. + * @param attemptNumber number of failed previous payment attempts. + * @param pending pending outgoing payment. + * @param sharedSecrets payment onion shared secrets, used to decrypt failures. + * @param failures previous payment failures. + */ + data class PaymentAttempt( + val request: PayInvoice, + val attemptNumber: Int, + val pending: LightningOutgoingPayment.Part, + val sharedSecrets: SharedSecrets, + val failures: List> + ) { + val fees: MilliSatoshi = pending.amount - request.amount + } // NB: this function should only be used in tests. - fun getPendingPayment(parentId: UUID): PaymentAttempt? = pending[parentId] + fun getPendingPayment(paymentId: UUID): PaymentAttempt? = pending[paymentId] - private fun getPaymentAttempt(childId: UUID): PaymentAttempt? = childToParentId[childId]?.let { pending[it] } + private fun getPaymentAttempt(childId: UUID): PaymentAttempt? = childToPaymentId[childId]?.let { pending[it] } + + private suspend fun sendPaymentInternal(request: PayInvoice, failures: List>, channels: Map, currentBlockHeight: Int, logger: MDCLogger): Either { + val attemptNumber = failures.size + val trampolineFees = (request.trampolineFeesOverride ?: walletParams.trampolineFees)[attemptNumber] + logger.info { "trying payment with fee_base=${trampolineFees.feeBase}, fee_proportional=${trampolineFees.feeProportional}" } + val trampolineAmount = request.amount + trampolineFees.calculateFees(request.amount) + return when (val result = selectChannel(trampolineAmount, channels)) { + is Either.Left -> { + logger.warning { "payment failed: ${result.value}" } + if (attemptNumber == 0) { + db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails)) + } + db.completeOutgoingPaymentOffchain(request.paymentId, result.value) + removeFromState(request.paymentId) + Either.Left(Failure(request, OutgoingPaymentFailure(result.value, failures))) + } + is Either.Right -> { + val hop = NodeHop(walletParams.trampolineNode.id, request.recipient, trampolineFees.cltvExpiryDelta, trampolineFees.calculateFees(request.amount)) + val (childPayment, sharedSecrets, cmd) = createOutgoingPayment(request, result.value, hop, currentBlockHeight) + if (attemptNumber == 0) { + db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, listOf(childPayment), LightningOutgoingPayment.Status.Pending)) + } else { + db.addOutgoingLightningParts(request.paymentId, listOf(childPayment)) + } + val payment = PaymentAttempt(request, attemptNumber, childPayment, sharedSecrets, failures) + pending[request.paymentId] = payment + Either.Right(Progress(request, payment.fees, listOf(cmd))) + } + } + } suspend fun sendPayment(request: PayInvoice, channels: Map, currentBlockHeight: Int): SendPaymentResult { val logger = MDCLogger(logger, staticMdc = request.mdc()) @@ -72,404 +120,212 @@ class OutgoingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle logger.error { "invoice has already been paid" } return Failure(request, FinalFailure.AlreadyPaid.toPaymentFailure()) } - val trampolineFees = request.trampolineFeesOverride ?: walletParams.trampolineFees - val (trampolineAmount, trampolineExpiry, trampolinePacket) = createTrampolinePayload(request, trampolineFees.first(), currentBlockHeight) - return when (val result = routeCalculation.findRoutes(request.paymentId, trampolineAmount, channels)) { - is Either.Left -> { - logger.warning { "payment failed: ${result.value}" } - db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails)) - val finalFailure = result.value - db.completeOutgoingPaymentOffchain(request.paymentId, finalFailure) - Failure(request, finalFailure.toPaymentFailure()) - } - is Either.Right -> { - // We generate a random secret for this payment to avoid leaking the invoice secret to the trampoline node. - val trampolinePaymentSecret = Lightning.randomBytes32() - val trampolinePayload = PaymentAttempt.TrampolinePayload(trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolinePacket) - val childPayments = createChildPayments(request, result.value, trampolinePayload) - db.addOutgoingPayment(LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, childPayments.map { it.first }, LightningOutgoingPayment.Status.Pending)) - val payment = PaymentAttempt.PaymentInProgress(request, 0, trampolinePayload, childPayments.associate { it.first.id to Pair(it.first, it.second) }, setOf(), listOf()) - pending[request.paymentId] = payment - Progress(request, payment.fees, childPayments.map { it.third }) - } - } - } - - private fun createChildPayments(request: PayInvoice, routes: List, trampolinePayload: PaymentAttempt.TrampolinePayload): List> { - val logger = MDCLogger(logger, staticMdc = request.mdc()) - val childPayments = routes.map { createOutgoingPart(request, it, trampolinePayload) } - childToParentId.putAll(childPayments.map { it.first.id to request.paymentId }) - childPayments.forEach { logger.info(mapOf("childPaymentId" to it.first.id)) { "sending ${it.first.amount} to channel ${it.third.channelId}" } } - return childPayments + return sendPaymentInternal(request, listOf(), channels, currentBlockHeight, logger).fold({ it }, { it }) } - suspend fun processAddFailed(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddFailed, channels: Map): ProcessFailureResult? { - val add = event.cmd - val payment = getPaymentAttempt(add.paymentId) ?: return processPostRestartFailure(add.paymentId, Either.Left(event.error)) - val logger = MDCLogger(logger, staticMdc = mapOf("channelId" to channelId, "childPaymentId" to add.paymentId) + payment.request.mdc()) - - logger.debug { "could not send HTLC: ${event.error.message}" } - db.completeOutgoingLightningPart(add.paymentId, OutgoingPaymentFailure.convertFailure(Either.Left(event.error))) - - val (updated, result) = when (payment) { - is PaymentAttempt.PaymentInProgress -> { - val ignore = payment.ignore + channelId // we ignore the failing channel in retries - when (val routes = routeCalculation.findRoutes(payment.request.paymentId, add.amount, channels - ignore)) { - is Either.Left -> PaymentAttempt.PaymentAborted(payment.request, routes.value, payment.pending, payment.failures).failChild(add.paymentId, Either.Left(event.error), db, logger) - is Either.Right -> { - val newPayments = createChildPayments(payment.request, routes.value, payment.trampolinePayload) - db.addOutgoingLightningParts(payment.request.paymentId, newPayments.map { it.first }) - val updatedPayments = payment.pending - add.paymentId + newPayments.map { it.first.id to Pair(it.first, it.second) } - val updated = payment.copy(ignore = ignore, failures = payment.failures + Either.Left(event.error), pending = updatedPayments) - val result = Progress(payment.request, updated.fees, newPayments.map { it.third }) - Pair(updated, result) - } - } - } - is PaymentAttempt.PaymentAborted -> payment.failChild(add.paymentId, Either.Left(event.error), db, logger) - is PaymentAttempt.PaymentSucceeded -> payment.failChild(add.paymentId, db, logger) + /** + * This may happen if we hit channel limits (e.g. max-accepted-htlcs). + * This is a temporary failure that we cannot automatically resolve: we must wait for the channel to be usable again. + */ + suspend fun processAddFailed(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddFailed): Failure? { + val payment = getPaymentAttempt(event.cmd.paymentId) ?: return processPostRestartFailure(event.cmd.paymentId, Either.Left(event.error)) + val logger = MDCLogger(logger, staticMdc = mapOf("channelId" to channelId, "childPaymentId" to event.cmd.paymentId) + payment.request.mdc()) + + if (payment.pending.id != event.cmd.paymentId) { + logger.warning { "ignoring HTLC that does not match pending payment part (${event.cmd.paymentId} != ${payment.pending.id})" } + return null } - updateGlobalState(add.paymentId, updated) - - return result + logger.info { "could not send HTLC: ${event.error.message}" } + db.completeOutgoingLightningPart(event.cmd.paymentId, OutgoingPaymentFailure.convertFailure(Either.Left(event.error))) + db.completeOutgoingPaymentOffchain(payment.request.paymentId, FinalFailure.NoAvailableChannels) + removeFromState(payment.request.paymentId) + return Failure(payment.request, OutgoingPaymentFailure(FinalFailure.NoAvailableChannels, payment.failures + Either.Left(event.error))) } - suspend fun processAddSettled(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddSettledFail, channels: Map, currentBlockHeight: Int): ProcessFailureResult? { - val payment = getPaymentAttempt(event.paymentId) ?: return processPostRestartFailure(event.paymentId, Either.Right(UnknownFailureMessage(0))) + suspend fun processAddSettledFailed(channelId: ByteVector32, event: ChannelAction.ProcessCmdRes.AddSettledFail, channels: Map, currentBlockHeight: Int): ProcessFailureResult? { + val payment = getPaymentAttempt(event.paymentId) ?: return processPostRestartFailure(event.paymentId, Either.Left(CannotDecryptFailure(channelId, "restarted"))) val logger = MDCLogger(logger, staticMdc = mapOf("channelId" to channelId, "childPaymentId" to event.paymentId) + payment.request.mdc()) - val failure: FailureMessage = when (event.result) { - is ChannelAction.HtlcResult.Fail.RemoteFail -> when (val part = payment.pending[event.paymentId]) { - null -> UnknownFailureMessage(0) - else -> when (val decrypted = FailurePacket.decrypt(event.result.fail.reason.toByteArray(), part.second)) { - is Try.Failure -> UnknownFailureMessage(1) - is Try.Success -> decrypted.result.failureMessage - } - } - else -> UnknownFailureMessage(FailureMessage.BADONION) + if (payment.pending.id != event.paymentId) { + logger.warning { "ignoring HTLC that does not match latest payment part (${event.paymentId} != ${payment.pending.id})" } + // This case may happen when we receive AddSettledFailed again for the previous attempt. + // This can happen if we disconnect and re-process the update_fail_htlc message on reconnection. + return null } - logger.debug { "HTLC failed: ${failure.message}" } - db.completeOutgoingLightningPart(event.paymentId, OutgoingPaymentFailure.convertFailure(Either.Right(failure))) - - val (updated, result) = when (payment) { - is PaymentAttempt.PaymentInProgress -> { - val trampolineFees = payment.request.trampolineFeesOverride ?: walletParams.trampolineFees - val finalError = when { - trampolineFees.size <= payment.attemptNumber + 1 -> FinalFailure.RetryExhausted - failure == UnknownNextPeer -> FinalFailure.RecipientUnreachable - failure != TrampolineExpiryTooSoon && failure != TrampolineFeeInsufficient -> FinalFailure.UnknownError // non-retriable error - else -> null + val failure = when (event.result) { + is ChannelAction.HtlcResult.Fail.RemoteFail -> when (val decrypted = FailurePacket.decrypt(event.result.fail.reason.toByteArray(), payment.sharedSecrets)) { + is Try.Failure -> { + logger.warning { "could not decrypt failure packet: ${decrypted.error.message}" } + Either.Left(CannotDecryptFailure(channelId, decrypted.error.message ?: "unknown")) } - if (finalError != null) { - PaymentAttempt.PaymentAborted(payment.request, finalError, payment.pending, listOf()).failChild(event.paymentId, Either.Right(failure), db, logger) - } else { - // The trampoline node is asking us to retry the payment with more fees. - logger.debug { "child payment failed because of fees" } - val updated = payment.copy(pending = payment.pending - event.paymentId) - if (updated.pending.isNotEmpty()) { - // We wait for all pending HTLCs to be settled before retrying. - // NB: we don't update failures here to avoid duplicate trampoline errors - Pair(updated, null) - } else { - val nextFees = trampolineFees[payment.attemptNumber + 1] - logger.info { "retrying payment with higher fees (base=${nextFees.feeBase}, proportional=${nextFees.feeProportional})..." } - val (trampolineAmount, trampolineExpiry, trampolinePacket) = createTrampolinePayload(payment.request, nextFees, currentBlockHeight) - when (val routes = routeCalculation.findRoutes(payment.request.paymentId, trampolineAmount, channels)) { - is Either.Left -> { - logger.warning { "payment failed: ${routes.value}" } - val aborted = PaymentAttempt.PaymentAborted(payment.request, routes.value, mapOf(), payment.failures + Either.Right(failure)) - val result = Failure(payment.request, OutgoingPaymentFailure(aborted.reason, aborted.failures)) - db.completeOutgoingPaymentOffchain(payment.request.paymentId, result.failure.reason) - Pair(aborted, result) - } - is Either.Right -> { - // We generate a random secret for this payment to avoid leaking the invoice secret to the trampoline node. - val trampolinePaymentSecret = Lightning.randomBytes32() - val trampolinePayload = PaymentAttempt.TrampolinePayload(trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolinePacket) - val childPayments = createChildPayments(payment.request, routes.value, trampolinePayload) - db.addOutgoingLightningParts(payment.request.paymentId, childPayments.map { it.first }) - val newAttempt = PaymentAttempt.PaymentInProgress( - payment.request, - payment.attemptNumber + 1, - trampolinePayload, - childPayments.associate { it.first.id to Pair(it.first, it.second) }, - setOf(), // we reset ignored channels - payment.failures + Either.Right(failure) - ) - val result = Progress(newAttempt.request, newAttempt.fees, childPayments.map { it.third }) - Pair(newAttempt, result) - } - } - } + is Try.Success -> { + logger.debug { "HTLC failed: ${decrypted.result.failureMessage.message}" } + Either.Right(decrypted.result.failureMessage) } } - is PaymentAttempt.PaymentAborted -> payment.failChild(event.paymentId, Either.Right(failure), db, logger) - is PaymentAttempt.PaymentSucceeded -> payment.failChild(event.paymentId, db, logger) + is ChannelAction.HtlcResult.Fail.RemoteFailMalformed -> { + logger.warning { "our peer couldn't decrypt our payment onion (failureCode=${event.result.fail.failureCode})" } + Either.Left(CannotDecryptFailure(channelId, "malformed onion")) + } + is ChannelAction.HtlcResult.Fail.OnChainFail -> { + logger.warning { "channel closed while our HTLC was pending: ${event.result.cause.message}" } + Either.Left(event.result.cause) + } + is ChannelAction.HtlcResult.Fail.Disconnected -> { + logger.warning { "we got disconnected before signing outgoing HTLC" } + Either.Left(CannotSignDisconnected(channelId)) + } } - updateGlobalState(event.paymentId, updated) + // We update the status in our DB. + db.completeOutgoingLightningPart(event.paymentId, OutgoingPaymentFailure.convertFailure(failure)) - return result + val trampolineFees = payment.request.trampolineFeesOverride ?: walletParams.trampolineFees + val finalError = when { + trampolineFees.size <= payment.attemptNumber + 1 -> FinalFailure.RetryExhausted + failure == Either.Right(UnknownNextPeer) -> FinalFailure.RecipientUnreachable + failure != Either.Right(TrampolineExpiryTooSoon) && failure != Either.Right(TrampolineFeeInsufficient) -> FinalFailure.UnknownError // non-retriable error + else -> null + } + return if (finalError != null) { + db.completeOutgoingPaymentOffchain(payment.request.paymentId, finalError) + removeFromState(payment.request.paymentId) + Failure(payment.request, OutgoingPaymentFailure(finalError, payment.failures + failure)) + } else { + // The trampoline node is asking us to retry the payment with more fees or a larger expiry delta. + sendPaymentInternal(payment.request, payment.failures + failure, channels, currentBlockHeight, logger).fold({ it }, { it }) + } } - private suspend fun processPostRestartFailure(partId: UUID, failure: Either): ProcessFailureResult? { - when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { + private suspend fun processPostRestartFailure(partId: UUID, failure: Either): Failure? { + return when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { null -> { logger.error { "paymentId=$partId doesn't match any known payment attempt" } - return null + null } else -> { val logger = MDCLogger(logger, staticMdc = mapOf("childPaymentId" to partId) + payment.mdc()) logger.debug { "could not send HTLC (wallet restart): ${failure.fold({ it.message }, { it.message })}" } val status = LightningOutgoingPayment.Part.Status.Failed(OutgoingPaymentFailure.convertFailure(failure)) db.completeOutgoingLightningPart(partId, status.failure) - val hasMorePendingParts = payment.parts.any { it.status == LightningOutgoingPayment.Part.Status.Pending && it.id != partId } - return if (!hasMorePendingParts) { - logger.warning { "payment failed: ${FinalFailure.WalletRestarted}" } - db.completeOutgoingPaymentOffchain(payment.id, FinalFailure.WalletRestarted) - val request = when (payment.details) { - is LightningOutgoingPayment.Details.Normal -> PayInvoice(payment.id, payment.recipientAmount, payment.details) - else -> { - logger.debug { "cannot recreate send-payment-request failure from db data with details=${payment.details}" } - return null - } - } - Failure( - request = request, - failure = OutgoingPaymentFailure( - reason = FinalFailure.WalletRestarted, - failures = payment.parts.map { it.status }.filterIsInstance() + status - ) - ) - } else { - null - } + logger.warning { "payment failed: ${FinalFailure.WalletRestarted}" } + val request = PayInvoice(payment.id, payment.recipientAmount, payment.details) + db.completeOutgoingPaymentOffchain(payment.id, FinalFailure.WalletRestarted) + removeFromState(payment.id) + val failures = payment.parts.map { it.status }.filterIsInstance() + status + Failure(request, OutgoingPaymentFailure(FinalFailure.WalletRestarted, failures)) } } } - suspend fun processAddSettled(event: ChannelAction.ProcessCmdRes.AddSettledFulfill): ProcessFulfillResult? { + suspend fun processAddSettledFulfilled(event: ChannelAction.ProcessCmdRes.AddSettledFulfill): Success? { val preimage = event.result.paymentPreimage val payment = getPaymentAttempt(event.paymentId) ?: return processPostRestartFulfill(event.paymentId, preimage) val logger = MDCLogger(logger, staticMdc = mapOf("childPaymentId" to event.paymentId) + payment.request.mdc()) - logger.debug { "HTLC fulfilled" } - val part = payment.pending[event.paymentId]?.first?.copy(status = LightningOutgoingPayment.Part.Status.Succeeded(preimage)) - db.completeOutgoingLightningPart(event.paymentId, preimage) - - val updated = when (payment) { - is PaymentAttempt.PaymentInProgress -> PaymentAttempt.PaymentSucceeded(payment.request, preimage, part?.let { listOf(it) } ?: listOf(), payment.pending - event.paymentId) - is PaymentAttempt.PaymentSucceeded -> payment.copy(pending = payment.pending - event.paymentId, parts = part?.let { payment.parts + it } ?: payment.parts) - is PaymentAttempt.PaymentAborted -> { - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - logger.warning { "payment succeeded after partial failure: we may have paid less than the full amount" } - PaymentAttempt.PaymentSucceeded(payment.request, preimage, part?.let { listOf(it) } ?: listOf(), payment.pending - event.paymentId) - } + if (payment.pending.id != event.paymentId) { + logger.error { "fulfilled HTLC that does not match latest payment part, this should never happen (${event.paymentId} != ${payment.pending.id})" } + return null } - updateGlobalState(event.paymentId, updated) - - return if (updated.isComplete()) { - logger.info { "payment successfully sent (fees=${updated.fees})" } - db.completeOutgoingPaymentOffchain(payment.request.paymentId, preimage) - val r = payment.request - Success(r, LightningOutgoingPayment(r.paymentId, r.amount, r.recipient, r.paymentDetails, updated.parts, LightningOutgoingPayment.Status.Completed.Succeeded.OffChain(preimage)), preimage) - } else { - PreimageReceived(payment.request, preimage) - } + logger.info { "payment successfully sent (fees=${payment.fees})" } + db.completeOutgoingLightningPart(event.paymentId, preimage) + db.completeOutgoingPaymentOffchain(payment.request.paymentId, preimage) + removeFromState(payment.request.paymentId) + val status = LightningOutgoingPayment.Status.Completed.Succeeded.OffChain(preimage) + val part = payment.pending.copy(status = LightningOutgoingPayment.Part.Status.Succeeded(preimage)) + val result = LightningOutgoingPayment(payment.request.paymentId, payment.request.amount, payment.request.recipient, payment.request.paymentDetails, listOf(part), status) + return Success(payment.request, result, preimage) } - private suspend fun processPostRestartFulfill(partId: UUID, preimage: ByteVector32): ProcessFulfillResult? { - when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { + private suspend fun processPostRestartFulfill(partId: UUID, preimage: ByteVector32): Success? { + return when (val payment = db.getLightningOutgoingPaymentFromPartId(partId)) { null -> { logger.error { "paymentId=$partId doesn't match any known payment attempt" } - return null + null } else -> { val logger = MDCLogger(logger, staticMdc = mapOf("childPaymentId" to partId) + payment.mdc()) - logger.debug { "HTLC succeeded (wallet restart): $preimage" } db.completeOutgoingLightningPart(partId, preimage) - // We try to re-create the request from what we have in the DB. - val request = when (payment.details) { - is LightningOutgoingPayment.Details.Normal -> PayInvoice(payment.id, payment.recipientAmount, payment.details) - else -> { - logger.warning { "cannot recreate send-payment-request fulfill from db data with details=${payment.details}" } - return null - } - } - val hasMorePendingParts = payment.parts.any { it.status == LightningOutgoingPayment.Part.Status.Pending && it.id != partId } - return if (!hasMorePendingParts) { - logger.info { "payment successfully sent (wallet restart)" } - db.completeOutgoingPaymentOffchain(payment.id, preimage) - // NB: we reload the payment to ensure all parts status are updated - // this payment cannot be null - val succeeded = db.getLightningOutgoingPayment(payment.id)!! - Success(request, succeeded, preimage) - } else { - PreimageReceived(request, preimage) - } + logger.info { "payment successfully sent (wallet restart)" } + val request = PayInvoice(payment.id, payment.recipientAmount, payment.details) + db.completeOutgoingPaymentOffchain(payment.id, preimage) + removeFromState(payment.id) + // NB: we reload the payment to ensure all parts status are updated + // this payment cannot be null + val succeeded = db.getLightningOutgoingPayment(payment.id)!! + Success(request, succeeded, preimage) } } } - private fun updateGlobalState(processedChildId: UUID, updatedPayment: PaymentAttempt) { - childToParentId.remove(processedChildId) - if (updatedPayment.isComplete()) { - pending.remove(updatedPayment.request.paymentId) - } else { - pending[updatedPayment.request.paymentId] = updatedPayment + private fun removeFromState(paymentId: UUID) { + val children = childToPaymentId.filterValues { it == paymentId }.keys + children.forEach { childToPaymentId.remove(it) } + pending.remove(paymentId) + } + + /** + * We assume that we have at most one channel with our trampoline node. + * We return it if it's ready and has enough balance for the payment, otherwise we return a failure. + */ + private fun selectChannel(toSend: MilliSatoshi, channels: Map): Either { + return when (val available = channels.values.firstOrNull { it is Normal }) { + is Normal -> when { + toSend < available.channelUpdate.htlcMinimumMsat -> Either.Left(FinalFailure.InvalidPaymentAmount) + available.commitments.availableBalanceForSend() < toSend -> Either.Left(FinalFailure.InsufficientBalance) + else -> Either.Right(available) + } + else -> { + val failure = when { + channels.values.any { it is Syncing || it is Offline } -> FinalFailure.ChannelNotConnected + channels.values.any { it is WaitForOpenChannel || it is WaitForAcceptChannel || it is WaitForFundingCreated || it is WaitForFundingSigned || it is WaitForFundingConfirmed || it is WaitForChannelReady } -> FinalFailure.ChannelOpening + channels.values.any { it is ShuttingDown || it is Negotiating || it is Closing || it is Closed || it is WaitForRemotePublishFutureCommitment } -> FinalFailure.ChannelClosing + else -> FinalFailure.NoAvailableChannels + } + Either.Left(failure) + } } } - private fun createOutgoingPart(request: PayInvoice, route: RouteCalculation.Route, trampolinePayload: PaymentAttempt.TrampolinePayload): Triple { + private fun createOutgoingPayment(request: PayInvoice, channel: Normal, hop: NodeHop, currentBlockHeight: Int): Triple { + val logger = MDCLogger(logger, staticMdc = request.mdc()) val childId = UUID.randomUUID() + childToPaymentId[childId] = request.paymentId + val (amount, expiry, onion) = createPaymentOnion(request, hop, currentBlockHeight) val outgoingPayment = LightningOutgoingPayment.Part( id = childId, - amount = route.amount, - route = listOf(HopDesc(nodeParams.nodeId, route.channel.commitments.remoteNodeId, route.channel.shortChannelId), HopDesc(route.channel.commitments.remoteNodeId, request.recipient)), + amount = amount, + route = listOf(HopDesc(nodeParams.nodeId, hop.nodeId, channel.shortChannelId), HopDesc(hop.nodeId, hop.nextNodeId)), status = LightningOutgoingPayment.Part.Status.Pending ) - val channelHops: List = listOf(ChannelHop(nodeParams.nodeId, route.channel.commitments.remoteNodeId, route.channel.channelUpdate)) - val (cmdAdd, secrets) = OutgoingPaymentPacket.buildCommand(childId, request.paymentHash, channelHops, trampolinePayload.createFinalPayload(route.amount)) - return Triple(outgoingPayment, secrets, WrappedChannelCommand(route.channel.channelId, cmdAdd)) + logger.info { "sending $amount to channel ${channel.shortChannelId}" } + val add = ChannelCommand.Htlc.Add(amount, request.paymentHash, expiry, onion.packet, paymentId = childId, commit = true) + return Triple(outgoingPayment, onion.sharedSecrets, WrappedChannelCommand(channel.channelId, add)) } - private fun createTrampolinePayload(request: PayInvoice, fees: TrampolineFees, currentBlockHeight: Int): Triple { - // We are either directly paying our peer (the trampoline node) or a remote node via our peer (using trampoline). - val trampolineRoute = when (request.recipient) { - walletParams.trampolineNode.id -> listOf( - NodeHop(nodeParams.nodeId, request.recipient, /* ignored */ CltvExpiryDelta(0), /* ignored */ 0.msat) - ) - else -> listOf( - NodeHop(nodeParams.nodeId, walletParams.trampolineNode.id, /* ignored */ CltvExpiryDelta(0), /* ignored */ 0.msat), - NodeHop(walletParams.trampolineNode.id, request.recipient, fees.cltvExpiryDelta, fees.calculateFees(request.amount)) - ) - } - when (val paymentRequest = request.paymentDetails.paymentRequest) { + private fun createPaymentOnion(request: PayInvoice, hop: NodeHop, currentBlockHeight: Int): Triple { + return when (val paymentRequest = request.paymentDetails.paymentRequest) { is Bolt11Invoice -> { val minFinalExpiryDelta = paymentRequest.minFinalExpiryDelta ?: Channel.MIN_CLTV_EXPIRY_DELTA - val finalExpiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, minFinalExpiryDelta) - val finalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(request.amount, finalExpiry, paymentRequest.paymentSecret, paymentRequest.paymentMetadata) + val expiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, minFinalExpiryDelta) val invoiceFeatures = paymentRequest.features - val (trampolineAmount, trampolineExpiry, trampolineOnion) = if (invoiceFeatures.hasFeature(Feature.TrampolinePayment) || invoiceFeatures.hasFeature(Feature.ExperimentalTrampolinePayment)) { - // We may be paying an older version of lightning-kmp that only supports trampoline packets of size 400. - OutgoingPaymentPacket.buildPacket(request.paymentHash, trampolineRoute, finalPayload, 400) + if (request.recipient == walletParams.trampolineNode.id) { + // We are directly paying our trampoline node. + OutgoingPaymentPacket.buildPacketToTrampolinePeer(paymentRequest, request.amount, expiry) + } else if (invoiceFeatures.hasFeature(Feature.TrampolinePayment) || invoiceFeatures.hasFeature(Feature.ExperimentalTrampolinePayment)) { + OutgoingPaymentPacket.buildPacketToTrampolineRecipient(paymentRequest, request.amount, expiry, hop) } else { - OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(paymentRequest, trampolineRoute, finalPayload) + OutgoingPaymentPacket.buildPacketToLegacyRecipient(paymentRequest, request.amount, expiry, hop) } - return Triple(trampolineAmount, trampolineExpiry, trampolineOnion.packet) } is Bolt12Invoice -> { - val finalExpiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, CltvExpiryDelta(0)) - val (trampolineAmount, trampolineExpiry, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(paymentRequest, trampolineRoute.last(), request.amount, finalExpiry) - return Triple(trampolineAmount, trampolineExpiry, trampolineOnion.packet) - } - } - } - - sealed class PaymentAttempt { - abstract val request: PayInvoice - abstract val pending: Map> - abstract val fees: MilliSatoshi - - fun isComplete(): Boolean = pending.isEmpty() - - /** - * @param totalAmount total amount that the trampoline node should receive. - * @param expiry expiry at the trampoline node. - * @param paymentSecret trampoline payment secret (should be different from the invoice payment secret). - * @param packet trampoline onion packet. - */ - data class TrampolinePayload(val totalAmount: MilliSatoshi, val expiry: CltvExpiry, val paymentSecret: ByteVector32, val packet: OnionRoutingPacket) { - fun createFinalPayload(partialAmount: MilliSatoshi): PaymentOnion.FinalPayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(partialAmount, totalAmount, expiry, paymentSecret, packet) - } - - /** - * While a payment is in progress, we listen to child payments failures. - * When we receive failures, we retry the failed amount with different routes/fees. - * - * @param request payment request containing the total amount to send. - * @param attemptNumber number of failed previous payment attempts. - * @param trampolinePayload trampoline payload for the current payment attempt. - * @param pending pending child payments (HTLCs were sent, we are waiting for a fulfill or a failure). - * @param ignore channels that should be ignored (previously returned an error). - * @param failures previous child payment failures. - */ - data class PaymentInProgress( - override val request: PayInvoice, - val attemptNumber: Int, - val trampolinePayload: TrampolinePayload, - override val pending: Map>, - val ignore: Set, - val failures: List> - ) : PaymentAttempt() { - override val fees: MilliSatoshi = pending.values.map { it.first.amount }.sum() - request.amount - } - - /** - * When we exhaust our retry attempts without success or encounter a non-recoverable error, we abort the payment. - * Once we're in that state, we wait for all the pending child payments to settle. - * - * @param request payment request containing the total amount to send. - * @param reason failure reason. - * @param pending pending child payments (we are waiting for them to be failed downstream). - * @param failures child payment failures. - */ - data class PaymentAborted( - override val request: PayInvoice, - val reason: FinalFailure, - override val pending: Map>, - val failures: List> - ) : PaymentAttempt() { - override val fees: MilliSatoshi = 0.msat - - suspend fun failChild(childId: UUID, failure: Either, db: OutgoingPaymentsDb, logger: MDCLogger): Pair { - val updated = copy(pending = pending - childId, failures = failures + failure) - val result = if (updated.isComplete()) { - logger.warning { "payment failed: ${updated.reason}" } - db.completeOutgoingPaymentOffchain(request.paymentId, updated.reason) - Failure(request, OutgoingPaymentFailure(updated.reason, updated.failures)) - } else { - null - } - return Pair(updated, result) - } - } - - /** - * Once we receive a first fulfill for a child payment, we can consider that the whole payment succeeded (because we - * received the payment preimage that we can use as a proof of payment). - * Once we're in that state, we wait for all the pending child payments to fulfill. - * - * @param request payment request containing the total amount to send. - * @param preimage payment preimage. - * @param parts fulfilled child payments. - * @param pending pending child payments (we are waiting for them to be fulfilled downstream). - */ - data class PaymentSucceeded( - override val request: PayInvoice, - val preimage: ByteVector32, - val parts: List, - override val pending: Map> - ) : PaymentAttempt() { - override val fees: MilliSatoshi = parts.map { it.amount }.sum() + pending.values.map { it.first.amount }.sum() - request.amount - - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - suspend fun failChild(childId: UUID, db: OutgoingPaymentsDb, logger: MDCLogger): Pair { - logger.warning { "partial payment failure after fulfill: we may have paid less than the full amount" } - val updated = copy(pending = pending - childId) - val result = if (updated.isComplete()) { - logger.info { "payment successfully sent (fees=${updated.fees})" } - db.completeOutgoingPaymentOffchain(request.paymentId, preimage) - Success(request, LightningOutgoingPayment(request.paymentId, request.amount, request.recipient, request.paymentDetails, parts, LightningOutgoingPayment.Status.Completed.Succeeded.OffChain(preimage)), preimage) - } else { - null - } - return Pair(updated, result) + // The recipient already included a final cltv-expiry-delta in their invoice blinded paths. + val minFinalExpiryDelta = CltvExpiryDelta(0) + val expiry = nodeParams.paymentRecipientExpiryParams.computeFinalExpiry(currentBlockHeight, minFinalExpiryDelta) + OutgoingPaymentPacket.buildPacketToBlindedRecipient(paymentRequest, request.amount, expiry, hop) } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt index 920ac64d1..6c09586c1 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt @@ -6,17 +6,14 @@ import fr.acinq.bitcoin.PrivateKey import fr.acinq.bitcoin.PublicKey import fr.acinq.bitcoin.utils.Either import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.Feature import fr.acinq.lightning.Lightning import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.channel.ChannelCommand import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets -import fr.acinq.lightning.crypto.sphinx.SharedSecrets import fr.acinq.lightning.crypto.sphinx.Sphinx -import fr.acinq.lightning.router.ChannelHop -import fr.acinq.lightning.router.Hop import fr.acinq.lightning.router.NodeHop -import fr.acinq.lightning.utils.UUID import fr.acinq.lightning.wire.* object OutgoingPaymentPacket { @@ -24,128 +21,127 @@ object OutgoingPaymentPacket { /** * Build an encrypted onion packet from onion payloads and node public keys. */ - private fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int?): PacketAndSecrets { - require(nodes.size == payloads.size) + fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int? = null): PacketAndSecrets { val sessionKey = Lightning.randomKey() + return buildOnion(sessionKey, nodes, payloads, associatedData, payloadLength) + } + + private fun buildOnion(sessionKey: PrivateKey, nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int? = null): PacketAndSecrets { + require(nodes.size == payloads.size) val payloadsBin = payloads.map { it.write() } val totalPayloadLength = payloadLength ?: payloadsBin.sumOf { it.size + Sphinx.MacLength } return Sphinx.create(sessionKey, nodes, payloadsBin, associatedData, totalPayloadLength) } /** - * Build the onion payloads for each hop. + * Build an encrypted payment onion packet when the final recipient supports trampoline. + * The trampoline node will receive instructions on how much to relay to the final recipient. * - * @param hops the hops as computed by the router + extra routes from payment request - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, payloads) tuple where: - * - firstAmount is the amount for the first htlc in the route - * - firstExpiry is the cltv expiry for the first htlc in the route - * - a sequence of payloads that will be used to build the onion + * @param invoice a Bolt 11 invoice that contains the trampoline feature bit. + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. + * @param hop the trampoline hop from the trampoline node to the recipient. */ - private fun buildPayloads(hops: List, finalPayload: PaymentOnion.FinalPayload): Triple> { - return hops.reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> - val (amount, expiry, payloads) = triple - val payload = when (hop) { - // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. - is ChannelHop -> PaymentOnion.ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) - is NodeHop -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) - } - Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) + fun buildPacketToTrampolineRecipient(invoice: Bolt11Invoice, amount: MilliSatoshi, expiry: CltvExpiry, hop: NodeHop): Triple { + require(invoice.features.hasFeature(Feature.ExperimentalTrampolinePayment) || invoice.features.hasFeature(Feature.TrampolinePayment)) { "invoice must support trampoline" } + val trampolineOnion = run { + val finalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, invoice.paymentSecret, invoice.paymentMetadata) + val trampolinePayload = PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) + // We may be paying an older version of lightning-kmp that only supports trampoline packets of size 400. + buildOnion(listOf(hop.nodeId, hop.nextNodeId), listOf(trampolinePayload, finalPayload), invoice.paymentHash, payloadLength = 400) } + val trampolineAmount = amount + hop.fee(amount) + val trampolineExpiry = expiry + hop.cltvExpiryDelta + // We generate a random secret to avoid leaking the invoice secret to the trampoline node. + val trampolinePaymentSecret = Lightning.randomBytes32() + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion) } /** - * Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline. - * The next-to-last trampoline node payload will contain instructions to convert to a legacy payment. + * Build an encrypted payment onion packet when the final recipient is our trampoline node. * - * @param invoice a Bolt11 invoice (features and routing hints will be provided to the next-to-last node). - * @param hops the trampoline hops (including ourselves in the first hop, and the non-trampoline final recipient in the last hop). - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, onion) triple where: - * - firstAmount is the amount for the trampoline node in the route - * - firstExpiry is the cltv expiry for the first trampoline node in the route - * - the trampoline onion to include in final payload of a normal onion + * @param invoice a Bolt 11 invoice that contains the trampoline feature bit. + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. */ - fun buildTrampolineToNonTrampolinePacket(invoice: Bolt11Invoice, hops: List, finalPayload: PaymentOnion.FinalPayload.Standard): Triple { - // NB: the final payload will never reach the recipient, since the next-to-last trampoline hop will convert that to a legacy payment - // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. - val dummyFinalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalPayload.amount, finalPayload.expiry, finalPayload.paymentSecret, null) - val (firstAmount, firstExpiry, initialPayloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(dummyFinalPayload))) { triple, hop -> - val (amount, expiry, payloads) = triple - val payload = when (payloads.size) { - // The next-to-last trampoline hop must include invoice data to indicate the conversion to a legacy payment. - 1 -> PaymentOnion.RelayToNonTrampolinePayload.create(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) - else -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) - } - Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) - } - var payloads = initialPayloads - val nodes = hops.map { it.nextNodeId } - var onion = buildOnion(nodes, payloads, invoice.paymentHash, payloadLength = null) - // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. - while (onion.packet.payload.size() > 1000) { - payloads = payloads.map { payload -> when (payload) { - is PaymentOnion.RelayToNonTrampolinePayload -> payload.copy(records = payload.records.copy(records = payload.records.records.map { when (it) { - is OnionPaymentPayloadTlv.InvoiceRoutingInfo -> OnionPaymentPayloadTlv.InvoiceRoutingInfo(it.extraHops.dropLast(1)) - else -> it - } }.toSet())) - else -> payload - } } - onion = buildOnion(nodes, payloads, invoice.paymentHash, payloadLength = null) + fun buildPacketToTrampolinePeer(invoice: Bolt11Invoice, amount: MilliSatoshi, expiry: CltvExpiry): Triple { + require(invoice.features.hasFeature(Feature.ExperimentalTrampolinePayment) || invoice.features.hasFeature(Feature.TrampolinePayment)) { "invoice must support trampoline" } + val trampolineOnion = run { + val finalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, invoice.paymentSecret, invoice.paymentMetadata) + buildOnion(listOf(invoice.nodeId), listOf(finalPayload), invoice.paymentHash) } - return Triple(firstAmount, firstExpiry, onion) + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amount, amount, expiry, invoice.paymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(invoice.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(amount, expiry, paymentOnion) } /** - * Build an encrypted trampoline onion packet when the final recipient is using a blinded path. - * The trampoline payload will contain data from the invoice to allow the trampoline node to pay the blinded path. - * We only need a single trampoline node, who will find a route to the blinded path's introduction node without learning the recipient's identity. + * Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline. + * The trampoline node will receive instructions to convert to a legacy payment. + * This reveals to the trampoline node who the recipient is and details from the invoice. + * This must be deprecated once recipients support either trampoline or blinded paths. * - * @param invoice a Bolt12 invoice (blinded path data will be provided to the trampoline node). + * @param invoice a Bolt11 invoice (features and routing hints will be provided to the trampoline node). + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. * @param hop the trampoline hop from the trampoline node to the recipient. - * @param finalAmount amount that should be received by the final recipient. - * @param finalExpiry cltv expiry that should be received by the final recipient. */ - fun buildTrampolineToNonTrampolinePacket(invoice: Bolt12Invoice, hop: NodeHop, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry): Triple { - var payload = PaymentOnion.RelayToBlindedPayload.create(finalAmount, finalExpiry, invoice) - var onion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, payloadLength = null) - // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. - while (onion.packet.payload.size() > 1000) { - payload = payload.copy(records = payload.records.copy(records = payload.records.records.map { when (it) { - is OnionPaymentPayloadTlv.OutgoingBlindedPaths -> OnionPaymentPayloadTlv.OutgoingBlindedPaths(it.paths.dropLast(1)) - else -> it - } }.toSet())) - onion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, payloadLength = null) + fun buildPacketToLegacyRecipient(invoice: Bolt11Invoice, amount: MilliSatoshi, expiry: CltvExpiry, hop: NodeHop): Triple { + val trampolineOnion = run { + // NB: the final payload will never reach the recipient, since the trampoline node will convert that to a legacy payment. + // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. + val dummyFinalPayload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, invoice.paymentSecret, null) + var routingInfo = invoice.routingInfo + var trampolinePayload = PaymentOnion.RelayToNonTrampolinePayload.create(amount, amount, expiry, hop.nextNodeId, invoice, routingInfo) + var trampolineOnion = buildOnion(listOf(hop.nodeId, hop.nextNodeId), listOf(trampolinePayload, dummyFinalPayload), invoice.paymentHash) + // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. + while (trampolineOnion.packet.payload.size() > 1000) { + routingInfo = routingInfo.dropLast(1) + trampolinePayload = PaymentOnion.RelayToNonTrampolinePayload.create(amount, amount, expiry, hop.nextNodeId, invoice, routingInfo) + trampolineOnion = buildOnion(listOf(hop.nodeId, hop.nextNodeId), listOf(trampolinePayload, dummyFinalPayload), invoice.paymentHash) + } + trampolineOnion } - return Triple(finalAmount + hop.fee(finalAmount), finalExpiry + hop.cltvExpiryDelta, onion) - } - - /** - * Build an encrypted onion packet with the given final payload. - * - * @param hops the hops as computed by the router + extra routes from payment request, including ourselves in the first hop - * @param finalPayload payload data for the final node (amount, expiry, etc) - * @return a (firstAmount, firstExpiry, onion) tuple where: - * - firstAmount is the amount for the first htlc in the route - * - firstExpiry is the cltv expiry for the first htlc in the route - * - the onion to include in the HTLC - */ - fun buildPacket(paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload, payloadLength: Int?): Triple { - val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) - val nodes = hops.map { it.nextNodeId } - // BOLT 2 requires that associatedData == paymentHash - val onion = buildOnion(nodes, payloads, paymentHash, payloadLength) - return Triple(firstAmount, firstExpiry, onion) + val trampolineAmount = amount + hop.fee(amount) + val trampolineExpiry = expiry + hop.cltvExpiryDelta + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, invoice.paymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion) } /** - * Build the command to add an HTLC with the given final payload and using the provided hops. + * Build an encrypted trampoline onion packet when the final recipient is using a blinded path. + * The trampoline node will receive data from the invoice to allow them to pay the blinded path. + * The data revealed to the trampoline node doesn't leak anything about the recipient's identity. + * We only need a single trampoline node, who will find routes to the blinded paths. * - * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) + * @param invoice a Bolt12 invoice (blinded path data will be provided to the trampoline node). + * @param amount amount that should be received by the final recipient. + * @param expiry cltv expiry that should be received by the final recipient. + * @param hop the trampoline hop from the trampoline node to the recipient. */ - fun buildCommand(paymentId: UUID, paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload): Pair { - val (firstAmount, firstExpiry, onion) = buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) - return Pair(ChannelCommand.Htlc.Add(firstAmount, paymentHash, firstExpiry, onion.packet, paymentId, commit = true), onion.sharedSecrets) + fun buildPacketToBlindedRecipient(invoice: Bolt12Invoice, amount: MilliSatoshi, expiry: CltvExpiry, hop: NodeHop): Triple { + val trampolineOnion = run { + var blindedPaths = invoice.blindedPaths + var trampolinePayload = PaymentOnion.RelayToBlindedPayload.create(amount, expiry, invoice.features, blindedPaths) + var trampolineOnion = buildOnion(listOf(hop.nodeId), listOf(trampolinePayload), invoice.paymentHash) + // Ensure that this onion can fit inside the outer 1300 bytes onion. The outer onion fields need ~150 bytes and we add some safety margin. + while (trampolineOnion.packet.payload.size() > 1000) { + blindedPaths = blindedPaths.dropLast(1) + trampolinePayload = PaymentOnion.RelayToBlindedPayload.create(amount, expiry, invoice.features, blindedPaths) + trampolineOnion = buildOnion(listOf(hop.nodeId), listOf(trampolinePayload), invoice.paymentHash) + } + trampolineOnion + } + val trampolineAmount = amount + hop.fee(amount) + val trampolineExpiry = expiry + hop.cltvExpiryDelta + // We generate a random secret to avoid leaking the invoice secret to the trampoline node. + val trampolinePaymentSecret = Lightning.randomBytes32() + val payload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet) + val paymentOnion = buildOnion(listOf(hop.nodeId), listOf(payload), invoice.paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(trampolineAmount, trampolineExpiry, paymentOnion) } fun buildHtlcFailure(nodeSecret: PrivateKey, paymentHash: ByteVector32, onion: OnionRoutingPacket, reason: ChannelCommand.Htlc.Settlement.Fail.Reason): Either { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt deleted file mode 100644 index ac4e94654..000000000 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/RouteCalculation.kt +++ /dev/null @@ -1,61 +0,0 @@ -package fr.acinq.lightning.payment - -import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.bitcoin.Satoshi -import fr.acinq.bitcoin.utils.Either -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.channel.states.* -import fr.acinq.lightning.logging.LoggerFactory -import fr.acinq.lightning.logging.MDCLogger -import fr.acinq.lightning.utils.UUID -import fr.acinq.lightning.utils.msat - -class RouteCalculation(loggerFactory: LoggerFactory) { - - private val logger = loggerFactory.newLogger(this::class) - - data class Route(val amount: MilliSatoshi, val channel: Normal) - - data class ChannelBalance(val c: Normal) { - val balance: MilliSatoshi = c.commitments.availableBalanceForSend() - val capacity: Satoshi = c.commitments.latest.fundingAmount - } - - fun findRoutes(paymentId: UUID, amount: MilliSatoshi, channels: Map): Either> { - val logger = MDCLogger(logger, staticMdc = mapOf("paymentId" to paymentId, "amount" to amount)) - - val sortedChannels = channels.values.filterIsInstance().map { ChannelBalance(it) }.sortedBy { it.balance }.reversed() - if (sortedChannels.isEmpty()) { - val failure = when { - channels.values.any { it is Syncing || it is Offline } -> FinalFailure.ChannelNotConnected - channels.values.any { it is WaitForOpenChannel || it is WaitForAcceptChannel || it is WaitForFundingCreated || it is WaitForFundingSigned || it is WaitForFundingConfirmed || it is WaitForChannelReady } -> FinalFailure.ChannelOpening - channels.values.any { it is ShuttingDown || it is Negotiating || it is Closing || it is WaitForRemotePublishFutureCommitment } -> FinalFailure.ChannelClosing - // This may happen if adding an HTLC failed because we hit channel limits (e.g. max-accepted-htlcs) and we're retrying with this channel filtered out. - else -> FinalFailure.NoAvailableChannels - } - logger.warning { "no available channels: $failure" } - return Either.Left(failure) - } - - val filteredChannels = sortedChannels.filter { it.balance >= it.c.channelUpdate.htlcMinimumMsat } - var remaining = amount - val routes = mutableListOf() - for (channel in filteredChannels) { - val toSend = channel.balance.min(remaining) - routes.add(Route(toSend, channel.c)) - remaining -= toSend - if (remaining == 0.msat) { - break - } - } - - return if (remaining > 0.msat) { - logger.info { "insufficient balance: ${sortedChannels.joinToString { "${it.c.shortChannelId}->${it.balance}/${it.capacity}" }}" } - Either.Left(FinalFailure.InsufficientBalance) - } else { - logger.info { "routes found: ${routes.map { "${it.channel.shortChannelId}->${it.amount}" }}" } - Either.Right(routes) - } - } - -} \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt index bde0d2bc6..73c651239 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt @@ -9,10 +9,7 @@ import fr.acinq.bitcoin.io.Input import fr.acinq.bitcoin.io.Output import fr.acinq.bitcoin.utils.Either import fr.acinq.bitcoin.utils.flatMap -import fr.acinq.lightning.CltvExpiry -import fr.acinq.lightning.CltvExpiryDelta -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.ShortChannelId +import fr.acinq.lightning.* import fr.acinq.lightning.payment.Bolt11Invoice import fr.acinq.lightning.payment.Bolt12Invoice import fr.acinq.lightning.utils.msat @@ -498,7 +495,7 @@ object PaymentOnion { } } - fun create(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice): RelayToNonTrampolinePayload = + fun create(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice, routingInfo: List): RelayToNonTrampolinePayload = RelayToNonTrampolinePayload( TlvStream( buildSet { @@ -508,7 +505,7 @@ object PaymentOnion { add(OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, totalAmount)) invoice.paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } add(OnionPaymentPayloadTlv.InvoiceFeatures(invoice.features.toByteArray().toByteVector())) - add(OnionPaymentPayloadTlv.InvoiceRoutingInfo(invoice.routingInfo.map { it.hints })) + add(OnionPaymentPayloadTlv.InvoiceRoutingInfo(routingInfo.map { it.hints })) } ) ) @@ -538,14 +535,14 @@ object PaymentOnion { } } - fun create(amount: MilliSatoshi, expiry: CltvExpiry, invoice: Bolt12Invoice): RelayToBlindedPayload = + fun create(amount: MilliSatoshi, expiry: CltvExpiry, features: Features, blindedPaths: List): RelayToBlindedPayload = RelayToBlindedPayload( TlvStream( setOf( OnionPaymentPayloadTlv.AmountToForward(amount), OnionPaymentPayloadTlv.OutgoingCltv(expiry), - OnionPaymentPayloadTlv.OutgoingBlindedPaths(invoice.blindedPaths), - OnionPaymentPayloadTlv.InvoiceFeatures(invoice.features.toByteArray().toByteVector()) + OnionPaymentPayloadTlv.OutgoingBlindedPaths(blindedPaths), + OnionPaymentPayloadTlv.InvoiceFeatures(features.toByteArray().toByteVector()) ) ) ) diff --git a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt index c3b63ad94..7606a1eff 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt @@ -14,7 +14,6 @@ import fr.acinq.lightning.json.JsonSerializers import fr.acinq.lightning.logging.MDCLogger import fr.acinq.lightning.logging.mdc import fr.acinq.lightning.payment.OutgoingPaymentPacket -import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.serialization.Serialization import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.testLoggerFactory @@ -429,16 +428,11 @@ object TestsHelper { } fun makeCmdAdd(amount: MilliSatoshi, destination: PublicKey, currentBlockHeight: Long, paymentPreimage: ByteVector32 = randomBytes32(), paymentId: UUID = UUID.randomUUID()): Pair { - val paymentHash: ByteVector32 = Crypto.sha256(paymentPreimage).toByteVector32() + val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() val expiry = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) - val dummyKey = PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010101")).publicKey() - val dummyUpdate = ChannelUpdate(ByteVector64.Zeroes, BlockHash(ByteVector32.Zeroes), ShortChannelId(144, 0, 0), 0, 0, 0, CltvExpiryDelta(1), 0.msat, 0.msat, 0, null) - val cmd = OutgoingPaymentPacket.buildCommand( - paymentId, - paymentHash, - listOf(ChannelHop(dummyKey, destination, dummyUpdate)), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), null) - ).first.copy(commit = false) + val payload = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), null) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destination), listOf(payload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet + val cmd = ChannelCommand.Htlc.Add(amount, paymentHash, expiry, onion, paymentId, commit = false) return Pair(paymentPreimage, cmd) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt b/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt index 9e92686f3..af58cc686 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/io/peer/PeerTest.kt @@ -486,101 +486,52 @@ class PeerTest : LightningTestSuite() { @Test fun `payment between two nodes -- with disconnection`() = runSuspendTest { - // We create two channels between Alice and Bob to ensure that the payment is split in two parts. - val (aliceChan1, bobChan1) = TestsHelper.reachNormal(aliceFundingAmount = 100_000.sat, bobFundingAmount = 100_000.sat, alicePushAmount = 0.msat, bobPushAmount = 0.msat) - val (aliceChan2, bobChan2) = TestsHelper.reachNormal(aliceFundingAmount = 100_000.sat, bobFundingAmount = 100_000.sat, alicePushAmount = 0.msat, bobPushAmount = 0.msat) - val nodeParams = Pair(aliceChan1.staticParams.nodeParams, bobChan1.staticParams.nodeParams) + val (alice0, bob0) = TestsHelper.reachNormal() + val nodeParams = Pair(alice0.staticParams.nodeParams, bob0.staticParams.nodeParams) val walletParams = Pair( // Alice must declare Bob as her trampoline node to enable direct payments. - TestConstants.Alice.walletParams.copy(trampolineNode = NodeUri(nodeParams.second.nodeId, "bob.com", 9735)), + TestConstants.Alice.walletParams.copy(trampolineNode = NodeUri(bob0.staticParams.nodeParams.nodeId, "bob.com", 9735)), TestConstants.Bob.walletParams ) - // Bob sends a multipart payment to Alice. - val (alice, bob, alice2bob1, bob2alice1) = newPeers(this, nodeParams, walletParams, listOf(aliceChan1 to bobChan1, aliceChan2 to bobChan2), automateMessaging = false) + val (alice, bob, alice2bob1, bob2alice1) = newPeers(this, nodeParams, walletParams, listOf(alice0 to bob0), automateMessaging = false) val invoice = alice.createInvoice(randomBytes32(), 150_000_000.msat, Either.Left("test invoice"), null) bob.send(PayInvoice(UUID.randomUUID(), invoice.amount!!, LightningOutgoingPayment.Details.Normal(invoice))) - // Bob sends one HTLC on each channel. - val htlcs = listOf( - bob2alice1.expect(), - bob2alice1.expect(), - ) - assertEquals(2, htlcs.map { it.channelId }.toSet().size) - val commitSigsBob = listOf( - bob2alice1.expect(), - bob2alice1.expect(), - ) + // Bob sends an HTLC to Alice. + alice.forward(bob2alice1.expect()) + alice.forward(bob2alice1.expect()) - // We cross-sign the HTLC on the first channel. - run { - val htlc = htlcs.find { it.channelId == aliceChan1.channelId } - assertNotNull(htlc) - alice.forward(htlc) - val commitSigBob = commitSigsBob.find { it.channelId == aliceChan1.channelId } - assertNotNull(commitSigBob) - alice.forward(commitSigBob) - bob.forward(alice2bob1.expect()) - bob.forward(alice2bob1.expect()) - alice.forward(bob2alice1.expect()) - } - // We start cross-signing the HTLC on the second channel. - run { - val htlc = htlcs.find { it.channelId == aliceChan2.channelId } - assertNotNull(htlc) - alice.forward(htlc) - val commitSigBob = commitSigsBob.find { it.channelId == aliceChan2.channelId } - assertNotNull(commitSigBob) - alice.forward(commitSigBob) - bob.forward(alice2bob1.expect()) - bob.forward(alice2bob1.expect()) - bob2alice1.expect() // Alice doesn't receive Bob's revocation. - } + // We start cross-signing the HTLC. + bob.forward(alice2bob1.expect()) + bob.forward(alice2bob1.expect()) + bob2alice1.expect() // Alice doesn't receive Bob's revocation. - // We disconnect before Alice receives Bob's revocation on the second channel. + // We disconnect before Alice receives Bob's revocation. alice.disconnect() alice.send(Disconnected) bob.disconnect() bob.send(Disconnected) // On reconnection, Bob retransmits its revocation. - val (_, _, alice2bob2, bob2alice2) = connect(this, connectionId = 1, alice, bob, channelsCount = 2, expectChannelReady = false, automateMessaging = false) + val (_, _, alice2bob2, bob2alice2) = connect(this, connectionId = 1, alice, bob, channelsCount = 1, expectChannelReady = false, automateMessaging = false) alice.forward(bob2alice2.expect(), connectionId = 1) // Alice has now received the complete payment and fulfills it. - val fulfills = listOf( - alice2bob2.expect(), - alice2bob2.expect(), - ) - val commitSigsAlice = listOf( - alice2bob2.expect(), - alice2bob2.expect(), - ) + bob.forward(alice2bob2.expect(), connectionId = 1) + bob.forward(alice2bob2.expect(), connectionId = 1) + alice.forward(bob2alice2.expect(), connectionId = 1) + bob2alice2.expect() // Alice doesn't receive Bob's signature. - // We fulfill the first HTLC. - run { - val fulfill = fulfills.find { it.channelId == aliceChan1.channelId } - assertNotNull(fulfill) - bob.forward(fulfill, connectionId = 1) - val commitSigAlice = commitSigsAlice.find { it.channelId == aliceChan1.channelId } - assertNotNull(commitSigAlice) - bob.forward(commitSigAlice, connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - bob.forward(alice2bob2.expect(), connectionId = 1) - } + // We disconnect before Alice receives Bob's signature. + alice.disconnect() + alice.send(Disconnected) + bob.disconnect() + bob.send(Disconnected) - // We fulfill the second HTLC. - run { - val fulfill = fulfills.find { it.channelId == aliceChan2.channelId } - assertNotNull(fulfill) - bob.forward(fulfill, connectionId = 1) - val commitSigAlice = commitSigsAlice.find { it.channelId == aliceChan2.channelId } - assertNotNull(commitSigAlice) - bob.forward(commitSigAlice, connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - alice.forward(bob2alice2.expect(), connectionId = 1) - bob.forward(alice2bob2.expect(), connectionId = 1) - } + // On reconnection, Bob retransmits its signature. + val (_, _, alice2bob3, bob2alice3) = connect(this, connectionId = 2, alice, bob, channelsCount = 1, expectChannelReady = false, automateMessaging = false) + alice.forward(bob2alice3.expect(), connectionId = 2) + bob.forward(alice2bob3.expect(), connectionId = 2) assertEquals(invoice.amount, alice.db.payments.getIncomingPayment(invoice.paymentHash)?.received?.amount) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt index 83b6a422a..cb0ee6790 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt @@ -16,8 +16,6 @@ import fr.acinq.lightning.db.IncomingPaymentsDb import fr.acinq.lightning.io.AddLiquidityForIncomingPayment import fr.acinq.lightning.io.SendOnTheFlyFundingMessage import fr.acinq.lightning.io.WrappedChannelCommand -import fr.acinq.lightning.router.ChannelHop -import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.tests.utils.runSuspendTest @@ -344,15 +342,16 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { val (paymentHandler, incomingPayment, paymentSecret) = createFixture(defaultAmount) checkDbPayment(incomingPayment, paymentHandler.db) val willAddHtlc = run { - // We simulate a trampoline-relay with a dummy channel hop between the liquidity provider and the wallet. - val (amount, expiry, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - incomingPayment.paymentHash, - listOf(NodeHop(TestConstants.Alice.nodeParams.nodeId, TestConstants.Bob.nodeParams.nodeId, CltvExpiryDelta(144), 0.msat)), - makeMppPayload(defaultAmount, defaultAmount, paymentSecret), - null - ) - assertTrue(trampolineOnion.packet.payload.size() < 500) - makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amount, amount, expiry, randomBytes32(), trampolineOnion.packet)) + // We simulate a trampoline-relay: the trampoline node will relay the a payment onion with a dummy channel hop. + val trampolinePayload = makeMppPayload(defaultAmount, defaultAmount, paymentSecret) + val trampolineOnion = OutgoingPaymentPacket.buildOnion( + nodes = listOf(TestConstants.Bob.nodeParams.nodeId), + payloads = listOf(trampolinePayload), + associatedData = incomingPayment.paymentHash, + ).packet + assertTrue(trampolineOnion.payload.size() < 500) + val finalPayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolinePayload.amount, trampolinePayload.totalAmount, trampolinePayload.expiry, randomBytes32(), trampolineOnion) + makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, finalPayload) } val result = paymentHandler.process(willAddHtlc, Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, TestConstants.fundingRates) assertIs(result) @@ -395,15 +394,16 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { val (paymentHandler, incomingPayment, _) = createFixture(defaultAmount) checkDbPayment(incomingPayment, paymentHandler.db) val willAddHtlc = run { - // We simulate a trampoline-relay with a dummy channel hop between the liquidity provider and the wallet. - val (amount, expiry, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - incomingPayment.paymentHash, - listOf(NodeHop(TestConstants.Alice.nodeParams.nodeId, TestConstants.Bob.nodeParams.nodeId, CltvExpiryDelta(144), 0.msat)), - makeMppPayload(defaultAmount, defaultAmount, randomBytes32()), - null - ) - assertTrue(trampolineOnion.packet.payload.size() < 500) - makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amount, amount, expiry, randomBytes32(), trampolineOnion.packet)) + // We simulate a trampoline-relay: the trampoline node will relay the a payment onion with a dummy channel hop. + val trampolinePayload = makeMppPayload(defaultAmount, defaultAmount, randomBytes32()) + val trampolineOnion = OutgoingPaymentPacket.buildOnion( + nodes = listOf(TestConstants.Bob.nodeParams.nodeId), + payloads = listOf(trampolinePayload), + associatedData = incomingPayment.paymentHash, + ).packet + assertTrue(trampolineOnion.payload.size() < 500) + val finalPayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolinePayload.amount, trampolinePayload.totalAmount, trampolinePayload.expiry, randomBytes32(), trampolineOnion) + makeWillAddHtlc(paymentHandler, incomingPayment.paymentHash, finalPayload) } val result = paymentHandler.process(willAddHtlc, Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, TestConstants.fundingRates) assertIs(result) @@ -1780,27 +1780,9 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { val defaultAmount = 150_000_000.msat val feeCreditFeatures = Features(Feature.ExperimentalSplice to FeatureSupport.Optional, Feature.OnTheFlyFunding to FeatureSupport.Optional, Feature.FundingFeeCredit to FeatureSupport.Optional) - private fun channelHops(destination: PublicKey): List { - val dummyKey = PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010101")).publicKey() - val dummyUpdate = ChannelUpdate( - signature = ByteVector64.Zeroes, - chainHash = BlockHash(ByteVector32.Zeroes), - shortChannelId = ShortChannelId(144, 0, 0), - timestampSeconds = 0, - messageFlags = 0, - channelFlags = 0, - cltvExpiryDelta = CltvExpiryDelta(144), - htlcMinimumMsat = 1000.msat, - feeBaseMsat = 1.msat, - feeProportionalMillionths = 10, - htlcMaximumMsat = null - ) - val channelHop = ChannelHop(dummyKey, destination, dummyUpdate) - return listOf(channelHop) - } - private fun makeCmdAddHtlc(destination: PublicKey, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload): ChannelCommand.Htlc.Add { - return OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, channelHops(destination), finalPayload).first.copy(commit = true) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destination), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet + return ChannelCommand.Htlc.Add(finalPayload.amount, paymentHash, finalPayload.expiry, onion, UUID.randomUUID(), commit = true) } private fun makeUpdateAddHtlc( @@ -1816,9 +1798,9 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { null -> destination.nodeParams.nodeId else -> RouteBlinding.derivePrivateKey(destination.nodeParams.nodePrivateKey, blinding).publicKey() } - val (_, _, packetAndSecrets) = OutgoingPaymentPacket.buildPacket(paymentHash, channelHops(destinationNodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destinationNodeId), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet val amount = finalPayload.amount - (fundingFee?.amount ?: 0.msat) - return UpdateAddHtlc(channelId, id, amount, paymentHash, finalPayload.expiry, packetAndSecrets.packet, blinding, fundingFee) + return UpdateAddHtlc(channelId, id, amount, paymentHash, finalPayload.expiry, onion, blinding, fundingFee) } private fun makeWillAddHtlc(destination: IncomingPaymentHandler, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload, blinding: PublicKey? = null): WillAddHtlc { @@ -1826,8 +1808,8 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { null -> destination.nodeParams.nodeId else -> RouteBlinding.derivePrivateKey(destination.nodeParams.nodePrivateKey, blinding).publicKey() } - val (_, _, packetAndSecrets) = OutgoingPaymentPacket.buildPacket(paymentHash, channelHops(destinationNodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) - return WillAddHtlc(destination.nodeParams.chainHash, randomBytes32(), finalPayload.amount, paymentHash, finalPayload.expiry, packetAndSecrets.packet, blinding) + val onion = OutgoingPaymentPacket.buildOnion(listOf(destinationNodeId), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength).packet + return WillAddHtlc(destination.nodeParams.chainHash, randomBytes32(), finalPayload.amount, paymentHash, finalPayload.expiry, onion, blinding) } private fun makeMppPayload( diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt index 01cb3c0ed..ad6f5ca3c 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt @@ -1,13 +1,14 @@ package fr.acinq.lightning.payment -import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Chain +import fr.acinq.bitcoin.Crypto +import fr.acinq.bitcoin.PrivateKey import fr.acinq.bitcoin.utils.Either import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.Lightning.randomKey -import fr.acinq.lightning.blockchain.fee.FeeratePerKw import fr.acinq.lightning.channel.* -import fr.acinq.lightning.channel.states.Normal import fr.acinq.lightning.channel.states.Offline import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.Sphinx @@ -15,11 +16,9 @@ import fr.acinq.lightning.db.InMemoryPaymentsDb import fr.acinq.lightning.db.LightningOutgoingPayment import fr.acinq.lightning.db.OutgoingPaymentsDb import fr.acinq.lightning.io.PayInvoice -import fr.acinq.lightning.io.WrappedChannelCommand import fr.acinq.lightning.tests.TestConstants import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.tests.utils.runSuspendTest -import fr.acinq.lightning.transactions.CommitmentSpec import fr.acinq.lightning.utils.* import fr.acinq.lightning.wire.* import kotlin.test.* @@ -57,10 +56,25 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { Feature.RouteBlinding to FeatureSupport.Optional, ) // The following invoice requires payment_metadata. - val invoice1 = - Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.PaymentMetadata to FeatureSupport.Mandatory)) + val invoice1 = run { + val unsupportedFeatures = Features( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Mandatory + ) + Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = unsupportedFeatures) + } // The following invoice requires unknown feature bit 188. - val invoice2 = Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = Features(mapOf(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory), setOf(UnknownFeature(188)))) + val invoice2 = run { + val unsupportedFeatures = Features( + mapOf( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory + ), + setOf(UnknownFeature(188)) + ) + Bolt11InvoiceTestsCommon.createInvoiceUnsafe(features = unsupportedFeatures) + } for (invoice in listOf(invoice1, invoice2)) { val outgoingPaymentHandler = OutgoingPaymentHandler(alice.staticParams.nodeParams.copy(features = features), defaultWalletParams, InMemoryPaymentsDb()) val payment = PayInvoice(UUID.randomUUID(), 15_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) @@ -111,16 +125,16 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { @Test fun `invoice already paid`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) val invoice = makeInvoice(amount = 100_000.msat, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val (channelId, add) = filterAddHtlcCommands(result).first() - outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId, add, randomBytes32())) as OutgoingPaymentHandler.Success + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) + outgoingPaymentHandler.processAddSettledFulfilled(createRemoteFulfill(channelId, add, randomBytes32())) as OutgoingPaymentHandler.Success val duplicatePayment = payment.copy(paymentId = UUID.randomUUID()) - val error = outgoingPaymentHandler.sendPayment(duplicatePayment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure + val error = outgoingPaymentHandler.sendPayment(duplicatePayment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure assertEquals(error.failure.reason, FinalFailure.AlreadyPaid) } @@ -134,14 +148,13 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Send payment 1 of 2: this should work because we're still under the maxAcceptedHtlcs. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result is OutgoingPaymentHandler.Progress } - - val progress = result as OutgoingPaymentHandler.Progress - assertEquals(1, result.actions.size) - val processResult = alice.processSameState(progress.actions.first().channelCommand) - assertTrue { processResult.second.filterIsInstance().isEmpty() } - alice = processResult.first + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) + + val (alice1, actions1) = alice.processSameState(progress.actions.first().channelCommand) + assertTrue(actions1.filterIsInstance().isEmpty()) + alice = alice1 assertNotNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) val dbPayment = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) @@ -154,21 +167,18 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Send payment 2 of 2: this should exceed the configured maxAcceptedHtlcs. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 50_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result1 is OutgoingPaymentHandler.Progress } + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) - val progress = result1 as OutgoingPaymentHandler.Progress - assertEquals(1, result1.actions.size) val cmdAdd = progress.actions.first().channelCommand - val processResult = alice.processSameState(cmdAdd) - alice = processResult.first - - val addFailure = processResult.second.filterIsInstance().firstOrNull() + val (_, actions1) = alice.processSameState(cmdAdd) + val addFailure = actions1.filterIsInstance().firstOrNull() assertNotNull(addFailure) // Now the channel error gets sent back to the OutgoingPaymentHandler. - val result2 = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure, mapOf(alice.channelId to alice.state)) + val failure = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.NoAvailableChannels, listOf(Either.Left(TooManyAcceptedHtlcs(alice.channelId, 1))))) - assertFailureEquals(result2 as OutgoingPaymentHandler.Failure, expected) + assertFailureEquals(failure as OutgoingPaymentHandler.Failure, expected) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) @@ -179,22 +189,22 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { fun `channel restrictions -- maxHtlcValueInFlight`() = runSuspendTest { var (alice, _) = TestsHelper.reachNormal() val maxHtlcValueInFlightMsat = 150_000L - alice = - alice.copy(state = alice.state.copy(commitments = alice.commitments.copy(params = alice.commitments.params.copy(remoteParams = alice.commitments.params.remoteParams.copy(maxHtlcValueInFlightMsat = maxHtlcValueInFlightMsat))))) + alice = alice.copy( + state = alice.state.copy(commitments = alice.commitments.copy(params = alice.commitments.params.copy(remoteParams = alice.commitments.params.remoteParams.copy(maxHtlcValueInFlightMsat = maxHtlcValueInFlightMsat)))) + ) val outgoingPaymentHandler = OutgoingPaymentHandler(alice.staticParams.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) run { // Send payment 1 of 2: this should work because we're still under the maxHtlcValueInFlightMsat. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result is OutgoingPaymentHandler.Progress } - - val progress = result as OutgoingPaymentHandler.Progress - assertEquals(1, result.actions.size) - val processResult = alice.processSameState(progress.actions.first().channelCommand) - assertTrue { processResult.second.filterIsInstance().isEmpty() } - alice = processResult.first + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) + + val (alice1, actions1) = alice.processSameState(progress.actions.first().channelCommand) + assertTrue(actions1.filterIsInstance().isEmpty()) + alice = alice1 assertNotNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) val dbPayment = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) @@ -207,21 +217,18 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // Send payment 2 of 2: this should exceed the configured maxHtlcValueInFlightMsat. val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 100_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) - assertTrue { result1 is OutgoingPaymentHandler.Progress } + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), alice.currentBlockHeight) + assertIs(progress) + assertEquals(1, progress.actions.size) - val progress = result1 as OutgoingPaymentHandler.Progress - assertEquals(1, result1.actions.size) val cmdAdd = progress.actions.first().channelCommand - val processResult = alice.processSameState(cmdAdd) - alice = processResult.first - - val addFailure = processResult.second.filterIsInstance().firstOrNull() + val (_, actions1) = alice.processSameState(cmdAdd) + val addFailure = actions1.filterIsInstance().firstOrNull() assertNotNull(addFailure) // Now the channel error gets sent back to the OutgoingPaymentHandler. - val result2 = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure, mapOf(alice.channelId to alice.state)) + val failure = outgoingPaymentHandler.processAddFailed(alice.channelId, addFailure) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.NoAvailableChannels, listOf(Either.Left(HtlcValueTooHighInFlight(alice.channelId, maxHtlcValueInFlightMsat.toULong(), 200_000.msat))))) - assertFailureEquals(result2 as OutgoingPaymentHandler.Failure, expected) + assertFailureEquals(failure as OutgoingPaymentHandler.Failure, expected) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) @@ -229,7 +236,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } @Test - fun `successful first attempt -- single part`() = runSuspendTest { + fun `successful first attempt`() = runSuspendTest { val recipientKey = randomKey() val invoice = makeInvoice(amount = 195_000.msat, supportsTrampoline = true, privKey = recipientKey) val payment = PayInvoice(UUID.randomUUID(), 200_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) // we slightly overpay the invoice amount @@ -237,7 +244,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } @Test - fun `successful first attempt -- single part + backwards-compatibility trampoline bit`() = runSuspendTest { + fun `successful first attempt -- backwards-compatibility trampoline bit`() = runSuspendTest { val recipientKey = randomKey() val invoice = run { // Invoices generated by older versions of wallets based on lightning-kmp will generate invoices with the following feature bits. @@ -262,19 +269,17 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } private suspend fun testSinglePartTrampolinePayment(payment: PayInvoice, invoice: Bolt11Invoice, recipientKey: PrivateKey) { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(3.sat, 10_000, CltvExpiryDelta(144)))) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(1, adds.size) - val (channelId, add) = adds.first() + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) assertEquals(205_000.msat, add.amount) assertEquals(payment.paymentHash, add.paymentHash) // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) + val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptRelayToTrampoline(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) assertEquals(205_000.msat, outerB.amount) assertEquals(205_000.msat, outerB.totalAmount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) @@ -291,7 +296,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { assertEquals(invoice.paymentSecret, payloadC.paymentSecret) val preimage = randomBytes32() - val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId, add, preimage)) as OutgoingPaymentHandler.Success + val success = outgoingPaymentHandler.processAddSettledFulfilled(createRemoteFulfill(channelId, add, preimage)) as OutgoingPaymentHandler.Success assertEquals(preimage, success.preimage) assertEquals(5_000.msat, success.payment.fees) assertEquals(200_000.msat, success.payment.recipientAmount) @@ -310,65 +315,8 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } @Test - fun `successful first attempt -- multiple parts`() = runSuspendTest { - val channels = makeChannels() - val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(10.sat, 0, CltvExpiryDelta(144)))) - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) - val recipientKey = randomKey() - val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = recipientKey) - val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - - adds.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - assertEquals(310_000.msat, outerB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) - assertEquals(300_000.msat, innerB.amountToForward) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) - assertEquals(payment.recipient, innerB.outgoingNodeId) - - // The recipient should receive the right amount and expiry. - val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! - val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! - assertEquals(300_000.msat, payloadC.amount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) - assertEquals(payloadC.amount, payloadC.totalAmount) - assertEquals(invoice.paymentSecret, payloadC.paymentSecret) - } - - val preimage = randomBytes32() - val (channelId1, add1) = adds[0] - val fulfill1 = createRemoteFulfill(channelId1, add1, preimage) - val success1 = outgoingPaymentHandler.processAddSettled(fulfill1) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - val (channelId2, add2) = adds[1] - val fulfill2 = ChannelAction.ProcessCmdRes.AddSettledFulfill(add2.paymentId, makeUpdateAddHtlc(channelId2, add2), ChannelAction.HtlcResult.Fulfill.OnChainFulfill(preimage)) - val success2 = outgoingPaymentHandler.processAddSettled(fulfill2) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(10_000.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(invoice.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(310_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 10_000.msat, partsCount = 2) - } - - @Test - fun `successful first attempt -- multiple parts + legacy recipient`() = runSuspendTest { - val channels = makeChannels() + fun `successful first attempt -- legacy recipient`() = runSuspendTest { + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(10.sat, 0, CltvExpiryDelta(144)))) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) val recipientKey = randomKey() @@ -376,51 +324,46 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { val invoice = makeInvoice(amount = null, supportsTrampoline = false, privKey = recipientKey, extraHops = extraHops) val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - - adds.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, _) = PaymentPacketTestsCommon.decryptRelayToNonTrampolinePayload(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - assertEquals(310_000.msat, outerB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) - assertEquals(300_000.msat, innerB.amountToForward) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) - assertEquals(payment.recipient, innerB.outgoingNodeId) - assertEquals(invoice.paymentSecret, innerB.paymentSecret) - assertEquals(invoice.features.toByteArray().toByteVector(), innerB.invoiceFeatures) - assertFalse(innerB.invoiceRoutingInfo.isEmpty()) - assertEquals(invoice.routingInfo.map { it.hints }, innerB.invoiceRoutingInfo) - } + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) + assertEquals(310_000.msat, add.amount) + assertEquals(payment.paymentHash, add.paymentHash) + // The trampoline node should receive the right forwarding information. + val (outerB, innerB) = PaymentPacketTestsCommon.decryptRelayToLegacy(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) + assertEquals(add.amount, outerB.amount) + assertEquals(310_000.msat, outerB.totalAmount) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(144) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) + assertEquals(300_000.msat, innerB.amountToForward) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) + assertEquals(payment.recipient, innerB.outgoingNodeId) + assertEquals(invoice.paymentSecret, innerB.paymentSecret) + assertEquals(invoice.features.toByteArray().toByteVector(), innerB.invoiceFeatures) + assertFalse(innerB.invoiceRoutingInfo.isEmpty()) + assertEquals(invoice.routingInfo.map { it.hints }, innerB.invoiceRoutingInfo) val preimage = randomBytes32() - val (channelId1, add1) = adds[0] - val success1 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId1, add1, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - val (channelId2, add2) = adds[1] - val success2 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(channelId2, add2, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(10_000.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(invoice.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(310_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) + val success = outgoingPaymentHandler.processAddSettledFulfilled(createRemoteFulfill(channelId, add, preimage)) + assertNotNull(success) + assertEquals(preimage, success.preimage) + assertEquals(10_000.msat, success.payment.fees) + assertEquals(300_000.msat, success.payment.recipientAmount) + assertEquals(invoice.nodeId, success.payment.recipient) + assertEquals(invoice.paymentHash, success.payment.paymentHash) + assertEquals(invoice, (success.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) + assertEquals(preimage, (success.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) + assertEquals(1, success.payment.parts.size) + assertEquals(310_000.msat, success.payment.parts.first().amount) + val status = success.payment.parts.first().status + assertIs(status) + assertEquals(preimage, status.preimage) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 10_000.msat, partsCount = 2) + assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 10_000.msat, partsCount = 1) } @Test fun `successful first attempt -- random final expiry`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy(trampolineFees = listOf(TrampolineFees(25.sat, 0, CltvExpiryDelta(48)))) val recipientExpiryParams = RecipientCltvExpiryParams(CltvExpiryDelta(144), CltvExpiryDelta(288)) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams.copy(paymentRecipientExpiryParams = recipientExpiryParams), walletParams, InMemoryPaymentsDb()) @@ -428,230 +371,126 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = recipientKey) val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - - adds.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - val minFinalExpiry = CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA + recipientExpiryParams.min - assertTrue(minFinalExpiry + CltvExpiryDelta(48) <= outerB.expiry) - assertTrue(minFinalExpiry <= innerB.outgoingCltv) - - // The recipient should receive the right amount and expiry. - val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! - val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! - assertEquals(300_000.msat, payloadC.amount) - assertTrue(minFinalExpiry <= payloadC.expiry) - assertEquals(innerB.outgoingCltv, payloadC.expiry) - } - } - - @Test - fun `successful first attempt -- multiple parts + recipient is our peer`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - // The invoice comes from Bob, our direct peer (and trampoline node). - val preimage = randomBytes32() - val incomingPaymentHandler = IncomingPaymentHandler(TestConstants.Bob.nodeParams, InMemoryPaymentsDb()) - val invoice = incomingPaymentHandler.createInvoice(preimage, amount = null, Either.Left("phoenix to phoenix"), listOf()) - val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val result = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(result) - assertEquals(2, adds.size) - assertEquals(300_000.msat, adds.map { it.second.amount }.sum()) - adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - adds.forEach { (channelId, add) -> - // Bob should receive the right final information. - val payloadB = IncomingPaymentPacket.decrypt(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey).right!! - assertIs(payloadB) - assertEquals(add.amount, payloadB.amount) - assertEquals(300_000.msat, payloadB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + TestConstants.Alice.nodeParams.minFinalCltvExpiryDelta, payloadB.expiry) - assertEquals(invoice.paymentSecret, payloadB.paymentSecret) - } - - // Bob receives these 2 HTLCs. - val process1 = incomingPaymentHandler.process(makeUpdateAddHtlc(adds[0].first, adds[0].second, 3), Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, remoteFundingRates = null) - assertTrue(process1 is IncomingPaymentHandler.ProcessAddResult.Pending) - val process2 = incomingPaymentHandler.process(makeUpdateAddHtlc(adds[1].first, adds[1].second, 5), Features.empty, TestConstants.defaultBlockHeight, TestConstants.feeratePerKw, remoteFundingRates = null) - assertTrue(process2 is IncomingPaymentHandler.ProcessAddResult.Accepted) - val fulfills = process2.actions.filterIsInstance().mapNotNull { it.channelCommand as? ChannelCommand.Htlc.Settlement.Fulfill } - assertEquals(2, fulfills.size) - - // Alice receives the fulfill for these 2 HTLCs. - val (channelId1, add1) = adds[0] - val fulfill1 = createRemoteFulfill(channelId1, add1, preimage) - val success1 = outgoingPaymentHandler.processAddSettled(fulfill1) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - val (channelId2, add2) = adds[1] - val fulfill2 = ChannelAction.ProcessCmdRes.AddSettledFulfill(add2.paymentId, makeUpdateAddHtlc(channelId2, add2), ChannelAction.HtlcResult.Fulfill.OnChainFulfill(preimage)) - val success2 = outgoingPaymentHandler.processAddSettled(fulfill2) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(0.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(TestConstants.Bob.nodeParams.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(300_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) + val result = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val (channelId, add) = findAddHtlcCommand(result) + // The trampoline node should receive the right forwarding information. + val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptRelayToTrampoline(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) + assertEquals(add.amount, outerB.amount) + val minFinalExpiry = CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA + recipientExpiryParams.min + assertTrue(minFinalExpiry + CltvExpiryDelta(48) <= outerB.expiry) + assertTrue(minFinalExpiry <= innerB.outgoingCltv) - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 0.msat, partsCount = 2) + // The recipient should receive the right amount and expiry. + val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! + val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! + assertEquals(300_000.msat, payloadC.amount) + assertTrue(minFinalExpiry <= payloadC.expiry) + assertEquals(innerB.outgoingCltv, payloadC.expiry) } @Test fun `successful second attempt`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) val recipientKey = randomKey() val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = recipientKey) val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(progress1) - assertEquals(2, adds1.size) - assertEquals(300_000.msat, adds1.map { it.second.amount }.sum()) + val progress1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress1) + val (channelId1, add1) = findAddHtlcCommand(progress1) + assertEquals(alice.channelId, channelId1) + assertEquals(300_000.msat, add1.amount) // This first attempt fails because fees are too low. val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail1 = outgoingPaymentHandler.processAddSettled(adds1[0].first, createRemoteFailure(adds1[0].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) - assertNull(fail1) - val progress2 = outgoingPaymentHandler.processAddSettled(adds1[1].first, createRemoteFailure(adds1[1].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds2 = filterAddHtlcCommands(progress2) - assertEquals(2, adds2.size) - assertEquals(301_030.msat, adds2.map { it.second.amount }.sum()) - adds2.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } - adds2.forEach { (channelId, add) -> - // The trampoline node should receive the right forwarding information. - val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptNodeRelay(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey) - assertEquals(add.amount, outerB.amount) - assertEquals(301_030.msat, outerB.totalAmount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(576) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) - assertEquals(300_000.msat, innerB.amountToForward) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) - assertEquals(payment.recipient, innerB.outgoingNodeId) - - // The recipient should receive the right amount and expiry. - val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! - val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! - assertEquals(300_000.msat, payloadC.amount) - assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) - assertEquals(payloadC.amount, payloadC.totalAmount) - assertEquals(invoice.paymentSecret, payloadC.paymentSecret) - } + val progress2 = outgoingPaymentHandler.processAddSettledFailed(channelId1, createRemoteFailure(add1, attempt, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress2) + val (channelId2, add2) = findAddHtlcCommand(progress2) + assertEquals(channelId1, channelId2) + assertEquals(301_030.msat, add2.amount) + assertEquals(payment.paymentHash, add2.paymentHash) + // The trampoline node should receive the right forwarding information. + val (outerB, innerB, packetC) = PaymentPacketTestsCommon.decryptRelayToTrampoline(makeUpdateAddHtlc(channelId2, add2), TestConstants.Bob.nodeParams.nodePrivateKey) + assertEquals(add2.amount, outerB.amount) + assertEquals(301_030.msat, outerB.totalAmount) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + CltvExpiryDelta(576) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, outerB.expiry) + assertEquals(300_000.msat, innerB.amountToForward) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, innerB.outgoingCltv) + assertEquals(payment.recipient, innerB.outgoingNodeId) + + // The recipient should receive the right amount and expiry. + val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC).right!! + val payloadC = PaymentOnion.FinalPayload.Standard.read(payloadBytesC.payload).right!! + assertEquals(300_000.msat, payloadC.amount) + assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) + assertEquals(payloadC.amount, payloadC.totalAmount) + assertEquals(invoice.paymentSecret, payloadC.paymentSecret) val dbPayment1 = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) assertNotNull(dbPayment1) - assertTrue(dbPayment1.status is LightningOutgoingPayment.Status.Pending) - assertEquals(2, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Failed }.size) - assertEquals(2, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Pending }.size) + assertIs(dbPayment1.status) + assertEquals(1, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Failed }.size) + assertEquals(1, dbPayment1.parts.filter { it.status is LightningOutgoingPayment.Part.Status.Pending }.size) // The second attempt succeeds. val preimage = randomBytes32() - val success1 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds2[0].first, adds2[0].second, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - assertNotNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - val success2 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds2[1].first, adds2[1].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(1_030.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(invoice.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(301_030.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) + val success = outgoingPaymentHandler.processAddSettledFulfilled(createRemoteFulfill(channelId2, add2, preimage)) + assertIs(success) + assertEquals(preimage, success.preimage) + assertEquals(1_030.msat, success.payment.fees) + assertEquals(300_000.msat, success.payment.recipientAmount) + assertEquals(invoice.nodeId, success.payment.recipient) + assertEquals(invoice.paymentHash, success.payment.paymentHash) + assertEquals(invoice, (success.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) + assertEquals(preimage, (success.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) + assertEquals(1, success.payment.parts.size) + assertEquals(301_030.msat, success.payment.parts.first().amount) + val status = success.payment.parts.first().status + assertIs(status) + assertEquals(preimage, status.preimage) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) val dbPayment2 = outgoingPaymentHandler.db.getLightningOutgoingPayment(payment.paymentId) assertNotNull(dbPayment2) - assertTrue(dbPayment2.status is LightningOutgoingPayment.Status.Completed.Succeeded.OffChain) - assertEquals(2, dbPayment2.parts.size) + assertIs(dbPayment2.status) + assertEquals(1, dbPayment2.parts.size) assertTrue(dbPayment2.parts.all { it.status is LightningOutgoingPayment.Part.Status.Succeeded }) } - @Test - fun `successful second attempt -- recipient is our peer`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - // The invoice comes from Bob, our direct peer (and trampoline node). - val invoice = makeInvoice(amount = null, supportsTrampoline = true, privKey = TestConstants.Bob.nodeParams.nodePrivateKey) - val payment = PayInvoice(UUID.randomUUID(), 300_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val result1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(result1) - assertEquals(2, adds1.size) - - // The first attempt fails because of a local channel error. - val result2 = outgoingPaymentHandler.processAddFailed(adds1[1].first, ChannelAction.ProcessCmdRes.AddFailed(adds1[1].second, TooManyAcceptedHtlcs(adds1[1].first, 10), null), channels) as OutgoingPaymentHandler.Progress - val adds2 = filterAddHtlcCommands(result2) - assertEquals(1, adds2.size) - - // The other HTLCs succeed. - val preimage = randomBytes32() - val (channelId1, add1) = adds1[0] - val fulfill1 = createRemoteFulfill(channelId1, add1, preimage) - val success1 = outgoingPaymentHandler.processAddSettled(fulfill1) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), success1) - - val (channelId2, add2) = adds2[0] - val fulfill2 = createRemoteFulfill(channelId2, add2, preimage) - val success2 = outgoingPaymentHandler.processAddSettled(fulfill2) as OutgoingPaymentHandler.Success - assertEquals(preimage, success2.preimage) - assertEquals(0.msat, success2.payment.fees) - assertEquals(300_000.msat, success2.payment.recipientAmount) - assertEquals(TestConstants.Bob.nodeParams.nodeId, success2.payment.recipient) - assertEquals(invoice.paymentHash, success2.payment.paymentHash) - assertEquals(invoice, (success2.payment.details as LightningOutgoingPayment.Details.Normal).paymentRequest) - assertEquals(preimage, (success2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) - assertEquals(2, success2.payment.parts.size) - assertEquals(300_000.msat, success2.payment.parts.map { it.amount }.sum()) - assertEquals(setOf(preimage), success2.payment.parts.map { (it.status as LightningOutgoingPayment.Part.Status.Succeeded).preimage }.toSet()) - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 300_000.msat, fees = 0.msat, partsCount = 2) - } - @Test fun `insufficient funds when retrying with higher fees`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal(aliceFundingAmount = 100_000.sat, alicePushAmount = 0.msat, bobFundingAmount = 0.sat, bobPushAmount = 0.msat) + assertTrue(83_500_000.msat < alice.commitments.availableBalanceForSend()) + assertTrue(alice.commitments.availableBalanceForSend() < 84_000_000.msat) val walletParams = defaultWalletParams.copy( trampolineFees = listOf( - TrampolineFees(10.sat, 0, CltvExpiryDelta(144)), TrampolineFees(100.sat, 0, CltvExpiryDelta(144)), + TrampolineFees(1000.sat, 0, CltvExpiryDelta(144)), ) ) val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, walletParams, InMemoryPaymentsDb()) val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 650_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) + val payment = PayInvoice(UUID.randomUUID(), 83_000_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(progress1) - assertEquals(4, adds1.size) - assertEquals(660_000.msat, adds1.map { it.second.amount }.sum()) + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + val (_, add1) = findAddHtlcCommand(progress) + assertEquals(83_100_000.msat, add1.amount) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - assertNull(outgoingPaymentHandler.processAddSettled(adds1[0].first, createRemoteFailure(adds1[0].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight)) - assertNull(outgoingPaymentHandler.processAddSettled(adds1[1].first, createRemoteFailure(adds1[1].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight)) - assertNull(outgoingPaymentHandler.processAddSettled(adds1[2].first, createRemoteFailure(adds1[2].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight)) - val fail = outgoingPaymentHandler.processAddSettled(adds1[3].first, createRemoteFailure(adds1[3].second, attempt, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure + val fail = outgoingPaymentHandler.processAddSettledFailed(alice.channelId, createRemoteFailure(add1, attempt, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.InsufficientBalance, listOf(Either.Right(TrampolineFeeInsufficient)))) assertFailureEquals(expected, fail) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 4) + assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 1) } @Test fun `retries exhausted`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val walletParams = defaultWalletParams.copy( trampolineFees = listOf( TrampolineFees(10.sat, 0, CltvExpiryDelta(144)), @@ -662,20 +501,21 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 220_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds1 = filterAddHtlcCommands(progress1) - assertEquals(1, adds1.size) - assertEquals(230_000.msat, adds1.map { it.second.amount }.sum()) + val progress1 = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress1) + val (_, add1) = findAddHtlcCommand(progress1) + assertEquals(230_000.msat, add1.amount) val attempt1 = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val progress2 = outgoingPaymentHandler.processAddSettled(adds1[0].first, createRemoteFailure(adds1[0].second, attempt1, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds2 = filterAddHtlcCommands(progress2) - assertEquals(1, adds2.size) - assertEquals(240_000.msat, adds2.map { it.second.amount }.sum()) + val progress2 = outgoingPaymentHandler.processAddSettledFailed(alice.channelId, createRemoteFailure(add1, attempt1, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress2) + val (_, add2) = findAddHtlcCommand(progress2) + assertEquals(240_000.msat, add2.amount) val attempt2 = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(adds2[0].first, createRemoteFailure(adds2[0].second, attempt2, TrampolineFeeInsufficient), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure - val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.RetryExhausted, listOf(Either.Right(TrampolineFeeInsufficient)))) + val fail = outgoingPaymentHandler.processAddSettledFailed(alice.channelId, createRemoteFailure(add2, attempt2, TrampolineFeeInsufficient), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) + val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(FinalFailure.RetryExhausted, listOf(Either.Right(TrampolineFeeInsufficient), Either.Right(TrampolineFeeInsufficient)))) assertFailureEquals(expected, fail) assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) @@ -689,18 +529,19 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { Pair(IncorrectOrUnknownPaymentDetails(50_000.msat, TestConstants.defaultBlockHeight.toLong()), FinalFailure.UnknownError) ) fatalFailures.forEach { (remoteFailure, userFailure) -> - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 50_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(1, adds.size) - assertEquals(50_000.msat, adds.map { it.second.amount }.sum()) + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + val (_, add) = findAddHtlcCommand(progress) + assertEquals(50_000.msat, add.amount) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val fail = outgoingPaymentHandler.processAddSettled(adds[0].first, createRemoteFailure(adds[0].second, attempt, remoteFailure), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure + val fail = outgoingPaymentHandler.processAddSettledFailed(alice.channelId, createRemoteFailure(add, attempt, remoteFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) val expected = OutgoingPaymentHandler.Failure(payment, OutgoingPaymentFailure(userFailure, listOf(Either.Right(remoteFailure)))) assertFailureEquals(expected, fail) @@ -709,212 +550,74 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } } - private suspend fun testLocalChannelFailures(invoice: Bolt11Invoice) { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val payment = PayInvoice(UUID.randomUUID(), 5_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - var progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - assertEquals(1, progress.actions.size) - assertEquals(5_000.msat, filterAddHtlcCommands(progress).map { it.second.amount }.sum()) - - // Channels fail, so we retry with different channels, without raising the fees. - val localFailures = listOf( - { channelId: ByteVector32 -> TooManyAcceptedHtlcs(channelId, 15) }, - { channelId: ByteVector32 -> InsufficientFunds(channelId, 5_000.msat, 1.sat, 20.sat, 1.sat) }, - { channelId: ByteVector32 -> HtlcValueTooHighInFlight(channelId, 150_000U, 155_000.msat) }, - { channelId: ByteVector32 -> ForbiddenDuringSplice(channelId, "update-add-htlc") } - ) - localFailures.forEach { localFailure -> - val (channelId, add) = filterAddHtlcCommands(progress).first() - progress = outgoingPaymentHandler.processAddFailed(channelId, ChannelAction.ProcessCmdRes.AddFailed(add, localFailure(channelId), null), channels) as OutgoingPaymentHandler.Progress - assertEquals(5_000.msat, add.amount) - } - - // The last channel fails: we don't have any channels available to retry. - val (channelId, add) = filterAddHtlcCommands(progress).first() - val fail = outgoingPaymentHandler.processAddFailed(channelId, ChannelAction.ProcessCmdRes.AddFailed(add, TooManyAcceptedHtlcs(channelId, 15), null), channels) as OutgoingPaymentHandler.Failure - assertEquals(FinalFailure.InsufficientBalance, fail.failure.reason) - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentFailed(outgoingPaymentHandler.db, payment.paymentId, 5) - } - - @Test - fun `local channel failures`() = runSuspendTest { - testLocalChannelFailures(makeInvoice(amount = null, supportsTrampoline = true)) - } - - @Test - fun `local channel failures -- recipient is our peer`() = runSuspendTest { - // The invoice comes from Bob, our direct peer (and trampoline node). - testLocalChannelFailures(makeInvoice(amount = null, supportsTrampoline = true, privKey = TestConstants.Bob.nodeParams.nodePrivateKey)) - } - - @Test - fun `local channel failure followed by success`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 5_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val progress1 = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - assertEquals(1, progress1.actions.size) - assertEquals(5_000.msat, filterAddHtlcCommands(progress1).map { it.second.amount }.sum()) - - // This first payment fails: - val (channelId, add) = filterAddHtlcCommands(progress1).first() - val progress2 = outgoingPaymentHandler.processAddFailed(channelId, ChannelAction.ProcessCmdRes.AddFailed(add, TooManyAcceptedHtlcs(channelId, 1), null), channels) as OutgoingPaymentHandler.Progress - assertEquals(1, progress2.actions.size) - val adds = filterAddHtlcCommands(progress2) - assertEquals(5_000.msat, adds.map { it.second.amount }.sum()) - - // This second attempt succeeds: - val preimage = randomBytes32() - val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[0].first, adds[0].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(0.msat, success.payment.fees) - assertEquals(5_000.msat, success.payment.recipientAmount) - assertEquals(preimage, success.preimage) - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 5_000.msat, fees = 0.msat, partsCount = 1) - } - - @Test - fun `partial failure then fulfill -- spec violation`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 310_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - - val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val remoteFailure = IncorrectOrUnknownPaymentDetails(310_000.msat, TestConstants.defaultBlockHeight.toLong()) - assertNull(outgoingPaymentHandler.processAddSettled(adds[0].first, createRemoteFailure(adds[0].second, attempt, remoteFailure), channels, TestConstants.defaultBlockHeight)) - - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - val preimage = randomBytes32() - val success = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[1].first, adds[1].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, success.preimage) - assertEquals((-250_000).msat, success.payment.fees) // since we paid much less than the expected amount, it results in negative fees - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 310_000.msat, fees = (-250_000).msat, partsCount = 1) - } - - @Test - fun `partial fulfill then failure -- spec violation`() = runSuspendTest { - val channels = makeChannels() - val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, InMemoryPaymentsDb()) - val invoice = makeInvoice(amount = null, supportsTrampoline = true) - val payment = PayInvoice(UUID.randomUUID(), 310_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(2, adds.size) - assertEquals(310_000.msat, adds.map { it.second.amount }.sum()) - - val preimage = randomBytes32() - val expected = OutgoingPaymentHandler.PreimageReceived(payment, preimage) - val result = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[0].first, adds[0].second, preimage)) - assertEquals(expected, result) - - // The recipient released the preimage without receiving the full payment amount. - // This is a spec violation and is too bad for them, we obtained a proof of payment without paying the full amount. - val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val remoteFailure = IncorrectOrUnknownPaymentDetails(310_000.msat, TestConstants.defaultBlockHeight.toLong()) - val success = outgoingPaymentHandler.processAddSettled(adds[1].first, createRemoteFailure(adds[1].second, attempt, remoteFailure), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Success - assertEquals(preimage, success.preimage) - assertEquals((-60_000).msat, success.payment.fees) // since we paid much less than the expected amount, it results in negative fees - - assertNull(outgoingPaymentHandler.getPendingPayment(payment.paymentId)) - assertDbPaymentSucceeded(outgoingPaymentHandler.db, payment.paymentId, amount = 310_000.msat, fees = (-60_000).msat, partsCount = 1) - } - @Test fun `failure after a wallet restart`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val db = InMemoryPaymentsDb() // Step 1: a payment attempt is made. - val (adds, attempt) = run { + val (add, attempt) = run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 550_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) val attempt = outgoingPaymentHandler.getPendingPayment(payment.paymentId)!! - val adds = filterAddHtlcCommands(progress) - assertEquals(3, adds.size) - Pair(adds, attempt) + val (_, add) = findAddHtlcCommand(progress) + Pair(add, attempt) } // Step 2: the wallet restarts and payment fails. run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - assertNull(outgoingPaymentHandler.processAddFailed(adds[0].first, ChannelAction.ProcessCmdRes.AddFailed(adds[0].second, ChannelUnavailable(adds[0].first), null), channels)) - assertNull(outgoingPaymentHandler.processAddSettled(adds[1].first, createRemoteFailure(adds[1].second, attempt, TemporaryNodeFailure), channels, TestConstants.defaultBlockHeight)) - val result = outgoingPaymentHandler.processAddSettled(adds[2].first, createRemoteFailure(adds[2].second, attempt, PermanentNodeFailure), channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Failure - val failures: List> = listOf( - Either.Left(ChannelUnavailable(adds[0].first)), - // Since we've lost the shared secrets, we can't decrypt remote failures. - Either.Right(UnknownFailureMessage(0)), - Either.Right(UnknownFailureMessage(0)) - ) - assertFailureEquals(OutgoingPaymentHandler.Failure(attempt.request, OutgoingPaymentFailure(FinalFailure.WalletRestarted, failures)), result) + val fail = outgoingPaymentHandler.processAddSettledFailed(alice.channelId, createRemoteFailure(add, attempt, TemporaryNodeFailure), mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(fail) + assertEquals(attempt.request, fail.request) + assertEquals(FinalFailure.WalletRestarted, fail.failure.reason) + assertEquals(1, fail.failure.failures.size) + // Since we haven't stored the shared secrets, we can't decrypt remote failure. + assertIs(fail.failure.failures.first().failure) } } @Test fun `success after a wallet restart`() = runSuspendTest { - val channels = makeChannels() + val (alice, _) = TestsHelper.reachNormal() val preimage = randomBytes32() val invoice = makeInvoice(amount = null, supportsTrampoline = true) val payment = PayInvoice(UUID.randomUUID(), 550_000.msat, LightningOutgoingPayment.Details.Normal(invoice)) val db = InMemoryPaymentsDb() // Step 1: a payment attempt is made. - val adds = run { + val add = run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - val progress = outgoingPaymentHandler.sendPayment(payment, channels, TestConstants.defaultBlockHeight) as OutgoingPaymentHandler.Progress - val adds = filterAddHtlcCommands(progress) - assertEquals(3, adds.size) - // A first part is fulfilled before the wallet restarts. - val result = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[0].first, adds[0].second, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), result) - adds + val progress = outgoingPaymentHandler.sendPayment(payment, mapOf(alice.channelId to alice.state), TestConstants.defaultBlockHeight) + assertIs(progress) + findAddHtlcCommand(progress).second } // Step 2: the wallet restarts and payment succeeds. run { val outgoingPaymentHandler = OutgoingPaymentHandler(TestConstants.Alice.nodeParams, defaultWalletParams, db) - val result1 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[1].first, adds[1].second, preimage)) - assertEquals(OutgoingPaymentHandler.PreimageReceived(payment, preimage), result1) - val result2 = outgoingPaymentHandler.processAddSettled(createRemoteFulfill(adds[2].first, adds[2].second, preimage)) as OutgoingPaymentHandler.Success - assertEquals(preimage, result2.preimage) - assertEquals(3, result2.payment.parts.size) - assertEquals(payment, PayInvoice(result2.payment.id, result2.payment.recipientAmount, result2.payment.details as LightningOutgoingPayment.Details.Normal)) - assertEquals(preimage, (result2.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) + val success = outgoingPaymentHandler.processAddSettledFulfilled(createRemoteFulfill(alice.channelId, add, preimage)) + assertIs(success) + assertEquals(preimage, success.preimage) + assertEquals(1, success.payment.parts.size) + assertEquals(payment, success.request) + assertEquals(preimage, (success.payment.status as LightningOutgoingPayment.Status.Completed.Succeeded.OffChain).preimage) } } private fun makeInvoice(amount: MilliSatoshi?, supportsTrampoline: Boolean, privKey: PrivateKey = randomKey(), extraHops: List> = listOf()): Bolt11Invoice { val paymentPreimage: ByteVector32 = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() - - val invoiceFeatures = mutableMapOf( - Feature.VariableLengthOnion to FeatureSupport.Optional, - Feature.PaymentSecret to FeatureSupport.Mandatory, - Feature.BasicMultiPartPayment to FeatureSupport.Optional - ) - if (supportsTrampoline) { - invoiceFeatures[Feature.ExperimentalTrampolinePayment] = FeatureSupport.Optional + val invoiceFeatures: Map = buildMap { + put(Feature.VariableLengthOnion, FeatureSupport.Optional) + put(Feature.PaymentSecret, FeatureSupport.Mandatory) + put(Feature.BasicMultiPartPayment, FeatureSupport.Optional) + if (supportsTrampoline) put(Feature.ExperimentalTrampolinePayment, FeatureSupport.Optional) } - return Bolt11Invoice.create( chain = Chain.Mainnet, amount = amount, @@ -922,50 +625,21 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { privateKey = privKey, description = Either.Left("unit test"), minFinalCltvExpiryDelta = Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA, - features = Features(invoiceFeatures.toMap()), + features = Features(invoiceFeatures), extraHops = extraHops ) } - private fun makeChannels(): Map { - val (alice, _) = TestsHelper.reachNormal() - val reserve = alice.commitments.latest.localChannelReserve - val channelDetails = listOf( - Pair(ShortChannelId(1), 250_000.msat), - Pair(ShortChannelId(2), 150_000.msat), - Pair(ShortChannelId(3), 0.msat), - Pair(ShortChannelId(4), 10_000.msat), - Pair(ShortChannelId(5), 200_000.msat), - Pair(ShortChannelId(6), 100_000.msat), - ) - return channelDetails.associate { - val channelId = randomBytes32() - val channel = alice.state.copy( - shortChannelId = it.first, - commitments = alice.commitments.copy( - params = alice.commitments.params.copy(channelId = channelId), - active = alice.commitments.active.map { commitment -> - commitment.copy(remoteCommit = commitment.remoteCommit.copy(spec = CommitmentSpec(setOf(), FeeratePerKw(0.sat), 50_000.msat, (it.second + ((Commitments.ANCHOR_AMOUNT * 2) + reserve).toMilliSatoshi())))) - } - ) - ) - channelId to channel - } - } - - private fun filterAddHtlcCommands(progress: OutgoingPaymentHandler.Progress): List> { - val addCommands = mutableListOf>() - for (action in progress.actions) { - val addCommand = action.channelCommand as? ChannelCommand.Htlc.Add - if (addCommand != null) { - addCommands.add(Pair(action.channelId, addCommand)) + private fun findAddHtlcCommand(progress: OutgoingPaymentHandler.Progress): Pair { + return progress.actions.firstNotNullOf { + when (val cmd = it.channelCommand) { + is ChannelCommand.Htlc.Add -> Pair(it.channelId, cmd) + else -> null } } - return addCommands.toList() } - private fun makeUpdateAddHtlc(channelId: ByteVector32, cmd: ChannelCommand.Htlc.Add, htlcId: Long = 0): UpdateAddHtlc = - UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) + private fun makeUpdateAddHtlc(channelId: ByteVector32, cmd: ChannelCommand.Htlc.Add, htlcId: Long = 0): UpdateAddHtlc = UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) private fun createRemoteFulfill(channelId: ByteVector32, add: ChannelCommand.Htlc.Add, preimage: ByteVector32): ChannelAction.ProcessCmdRes.AddSettledFulfill { val updateAddHtlc = makeUpdateAddHtlc(channelId, add) @@ -973,8 +647,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { } private fun createRemoteFailure(add: ChannelCommand.Htlc.Add, attempt: OutgoingPaymentHandler.PaymentAttempt, failureMessage: FailureMessage): ChannelAction.ProcessCmdRes.AddSettledFail { - val sharedSecrets = attempt.pending.getValue(add.paymentId).second - val reason = FailurePacket.create(sharedSecrets.perHopSecrets.last().first, failureMessage) + val reason = FailurePacket.create(attempt.sharedSecrets.perHopSecrets.last().first, failureMessage) val updateAddHtlc = makeUpdateAddHtlc(randomBytes32(), add) return ChannelAction.ProcessCmdRes.AddSettledFail( add.paymentId, diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt index 9b699dee3..e57d8390b 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt @@ -10,122 +10,70 @@ import fr.acinq.lightning.Lightning.randomBytes64 import fr.acinq.lightning.Lightning.randomKey import fr.acinq.lightning.channel.states.Channel import fr.acinq.lightning.crypto.RouteBlinding +import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets import fr.acinq.lightning.crypto.sphinx.Sphinx import fr.acinq.lightning.crypto.sphinx.Sphinx.hash -import fr.acinq.lightning.payment.OutgoingPaymentHandler.PaymentAttempt import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.tests.utils.LightningTestSuite -import fr.acinq.lightning.utils.* +import fr.acinq.lightning.utils.currentTimestampMillis +import fr.acinq.lightning.utils.msat +import fr.acinq.lightning.utils.toByteVector +import fr.acinq.lightning.utils.toByteVector32 import fr.acinq.lightning.wire.* import kotlin.test.* class PaymentPacketTestsCommon : LightningTestSuite() { companion object { - private val privA = randomKey() - private val a = privA.publicKey() private val privB = randomKey() private val b = privB.publicKey() private val privC = randomKey() private val c = privC.publicKey() private val privD = randomKey() private val d = privD.publicKey() - private val privE = randomKey() - private val e = privE.publicKey() private val defaultChannelUpdate = ChannelUpdate(randomBytes64(), Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0, 1, 0, CltvExpiryDelta(0), 42000.msat, 0.msat, 0, 500000000.msat) - private val channelUpdateAB = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1), cltvExpiryDelta = CltvExpiryDelta(4), feeBaseMsat = 642000.msat, feeProportionalMillionths = 7) - private val channelUpdateBC = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(2), cltvExpiryDelta = CltvExpiryDelta(5), feeBaseMsat = 153000.msat, feeProportionalMillionths = 4) - private val channelUpdateCD = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(3), cltvExpiryDelta = CltvExpiryDelta(10), feeBaseMsat = 60000.msat, feeProportionalMillionths = 1) - private val channelUpdateDE = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(4), cltvExpiryDelta = CltvExpiryDelta(7), feeBaseMsat = 766000.msat, feeProportionalMillionths = 10) - - // simple route a -> b -> c -> d -> e - private val hops = listOf( - ChannelHop(a, b, channelUpdateAB), - ChannelHop(b, c, channelUpdateBC), - ChannelHop(c, d, channelUpdateCD), - ChannelHop(d, e, channelUpdateDE) - ) + private val channelUpdateBC = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1105), cltvExpiryDelta = CltvExpiryDelta(36), feeBaseMsat = 7_500.msat, feeProportionalMillionths = 250) + private val channelUpdateCD = defaultChannelUpdate.copy(shortChannelId = ShortChannelId(1729), cltvExpiryDelta = CltvExpiryDelta(24), feeBaseMsat = 5_000.msat, feeProportionalMillionths = 100) - private val finalAmount = 42000000.msat - private const val currentBlockCount = 400000L + private val finalAmount = 50_000_000.msat + private const val currentBlockCount = 400_000L private val finalExpiry = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA private val paymentPreimage = randomBytes32() private val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() private val paymentSecret = randomBytes32() private val paymentMetadata = randomBytes(64).toByteVector() - private val expiryDE = finalExpiry - private val amountDE = finalAmount - private val feeD = nodeFee(channelUpdateDE.feeBaseMsat, channelUpdateDE.feeProportionalMillionths, amountDE) + private val nonTrampolineFeatures = Features( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory, + Feature.BasicMultiPartPayment to FeatureSupport.Optional, + ) + private val trampolineFeatures = Features( + Feature.VariableLengthOnion to FeatureSupport.Mandatory, + Feature.PaymentSecret to FeatureSupport.Mandatory, + Feature.BasicMultiPartPayment to FeatureSupport.Optional, + Feature.ExperimentalTrampolinePayment to FeatureSupport.Optional, + ) - private val expiryCD = expiryDE + channelUpdateDE.cltvExpiryDelta - private val amountCD = amountDE + feeD + // Amount and expiry sent by C to D. + private val expiryCD = finalExpiry + private val amountCD = finalAmount private val feeC = nodeFee(channelUpdateCD.feeBaseMsat, channelUpdateCD.feeProportionalMillionths, amountCD) + // Amount and expiry sent by B to C. private val expiryBC = expiryCD + channelUpdateCD.cltvExpiryDelta private val amountBC = amountCD + feeC private val feeB = nodeFee(channelUpdateBC.feeBaseMsat, channelUpdateBC.feeProportionalMillionths, amountBC) + // Amount and expiry sent by A to B. private val expiryAB = expiryBC + channelUpdateBC.cltvExpiryDelta private val amountAB = amountBC + feeB - // simple trampoline route to e: - // .--. .--. - // / \ / \ - // a -> b -> c d e - - private val trampolineHops = listOf( - NodeHop(a, c, channelUpdateAB.cltvExpiryDelta + channelUpdateBC.cltvExpiryDelta, feeB), - NodeHop(c, d, channelUpdateCD.cltvExpiryDelta, feeC), - NodeHop(d, e, channelUpdateDE.cltvExpiryDelta, feeD) - ) - - private val trampolineChannelHops = listOf( - ChannelHop(a, b, channelUpdateAB), - ChannelHop(b, c, channelUpdateBC) - ) - - private fun testBuildOnion() { - val finalPayload = PaymentOnion.FinalPayload.Standard(TlvStream(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount))) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) - assertEquals(amountAB, firstAmount) - assertEquals(expiryAB, firstExpiry) - assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) - // let's peel the onion - testPeelOnion(onion.packet) - } - - private fun testPeelOnion(packet_b: OnionRoutingPacket) { - val addB = UpdateAddHtlc(randomBytes32(), 0, amountAB, paymentHash, expiryAB, packet_b) - val (payloadB, packetC) = decryptChannelRelay(addB, privB) - assertEquals(OnionRoutingPacket.PaymentPacketLength, packetC.payload.size()) - assertEquals(amountBC, payloadB.amountToForward) - assertEquals(expiryBC, payloadB.outgoingCltv) - assertEquals(channelUpdateBC.shortChannelId, payloadB.outgoingChannelId) - - val addC = UpdateAddHtlc(randomBytes32(), 1, amountBC, paymentHash, expiryBC, packetC) - val (payloadC, packetD) = decryptChannelRelay(addC, privC) - assertEquals(OnionRoutingPacket.PaymentPacketLength, packetD.payload.size()) - assertEquals(amountCD, payloadC.amountToForward) - assertEquals(expiryCD, payloadC.outgoingCltv) - assertEquals(channelUpdateCD.shortChannelId, payloadC.outgoingChannelId) - - val addD = UpdateAddHtlc(randomBytes32(), 2, amountCD, paymentHash, expiryCD, packetD) - val (payloadD, packetE) = decryptChannelRelay(addD, privD) - assertEquals(OnionRoutingPacket.PaymentPacketLength, packetE.payload.size()) - assertEquals(amountDE, payloadD.amountToForward) - assertEquals(expiryDE, payloadD.outgoingCltv) - assertEquals(channelUpdateDE.shortChannelId, payloadD.outgoingChannelId) - - val addE = UpdateAddHtlc(randomBytes32(), 2, amountDE, paymentHash, expiryDE, packetE) - val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! - assertIs(payloadE) - assertEquals(finalAmount, payloadE.amount) - assertEquals(finalAmount, payloadE.totalAmount) - assertEquals(finalExpiry, payloadE.expiry) - assertEquals(paymentSecret, payloadE.paymentSecret) - } + // C is directly connected to the recipient D. + val nodeHop_cd = NodeHop(c, d, channelUpdateCD.cltvExpiryDelta, feeC) + // B is not directly connected to the recipient D. + val nodeHop_bd = NodeHop(b, d, channelUpdateBC.cltvExpiryDelta + channelUpdateCD.cltvExpiryDelta, feeB + feeC) // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. fun decryptChannelRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { @@ -135,14 +83,25 @@ class PaymentPacketTestsCommon : LightningTestSuite() { return Pair(decoded, decrypted.nextPacket) } + // Wallets don't need to create channel routes, but it's useful to test the end-to-end flow. + fun encryptChannelRelay(paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload): Triple { + val (firstAmount, firstExpiry, payloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { (amount, expiry, payloads), hop -> + val payload = PaymentOnion.ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) + Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) + } + val onion = OutgoingPaymentPacket.buildOnion(hops.map { it.nextNodeId }, payloads, paymentHash, OnionRoutingPacket.PaymentPacketLength) + return Triple(firstAmount, firstExpiry, onion) + } + // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptNodeRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptRelayToTrampoline(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket).right!! assertTrue(decrypted.isLastPacket) val outerPayload = PaymentOnion.FinalPayload.Standard.read(decrypted.payload).right!! val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet).right!! + assertFalse(decryptedInner.isLastPacket) val innerPayload = PaymentOnion.NodeRelayPayload.read(decryptedInner.payload).right!! assertNull(innerPayload.records.get()) assertNull(innerPayload.records.get()) @@ -152,34 +111,35 @@ class PaymentPacketTestsCommon : LightningTestSuite() { } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptRelayToNonTrampolinePayload(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptRelayToLegacy(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket).right!! assertTrue(decrypted.isLastPacket) val outerPayload = PaymentOnion.FinalPayload.Standard.read(decrypted.payload).right!! val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet).right!! + assertFalse(decryptedInner.isLastPacket) val innerPayload = PaymentOnion.RelayToNonTrampolinePayload.read(decryptedInner.payload).right!! - return Triple(outerPayload, innerPayload, decryptedInner.nextPacket) + return Pair(outerPayload, innerPayload) } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptRelayToBlinded(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptRelayToBlinded(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket).right!! assertTrue(decrypted.isLastPacket) val outerPayload = PaymentOnion.FinalPayload.Standard.read(decrypted.payload).right!! val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet).right!! + assertTrue(decryptedInner.isLastPacket) val innerPayload = PaymentOnion.RelayToBlindedPayload.read(decryptedInner.payload).right!! - return Triple(outerPayload, innerPayload, decryptedInner.nextPacket) + return Pair(outerPayload, innerPayload) } - // Create an HTLC paying an empty blinded path. - fun createBlindedHtlc(): Pair { - val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), null, 1, currentTimestampMillis()) - val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privE)))) - val blindedRoute = RouteBlinding.create(randomKey(), listOf(e), listOf(blindedPayload.write().byteVector())).route + fun createBlindedHtlcCD(): Pair { + val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)))) + val blindedRoute = RouteBlinding.create(randomKey(), listOf(d), listOf(blindedPayload.write().byteVector())).route val finalPayload = PaymentOnion.FinalPayload.Blinded( TlvStream( OnionPaymentPayloadTlv.AmountToForward(finalAmount), @@ -189,345 +149,299 @@ class PaymentPacketTestsCommon : LightningTestSuite() { ), blindedPayload ) - val blindedHop = ChannelHop(d, blindedRoute.blindedNodeIds.last(), channelUpdateDE) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket(paymentMetadata.paymentHash, listOf(blindedHop), finalPayload, OnionRoutingPacket.PaymentPacketLength) - val add = UpdateAddHtlc(randomBytes32(), 2, amountE, paymentMetadata.paymentHash, expiryE, onionE.packet, blindedRoute.blindingKey, null) - return Pair(add, finalPayload) + val onionD = OutgoingPaymentPacket.buildOnion(listOf(blindedRoute.blindedNodeIds.last()), listOf(finalPayload), paymentHash, OnionRoutingPacket.PaymentPacketLength) + val addD = UpdateAddHtlc(randomBytes32(), 1, finalAmount, paymentHash, finalExpiry, onionD.packet, blindedRoute.blindingKey, null) + return Pair(addD, paymentMetadata) } - } - @Test - fun `build onion`() { - testBuildOnion() - } - - @Test - fun `build a command including the onion`() { - val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, null)) - assertTrue(add.amount > finalAmount) - assertEquals(add.cltvExpiry, finalExpiry + channelUpdateDE.cltvExpiryDelta + channelUpdateCD.cltvExpiryDelta + channelUpdateBC.cltvExpiryDelta) - assertEquals(add.paymentHash, paymentHash) - assertEquals(add.onion.payload.size(), OnionRoutingPacket.PaymentPacketLength) - - // let's peel the onion - testPeelOnion(add.onion) - } - - @Test - fun `build a command with no hops`() { - val paymentSecret = randomBytes32() - val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, paymentMetadata)) - assertEquals(add.amount, finalAmount) - assertEquals(add.cltvExpiry, finalExpiry) - assertEquals(add.paymentHash, paymentHash) - assertEquals(add.onion.payload.size(), OnionRoutingPacket.PaymentPacketLength) - - // let's peel the onion - val addB = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, add.onion) - val finalPayload = IncomingPaymentPacket.decrypt(addB, privB).right!! - assertIs(finalPayload) - assertEquals(finalPayload.amount, finalAmount) - assertEquals(finalPayload.totalAmount, finalAmount) - assertEquals(finalPayload.expiry, finalExpiry) - assertEquals(paymentSecret, finalPayload.paymentSecret) - assertEquals(paymentMetadata, finalPayload.paymentMetadata) + fun createBlindedPaymentInfo(u: ChannelUpdate): OfferTypes.PaymentInfo { + return OfferTypes.PaymentInfo(u.feeBaseMsat, u.feeProportionalMillionths, u.cltvExpiryDelta, u.htlcMinimumMsat, u.htlcMaximumMsat!!, Features.empty) + } } @Test - fun `build a trampoline payment`() { - // simple trampoline route to e: - // .--. .--. - // / \ / \ - // a -> b -> c d e - - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, paymentMetadata), - null - ) - assertEquals(amountBC, amountAC) - assertEquals(expiryBC, expiryAC) - - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) - assertEquals(amountAB, firstAmount) - assertEquals(expiryAB, firstExpiry) - - val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (payloadB, packetC) = decryptChannelRelay(addB, privB) - assertEquals(PaymentOnion.ChannelRelayPayload.create(channelUpdateBC.shortChannelId, amountBC, expiryBC), payloadB) - - val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val (outerC, innerC, packetD) = decryptNodeRelay(addC, privC) + fun `send a trampoline payment -- recipient connected to trampoline node`() { + // .--. + // / \ + // b -> c d + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + assertEquals(amountBC, firstAmount) + assertEquals(expiryBC, firstExpiry) + + // B sends an HTLC to its trampoline node C. + val addC = UpdateAddHtlc(randomBytes32(), 1, amountBC, paymentHash, expiryBC, onion.packet) + val (outerC, innerC, packetD) = decryptRelayToTrampoline(addC, privC) assertEquals(amountBC, outerC.amount) assertEquals(amountBC, outerC.totalAmount) assertEquals(expiryBC, outerC.expiry) - assertEquals(amountCD, innerC.amountToForward) - assertEquals(expiryCD, innerC.outgoingCltv) + assertEquals(finalAmount, innerC.amountToForward) + assertEquals(finalExpiry, innerC.outgoingCltv) assertEquals(d, innerC.outgoingNodeId) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength - ) + // C forwards the trampoline payment to D over its direct channel. + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } assertEquals(amountCD, amountD) assertEquals(expiryCD, expiryD) - val addD = UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet) - val (outerD, innerD, packetE) = decryptNodeRelay(addD, privD) - assertEquals(amountCD, outerD.amount) - assertEquals(amountCD, outerD.totalAmount) - assertEquals(expiryCD, outerD.expiry) - assertEquals(amountDE, innerD.amountToForward) - assertEquals(expiryDE, innerD.outgoingCltv) - assertEquals(e, innerD.outgoingNodeId) - - // d forwards the trampoline payment to e. - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(d, e, channelUpdateDE)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountDE, amountDE, expiryDE, randomBytes32(), packetE), - OnionRoutingPacket.PaymentPacketLength - ) - assertEquals(amountDE, amountE) - assertEquals(expiryDE, expiryE) - val addE = UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet) - val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD, paymentHash, expiryD, onionD.packet) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! val expectedFinalPayload = PaymentOnion.FinalPayload.Standard( TlvStream( OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), - OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount * 3), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount), OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata) ) ) - assertEquals(payloadE, expectedFinalPayload) + assertEquals(payloadD, expectedFinalPayload) } @Test - fun `build a trampoline payment with non-trampoline recipient`() { - // simple trampoline route to e where e doesn't support trampoline: - // .--. - // / \ - // a -> b -> c d -> e - - val routingHints = listOf(Bolt11Invoice.TaggedField.ExtraHop(randomKey().publicKey(), ShortChannelId(42), 10.msat, 100, CltvExpiryDelta(144))) - val invoiceFeatures = Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.BasicMultiPartPayment to FeatureSupport.Optional) - val invoice = Bolt11Invoice( - "lnbcrt", finalAmount, currentTimestampSeconds(), e, listOf( - Bolt11Invoice.TaggedField.PaymentHash(paymentHash), - Bolt11Invoice.TaggedField.PaymentSecret(paymentSecret), - Bolt11Invoice.TaggedField.PaymentMetadata(paymentMetadata), - Bolt11Invoice.TaggedField.DescriptionHash(randomBytes32()), - Bolt11Invoice.TaggedField.Features(invoiceFeatures.toByteArray().toByteVector()), - Bolt11Invoice.TaggedField.RoutingInfo(routingHints) - ), ByteVector.empty - ) - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket( - invoice, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null) - ) - assertEquals(amountBC, amountAC) - assertEquals(expiryBC, expiryAC) - - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) + fun `send a trampoline payment -- recipient not connected to trampoline node`() { + // .-> c -. + // / \ + // a -> b d + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_bd) assertEquals(amountAB, firstAmount) assertEquals(expiryAB, firstExpiry) - val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (_, packetC) = decryptChannelRelay(addB, privB) + // A sends an HTLC to its trampoline node B. + val addB = UpdateAddHtlc(randomBytes32(), 1, amountAB, paymentHash, expiryAB, onion.packet) + val (outerB, innerB, packetC) = decryptRelayToTrampoline(addB, privB) + assertEquals(amountAB, outerB.amount) + assertEquals(amountAB, outerB.totalAmount) + assertEquals(expiryAB, outerB.expiry) + assertEquals(finalAmount, innerB.amountToForward) + assertEquals(finalExpiry, innerB.outgoingCltv) + assertEquals(d, innerB.outgoingNodeId) + + // B forwards the trampoline payment to D over an indirect channel route. + val (amountC, expiryC, onionC) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerB.amountToForward, innerB.amountToForward, innerB.outgoingCltv, randomBytes32(), packetC) + encryptChannelRelay(paymentHash, listOf(ChannelHop(b, c, channelUpdateBC), ChannelHop(c, d, channelUpdateCD)), payloadD) + } + assertEquals(amountBC, amountC) + assertEquals(expiryBC, expiryC) + + // C relays the payment to D. + val addC = UpdateAddHtlc(randomBytes32(), 2, amountC, paymentHash, expiryC, onionC.packet) + val (payloadC, packetD) = decryptChannelRelay(addC, privC) + val addD = UpdateAddHtlc(randomBytes32(), 3, payloadC.amountToForward, paymentHash, payloadC.outgoingCltv, packetD) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + val expectedFinalPayload = PaymentOnion.FinalPayload.Standard( + TlvStream( + OnionPaymentPayloadTlv.AmountToForward(finalAmount), + OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount), + OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata) + ) + ) + assertEquals(payloadD, expectedFinalPayload) + } - val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val (outerC, innerC, packetD) = decryptNodeRelay(addC, privC) - assertIs(outerC) - assertEquals(amountBC, outerC.amount) - assertEquals(amountBC, outerC.totalAmount) - assertEquals(expiryBC, outerC.expiry) - assertNotEquals(invoice.paymentSecret, outerC.paymentSecret) - assertEquals(amountCD, innerC.amountToForward) - assertEquals(expiryCD, innerC.outgoingCltv) - assertEquals(d, innerC.outgoingNodeId) + @Test + fun `send a trampoline payment to legacy non-trampoline recipient`() { + // .-> c -. + // / \ + // a -> b d + val routingHint = listOf(Bolt11Invoice.TaggedField.ExtraHop(c, channelUpdateCD.shortChannelId, channelUpdateCD.feeBaseMsat, channelUpdateCD.feeProportionalMillionths, channelUpdateCD.cltvExpiryDelta)) + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), nonTrampolineFeatures, paymentSecret, paymentMetadata, extraHops = listOf(routingHint)) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToLegacyRecipient(invoice, finalAmount, finalExpiry, nodeHop_bd) + assertEquals(amountAB, firstAmount) + assertEquals(expiryAB, firstExpiry) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength + // A sends an HTLC to its trampoline node B. + val addB = UpdateAddHtlc(randomBytes32(), 1, amountAB, paymentHash, expiryAB, onion.packet) + val (outerB, innerB) = decryptRelayToLegacy(addB, privB) + assertEquals(amountAB, outerB.amount) + assertEquals(amountAB, outerB.totalAmount) + assertEquals(expiryAB, outerB.expiry) + assertEquals(invoice.paymentSecret, outerB.paymentSecret) + assertEquals(finalAmount, innerB.amountToForward) + assertEquals(finalAmount, innerB.totalAmount) + assertEquals(finalExpiry, innerB.outgoingCltv) + assertEquals(d, innerB.outgoingNodeId) + assertEquals(invoice.paymentSecret, innerB.paymentSecret) + assertEquals(invoice.paymentMetadata, innerB.paymentMetadata) + assertEquals(ByteVector("024100"), innerB.invoiceFeatures) // var_onion_optin, payment_secret, basic_mpp + assertEquals(listOf(routingHint), innerB.invoiceRoutingInfo) + + // B forwards the trampoline payment to D over an indirect channel route. + val (amountC, expiryC, onionC) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createSinglePartPayload(innerB.amountToForward, innerB.outgoingCltv, innerB.paymentSecret, innerB.paymentMetadata) + encryptChannelRelay(paymentHash, listOf(ChannelHop(b, c, channelUpdateBC), ChannelHop(c, d, channelUpdateCD)), payloadD) + } + assertEquals(amountBC, amountC) + assertEquals(expiryBC, expiryC) + + // C relays the payment to D. + val addC = UpdateAddHtlc(randomBytes32(), 2, amountC, paymentHash, expiryC, onionC.packet) + val (payloadC, packetD) = decryptChannelRelay(addC, privC) + val addD = UpdateAddHtlc(randomBytes32(), 3, payloadC.amountToForward, paymentHash, payloadC.outgoingCltv, packetD) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + val expectedFinalPayload = PaymentOnion.FinalPayload.Standard( + TlvStream( + OnionPaymentPayloadTlv.AmountToForward(finalAmount), + OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount), + OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata), + ) ) - assertEquals(amountCD, amountD) - assertEquals(expiryCD, expiryD) - val addD = UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet) - val (outerD, innerD, _) = decryptRelayToNonTrampolinePayload(addD, privD) - assertIs(outerD) - assertEquals(amountCD, outerD.amount) - assertEquals(amountCD, outerD.totalAmount) - assertEquals(expiryCD, outerD.expiry) - assertNotEquals(invoice.paymentSecret, outerD.paymentSecret) - assertEquals(finalAmount, innerD.amountToForward) - assertEquals(expiryDE, innerD.outgoingCltv) - assertEquals(e, innerD.outgoingNodeId) - assertEquals(finalAmount, innerD.totalAmount) - assertEquals(invoice.paymentSecret, innerD.paymentSecret) - assertEquals(invoice.paymentMetadata, innerD.paymentMetadata) - assertEquals(ByteVector("024100"), innerD.invoiceFeatures) // var_onion_optin, payment_secret, basic_mpp - assertEquals(listOf(routingHints), innerD.invoiceRoutingInfo) + assertEquals(payloadD, expectedFinalPayload) } @Test - fun `build a trampoline payment to blinded paths`() { + fun `send a trampoline payment to blinded paths`() { val features = Features(Feature.BasicMultiPartPayment to FeatureSupport.Optional) - val offer = OfferTypes.Offer.createNonBlindedOffer(finalAmount, "test offer", e, features, Block.LivenetGenesisBlock.hash) - // E uses a 1-hop blinded path from its LSP. + val offer = OfferTypes.Offer.createNonBlindedOffer(finalAmount, "test offer", d, features, Block.RegtestGenesisBlock.hash) + // D uses a 1-hop blinded path from its trampoline node C. val (invoice, blindedRoute) = run { val payerKey = randomKey() - val request = OfferTypes.InvoiceRequest(offer, finalAmount, 1, features, payerKey, null, Block.LivenetGenesisBlock.hash) - val paymentMetadata = OfferPaymentMetadata.V1(offer.offerId, finalAmount, paymentPreimage, payerKey.publicKey(), null, 1, currentTimestampMillis()) - val blindedPayloads = listOf( - RouteBlindingEncryptedData( - TlvStream( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(channelUpdateDE.shortChannelId), - RouteBlindingEncryptedDataTlv.PaymentRelay(channelUpdateDE.cltvExpiryDelta, channelUpdateDE.feeProportionalMillionths, channelUpdateDE.feeBaseMsat), - RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiry, 1.msat), - ) - ), - RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privE)))), - ).map { it.write().byteVector() } - val blindedRouteDetails = RouteBlinding.create(randomKey(), listOf(d, e), blindedPayloads) - val paymentInfo = OfferTypes.PaymentInfo(channelUpdateDE.feeBaseMsat, channelUpdateDE.feeProportionalMillionths, channelUpdateDE.cltvExpiryDelta, channelUpdateDE.htlcMinimumMsat, channelUpdateDE.htlcMaximumMsat!!, Features.empty) + val request = OfferTypes.InvoiceRequest(offer, finalAmount, 1, features, payerKey, "hello", Block.RegtestGenesisBlock.hash) + val paymentMetadata = OfferPaymentMetadata.V1(offer.offerId, finalAmount, paymentPreimage, payerKey.publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayloadC = RouteBlindingEncryptedData( + TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(channelUpdateCD.shortChannelId), + RouteBlindingEncryptedDataTlv.PaymentRelay(channelUpdateCD.cltvExpiryDelta, channelUpdateCD.feeProportionalMillionths, channelUpdateCD.feeBaseMsat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiry, 1.msat), + ) + ) + val blindedPayloadD = RouteBlindingEncryptedData( + TlvStream( + RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)) + ) + ) + val blindedRouteDetails = RouteBlinding.create(randomKey(), listOf(c, d), listOf(blindedPayloadC, blindedPayloadD).map { it.write().byteVector() }) + val paymentInfo = createBlindedPaymentInfo(channelUpdateCD) val path = Bolt12Invoice.Companion.PaymentBlindedContactInfo(OfferTypes.ContactInfo.BlindedPath(blindedRouteDetails.route), paymentInfo) - val invoice = Bolt12Invoice(request, paymentPreimage, blindedRouteDetails.blindedPrivateKey(privE), 600, features, listOf(path)) + val invoice = Bolt12Invoice(request, paymentPreimage, blindedRouteDetails.blindedPrivateKey(privD), 600, features, listOf(path)) assertEquals(invoice.nodeId, blindedRouteDetails.route.blindedNodeIds.last()) + assertNotEquals(invoice.nodeId, d) Pair(invoice, blindedRouteDetails.route) } - // C pays that invoice using a trampoline node to relay to the invoice's blinded path. - val (firstAmount, firstExpiry, onion) = run { - val trampolineHop = NodeHop(d, invoice.nodeId, channelUpdateDE.cltvExpiryDelta, feeD) - val (trampolineAmount, trampolineExpiry, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(invoice, trampolineHop, finalAmount, finalExpiry) - val trampolinePayload = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet) - OutgoingPaymentPacket.buildPacket(invoice.paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), trampolinePayload, OnionRoutingPacket.PaymentPacketLength) - } - assertEquals(amountCD, firstAmount) - assertEquals(expiryCD, firstExpiry) + // B pays that invoice using its trampoline node C to relay to the invoice's blinded path. + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToBlindedRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + assertEquals(amountBC, firstAmount) + assertEquals(expiryBC, firstExpiry) - // D decrypts the onion that contains a blinded path in the trampoline onion. - val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (outerD, innerD, _) = decryptRelayToBlinded(addD, privD) - assertEquals(amountCD, outerD.amount) - assertEquals(amountCD, outerD.totalAmount) - assertEquals(expiryCD, outerD.expiry) - assertEquals(finalAmount, innerD.amountToForward) - assertEquals(expiryDE, innerD.outgoingCltv) - assertEquals(listOf(blindedRoute), innerD.outgoingBlindedPaths.map { it.route.route }) - assertEquals(invoice.features.toByteArray().toByteVector(), innerD.invoiceFeatures) - - // D is the introduction node of the blinded path: it can decrypt the first blinded payload and relay to E. - val addE = run { - val (dataD, blindingE) = RouteBlinding.decryptPayload(privD, blindedRoute.blindingKey, blindedRoute.encryptedPayloads.first()).right!! - val payloadD = RouteBlindingEncryptedData.read(dataD.toByteArray()).right!! - assertEquals(channelUpdateDE.shortChannelId, payloadD.outgoingChannelId) - // D would normally create this payload based on the blinded path's payment_info field. - val payloadE = PaymentOnion.FinalPayload.Blinded( + // C decrypts the onion that contains a blinded path in the trampoline onion. + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (outerC, innerC) = decryptRelayToBlinded(addC, privC) + assertEquals(amountBC, outerC.amount) + assertEquals(amountBC, outerC.totalAmount) + assertEquals(expiryBC, outerC.expiry) + assertEquals(finalAmount, innerC.amountToForward) + assertEquals(finalExpiry, innerC.outgoingCltv) + assertEquals(listOf(blindedRoute), innerC.outgoingBlindedPaths.map { it.route.route }) + assertEquals(invoice.features.toByteArray().toByteVector(), innerC.invoiceFeatures) + + // C is the introduction node of the blinded path: it can decrypt the first blinded payload and relay to D. + val addD = run { + val (dataC, blindingD) = RouteBlinding.decryptPayload(privC, blindedRoute.blindingKey, blindedRoute.encryptedPayloads.first()).right!! + val payloadC = RouteBlindingEncryptedData.read(dataC.toByteArray()).right!! + assertEquals(channelUpdateCD.shortChannelId, payloadC.outgoingChannelId) + // C would normally create this payload based on the payment_relay field it received. + val payloadD = PaymentOnion.FinalPayload.Blinded( TlvStream( - OnionPaymentPayloadTlv.AmountToForward(finalAmount), - OnionPaymentPayloadTlv.TotalAmount(finalAmount), - OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.AmountToForward(innerC.amountToForward), + OnionPaymentPayloadTlv.TotalAmount(innerC.amountToForward), + OnionPaymentPayloadTlv.OutgoingCltv(innerC.outgoingCltv), OnionPaymentPayloadTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads.last()), ), - // This dummy value is ignored when creating the htlc (D is not the recipient). + // This dummy value is ignored when creating the htlc (C is not the recipient). RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(ByteVector("deadbeef")))) ) - val blindedHop = ChannelHop(d, blindedRoute.blindedNodeIds.last(), channelUpdateDE) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket(addD.paymentHash, listOf(blindedHop), payloadE, OnionRoutingPacket.PaymentPacketLength) - UpdateAddHtlc(randomBytes32(), 2, amountE, addD.paymentHash, expiryE, onionE.packet, blindingE, null) + val onionD = OutgoingPaymentPacket.buildOnion(listOf(blindedRoute.blindedNodeIds.last()), listOf(payloadD), addC.paymentHash, OnionRoutingPacket.PaymentPacketLength) + UpdateAddHtlc(randomBytes32(), 2, innerC.amountToForward, addC.paymentHash, innerC.outgoingCltv, onionD.packet, blindingD, null) } - // E can correctly decrypt the blinded payment. - val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! - assertIs(payloadE) - val paymentMetadata = OfferPaymentMetadata.fromPathId(e, payloadE.pathId) + // D can correctly decrypt the blinded payment. + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + assertIs(payloadD) + assertEquals(finalAmount, payloadD.amount) + assertEquals(finalExpiry, payloadD.expiry) + val paymentMetadata = OfferPaymentMetadata.fromPathId(d, payloadD.pathId) assertNotNull(paymentMetadata) assertEquals(offer.offerId, paymentMetadata.offerId) assertEquals(paymentMetadata.paymentHash, invoice.paymentHash) } @Test - fun `build a payment to a blinded path`() { - val (addE, payloadE) = createBlindedHtlc() - // E can correctly decrypt the blinded payment. - assertEquals(payloadE, IncomingPaymentPacket.decrypt(addE, privE).right) + fun `receive a channel payment`() { + // c -> d + val finalPayload = PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 1.5, finalExpiry, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), finalPayload) + val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + assertEquals(finalPayload, payloadD) } @Test - fun `fail to decrypt when the onion is invalid`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + fun `receive a blinded channel payment`() { + // c -> d + val (addD, paymentMetadata) = createBlindedHtlcCD() + val payloadD = IncomingPaymentPacket.decrypt(addD, privD).right!! + assertIs(payloadD) + assertEquals(finalAmount, payloadD.amount) + assertEquals(finalAmount, payloadD.totalAmount) + assertEquals(finalExpiry, payloadD.expiry) + assertEquals(paymentMetadata, OfferPaymentMetadata.fromPathId(d, payloadD.pathId)) + } + + @Test + fun `fail to decrypt when the payment onion is invalid`() { + val (firstAmount, firstExpiry, onion) = encryptChannelRelay( paymentHash, - hops, - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null), - OnionRoutingPacket.PaymentPacketLength + listOf(ChannelHop(c, d, channelUpdateCD)), + PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, null) ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reversed())) - val failure = IncomingPaymentPacket.decrypt(add, privB) + val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reversed())) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when the trampoline onion is invalid`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, null), - null - ) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reversed())), - OnionRoutingPacket.PaymentPacketLength - ) - val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val (_, packetC) = decryptChannelRelay(addB, privB) - val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val failure = IncomingPaymentPacket.decrypt(addC, privC) + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC, packetD) = decryptRelayToTrampoline(addC, privC) + // C modifies the trampoline onion before forwarding the trampoline payment to D. + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD.copy(payload = packetD.payload.reversed())) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD, paymentHash, expiryD, onionD.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when payment hash doesn't match associated data`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = encryptChannelRelay( paymentHash.reversed(), - hops, - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), - OnionRoutingPacket.PaymentPacketLength + listOf(ChannelHop(c, d, channelUpdateCD)), + PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, paymentMetadata) ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPaymentPacket.decrypt(add, privB) + val addD = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when blinded route data is invalid`() { - val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), null, 1, currentTimestampMillis()) - val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privE)))) - val blindedRoute = RouteBlinding.create(randomKey(), listOf(e), listOf(blindedPayload.write().byteVector())).route - val payloadE = PaymentOnion.FinalPayload.Blinded( + val paymentMetadata = OfferPaymentMetadata.V1(randomBytes32(), finalAmount, paymentPreimage, randomKey().publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)))) + val blindedRoute = RouteBlinding.create(randomKey(), listOf(d), listOf(blindedPayload.write().byteVector())).route + val payloadD = PaymentOnion.FinalPayload.Blinded( TlvStream( OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.TotalAmount(finalAmount), @@ -537,135 +451,105 @@ class PaymentPacketTestsCommon : LightningTestSuite() { ), blindedPayload ) - val blindedHop = ChannelHop(d, blindedRoute.blindedNodeIds.last(), channelUpdateDE) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket(paymentMetadata.paymentHash, listOf(blindedHop), payloadE, OnionRoutingPacket.PaymentPacketLength) - val addE = UpdateAddHtlc(randomBytes32(), 2, amountE, paymentMetadata.paymentHash, expiryE, onionE.packet, blindedRoute.blindingKey, null) - val failure = IncomingPaymentPacket.decrypt(addE, privE) + val onionD = OutgoingPaymentPacket.buildOnion(listOf(blindedRoute.blindedNodeIds.last()), listOf(payloadD), paymentHash, OnionRoutingPacket.PaymentPacketLength) + val addD = UpdateAddHtlc(randomBytes32(), 1, finalAmount, paymentHash, finalExpiry, onionD.packet, blindedRoute.blindingKey, null) + val failure = IncomingPaymentPacket.decrypt(addD, privD) assertTrue(failure.isLeft) - assertEquals(failure.left, InvalidOnionBlinding(hash(addE.onionRoutingPacket))) + assertEquals(failure.left, InvalidOnionBlinding(hash(addD.onionRoutingPacket))) } @Test - fun `fail to decrypt at the final node when amount has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - hops.take(1), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), - OnionRoutingPacket.PaymentPacketLength - ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPaymentPacket.decrypt(add, privB) - assertEquals(Either.Left(FinalIncorrectHtlcAmount(firstAmount - 100.msat)), failure) + fun `fail to decrypt when amount has been modified by trampoline node`() { + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC, packetD) = decryptRelayToTrampoline(addC, privC) + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD - 100.msat, paymentHash, expiryD, onionD.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) + assertEquals(Either.Left(FinalIncorrectHtlcAmount(finalAmount - 100.msat)), failure) } @Test - fun `fail to decrypt at the final node when expiry has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - hops.take(1), - PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), - OnionRoutingPacket.PaymentPacketLength - ) - val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet) - val failure = IncomingPaymentPacket.decrypt(add, privB) - assertEquals(Either.Left(FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))), failure) + fun `fail to decrypt when expiry has been modified by trampoline node`() { + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), trampolineFeatures, paymentSecret, paymentMetadata) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToTrampolineRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC, packetD) = decryptRelayToTrampoline(addC, privC) + val (amountD, expiryD, onionD) = run { + val payloadD = PaymentOnion.FinalPayload.Standard.createTrampolinePayload(innerC.amountToForward, innerC.amountToForward, innerC.outgoingCltv, randomBytes32(), packetD) + encryptChannelRelay(paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), payloadD) + } + val addD = UpdateAddHtlc(randomBytes32(), 2, amountD, paymentHash, expiryD - CltvExpiryDelta(12), onionD.packet) + val failure = IncomingPaymentPacket.decrypt(addD, privD) + assertEquals(Either.Left(FinalIncorrectCltvExpiry(finalExpiry - CltvExpiryDelta(12))), failure) } @Test - fun `fail to decrypt blinded payment at the final node when amount is too low`() { - val (addE, _) = createBlindedHtlc() - // E receives a smaller amount than expected and rejects the payment. - val failure = IncomingPaymentPacket.decrypt(addE.copy(amountMsat = addE.amountMsat - 1.msat), privE).left - assertEquals(InvalidOnionBlinding(hash(addE.onionRoutingPacket)), failure) + fun `fail to decrypt blinded payment when amount is too low`() { + val (addD, _) = createBlindedHtlcCD() + // D receives a smaller amount than expected and rejects the payment. + val failure = IncomingPaymentPacket.decrypt(addD.copy(amountMsat = addD.amountMsat - 1.msat), privD).left + assertEquals(InvalidOnionBlinding(hash(addD.onionRoutingPacket)), failure) } @Test - fun `fail to decrypt blinded payment at the final node when expiry is too low`() { - val (addE, _) = createBlindedHtlc() + fun `fail to decrypt blinded payment when expiry is too low`() { + val (addD, _) = createBlindedHtlcCD() // E receives a smaller expiry than expected and rejects the payment. - val failure = IncomingPaymentPacket.decrypt(addE.copy(cltvExpiry = addE.cltvExpiry - CltvExpiryDelta(1)), privE).left - assertEquals(InvalidOnionBlinding(hash(addE.onionRoutingPacket)), failure) - } - - @Test - fun `fail to decrypt at the final trampoline node when amount has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), - null - ) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) - val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) - // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). - val invalidTotalAmount = amountDE + 100.msat - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(d, e, channelUpdateDE)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountDE, invalidTotalAmount, expiryDE, randomBytes32(), packetE), - OnionRoutingPacket.PaymentPacketLength - ) - val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) - assertEquals(Either.Left(FinalIncorrectHtlcAmount(invalidTotalAmount)), failure) + val failure = IncomingPaymentPacket.decrypt(addD.copy(cltvExpiry = addD.cltvExpiry - CltvExpiryDelta(1)), privD).left + assertEquals(InvalidOnionBlinding(hash(addD.onionRoutingPacket)), failure) } @Test - fun `fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineHops, - PaymentOnion.FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), - null - ) - val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( - paymentHash, - trampolineChannelHops, - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) - val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) - // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(c, d, channelUpdateCD)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), - OnionRoutingPacket.PaymentPacketLength - ) - val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) - // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). - val invalidExpiry = expiryDE - CltvExpiryDelta(12) - val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( - paymentHash, - listOf(ChannelHop(d, e, channelUpdateDE)), - PaymentOnion.FinalPayload.Standard.createTrampolinePayload(amountDE, amountDE, invalidExpiry, randomBytes32(), packetE), - OnionRoutingPacket.PaymentPacketLength - ) - val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) - assertEquals(Either.Left(FinalIncorrectCltvExpiry(invalidExpiry)), failure) + fun `prune outgoing blinded paths`() { + // We create an invoice with a large number of blinded paths. + val features = Features(Feature.BasicMultiPartPayment to FeatureSupport.Optional) + val offer = OfferTypes.Offer.createNonBlindedOffer(finalAmount, "test offer", d, features, Block.RegtestGenesisBlock.hash) + val invoice = run { + val payerKey = randomKey() + val request = OfferTypes.InvoiceRequest(offer, finalAmount, 1, features, payerKey, "hello", Block.RegtestGenesisBlock.hash) + val paymentMetadata = OfferPaymentMetadata.V1(offer.offerId, finalAmount, paymentPreimage, payerKey.publicKey(), "hello", 1, currentTimestampMillis()) + val blindedPayloadD = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(paymentMetadata.toPathId(privD)))) + val blindedPaths = (0 until 20L).map { i -> + val blindedPayloadC = RouteBlindingEncryptedData( + TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(i)), + RouteBlindingEncryptedDataTlv.PaymentRelay(channelUpdateCD.cltvExpiryDelta, channelUpdateCD.feeProportionalMillionths, channelUpdateCD.feeBaseMsat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiry, 1.msat), + ) + ) + val blindedRouteDetails = RouteBlinding.create(randomKey(), listOf(c, d), listOf(blindedPayloadC, blindedPayloadD).map { it.write().byteVector() }) + val paymentInfo = createBlindedPaymentInfo(channelUpdateCD) + Bolt12Invoice.Companion.PaymentBlindedContactInfo(OfferTypes.ContactInfo.BlindedPath(blindedRouteDetails.route), paymentInfo) + } + Bolt12Invoice(request, paymentPreimage, randomKey(), 600, features, blindedPaths) + } + // B sends an HTLC to its trampoline node C: we prune the blinded paths to fit inside the onion. + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToBlindedRecipient(invoice, finalAmount, finalExpiry, nodeHop_cd) + assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) + val addC = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerC) = decryptRelayToBlinded(addC, privC) + assertTrue(innerC.outgoingBlindedPaths.size < invoice.blindedPaths.size) + innerC.outgoingBlindedPaths.forEach { assertTrue(invoice.blindedPaths.contains(it)) } } @Test - fun `relay to blinded with many large blinded routes`() { - val invoice = Bolt12Invoice.fromString("lni1qqs0sehhttf0swv6sxsxuefk9q23yj9h0cl4wyn324jlt2gll4fzk0syzquta4dy4m9jgp5s970w2lw8928ppnqz662mq8r6dyy7w9kgv0ann0zlk7aacwpykl7uu5adckf7t0sgpees93j6nwr4xle4ck4d2wup6j044fx6f5zkdacvjen3herm493en2lmqgprfesng92cpl68qvlxv96prguvtl7yv49nwq9a92zl5n2pv73u4wgqx08jupl9mcphw7ww5mv50j7xkcjxvjamrysy3wsfmgj4cu9aq4y6hszdk5rys4qwa4muxczk4nuj54gyw3jqygaqqk2edh9c90ergwvaxu93jynty9z0gzrpcwuhkzt05epatv45qqg93d3x5hx8jgwgkwpnz85y3narw93pqvmp27mzgwjaswd76lgxm88qw2ehqszd2mlgntvnwq0vzrk4ehq325pqdl3gcz4k7xeh9sdx5fr2uclhf7f3aqm9u9dq38rg6cvsqqqqqqq9yqjwyp2qxqsqqpvzzqnfhw87mvwxx0qrljmlxppxr7h4hmdgdzpsxnj0rw2jckjtc93g44v3gmrfva58gmnfdenjq6tnyp3k7mmvypq5dg8aqsqsxhj07sv0ez642nzanm4xvwtvyfaag2dry5wge0r3zqpt5g2mls3xqwsadj5tn7gm50l43xmsu2dszx8s672wrkqn48qp7ftw6kvhltlrkqsrwc9agnmlg8dzefxe4vlz9euhhwjzeflzrrkkurluwhzfjqf62trsq2dm4984vptwe3zpea0e3n27w75pyt2ym8g734awf99ep759a4xdf29epp5z3e303e08q2e6qdycnlzdkv4c80u2pr7ql2ndpa9zfafwyqkf9r9cqtaza35txq9rppm56ajy97nema0j8vtuld097u0hjhh5cmwhlp8du95va7wvq2emxkvafy2wg8ghl7vpkw2w0w0pp2tr6gqqmnvxy9u9jg9rkwn2flxn9d3vd9x0g94wdhgap22l780d0t9zywpj7yu949347hsag0s6pgpkuaf2pg8xqzjk336e8azw8672qtvvljye6gtevmnrygav24jvsxmypf25rqs8ersr9wfsqhdxjl6lmy3lqac4q96rjn02vtk4n6zcaqzr5qhlqlywrga7lmwt99gyj23ndzrl2y8380nxydqntsh7gpajau2wrvpshzpryq37n72ujf26djermgcf47wfsfdts3p97u0lnt6szck8zlqzqfz6ex8p9xtugj6jtuyf3flkqk9ahu8e6dxgvp49xl24x6hdmljcjqptk8erxq4uwnkv4g32p4d8un0f9az5ft59mwvgtf3xxm3xepx68x6vk46pf5xh4lyr35m32q5ukd5xj4kj7vka67uggjd4saahn3dswdwsqe8epg2pvele4cz0zvq288t94n69cj87ewal7yqvvpq6tj6mpjw2d36uhlf2khtjlpdfv672gvkszdpqht37shtf0ytyfnl7xkeudm2xscvpack3whk29hsfr36cv95ddzc2klk2yc93s09y2uh2xz5k5jap4550z2vd7gy9faxd2gcr4xjmke8yhn0wdsqxpe78qyqy0elprd0swe9yxc82as23vm5fyumc7nv470xfsjy3zjg4zfvpa6kv3ndl0w3ukc6627dua9g7pmzfupesqvpjfw3rjt39hlmk40gtralyk5a4lq42207acdqekpgmdjqpmw0zy3cz8d3cpvw4a28lqnf3etu8adl43zra0zk5m094h86g4yfm8cxnz7zqyqa5q4dz6ceqp0jeuakvyx4653h0kjllhhpfduquh4ch8rtu9lf20gqznhxe60n3m45k80vcgnvut8cfgpvy37ae93rpah9sa4ddyzyt9gu0uj5zny00vrkkxycrxm8wk3gzyrn3q8caxrqwa48fwqzlagph95y5tw555px0s97y4lgspgedkg4q9gxgv4qyc7x8433j0a5fnntdjhlsmx4wje53dj8y4c3757yt87kenvej8ugwkg05flugshd906tvtqcldrcgxfqqw8ecz926ft3ehdahpjhueag8znklv6qvz7w3htszz5k3we0dd8r2grxpz5edgtuh28u0c5ugrf8autgu3crkjt58ev9gcvnetn5kjwk4wehr8yq7lulsug7dhyg0reyqlhhsxs8ueghtm02g7jd2p9ytwstsacry3y8mv202y4qqqqqqqqqqqq9qq6qqqqqqqqqqqqqsqqqqqf2q5htpqqqqqqqraqqqqqqpqp5qqqqqqqqqqqqpqqqqqqqqh6uzcqqqqqqqqqqqqqqeqqrgqqqqqqqqqqqqzqqqqqqjnems5qqqpfqyv6g47m4gyqw3qa7penn7yv39ufk8j0mpj0ldfyu9qd7j2vcgt27rqe9l5ydpm2szfcs2uqczqqqtqggrxc2hkcjr5hvrn0kh6pkeecrjkdcyqn2kl6y6mymsrmqsa4wdcy2lqsxw9n37kdz8zq0ytckjvcy7jcwqrqj5rgd56q4e3g76k0lja96hzkdda7z5sxqwemrdjje2rqm7jcv6ll39dut0mqjvrl4skedkfz5mu").get() - val (trampolineAmount, trampolineExpiry, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToNonTrampolinePacket(invoice, NodeHop(randomKey().publicKey(), randomKey().publicKey(), CltvExpiryDelta(444), 0.msat), finalAmount, finalExpiry) - val trampolinePayload = PaymentAttempt.TrampolinePayload(trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet) - val channelHops: List = listOf(ChannelHop(randomKey().publicKey(), randomKey().publicKey(), defaultChannelUpdate)) - OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), randomBytes32(), channelHops, trampolinePayload.createFinalPayload(finalAmount)) + fun `prune outgoing routing info`() { + // We create an invoice with a large number of routing hints. + val routingHints = (0 until 50L).map { i -> listOf(Bolt11Invoice.TaggedField.ExtraHop(randomKey().publicKey(), ShortChannelId(i), 5.msat, 25, CltvExpiryDelta(36))) } + val invoice = Bolt11Invoice.create(Chain.Regtest, finalAmount, paymentHash, privD, Either.Left("test"), CltvExpiryDelta(6), nonTrampolineFeatures, paymentSecret, paymentMetadata, extraHops = routingHints) + // A sends an HTLC to its trampoline node B: we prune the routing hints to fit inside the onion. + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacketToLegacyRecipient(invoice, finalAmount, finalExpiry, nodeHop_bd) + assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) + val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) + val (_, innerB) = decryptRelayToLegacy(addB, privB) + assertTrue(innerB.invoiceRoutingInfo.isNotEmpty()) + assertTrue(innerB.invoiceRoutingInfo.size < routingHints.size) + innerB.invoiceRoutingInfo.forEach { assertTrue(routingHints.contains(it)) } } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt deleted file mode 100644 index e04098ea0..000000000 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/RouteCalculationTestsCommon.kt +++ /dev/null @@ -1,154 +0,0 @@ -package fr.acinq.lightning.payment - -import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.bitcoin.utils.Either -import fr.acinq.lightning.Lightning.randomBytes32 -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.ShortChannelId -import fr.acinq.lightning.blockchain.fee.FeeratePerKw -import fr.acinq.lightning.channel.Commitments -import fr.acinq.lightning.channel.states.Normal -import fr.acinq.lightning.channel.states.Offline -import fr.acinq.lightning.channel.states.Syncing -import fr.acinq.lightning.channel.TestsHelper.reachNormal -import fr.acinq.lightning.tests.utils.LightningTestSuite -import fr.acinq.lightning.transactions.CommitmentSpec -import fr.acinq.lightning.utils.* -import kotlin.random.Random -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertTrue - -class RouteCalculationTestsCommon : LightningTestSuite() { - - private val defaultChannel = reachNormal().first - private val paymentId = UUID.randomUUID() - private val routeCalculation = RouteCalculation(loggerFactory) - - private fun makeChannel(channelId: ByteVector32, balance: MilliSatoshi, htlcMin: MilliSatoshi): Normal { - val shortChannelId = ShortChannelId(Random.nextLong()) - val reserve = defaultChannel.commitments.latest.localChannelReserve - val commitments = defaultChannel.commitments.copy( - params = defaultChannel.commitments.params.copy(channelId = channelId), - active = defaultChannel.commitments.active.map { - it.copy(remoteCommit = it.remoteCommit.copy(spec = CommitmentSpec(setOf(), FeeratePerKw(0.sat), 50_000.msat, balance + ((Commitments.ANCHOR_AMOUNT * 2) + reserve).toMilliSatoshi()))) - } - ) - val channelUpdate = defaultChannel.state.channelUpdate.copy(htlcMinimumMsat = htlcMin) - return defaultChannel.state.copy(shortChannelId = shortChannelId, commitments = commitments, channelUpdate = channelUpdate) - } - - @Test - fun `make channel fixture`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - val offlineChannels = mapOf( - channelId1 to Offline(makeChannel(channelId1, 15_000.msat, 10.msat)), - channelId2 to Offline(makeChannel(channelId2, 20_000.msat, 5.msat)), - channelId3 to Offline(makeChannel(channelId3, 10_000.msat, 10.msat)), - ) - assertEquals(setOf(10_000.msat, 15_000.msat, 20_000.msat), offlineChannels.map { (it.value.state as Normal).commitments.availableBalanceForSend() }.toSet()) - - val normalChannels = mapOf( - channelId1 to makeChannel(channelId1, 15_000.msat, 10.msat), - channelId2 to makeChannel(channelId2, 15_000.msat, 5.msat), - channelId3 to makeChannel(channelId3, 10_000.msat, 10.msat), - ) - assertEquals(setOf(10_000.msat, 15_000.msat), normalChannels.map { it.value.commitments.availableBalanceForSend() }.toSet()) - } - - @Test - fun `no available channels`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to Offline(makeChannel(channelId1, 15_000.msat, 10.msat)), - channelId2 to Syncing(makeChannel(channelId2, 20_000.msat, 5.msat), channelReestablishSent = true), - channelId3 to Offline(makeChannel(channelId3, 10_000.msat, 10.msat)), - ) - assertEquals(Either.Left(FinalFailure.ChannelNotConnected), routeCalculation.findRoutes(paymentId, 5_000.msat, channels)) - } - - @Test - fun `insufficient balance`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to makeChannel(channelId1, 15_000.msat, 10.msat), - channelId2 to makeChannel(channelId2, 18_000.msat, 5.msat), - channelId3 to makeChannel(channelId3, 12_000.msat, 10.msat), - ) - assertEquals(Either.Left(FinalFailure.InsufficientBalance), routeCalculation.findRoutes(paymentId, 50_000.msat, channels)) - } - - @Test - fun `single payment`() { - val (channelId1, channelId2, channelId3) = listOf(randomBytes32(), randomBytes32(), randomBytes32()) - run { - val channels = mapOf( - channelId1 to makeChannel(channelId1, 35_000.msat, 10.msat), - channelId2 to makeChannel(channelId2, 30_000.msat, 5.msat), - channelId3 to makeChannel(channelId3, 38_000.msat, 10.msat), - ) - val routes = routeCalculation.findRoutes(paymentId, 38_000.msat, channels).right!! - assertEquals(listOf(RouteCalculation.Route(38_000.msat, channels.getValue(channelId3))), routes) - } - run { - val channels = mapOf(channelId3 to makeChannel(channelId3, 38_000.msat, 10.msat)) - val routes = routeCalculation.findRoutes(paymentId, 38_000.msat, channels).right!! - assertEquals(listOf(RouteCalculation.Route(38_000.msat, channels.getValue(channelId3))), routes) - } - } - - @Test - fun `ignore empty channels`() { - val (channelId1, channelId2, channelId3, channelId4) = listOf(randomBytes32(), randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to makeChannel(channelId1, 0.msat, 10.msat), - channelId2 to makeChannel(channelId2, 50.msat, 100.msat), - channelId3 to makeChannel(channelId3, 30_000.msat, 15.msat), - channelId4 to makeChannel(channelId4, 20_000.msat, 50.msat), - ) - val routes = routeCalculation.findRoutes(paymentId, 50_000.msat, channels).right!! - val expected = setOf( - RouteCalculation.Route(30_000.msat, channels.getValue(channelId3)), - RouteCalculation.Route(20_000.msat, channels.getValue(channelId4)), - ) - assertEquals(expected, routes.toSet()) - assertEquals(Either.Left(FinalFailure.InsufficientBalance), routeCalculation.findRoutes(paymentId, 50010.msat, channels)) - } - - @Test - fun `split payment across many channels`() { - val (channelId1, channelId2, channelId3, channelId4) = listOf(randomBytes32(), randomBytes32(), randomBytes32(), randomBytes32()) - val channels = mapOf( - channelId1 to makeChannel(channelId1, 50.msat, 10.msat), - channelId2 to makeChannel(channelId2, 150.msat, 100.msat), - channelId3 to makeChannel(channelId3, 25.msat, 15.msat), - channelId4 to makeChannel(channelId4, 75.msat, 50.msat), - ) - run { - val routes = routeCalculation.findRoutes(paymentId, 300.msat, channels).right!! - val expected = setOf( - RouteCalculation.Route(50.msat, channels.getValue(channelId1)), - RouteCalculation.Route(150.msat, channels.getValue(channelId2)), - RouteCalculation.Route(25.msat, channels.getValue(channelId3)), - RouteCalculation.Route(75.msat, channels.getValue(channelId4)), - ) - assertEquals(expected, routes.toSet()) - } - run { - val routes = routeCalculation.findRoutes(paymentId, 250.msat, channels).right!! - assertTrue(routes.size >= 3) - assertEquals(250.msat, routes.map { it.amount }.sum()) - } - run { - val routes = routeCalculation.findRoutes(paymentId, 200.msat, channels).right!! - assertTrue(routes.size >= 2) - assertEquals(200.msat, routes.map { it.amount }.sum()) - } - run { - val routes = routeCalculation.findRoutes(paymentId, 50.msat, channels).right!! - assertTrue(routes.size == 1) - assertEquals(50.msat, routes.map { it.amount }.sum()) - } - } - -} \ No newline at end of file