Skip to content

Commit

Permalink
Add support for trampoline failure encryption
Browse files Browse the repository at this point in the history
When returning trampoline failures for the payer (the creator of the
trampoline onion), they must be encrypted using the sphinx shared
secret of the trampoline onion.

When relaying a trampoline payment, we re-wrap the (peeled) trampoline
onion inside a payment onion: if we receive a failure for the outgoing
payment, it can be either coming from before the next trampoline node
or after them. If it's coming from before, we can decrypt that error
using the shared secrets we created for the payment onion: depending
on the error, we can then return our own error to the payer. If it's
coming from after the next trampoline onion, it will be encrypted for
the payer, so we cannot decrypt it. We must peel the shared secrets of
our payment onion, and then re-encrypted with the shared secret of the
incoming trampoline onion. This way only the payer will be able to
decrypt the failure, which is relayed back through each intermediate
trampoline node.
  • Loading branch information
t-bast committed Dec 6, 2024
1 parent db25da3 commit c333cd4
Show file tree
Hide file tree
Showing 15 changed files with 338 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ object Monitoring {
def apply(cmdFail: CMD_FAIL_HTLC): String = cmdFail.reason match {
case _: FailureReason.EncryptedDownstreamFailure => Remote
case FailureReason.LocalFailure(f) => f.getClass.getSimpleName
case FailureReason.LocalTrampolineFailure(f) => f.getClass.getSimpleName
}

def apply(pf: PaymentFailure): String = pf match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,23 +366,49 @@ object OutgoingPaymentPacket {
}

private def buildHtlcFailure(nodeSecret: PrivateKey, reason: FailureReason, add: UpdateAddHtlc): Either[CannotExtractSharedSecret, ByteVector] = {
extractSharedSecret(nodeSecret, add).map(sharedSecret => {
extractSharedSecret(nodeSecret, add).map(ss => {
reason match {
case FailureReason.EncryptedDownstreamFailure(packet) => Sphinx.FailurePacket.wrap(packet, sharedSecret)
case FailureReason.LocalFailure(failure) => Sphinx.FailurePacket.create(sharedSecret, failure)
case FailureReason.EncryptedDownstreamFailure(packet) =>
ss.trampolineOnionSecret_opt match {
case Some(trampolineOnionSecret) =>
// If we are unable to decrypt the downstream failure and the payment is using trampoline, the failure is
// intended for the payer. We encrypt it with the trampoline secret first and then the outer secret.
Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.wrap(packet, trampolineOnionSecret), ss.outerOnionSecret)
case None => Sphinx.FailurePacket.wrap(packet, ss.outerOnionSecret)
}
case FailureReason.LocalFailure(failure) =>
// This isn't a trampoline failure, so we only encrypt it for the node who created the outer onion.
Sphinx.FailurePacket.create(ss.outerOnionSecret, failure)
case FailureReason.LocalTrampolineFailure(failure) =>
// This is a trampoline failure: we try to encrypt it to the node who created the trampoline onion.
ss.trampolineOnionSecret_opt match {
case Some(trampolineOnionSecret) => Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(trampolineOnionSecret, failure), ss.outerOnionSecret)
case None => Sphinx.FailurePacket.create(ss.outerOnionSecret, failure) // this shouldn't happen, we only generate trampoline failures when there was a trampoline onion
}
}
})
}

private case class HtlcSharedSecrets(outerOnionSecret: ByteVector32, trampolineOnionSecret_opt: Option[ByteVector32])

/**
* We decrypt the onion again to extract the shared secret used to encrypt onion failures.
* We could avoid this by storing the shared secret after the initial onion decryption, but we would have to store it
* in the database since we must be able to fail HTLCs after restarting our node.
* It's simpler to extract it again from the encrypted onion.
*/
private def extractSharedSecret(nodeSecret: PrivateKey, add: UpdateAddHtlc): Either[CannotExtractSharedSecret, ByteVector32] = {
private def extractSharedSecret(nodeSecret: PrivateKey, add: UpdateAddHtlc): Either[CannotExtractSharedSecret, HtlcSharedSecrets] = {
Sphinx.peel(nodeSecret, Some(add.paymentHash), add.onionRoutingPacket) match {
case Right(Sphinx.DecryptedPacket(_, _, sharedSecret)) => Right(sharedSecret)
case Right(Sphinx.DecryptedPacket(payload, _, outerOnionSecret)) =>
// Let's look at the onion payload to see if it contains a trampoline onion.
PaymentOnionCodecs.perHopPayloadCodec.decode(payload.bits) match {
case Attempt.Successful(DecodeResult(perHopPayload, _)) =>
perHopPayload.get[OnionPaymentPayloadTlv.TrampolineOnion].flatMap(p => Sphinx.peel(nodeSecret, Some(add.paymentHash), p.packet).toOption) match {
case Some(Sphinx.DecryptedPacket(_, _, trampolineOnionSecret)) => Right(HtlcSharedSecrets(outerOnionSecret, Some(trampolineOnionSecret)))
case None => Right(HtlcSharedSecrets(outerOnionSecret, None))
}
case Attempt.Failure(_) => Right(HtlcSharedSecrets(outerOnionSecret, None))
}
case Left(_) => Left(CannotExtractSharedSecret(add.channelId, add))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,12 @@ object MultiPartHandler {

private def validateStandardPayment(nodeParams: NodeParams, add: UpdateAddHtlc, payload: FinalPayload.Standard, record: IncomingStandardPayment)(implicit log: LoggingAdapter): Option[CMD_FAIL_HTLC] = {
// We send the same error regardless of the failure to avoid probing attacks.
val cmdFail = CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(payload.totalAmount, nodeParams.currentBlockHeight)), commit = true)
val failure = if (payload.isTrampoline) {
FailureReason.LocalTrampolineFailure(IncorrectOrUnknownPaymentDetails(payload.totalAmount, nodeParams.currentBlockHeight))
} else {
FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(payload.totalAmount, nodeParams.currentBlockHeight))
}
val cmdFail = CMD_FAIL_HTLC(add.id, failure, commit = true)
val commonOk = validateCommon(nodeParams, add, payload, record)
val secretOk = validatePaymentSecret(add, payload, record.invoice)
if (commonOk && secretOk) None else Some(cmdFail)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,15 @@ object NodeRelay {
val amountOut = outgoingAmount(upstream, payloadOut)
val expiryOut = outgoingExpiry(upstream, payloadOut)
val fee = nodeFee(nodeParams.relayParams.minTrampolineFees, amountOut)
// We don't know yet how costly it is to reach the next node: we use a rough first estimate of twice our trampoline
// fees. If we fail to find routes, we will return a different error with higher fees and expiry delta.
val failure = TrampolineFeeOrExpiryInsufficient(nodeParams.relayParams.minTrampolineFees.feeBase * 2, nodeParams.relayParams.minTrampolineFees.feeProportionalMillionths * 2, nodeParams.channelConf.expiryDelta * 2)
if (upstream.amountIn - amountOut < fee) {
Some(TrampolineFeeInsufficient())
Some(failure)
} else if (upstream.expiryIn - expiryOut < nodeParams.channelConf.expiryDelta) {
Some(TrampolineExpiryTooSoon())
Some(failure)
} else if (expiryOut <= CltvExpiry(nodeParams.currentBlockHeight)) {
Some(TrampolineExpiryTooSoon())
Some(failure)
} else if (amountOut <= MilliSatoshi(0)) {
Some(InvalidOnionPayload(UInt64(2), 0))
} else {
Expand Down Expand Up @@ -181,31 +184,40 @@ object NodeRelay {
* This helper method translates relaying errors (returned by the downstream nodes) to a BOLT 4 standard error that we
* should return upstream.
*/
private def translateError(nodeParams: NodeParams, failures: Seq[PaymentFailure], upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay): Option[FailureMessage] = {
private def translateError(nodeParams: NodeParams, failures: Seq[PaymentFailure], upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay): FailureReason = {
val amountOut = outgoingAmount(upstream, nextPayload)
val routeNotFound = failures.collectFirst { case f@LocalFailure(_, _, RouteNotFound) => f }.nonEmpty
val routingFeeHigh = upstream.amountIn - amountOut >= nodeFee(nodeParams.relayParams.minTrampolineFees, amountOut) * 5
val trampolineFeesFailure = TrampolineFeeOrExpiryInsufficient(nodeParams.relayParams.minTrampolineFees.feeBase * 5, nodeParams.relayParams.minTrampolineFees.feeProportionalMillionths * 5, nodeParams.channelConf.expiryDelta * 5)
// We select the best error we can from our downstream attempts.
failures match {
case Nil => None
case Nil => FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure())
case LocalFailure(_, _, BalanceTooLow) :: Nil if routingFeeHigh =>
// We have direct channels to the target node, but not enough outgoing liquidity to use those channels.
// The routing fee proposed by the sender was high enough to find alternative, indirect routes, but didn't yield
// any result so we tell them that we don't have enough outgoing liquidity at the moment.
Some(TemporaryNodeFailure())
case LocalFailure(_, _, BalanceTooLow) :: Nil => Some(TrampolineFeeInsufficient()) // a higher fee/cltv may find alternative, indirect routes
case _ if routeNotFound => Some(TrampolineFeeInsufficient()) // if we couldn't find routes, it's likely that the fee/cltv was insufficient
// The routing fee proposed by the sender was high enough to find alternative, indirect routes, but didn't
// yield any result so we tell them that we don't have enough outgoing liquidity at the moment.
FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure())
case LocalFailure(_, _, BalanceTooLow) :: Nil =>
// A higher fee/cltv may find alternative, indirect routes.
FailureReason.LocalTrampolineFailure(trampolineFeesFailure)
case _ if routeNotFound =>
// If we couldn't find routes, it's likely that the fee/cltv was insufficient.
FailureReason.LocalTrampolineFailure(trampolineFeesFailure)
case _ =>
// Otherwise, we try to find a downstream error that we could decrypt.
val outgoingNodeFailure = nextPayload match {
case nextPayload: IntermediatePayload.NodeRelay.Standard => failures.collectFirst { case RemoteFailure(_, _, e) if e.originNode == nextPayload.outgoingNodeId => e.failureMessage }
case nextPayload: IntermediatePayload.NodeRelay.ToNonTrampoline => failures.collectFirst { case RemoteFailure(_, _, e) if e.originNode == nextPayload.outgoingNodeId => e.failureMessage }
nextPayload match {
case _: IntermediatePayload.NodeRelay.Standard =>
// If we received a failure from the next trampoline node, we won't be able to decrypt it: we should encrypt
// it with our trampoline shared secret and relay it upstream, because only the sender can decrypt it.
failures.collectFirst { case UnreadableRemoteFailure(_, _, packet) => FailureReason.EncryptedDownstreamFailure(packet) }
.getOrElse(FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure()))
case nextPayload: IntermediatePayload.NodeRelay.ToNonTrampoline =>
// The recipient doesn't support trampoline: if we received a failure from them, we forward it upstream.
failures.collectFirst { case RemoteFailure(_, _, e) if e.originNode == nextPayload.outgoingNodeId => FailureReason.LocalFailure(e.failureMessage) }
.getOrElse(FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure()))
// When using blinded paths, we will never get a failure from the final node (for privacy reasons).
case _: IntermediatePayload.NodeRelay.Blinded => None
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => None
case _: IntermediatePayload.NodeRelay.Blinded => FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure())
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure())
}
val otherNodeFailure = failures.collectFirst { case RemoteFailure(_, _, e) => e.failureMessage }
val failure = outgoingNodeFailure.getOrElse(otherNodeFailure.getOrElse(TemporaryNodeFailure()))
Some(failure)
}
}

Expand Down Expand Up @@ -245,15 +257,17 @@ class NodeRelay private(nodeParams: NodeParams,
case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) =>
context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure)
Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline)
parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.amount, Some(failure)) }
// Note that we don't treat this as a trampoline failure, which would be encrypted for the payer.
// This is a failure of the previous trampoline node who didn't send a valid MPP payment.
parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.amount, Some(FailureReason.LocalFailure(failure))) }
stopping()
case WrappedMultiPartPaymentSucceeded(MultiPartPaymentFSM.MultiPartPaymentSucceeded(_, parts)) =>
context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.amount).sum)
val upstream = Upstream.Hot.Trampoline(htlcs.toList)
validateRelay(nodeParams, upstream, nextPayload) match {
case Some(failure) =>
context.log.warn(s"rejecting trampoline payment reason=$failure")
rejectPayment(upstream, Some(failure))
rejectPayment(upstream, FailureReason.LocalTrampolineFailure(failure), nextPayload.isLegacy)
stopping()
case None =>
resolveNextNode(upstream, nextPayload, nextPacket_opt)
Expand Down Expand Up @@ -288,7 +302,7 @@ class NodeRelay private(nodeParams: NodeParams,
ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt)
case WrappedOutgoingNodeId(None) =>
context.log.warn("rejecting trampoline payment to blinded trampoline: cannot identify next node for scid={}", payloadOut.outgoing)
rejectPayment(upstream, Some(UnknownNextPeer()))
rejectPayment(upstream, FailureReason.LocalTrampolineFailure(UnknownNextPeer()), nextPayload.isLegacy)
stopping()
}
}
Expand All @@ -308,7 +322,7 @@ class NodeRelay private(nodeParams: NodeParams,
rejectExtraHtlcPartialFunction orElse {
case WrappedResolvedPaths(resolved) if resolved.isEmpty =>
context.log.warn("rejecting trampoline payment to blinded paths: no usable blinded path")
rejectPayment(upstream, Some(UnknownNextPeer()))
rejectPayment(upstream, FailureReason.LocalTrampolineFailure(UnknownNextPeer()), nextPayload.isLegacy)
stopping()
case WrappedResolvedPaths(resolved) =>
// We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient.
Expand Down Expand Up @@ -344,7 +358,7 @@ class NodeRelay private(nodeParams: NodeParams,
rejectExtraHtlcPartialFunction orElse {
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
context.log.warn("rejecting payment: failed to wake-up remote peer")
rejectPayment(upstream, Some(UnknownNextPeer()))
rejectPayment(upstream, FailureReason.LocalTrampolineFailure(UnknownNextPeer()), nextPayload.isLegacy)
stopping()
case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) =>
relay(upstream, recipient, Some(walletNodeId), Some(r.remoteFeatures), nextPayload, nextPacket_opt)
Expand Down Expand Up @@ -420,7 +434,7 @@ class NodeRelay private(nodeParams: NodeParams,
context.log.info("trampoline payment failed, attempting on-the-fly funding")
attemptOnTheFlyFunding(upstream, walletNodeId, recipient, nextPayload, failures, startedAt)
case _ =>
rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload))
rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload), nextPayload.isLegacy)
recordRelayDuration(startedAt, isSuccess = false)
stopping()
}
Expand All @@ -443,7 +457,7 @@ class NodeRelay private(nodeParams: NodeParams,
OutgoingPaymentPacket.buildOutgoingPayment(Origin.Hot(ActorRef.noSender, upstream), paymentHash, dummyRoute, recipient, 1.0) match {
case Left(f) =>
context.log.warn("could not create payment onion for on-the-fly funding: {}", f.getMessage)
rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload))
rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload), nextPayload.isLegacy)
recordRelayDuration(startedAt, isSuccess = false)
stopping()
case Right(nextPacket) =>
Expand All @@ -462,7 +476,7 @@ class NodeRelay private(nodeParams: NodeParams,
stopping()
case ProposeOnTheFlyFundingResponse.NotAvailable(reason) =>
context.log.warn("could not propose on-the-fly funding: {}", reason)
rejectPayment(upstream, Some(UnknownNextPeer()))
rejectPayment(upstream, FailureReason.LocalTrampolineFailure(UnknownNextPeer()), nextPayload.isLegacy)
recordRelayDuration(startedAt, isSuccess = false)
stopping()
}
Expand Down Expand Up @@ -501,15 +515,30 @@ class NodeRelay private(nodeParams: NodeParams,
rejectHtlc(add.id, add.channelId, add.amountMsat)
}

private def rejectHtlc(htlcId: Long, channelId: ByteVector32, amount: MilliSatoshi, failure: Option[FailureMessage] = None): Unit = {
val failureMessage = failure.getOrElse(IncorrectOrUnknownPaymentDetails(amount, nodeParams.currentBlockHeight))
val cmd = CMD_FAIL_HTLC(htlcId, FailureReason.LocalFailure(failureMessage), commit = true)
private def rejectHtlc(htlcId: Long, channelId: ByteVector32, amount: MilliSatoshi, failure_opt: Option[FailureReason] = None): Unit = {
val failure = failure_opt.getOrElse(FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(amount, nodeParams.currentBlockHeight)))
val cmd = CMD_FAIL_HTLC(htlcId, failure, commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd)
}

private def rejectPayment(upstream: Upstream.Hot.Trampoline, failure: Option[FailureMessage]): Unit = {
Metrics.recordPaymentRelayFailed(failure.map(_.getClass.getSimpleName).getOrElse("Unknown"), Tags.RelayType.Trampoline)
upstream.received.foreach(r => rejectHtlc(r.add.id, r.add.channelId, upstream.amountIn, failure))
private def rejectPayment(upstream: Upstream.Hot.Trampoline, failure: FailureReason, isLegacy: Boolean): Unit = {
val failure1 = failure match {
case failure: FailureReason.EncryptedDownstreamFailure =>
Metrics.recordPaymentRelayFailed("Unknown", Tags.RelayType.Trampoline)
failure
case failure: FailureReason.LocalFailure =>
Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline)
failure
case failure: FailureReason.LocalTrampolineFailure =>
Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline)
if (isLegacy) {
// The payer won't be able to decrypt our trampoline failure: we use a legacy failure for backwards-compat.
FailureReason.LocalFailure(LegacyTrampolineFeeInsufficient())
} else {
failure
}
}
upstream.received.foreach(r => rejectHtlc(r.add.id, r.add.channelId, upstream.amountIn, Some(failure1)))
}

private def fulfillPayment(upstream: Upstream.Hot.Trampoline, paymentPreimage: ByteVector32): Unit = upstream.received.foreach(r => {
Expand Down
Loading

0 comments on commit c333cd4

Please sign in to comment.