Skip to content

Commit

Permalink
Add EncodedNodeId for mobile wallets (#2867)
Browse files Browse the repository at this point in the history
We define a new type of `EncodedNodeId` that can be provided in blinded
paths to let a wallet provider know that the next node is a mobile
wallet.
  • Loading branch information
t-bast authored Jun 13, 2024
1 parent 741ac49 commit 3277e6d
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 59 deletions.
21 changes: 16 additions & 5 deletions eclair-core/src/main/scala/fr/acinq/eclair/EncodedNodeId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,29 @@ package fr.acinq.eclair

import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey

/** Identifying information for a remote node, used in blinded paths and onion contents. */
sealed trait EncodedNodeId

object EncodedNodeId {
/** Nodes are usually identified by their public key. */
case class Plain(publicKey: PublicKey) extends EncodedNodeId {
override def toString: String = publicKey.toString
}

def apply(publicKey: PublicKey): EncodedNodeId = WithPublicKey.Plain(publicKey)

/** For compactness, nodes may be identified by the shortChannelId of one of their public channels. */
case class ShortChannelIdDir(isNode1: Boolean, scid: RealShortChannelId) extends EncodedNodeId {
override def toString: String = if (isNode1) s"<-$scid" else s"$scid->"
}

def apply(publicKey: PublicKey): EncodedNodeId = Plain(publicKey)
// @formatter:off
sealed trait WithPublicKey extends EncodedNodeId { def publicKey: PublicKey }
object WithPublicKey {
/** Standard case where a node is identified by its public key. */
case class Plain(publicKey: PublicKey) extends WithPublicKey { override def toString: String = publicKey.toString }
/**
* Wallet nodes are not part of the public graph, and may not have channels yet.
* Wallet providers are usually able to contact such nodes using push notifications or similar mechanisms.
*/
case class Wallet(publicKey: PublicKey) extends WithPublicKey { override def toString: String = publicKey.toString }
}
// @formatter:on

}
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ object Sphinx extends Logging {
e = e.multiply(PrivateKey(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes)))
(BlindedNode(blindedPublicKey, encryptedPayload ++ mac), blindingKey)
}.unzip
BlindedRouteDetails(BlindedRoute(EncodedNodeId(publicKeys.head), blindingKeys.head, blindedHops), blindingKeys.last)
BlindedRouteDetails(BlindedRoute(EncodedNodeId.WithPublicKey.Plain(publicKeys.head), blindingKeys.head, blindedHops), blindingKeys.last)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ private class MessageRelay(nodeParams: NodeParams,
case Right(EncodedNodeId.ShortChannelIdDir(isNode1, scid)) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1)
waitForNextNodeId(msg, scid)
case Right(EncodedNodeId.Plain(nextNodeId)) =>
withNextNodeId(msg, nextNodeId)
case Right(encodedNodeId: EncodedNodeId.WithPublicKey) =>
withNextNodeId(msg, encodedNodeId.publicKey)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object OnionMessages {
}

object IntermediateNode {
def apply(publicKey: PublicKey): IntermediateNode = IntermediateNode(publicKey, EncodedNodeId(publicKey))
def apply(publicKey: PublicKey): IntermediateNode = IntermediateNode(publicKey, EncodedNodeId.WithPublicKey.Plain(publicKey))
}

// @formatter:off
Expand All @@ -64,7 +64,7 @@ object OnionMessages {
override def introductionNodeId: EncodedNodeId = route.introductionNodeId
}
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination {
override def introductionNodeId: EncodedNodeId = EncodedNodeId(nodeId)
override def introductionNodeId: EncodedNodeId = EncodedNodeId.WithPublicKey.Plain(nodeId)
}
// @formatter:on

Expand Down Expand Up @@ -93,7 +93,7 @@ object OnionMessages {
def buildRoute(blindingSecret: PrivateKey,
intermediateNodes: Seq[IntermediateNode],
recipient: Recipient): Sphinx.RouteBlinding.BlindedRouteDetails = {
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, EncodedNodeId(recipient.nodeId))
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, EncodedNodeId.WithPublicKey.Plain(recipient.nodeId))
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey) :+ recipient.nodeId, intermediatePayloads :+ lastPayload)
Expand Down
14 changes: 9 additions & 5 deletions eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,14 @@ private class SendingMessage(nodeParams: NodeParams,
Behaviors.receiveMessagePartial {
case SendMessage =>
contactInfo match {
case OfferTypes.BlindedPath(route@BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1, scid), _, _)) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedNodeIdResponse), scid, isNode1)
waitForNodeId(route)
case OfferTypes.BlindedPath(route@BlindedRoute(EncodedNodeId.Plain(publicKey), _, _)) => sendToDestination(OnionMessages.BlindedPath(route), publicKey)
case blindedPath: OfferTypes.BlindedPath =>
blindedPath.route.introductionNodeId match {
case EncodedNodeId.ShortChannelIdDir(isNode1, scid) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedNodeIdResponse), scid, isNode1)
waitForNodeId(blindedPath.route)
case introductionNode: EncodedNodeId.WithPublicKey =>
sendToDestination(OnionMessages.BlindedPath(blindedPath.route), introductionNode.publicKey)
}
case OfferTypes.RecipientNodeId(nodeId) => sendToDestination(OnionMessages.Recipient(nodeId, None), nodeId)
}
}
Expand Down Expand Up @@ -214,7 +218,7 @@ private class SendingMessage(nodeParams: NodeParams,
replyTo ! Postman.MessageFailed(failure.toString)
Behaviors.stopped
case Right(message) =>
val nextNodeId = EncodedNodeId(intermediateNodes.headOption.getOrElse(plainNodeId))
val nextNodeId = EncodedNodeId.WithPublicKey.Plain(intermediateNodes.headOption.getOrElse(plainNodeId))
val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, nodeParams.nodeId, Right(nextNodeId), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus)))
waitForSent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
private def resolveBlindedPaths(toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] = {
toResolve.headOption match {
case Some(paymentRoute) => paymentRoute.route.introductionNodeId match {
case EncodedNodeId.Plain(ourNodeId) if ourNodeId == nodeParams.nodeId && paymentRoute.route.length == 0 =>
case EncodedNodeId.WithPublicKey.Plain(ourNodeId) if ourNodeId == nodeParams.nodeId && paymentRoute.route.length == 0 =>
context.log.warn("ignoring blinded path (empty route with ourselves as the introduction node)")
resolveBlindedPaths(toResolve.tail, resolved)
case EncodedNodeId.Plain(ourNodeId) if ourNodeId == nodeParams.nodeId =>
case EncodedNodeId.WithPublicKey.Plain(ourNodeId) if ourNodeId == nodeParams.nodeId =>
// We are the introduction node of the blinded route: we need to decrypt the first payload.
val firstBlinding = paymentRoute.route.introductionNode.blindingEphemeralKey
val firstEncryptedPayload = paymentRoute.route.introductionNode.encryptedPayload
Expand Down Expand Up @@ -115,8 +115,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved)
}
}
case EncodedNodeId.Plain(remoteNodeId) =>
val path = ResolvedPath(FullBlindedRoute(remoteNodeId, paymentRoute.route.blindingKey, paymentRoute.route.blindedNodes), paymentRoute.paymentInfo)
case encodedNodeId: EncodedNodeId.WithPublicKey =>
val path = ResolvedPath(FullBlindedRoute(encodedNodeId.publicKey, paymentRoute.route.blindingKey, paymentRoute.route.blindedNodes), paymentRoute.paymentInfo)
resolveBlindedPaths(toResolve.tail, resolved :+ path)
case EncodedNodeId.ShortChannelIdDir(isNode1, scid) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedNodeId), scid, isNode1)
Expand Down Expand Up @@ -160,18 +160,16 @@ private class BlindedPathsResolver(nodeParams: NodeParams,
}
}

/** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.Plain]]. */
private def waitForNodeId(paymentRoute: PaymentBlindedRoute,
toResolve: Seq[PaymentBlindedRoute],
resolved: Seq[ResolvedPath]): Behavior[Command] =
/** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */
private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedNodeId(None) =>
context.log.warn("ignoring blinded path with unknown scid_dir={}", paymentRoute.route.introductionNodeId)
resolveBlindedPaths(toResolve, resolved)
case WrappedNodeId(Some(nodeId)) =>
context.log.debug("successfully resolved scid_dir={} to node_id={}", paymentRoute.route.introductionNodeId, nodeId)
// We've identified the node matching this scid_dir, we retry resolving with that node_id.
val paymentRouteWithNodeId = paymentRoute.copy(route = paymentRoute.route.copy(introductionNodeId = EncodedNodeId.Plain(nodeId)))
val paymentRouteWithNodeId = paymentRoute.copy(route = paymentRoute.route.copy(introductionNodeId = EncodedNodeId.WithPublicKey.Plain(nodeId)))
resolveBlindedPaths(paymentRouteWithNodeId +: toResolve, resolved)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.BlockHash
import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute}
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequestChain, InvoiceRequestPayerNote, InvoiceRequestQuantity, _}
import fr.acinq.eclair.wire.protocol.OfferTypes._
import fr.acinq.eclair.wire.protocol.TlvCodecs.{tlvField, tmillisatoshi, tu32, tu64overflow}
import fr.acinq.eclair.{EncodedNodeId, TimestampSecond, UInt64}
import scodec.Codec
import scodec.codecs._
import scodec.{Attempt, Codec, Err}

object OfferCodecs {
private val offerChains: Codec[OfferChains] = tlvField(list(blockHash).xmap[Seq[BlockHash]](_.toSeq, _.toList))
Expand All @@ -41,16 +41,21 @@ object OfferCodecs {

private val offerAbsoluteExpiry: Codec[OfferAbsoluteExpiry] = tlvField(tu64overflow.as[TimestampSecond])

private val isNode1: Codec[Boolean] = uint8.narrow(
n => if (n == 0) Attempt.Successful(true) else if (n == 1) Attempt.Successful(false) else Attempt.Failure(new Err.MatchingDiscriminatorNotFound(n)),
b => if (b) 0 else 1
)

private val shortChannelIdDirCodec: Codec[ShortChannelIdDir] =
(("isNode1" | isNode1) ::
("scid" | realshortchannelid)).as[ShortChannelIdDir]

val encodedNodeIdCodec: Codec[EncodedNodeId] = choice(shortChannelIdDirCodec.upcast[EncodedNodeId], publicKey.as[EncodedNodeId.Plain].upcast[EncodedNodeId])
/** A 32-bytes codec for public keys where the first byte is set manually. */
private def tweakFirstByteCodec(prefix: Byte): Codec[PublicKey] = bytes(32).xmap(b => PublicKey(prefix +: b), _.value.drop(1))

// The first byte encodes what type of identifier is used.
val encodedNodeIdCodec: Codec[EncodedNodeId] = discriminated[EncodedNodeId].by(uint8)
// If the first byte is 0x00 or 0x01, we're using a shortChannelId and a direction to identify the node.
.subcaseP(0x00) { case e: EncodedNodeId.ShortChannelIdDir if e.isNode1 => e }(realshortchannelid.xmap(EncodedNodeId.ShortChannelIdDir(true, _), _.scid))
.subcaseP(0x01) { case e: EncodedNodeId.ShortChannelIdDir if !e.isNode1 => e }(realshortchannelid.xmap(EncodedNodeId.ShortChannelIdDir(false, _), _.scid))
// If the first byte is 0x02 or 0x03, this is a standard public key.
.subcaseP(0x02) { case e: EncodedNodeId.WithPublicKey.Plain if e.publicKey.value.head == 0x02 => e }(tweakFirstByteCodec(2).xmap[EncodedNodeId.WithPublicKey.Plain](EncodedNodeId.WithPublicKey.Plain, _.publicKey))
.subcaseP(0x03) { case e: EncodedNodeId.WithPublicKey.Plain if e.publicKey.value.head == 0x03 => e }(tweakFirstByteCodec(3).xmap[EncodedNodeId.WithPublicKey.Plain](EncodedNodeId.WithPublicKey.Plain, _.publicKey))
// If the first byte is 0x04 or 0x05, this is a public key for a wallet node: we need to tweak back that first byte
// to be 0x02 or 0x03 to obtain a valid public key.
.subcaseP(0x04) { case e: EncodedNodeId.WithPublicKey.Wallet if e.publicKey.value.head == 0x02 => e }(tweakFirstByteCodec(2).xmap[EncodedNodeId.WithPublicKey.Wallet](EncodedNodeId.WithPublicKey.Wallet, _.publicKey))
.subcaseP(0x05) { case e: EncodedNodeId.WithPublicKey.Wallet if e.publicKey.value.head == 0x03 => e }(tweakFirstByteCodec(3).xmap[EncodedNodeId.WithPublicKey.Wallet](EncodedNodeId.WithPublicKey.Wallet, _.publicKey))

private val blindedNodeCodec: Codec[BlindedNode] =
(("nodeId" | publicKey) ::
Expand Down Expand Up @@ -186,7 +191,7 @@ object OfferCodecs {
.typecase(UInt64(240), signature)
).complete

val invoiceErrorTlvCodec: Codec[TlvStream[InvoiceErrorTlv]] = TlvCodecs.tlvStream[InvoiceErrorTlv](discriminated[InvoiceErrorTlv].by(varint)
private val invoiceErrorTlvCodec: Codec[TlvStream[InvoiceErrorTlv]] = TlvCodecs.tlvStream[InvoiceErrorTlv](discriminated[InvoiceErrorTlv].by(varint)
.typecase(UInt64(1), tlvField(tu64overflow.as[ErroneousField]))
.typecase(UInt64(3), tlvField(bytes.as[SuggestedValue]))
.typecase(UInt64(5), tlvField(utf8.as[Error]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ object RouteBlindingEncryptedDataTlv {

/**
* Id of the next node.
* Warning: the spec only allows a public key here. We allow reading a ShortChannelIdDir for phoenix but we should never write one.
*
* WARNING: the spec only allows a public key here. We allow reading any type of [[EncodedNodeId]] to support relaying
* to mobile wallets, but we should always write an [[EncodedNodeId.WithPublicKey.Plain]].
*/
case class OutgoingNodeId(nodeId: EncodedNodeId) extends RouteBlindingEncryptedDataTlv

object OutgoingNodeId {
def apply(publicKey: PublicKey): OutgoingNodeId = OutgoingNodeId(EncodedNodeId(publicKey))
def apply(publicKey: PublicKey): OutgoingNodeId = OutgoingNodeId(EncodedNodeId.WithPublicKey.Plain(publicKey))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.perHopPayloadCodec
import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.EncryptedData
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._
import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessage, OnionMessagePayloadTlv, OnionRoutingCodecs, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{EncodedNodeId, ShortChannelId, UInt64, randomBytes, randomKey}
import org.json4s._
import org.json4s.jackson.JsonMethods._
Expand Down Expand Up @@ -213,7 +213,7 @@ class OnionMessagesSpec extends AnyFunSuite {
assert(message.blindingKey == blindingOverride.publicKey) // blindingSecret was not used as the replyPath was used as is

process(destination, message) match {
case SendMessage(Right(EncodedNodeId.Plain(nextNodeId2)), message2) =>
case SendMessage(Right(EncodedNodeId.WithPublicKey.Plain(nextNodeId2)), message2) =>
assert(nextNodeId2 == destination.publicKey)
process(destination, message2) match {
case ReceiveMessage(finalPayload, _) => assert(finalPayload.pathId_opt.contains(hex"01234567"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
val payerKey = randomKey()
val invoice = createBolt12Invoice(Features.empty, payerKey)
val resolvedPaths = invoice.blindedPaths.map(path => {
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.WithPublicKey].publicKey
ResolvedPath(FullBlindedRoute(introductionNodeId, path.route.blindingKey, path.route.blindedNodes), path.paymentInfo)
})
val req = SendPaymentToNode(sender.ref, finalAmount, invoice, resolvedPaths, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams, payerKey_opt = Some(payerKey))
Expand Down Expand Up @@ -330,7 +330,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
val payerKey = randomKey()
val invoice = createBolt12Invoice(Features(BasicMultiPartPayment -> Optional), payerKey)
val resolvedPaths = invoice.blindedPaths.map(path => {
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.WithPublicKey].publicKey
ResolvedPath(FullBlindedRoute(introductionNodeId, path.route.blindingKey, path.route.blindedNodes), path.paymentInfo)
})
val req = SendPaymentToNode(sender.ref, finalAmount, invoice, resolvedPaths, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams, payerKey_opt = Some(payerKey))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
val paymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 1 msat, amount_bc, Features.empty)
val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, recipientKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo)))
val resolvedPaths = invoice.blindedPaths.map(path => {
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.WithPublicKey].publicKey
ResolvedPath(FullBlindedRoute(introductionNodeId, path.route.blindingKey, path.route.blindedNodes), path.paymentInfo)
})
val recipient = BlindedRecipient(invoice, resolvedPaths, amount_bc, expiry_bc, Set.empty)
Expand Down Expand Up @@ -497,7 +497,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll {
val paymentInfo = OfferTypes.PaymentInfo(fee_b, 0, channelUpdate_bc.cltvExpiryDelta, 0 msat, amount_bc, Features.empty)
val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, priv_c.privateKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo)))
val resolvedPaths = invoice.blindedPaths.map(path => {
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey
val introductionNodeId = path.route.introductionNodeId.asInstanceOf[EncodedNodeId.WithPublicKey].publicKey
ResolvedPath(FullBlindedRoute(introductionNodeId, path.route.blindingKey, path.route.blindedNodes), path.paymentInfo)
})
val recipient = BlindedRecipient(invoice, resolvedPaths, amount_bc, expiry_bc, Set.empty)
Expand Down
Loading

0 comments on commit 3277e6d

Please sign in to comment.