diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index 3cec8ebfa4..1fdfa1bab1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -108,6 +108,9 @@ object Sphinx extends Logging { val isLastPacket: Boolean = nextPacket.hmac == ByteVector32.Zeroes } + /** Shared secret used to encrypt the payload for a given node. */ + case class SharedSecret(secret: ByteVector32, remoteNodeId: PublicKey) + /** * A encrypted onion packet with all the associated shared secrets. * @@ -115,7 +118,7 @@ object Sphinx extends Logging { * @param sharedSecrets shared secrets (one per node in the route). Known (and needed) only if you're creating the * packet. Empty if you're just forwarding the packet to the next node. */ - case class PacketAndSecrets(packet: OnionRoutingPacket, sharedSecrets: Seq[(ByteVector32, PublicKey)]) + case class PacketAndSecrets(packet: OnionRoutingPacket, sharedSecrets: Seq[SharedSecret]) /** * Generate a deterministic filler to prevent intermediate nodes from knowing their position in the route. @@ -239,12 +242,12 @@ object Sphinx extends Logging { */ def create(sessionKey: PrivateKey, packetPayloadLength: Int, publicKeys: Seq[PublicKey], payloads: Seq[ByteVector], associatedData: Option[ByteVector32]): Try[PacketAndSecrets] = Try { require(payloadsTotalSize(payloads) <= packetPayloadLength, s"packet per-hop payloads cannot exceed $packetPayloadLength bytes") - val (ephemeralPublicKeys, sharedsecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys) - val filler = generateFiller("rho", packetPayloadLength, sharedsecrets.dropRight(1), payloads.dropRight(1)) + val (ephemeralPublicKeys, sharedSecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys) + val filler = generateFiller("rho", packetPayloadLength, sharedSecrets.dropRight(1), payloads.dropRight(1)) // We deterministically-derive the initial payload bytes: see https://github.com/lightningnetwork/lightning-rfc/pull/697 val startingBytes = generateStream(generateKey("pad", sessionKey.value), packetPayloadLength) - val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedsecrets.last, Left(startingBytes), filler) + val lastPacket = wrap(payloads.last, associatedData, ephemeralPublicKeys.last, sharedSecrets.last, Left(startingBytes), filler) @tailrec def loop(hopPayloads: Seq[ByteVector], ephKeys: Seq[PublicKey], sharedSecrets: Seq[ByteVector32], packet: OnionRoutingPacket): OnionRoutingPacket = { @@ -254,8 +257,8 @@ object Sphinx extends Logging { } } - val packet = loop(payloads.dropRight(1), ephemeralPublicKeys.dropRight(1), sharedsecrets.dropRight(1), lastPacket) - PacketAndSecrets(packet, sharedsecrets.zip(publicKeys)) + val packet = loop(payloads.dropRight(1), ephemeralPublicKeys.dropRight(1), sharedSecrets.dropRight(1), lastPacket) + PacketAndSecrets(packet, sharedSecrets.zip(publicKeys).map { case (secret, remoteNodeId) => SharedSecret(secret, remoteNodeId) }) } /** @@ -324,20 +327,18 @@ object Sphinx extends Logging { * @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped * failure packet otherwise. */ - def decrypt(packet: ByteVector, sharedSecrets: Seq[(ByteVector32, PublicKey)]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = { - @tailrec - def loop(packet: ByteVector, secrets: Seq[(ByteVector32, PublicKey)]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = secrets match { + @tailrec + def decrypt(packet: ByteVector, sharedSecrets: Seq[SharedSecret]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = { + sharedSecrets match { case Nil => Left(CannotDecryptFailurePacket(packet)) - case (secret, pubkey) :: tail => - val packet1 = wrap(packet, secret) - val um = generateKey("um", secret) + case ss :: tail => + val packet1 = wrap(packet, ss.secret) + val um = generateKey("um", ss.secret) FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match { - case Attempt.Successful(value) => Right(DecryptedFailurePacket(pubkey, value.value)) - case _ => loop(packet1, tail) + case Attempt.Successful(value) => Right(DecryptedFailurePacket(ss.remoteNodeId, value.value)) + case _ => decrypt(packet1, tail) } } - - loop(packet, sharedSecrets) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index e42ecc9c9a..bb696d1caa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -282,7 +282,7 @@ object IncomingPaymentPacket { * @param outgoingChannel channel to send the HTLC to. * @param sharedSecrets shared secrets (used to decrypt the error in case of payment failure). */ -case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[(ByteVector32, PublicKey)]) +case class OutgoingPaymentPacket(cmd: CMD_ADD_HTLC, outgoingChannel: ShortChannelId, sharedSecrets: Seq[Sphinx.SharedSecret]) /** Helpers to create outgoing payment packets. */ object OutgoingPaymentPacket { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 865505e0b1..8663e59b0a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -476,7 +476,7 @@ object PaymentLifecycle { sealed trait Data case object WaitingForRequest extends Data case class WaitingForRoute(request: SendPayment, failures: Seq[PaymentFailure], ignore: Ignore) extends Data - case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[(ByteVector32, PublicKey)], ignore: Ignore, route: Route) extends Data { + case class WaitingForComplete(request: SendPayment, cmd: CMD_ADD_HTLC, failures: Seq[PaymentFailure], sharedSecrets: Seq[Sphinx.SharedSecret], ignore: Ignore, route: Route) extends Data { val recipient = request.recipient } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 05bdf7a4ff..23dc635138 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -139,7 +139,7 @@ class SphinxSpec extends AnyFunSuite { val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2) val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3) assert(Seq(payload0, payload1, payload2, payload3, payload4) == referencePaymentPayloads) - assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1)) + assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret)) val packets = Seq(nextPacket0, nextPacket1, nextPacket2, nextPacket3, nextPacket4) assert(packets(0).hmac == ByteVector32(hex"901fb2bb905d1cfac67727f900daa2bb9da6801ac31ccce78663e5021e83983b")) @@ -159,7 +159,7 @@ class SphinxSpec extends AnyFunSuite { val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2) val Right(DecryptedPacket(payload4, nextPacket4, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3) assert(Seq(payload0, payload1, payload2, payload3, payload4) == paymentPayloadsFull) - assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1)) + assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret)) val packets = Seq(nextPacket0, nextPacket1, nextPacket2, nextPacket3, nextPacket4) assert(packets(0).hmac == ByteVector32(hex"859cd694cf604442547246f4fae144f255e71e30cb366b9775f488cac713f0db")) @@ -196,7 +196,7 @@ class SphinxSpec extends AnyFunSuite { val Right(DecryptedPacket(payload3, nextPacket3, sharedSecret3)) = peel(privKeys(3), associatedData, nextPacket2) val Right(DecryptedPacket(payload4, _, sharedSecret4)) = peel(privKeys(4), associatedData, nextPacket3) assert(Seq(payload0, payload1, payload2, payload3, payload4) == trampolinePaymentPayloads) - assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_._1)) + assert(Seq(sharedSecret0, sharedSecret1, sharedSecret2, sharedSecret3, sharedSecret4) == sharedSecrets.map(_.secret)) } test("create packet with invalid payload") { @@ -229,19 +229,19 @@ class SphinxSpec extends AnyFunSuite { val packet1 = FailurePacket.create(sharedSecrets.head, expected.failureMessage) assert(packet1.length == 292) - val Right(decrypted1) = FailurePacket.decrypt(packet1, Seq(0).map(i => (sharedSecrets(i), publicKeys(i)))) + val Right(decrypted1) = FailurePacket.decrypt(packet1, Seq(0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i)))) assert(decrypted1 == expected) val packet2 = FailurePacket.wrap(packet1, sharedSecrets(1)) assert(packet2.length == 292) - val Right(decrypted2) = FailurePacket.decrypt(packet2, Seq(1, 0).map(i => (sharedSecrets(i), publicKeys(i)))) + val Right(decrypted2) = FailurePacket.decrypt(packet2, Seq(1, 0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i)))) assert(decrypted2 == expected) val packet3 = FailurePacket.wrap(packet2, sharedSecrets(2)) assert(packet3.length == 292) - val Right(decrypted3) = FailurePacket.decrypt(packet3, Seq(2, 1, 0).map(i => (sharedSecrets(i), publicKeys(i)))) + val Right(decrypted3) = FailurePacket.decrypt(packet3, Seq(2, 1, 0).map(i => SharedSecret(sharedSecrets(i), publicKeys(i)))) assert(decrypted3 == expected) } @@ -258,7 +258,7 @@ class SphinxSpec extends AnyFunSuite { sharedSecrets(1)), sharedSecrets(2)) - assert(FailurePacket.decrypt(packet, Seq(0, 2, 1).map(i => (sharedSecrets(i), publicKeys(i)))).isLeft) + assert(FailurePacket.decrypt(packet, Seq(0, 2, 1).map(i => SharedSecret(sharedSecrets(i), publicKeys(i)))).isLeft) } test("last node replies with a short failure message (old reference test vector)") { @@ -565,7 +565,7 @@ class SphinxSpec extends AnyFunSuite { assert(payloadEve.allowedFeatures.isEmpty) assert(Seq(onionPayloadAlice, onionPayloadBob, onionPayloadCarol, onionPayloadDave, onionPayloadEve) == payloads) - assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_._1)) + assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_.secret)) val packets = Seq(packetForBob, packetForCarol, packetForDave, packetForEve, packetForNobody) assert(packets(0).hmac == ByteVector32(hex"73fba184685e19b9af78afe876aa4e4b4242382b293133771d95a2bd83fa9c62")) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 8b2bf22c63..597f50c3fa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -500,7 +500,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1)) val failure = TemporaryChannelFailure(Some(update_bc)) - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure))))) // payment lifecycle will ask the router to temporarily exclude this channel from its route calculations assert(routerForwarder.expectMsgType[ChannelCouldNotRelay].hop.shortChannelId == update_bc.shortChannelId) @@ -533,7 +533,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat) val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified)) // and node replies with a failure containing a new channel update - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING routerForwarder.expectMsg(defaultRouteRequest(a, cfg)) @@ -548,7 +548,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val channelUpdate_bc_modified_2 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(43), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat) val failure2 = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified_2)) // and node replies with a failure containing a new channel update - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head._1, failure2))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets2.head.secret, failure2))))) // this time the payment lifecycle will ask the router to temporarily exclude this channel from its route calculations routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c), Some(nodeParams.routerConf.channelExcludeDuration))) @@ -578,7 +578,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // the node replies with a temporary failure containing the same update as the one we already have (likely a balance issue) val failure = TemporaryChannelFailure(Some(update_bc)) - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure))))) // we should temporarily exclude that channel assert(routerForwarder.expectMsgType[ChannelCouldNotRelay].hop.shortChannelId == update_bc.shortChannelId) routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_bc.shortChannelId, b, c), Some(nodeParams.routerConf.channelExcludeDuration))) @@ -612,7 +612,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat) val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified)) // and node replies with a failure containing a new channel update - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure))))) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING val extraEdges1 = Seq( @@ -651,7 +651,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // we disable the channel val channelUpdate_cd_disabled = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, scid_cd, CltvExpiryDelta(42), update_cd.htlcMinimumMsat, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.htlcMaximumMsat, enable = false) val failure = ChannelDisabled(channelUpdate_cd_disabled.messageFlags, channelUpdate_cd_disabled.channelFlags, Some(channelUpdate_cd_disabled)) - val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1)._1, failure), sharedSecrets1.head._1) + val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1).secret, failure), sharedSecrets1.head.secret) sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)))) assert(routerForwarder.expectMsgType[RouteCouldRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, update_bc).map(_.shortChannelId)) @@ -674,7 +674,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val WaitingForComplete(_, cmd1, Nil, sharedSecrets1, _, route1) = paymentFSM.stateData register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1)) - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure))))) // payment lifecycle forwards the embedded channelUpdate to the router awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) @@ -713,7 +713,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // The payment fails inside the blinded route: the introduction node sends back an error. val failure = InvalidOnionBlinding(randomBytes32()) - val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head._1, failure) + val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head.secret, failure) sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)))) // We retry but we exclude the failed blinded route. @@ -955,7 +955,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val channelUpdate_bc_modified = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, c, scid_bc, CltvExpiryDelta(42), htlcMinimumMsat = update_bc.htlcMinimumMsat, feeBaseMsat = update_bc.feeBaseMsat, feeProportionalMillionths = update_bc.feeProportionalMillionths, htlcMaximumMsat = update_bc.htlcMaximumMsat) val failure = IncorrectCltvExpiry(CltvExpiry(5), Some(channelUpdate_bc_modified)) // and node replies with a failure containing a new channel update - sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head._1, failure))))) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, Sphinx.FailurePacket.create(sharedSecrets1.head.secret, failure))))) // The payment fails without retrying sender.expectMsgType[PaymentFailed]