Skip to content

Commit

Permalink
Properly type Sphinx shared secrets (#2959)
Browse files Browse the repository at this point in the history
Otherwise it gets confusing quite quickly, because there are a lot of
onion-related secrets (sphinx shared secret, path key, invoice payment
secret...).
  • Loading branch information
t-bast authored Dec 5, 2024
1 parent 2ad2260 commit feef44b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 35 deletions.
33 changes: 17 additions & 16 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,17 @@ object Sphinx extends Logging {
val isLastPacket: Boolean = nextPacket.hmac == ByteVector32.Zeroes
}

/** Shared secret used to encrypt the payload for a given node. */
case class SharedSecret(secret: ByteVector32, remoteNodeId: PublicKey)

/**
* A encrypted onion packet with all the associated shared secrets.
*
* @param packet encrypted onion packet.
* @param sharedSecrets shared secrets (one per node in the route). Known (and needed) only if you're creating the
* packet. Empty if you're just forwarding the packet to the next node.
*/
case class PacketAndSecrets(packet: OnionRoutingPacket, sharedSecrets: Seq[(ByteVector32, PublicKey)])
case class PacketAndSecrets(packet: OnionRoutingPacket, sharedSecrets: Seq[SharedSecret])

/**
* Generate a deterministic filler to prevent intermediate nodes from knowing their position in the route.
Expand Down Expand Up @@ -239,12 +242,12 @@ object Sphinx extends Logging {
*/
def create(sessionKey: PrivateKey, packetPayloadLength: Int, publicKeys: Seq[PublicKey], payloads: Seq[ByteVector], associatedData: Option[ByteVector32]): Try[PacketAndSecrets] = Try {
require(payloadsTotalSize(payloads) <= packetPayloadLength, s"packet per-hop payloads cannot exceed $packetPayloadLength bytes")
val (ephemeralPublicKeys, sharedsecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys)
val filler = generateFiller("rho", packetPayloadLength, sharedsecrets.dropRight(1), payloads.dropRight(1))
val (ephemeralPublicKeys, sharedSecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys)
val filler = generateFiller("rho", packetPayloadLength, sharedSecrets.dropRight(1), payloads.dropRight(1))

// We deterministically-derive the initial payload bytes: see https://github.com/lightningnetwork/lightning-rfc/pull/697
val startingBytes = generateStream(generateKey("pad", sessionKey.value), packetPayloadLength)
val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedsecrets.last, Left(startingBytes), filler)
val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedSecrets.last, Left(startingBytes), filler)

@tailrec
def loop(hopPayloads: Seq[ByteVector], ephKeys: Seq[PublicKey], sharedSecrets: Seq[ByteVector32], packet: OnionRoutingPacket): OnionRoutingPacket = {
Expand All @@ -254,8 +257,8 @@ object Sphinx extends Logging {
}
}

val packet = loop(payloads.dropRight(1), ephemeralPublicKeys.dropRight(1), sharedsecrets.dropRight(1), lastPacket)
PacketAndSecrets(packet, sharedsecrets.zip(publicKeys))
val packet = loop(payloads.dropRight(1), ephemeralPublicKeys.dropRight(1), sharedSecrets.dropRight(1), lastPacket)
PacketAndSecrets(packet, sharedSecrets.zip(publicKeys).map { case (secret, remoteNodeId) => SharedSecret(secret, remoteNodeId) })
}

/**
Expand Down Expand Up @@ -324,20 +327,18 @@ object Sphinx extends Logging {
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
* failure packet otherwise.
*/
def decrypt(packet: ByteVector, sharedSecrets: Seq[(ByteVector32, PublicKey)]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = {
@tailrec
def loop(packet: ByteVector, secrets: Seq[(ByteVector32, PublicKey)]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = secrets match {
@tailrec
def decrypt(packet: ByteVector, sharedSecrets: Seq[SharedSecret]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = {
sharedSecrets match {
case Nil => Left(CannotDecryptFailurePacket(packet))
case (secret, pubkey) :: tail =>
val packet1 = wrap(packet, secret)
val um = generateKey("um", secret)
case ss :: tail =>
val packet1 = wrap(packet, ss.secret)
val um = generateKey("um", ss.secret)
FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match {
case Attempt.Successful(value) => Right(DecryptedFailurePacket(pubkey, value.value))
case _ => loop(packet1, tail)
case Attempt.Successful(value) => Right(DecryptedFailurePacket(ss.remoteNodeId, value.value))
case _ => decrypt(packet1, tail)
}
}

loop(packet, sharedSecrets)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ object IncomingPaymentPacket {
* @param outgoingChannel channel to send the HTLC to.
* @param sharedSecrets shared secrets (used to decrypt the error in case of payment failure).
*/
case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[(ByteVector32, PublicKey)])
case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[Sphinx.SharedSecret])

/** Helpers to create outgoing payment packets. */
object OutgoingPaymentPacket {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ object PaymentLifecycle {
sealed trait Data
case object WaitingForRequest extends Data
case class WaitingForRoute(request: SendPayment, failures: Seq[PaymentFailure], ignore: Ignore) extends Data
case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignore: Ignore, route: Route) extends Data {
case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[Sphinx.SharedSecret], ignore: Ignore, route: Route) extends Data {
val recipient = request.recipient
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2)
val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3)
assert(Seq(payload0, payload1, payload2, payload3, payload4) == referencePaymentPayloads)
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1))
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret))

val packets = Seq(nextPacket0, nextPacket1, nextPacket2, nextPacket3, nextPacket4)
assert(packets(0).hmac == ByteVector32(hex"901fb2bb905d1cfac67727f900daa2bb9da6801ac31ccce78663e5021e83983b"))
Expand All @@ -159,7 +159,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2)
val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3)
assert(Seq(payload0, payload1, payload2, payload3, payload4) == paymentPayloadsFull)
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1))
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret))

val packets = Seq(nextPacket0, nextPacket1, nextPacket2, nextPacket3, nextPacket4)
assert(packets(0).hmac == ByteVector32(hex"859cd694cf604442547246f4fae144f255e71e30cb366b9775f488cac713f0db"))
Expand Down Expand Up @@ -196,7 +196,7 @@ class SphinxSpec extends AnyFunSuite {
val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2)
val Right(DecryptedPacket(payload4, _, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3)
assert(Seq(payload0, payload1, payload2, payload3, payload4) == trampolinePaymentPayloads)
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1))
assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret))
}

test("create packet with invalid payload") {
Expand Down Expand Up @@ -229,19 +229,19 @@ class SphinxSpec extends AnyFunSuite {
val packet1 = FailurePacket.create(sharedSecrets.head, expected.failureMessage)
assert(packet1.length == 292)

val Right(decrypted1) = FailurePacket.decrypt(packet1, Seq(0).map(i => (sharedSecrets(i), publicKeys(i))))
val Right(decrypted1) = FailurePacket.decrypt(packet1, Seq(0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i))))
assert(decrypted1 == expected)

val packet2 = FailurePacket.wrap(packet1, sharedSecrets(1))
assert(packet2.length == 292)

val Right(decrypted2) = FailurePacket.decrypt(packet2, Seq(1, 0).map(i => (sharedSecrets(i), publicKeys(i))))
val Right(decrypted2) = FailurePacket.decrypt(packet2, Seq(1, 0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i))))
assert(decrypted2 == expected)

val packet3 = FailurePacket.wrap(packet2, sharedSecrets(2))
assert(packet3.length == 292)

val Right(decrypted3) = FailurePacket.decrypt(packet3, Seq(2, 1, 0).map(i => (sharedSecrets(i), publicKeys(i))))
val Right(decrypted3) = FailurePacket.decrypt(packet3, Seq(2, 1, 0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i))))
assert(decrypted3 == expected)
}

Expand All @@ -258,7 +258,7 @@ class SphinxSpec extends AnyFunSuite {
sharedSecrets(1)),
sharedSecrets(2))

assert(FailurePacket.decrypt(packet, Seq(0, 2, 1).map(i => (sharedSecrets(i), publicKeys(i)))).isLeft)
assert(FailurePacket.decrypt(packet, Seq(0, 2, 1).map(i => SharedSecret(sharedSecrets(i), publicKeys(i)))).isLeft)
}

test("last node replies with a short failure message (old reference test vector)") {
Expand Down Expand Up @@ -565,7 +565,7 @@ class SphinxSpec extends AnyFunSuite {
assert(payloadEve.allowedFeatures.isEmpty)

assert(Seq(onionPayloadAlice, onionPayloadBob, onionPayloadCarol, onionPayloadDave, onionPayloadEve) == payloads)
assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_._1))
assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_.secret))

val packets = Seq(packetForBob, packetForCarol, packetForDave, packetForEve, packetForNobody)
assert(packets(0).hmac == ByteVector32(hex"73fba184685e19b9af78afe876aa4e4b4242382b293133771d95a2bd83fa9c62"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {

register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1))
val failure = TemporaryChannelFailure(Some(update_bc))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

// payment lifecycle will ask the router to temporarily exclude this channel from its route calculations
assert(routerForwarder.expectMsgType[ChannelCouldNotRelay].hop.shortChannelId == update_bc.shortChannelId)
Expand Down Expand Up @@ -533,7 +533,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING
routerForwarder.expectMsg(defaultRouteRequest(a, cfg))
Expand All @@ -548,7 +548,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified_2 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(43), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure2 = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified_2))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head._1, failure2)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head.secret, failure2)))))

// this time the payment lifecycle will ask the router to temporarily exclude this channel from its route calculations
routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c), Some(nodeParams.routerConf.channelExcludeDuration)))
Expand Down Expand Up @@ -578,7 +578,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {

// the node replies with a temporary failure containing the same update as the one we already have (likely a balance issue)
val failure = TemporaryChannelFailure(Some(update_bc))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))
// we should temporarily exclude that channel
assert(routerForwarder.expectMsgType[ChannelCouldNotRelay].hop.shortChannelId == update_bc.shortChannelId)
routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c), Some(nodeParams.routerConf.channelExcludeDuration)))
Expand Down Expand Up @@ -612,7 +612,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING
val extraEdges1 = Seq(
Expand Down Expand Up @@ -651,7 +651,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
// we disable the channel
val channelUpdate_cd_disabled = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, scid_cd, CltvExpiryDelta(42), update_cd.htlcMinimumMsat, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.htlcMaximumMsat, enable = false)
val failure = ChannelDisabled(channelUpdate_cd_disabled.messageFlags, channelUpdate_cd_disabled.channelFlags, Some(channelUpdate_cd_disabled))
val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1)._1, failure), sharedSecrets1.head._1)
val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1).secret, failure), sharedSecrets1.head.secret)
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion))))

assert(routerForwarder.expectMsgType[RouteCouldRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, update_bc).map(_.shortChannelId))
Expand All @@ -674,7 +674,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData

register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

// payment lifecycle forwards the embedded channelUpdate to the router
awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE)
Expand Down Expand Up @@ -713,7 +713,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {

// The payment fails inside the blinded route: the introduction node sends back an error.
val failure = InvalidOnionBlinding(randomBytes32())
val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head._1, failure)
val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head.secret, failure)
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion))))

// We retry but we exclude the failed blinded route.
Expand Down Expand Up @@ -955,7 +955,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec {
val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat)
val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified))
// and node replies with a failure containing a new channel update
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure)))))
sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure)))))

// The payment fails without retrying
sender.expectMsgType[PaymentFailed]
Expand Down

0 comments on commit feef44b

Please sign in to comment.