Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sciddir_or_pubkey #2752

Merged
merged 12 commits into from
Nov 14, 2023
14 changes: 7 additions & 7 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
} else {
val recipientAmount = recipientAmount_opt.getOrElse(invoice.amount_opt.getOrElse(route.amount))
val trampoline_opt = trampolineFees_opt.map(fees => TrampolineAttempt(trampolineSecret_opt.getOrElse(randomBytes32()), fees, trampolineExpiryDelta_opt.get))
val sendPayment = SendPaymentToRoute(recipientAmount, invoice, route, externalId_opt, parentId_opt, trampoline_opt)
val sendPayment = SendPaymentToRoute(recipientAmount, invoice, Nil, route, externalId_opt, parentId_opt, trampoline_opt)
(appKit.paymentInitiator ? sendPayment).mapTo[SendPaymentToRouteResponse]
}
}
Expand All @@ -442,7 +442,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
externalId_opt match {
case Some(externalId) if externalId.length > externalIdMaxLength => Left(new IllegalArgumentException(s"externalId is too long: cannot exceed $externalIdMaxLength characters"))
case _ if invoice.isExpired() => Left(new IllegalArgumentException("invoice has expired"))
case _ => Right(SendPaymentToNode(ActorRef.noSender, amount, invoice, maxAttempts, externalId_opt, routeParams = routeParams))
case _ => Right(SendPaymentToNode(ActorRef.noSender, amount, invoice, Nil, maxAttempts, externalId_opt, routeParams = routeParams))
}
case Left(t) => Left(t)
}
Expand Down Expand Up @@ -663,15 +663,15 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
userCustomContent: ByteVector)(implicit timeout: Timeout): Future[SendOnionMessageResponse] = {
TlvCodecs.tlvStream(MessageOnionCodecs.onionTlvCodec).decode(userCustomContent.bits) match {
case Attempt.Successful(DecodeResult(userTlvs, _)) =>
val destination = recipient match {
case Left(key) => OnionMessages.Recipient(key, None)
case Right(route) => OnionMessages.BlindedPath(route)
val contactInfo = recipient match {
case Left(key) => OfferTypes.RecipientNodeId(key)
case Right(route) => OfferTypes.BlindedPath(route)
}
val routingStrategy = intermediateNodes_opt match {
case Some(intermediateNodes) => OnionMessages.RoutingStrategy.UseRoute(intermediateNodes)
case None => OnionMessages.RoutingStrategy.FindRoute
}
appKit.postman.ask(ref => Postman.SendMessage(destination, routingStrategy, userTlvs, expectsReply, ref)).map {
appKit.postman.ask(ref => Postman.SendMessage(contactInfo, routingStrategy, userTlvs, expectsReply, ref)).map {
case Postman.Response(payload) => SendOnionMessageResponse(sent = true, None, Some(SendOnionMessageResponsePayload(payload.records)))
case Postman.NoReply => SendOnionMessageResponse(sent = true, Some("No response"), None)
case Postman.MessageSent => SendOnionMessageResponse(sent = true, None, None)
Expand Down Expand Up @@ -702,7 +702,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
case Left(t) => return Future.failed(t)
}
val sendPaymentConfig = OfferPayment.SendPaymentConfig(externalId_opt, connectDirectly, maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts), routeParams, blocking)
val offerPayment = appKit.system.spawnAnonymous(OfferPayment(appKit.nodeParams, appKit.postman, appKit.paymentInitiator))
val offerPayment = appKit.system.spawnAnonymous(OfferPayment(appKit.nodeParams, appKit.postman, appKit.router, appKit.paymentInitiator))
offerPayment.ask((ref: typed.ActorRef[Any]) => OfferPayment.PayOffer(ref.toClassic, offer, amount, quantity, sendPaymentConfig)).flatMap {
case f: OfferPayment.Failure => Future.failed(new Exception(f.toString))
case x => Future.successful(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,17 @@ object InvoiceSerializer extends MinimalSerializer({
UnknownFeatureSerializer
)),
JField("blindedPaths", JArray(p.blindedPaths.map(path => {
val introductionNode = path.route match {
case OfferTypes.BlindedPath(route) => route.introductionNodeId.toString
case OfferTypes.CompactBlindedPath(shortIdDir, _, _) => s"${if (shortIdDir.isNode1) '0' else '1'}x${shortIdDir.scid.toString}"
}
val blindedNodes = path.route match {
case OfferTypes.BlindedPath(route) => route.blindedNodes
case OfferTypes.CompactBlindedPath(_, _, nodes) => nodes
}
JObject(List(
JField("introductionNodeId", JString(path.route.introductionNodeId.toString())),
JField("blindedNodeIds", JArray(path.route.blindedNodes.map(n => JString(n.blindedPublicKey.toString())).toList))
JField("introductionNodeId", JString(introductionNode)),
JField("blindedNodeIds", JArray(blindedNodes.map(n => JString(n.blindedPublicKey.toString)).toList))
))
}).toList)),
JField("createdAt", JLong(p.createdAt.toLong)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ object OnionMessages {
}

// @formatter:off
sealed trait Destination
case class BlindedPath(route: Sphinx.RouteBlinding.BlindedRoute) extends Destination
sealed trait Destination {
def nodeId: PublicKey
}
case class BlindedPath(route: Sphinx.RouteBlinding.BlindedRoute) extends Destination {
override def nodeId: PublicKey = route.introductionNodeId
}
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination
// @formatter:on

Expand Down
72 changes: 45 additions & 27 deletions eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute
import fr.acinq.eclair.io.MessageRelay
import fr.acinq.eclair.io.MessageRelay.RelayPolicy
import fr.acinq.eclair.message.OnionMessages.{Destination, RoutingStrategy}
import fr.acinq.eclair.payment.offer.OfferManager
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, MessageRouteResponse}
import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload}
import fr.acinq.eclair.wire.protocol.{OnionMessage, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.{NodeParams, ShortChannelId, randomBytes32, randomKey}
import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, ContactInfo}
import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey}

import scala.collection.mutable

Expand All @@ -40,13 +41,13 @@ object Postman {
/**
* Builds a message packet and send it to the destination using the provided path.
*
* @param destination Recipient of the message
* @param contactInfo Recipient of the message
* @param routingStrategy How to reach the destination (recipient or blinded path introduction node).
* @param message Content of the message to send
* @param expectsReply Whether the message expects a reply
* @param replyTo Actor to send the status and reply to
*/
case class SendMessage(destination: Destination,
case class SendMessage(contactInfo: ContactInfo,
routingStrategy: RoutingStrategy,
message: TlvStream[OnionMessagePayloadTlv],
expectsReply: Boolean,
Expand All @@ -63,7 +64,7 @@ object Postman {
case class MessageFailed(reason: String) extends MessageStatus
// @formatter:on

def apply(nodeParams: NodeParams, switchboard: akka.actor.ActorRef, router: ActorRef[Router.MessageRouteRequest], register: akka.actor.ActorRef, offerManager: typed.ActorRef[OfferManager.RequestInvoice]): Behavior[Command] = {
def apply(nodeParams: NodeParams, switchboard: akka.actor.ActorRef, router: ActorRef[Router.PostmanRequest], register: akka.actor.ActorRef, offerManager: typed.ActorRef[OfferManager.RequestInvoice]): Behavior[Command] = {
Behaviors.setup(context => {
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[OnionMessages.ReceiveMessage](r => WrappedMessage(r.finalPayload)))

Expand Down Expand Up @@ -110,31 +111,32 @@ object SendingMessage {
case object SendMessage extends Command
private case class SendingStatus(status: MessageRelay.Status) extends Command
private case class WrappedMessageRouteResponse(response: MessageRouteResponse) extends Command
private case class WrappedNodeIdResponse(nodeId_opt: Option[PublicKey]) extends Command
// @formatter:on

def apply(nodeParams: NodeParams,
router: ActorRef[Router.MessageRouteRequest],
router: ActorRef[Router.PostmanRequest],
postman: ActorRef[Postman.Command],
switchboard: akka.actor.ActorRef,
register: akka.actor.ActorRef,
destination: Destination,
contactInfo: ContactInfo,
message: TlvStream[OnionMessagePayloadTlv],
routingStrategy: RoutingStrategy,
expectsReply: Boolean,
replyTo: ActorRef[Postman.OnionMessageResponse]): Behavior[Command] = {
Behaviors.setup(context => {
val actor = new SendingMessage(nodeParams, router, postman, switchboard, register, destination, message, routingStrategy, expectsReply, replyTo, context)
val actor = new SendingMessage(nodeParams, router, postman, switchboard, register, contactInfo, message, routingStrategy, expectsReply, replyTo, context)
actor.start()
})
}
}

private class SendingMessage(nodeParams: NodeParams,
router: ActorRef[Router.MessageRouteRequest],
router: ActorRef[Router.PostmanRequest],
postman: ActorRef[Postman.Command],
switchboard: akka.actor.ActorRef,
register: akka.actor.ActorRef,
destination: Destination,
contactInfo: ContactInfo,
message: TlvStream[OnionMessagePayloadTlv],
routingStrategy: RoutingStrategy,
expectsReply: Boolean,
Expand All @@ -146,40 +148,56 @@ private class SendingMessage(nodeParams: NodeParams,
def start(): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case SendMessage =>
val targetNodeId = destination match {
case OnionMessages.BlindedPath(route) => route.introductionNodeId
case OnionMessages.Recipient(nodeId, _, _, _) => nodeId
}
routingStrategy match {
case RoutingStrategy.UseRoute(intermediateNodes) => sendToRoute(intermediateNodes, targetNodeId)
case RoutingStrategy.FindRoute if targetNodeId == nodeParams.nodeId =>
context.self ! WrappedMessageRouteResponse(MessageRoute(Nil, targetNodeId))
waitForRouteFromRouter()
case RoutingStrategy.FindRoute =>
router ! Router.MessageRouteRequest(context.messageAdapter(WrappedMessageRouteResponse), nodeParams.nodeId, targetNodeId, Set.empty)
waitForRouteFromRouter()
contactInfo match {
case compact: OfferTypes.CompactBlindedPath =>
router ! Router.GetNodeId(context.messageAdapter(WrappedNodeIdResponse), compact.introductionNode.scid, compact.introductionNode.isNode1)
waitForNodeId(compact)
case OfferTypes.BlindedPath(route) => sendToDestination(OnionMessages.BlindedPath(route))
case OfferTypes.RecipientNodeId(nodeId) => sendToDestination(OnionMessages.Recipient(nodeId, None))
}
}
}

private def waitForRouteFromRouter(): Behavior[Command] = {
private def waitForNodeId(compactBlindedPath: CompactBlindedPath): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedNodeIdResponse(None) =>
replyTo ! Postman.MessageFailed(s"Could not resolve introduction node for compact blinded path (scid=${compactBlindedPath.introductionNode.scid.toCoordinatesString})")
Behaviors.stopped
case WrappedNodeIdResponse(Some(nodeId)) =>
sendToDestination(OnionMessages.BlindedPath(BlindedRoute(nodeId, compactBlindedPath.blindingKey, compactBlindedPath.blindedNodes)))
}
}

private def sendToDestination(destination: Destination): Behavior[Command] = {
routingStrategy match {
case RoutingStrategy.UseRoute(intermediateNodes) => sendToRoute(intermediateNodes, destination)
case RoutingStrategy.FindRoute if destination.nodeId == nodeParams.nodeId =>
context.self ! WrappedMessageRouteResponse(MessageRoute(Nil, destination.nodeId))
waitForRouteFromRouter(destination)
case RoutingStrategy.FindRoute =>
router ! Router.MessageRouteRequest(context.messageAdapter(WrappedMessageRouteResponse), nodeParams.nodeId, destination.nodeId, Set.empty)
waitForRouteFromRouter(destination)
}
}

private def waitForRouteFromRouter(destination: Destination): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedMessageRouteResponse(MessageRoute(intermediateNodes, targetNodeId)) =>
context.log.debug("Found route: {}", (intermediateNodes :+ targetNodeId).mkString(" -> "))
sendToRoute(intermediateNodes, targetNodeId)
sendToRoute(intermediateNodes, destination)
case WrappedMessageRouteResponse(MessageRouteNotFound(targetNodeId)) =>
context.log.debug("No route found to {}", targetNodeId)
replyTo ! Postman.MessageFailed("No route found")
Behaviors.stopped
}
}

private def sendToRoute(intermediateNodes: Seq[PublicKey], targetNodeId: PublicKey): Behavior[Command] = {
private def sendToRoute(intermediateNodes: Seq[PublicKey], destination: Destination): Behavior[Command] = {
val messageId = randomBytes32()
val replyRoute =
if (expectsReply) {
val numHopsToAdd = 0.max(nodeParams.onionMessageConfig.minIntermediateHops - intermediateNodes.length - 1)
val intermediateHops = (Seq(targetNodeId) ++ intermediateNodes.reverse ++ Seq.fill(numHopsToAdd)(nodeParams.nodeId)).map(OnionMessages.IntermediateNode(_))
val intermediateHops = (Seq(destination.nodeId) ++ intermediateNodes.reverse ++ Seq.fill(numHopsToAdd)(nodeParams.nodeId)).map(OnionMessages.IntermediateNode(_))
val lastHop = OnionMessages.Recipient(nodeParams.nodeId, Some(messageId))
Some(OnionMessages.buildRoute(randomKey(), intermediateHops, lastHop))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import fr.acinq.bitcoin.Bech32
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64, Crypto}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute
import fr.acinq.eclair.wire.protocol.OfferTypes._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.{GenericTlv, OfferCodecs, OfferTypes, TlvStream}
Expand Down Expand Up @@ -52,7 +53,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice {
// We add invoice features that are implicitly required for Bolt 12 (the spec doesn't allow explicitly setting them).
f.add(Features.VariableLengthOnion, FeatureSupport.Mandatory).add(Features.RouteBlinding, FeatureSupport.Mandatory)
}
val blindedPaths: Seq[PaymentBlindedRoute] = records.get[InvoicePaths].get.paths.zip(records.get[InvoiceBlindedPay].get.paymentInfo).map { case (route, info) => PaymentBlindedRoute(route, info) }
val blindedPaths: Seq[PaymentBlindedContactInfo] = records.get[InvoicePaths].get.paths.zip(records.get[InvoiceBlindedPay].get.paymentInfo).map { case (route, info) => PaymentBlindedContactInfo(route, info) }
val fallbacks: Option[Seq[FallbackAddress]] = records.get[InvoiceFallbacks].map(_.addresses)
val signature: ByteVector64 = records.get[Signature].get.signature

Expand Down Expand Up @@ -86,7 +87,9 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice {

}

case class PaymentBlindedRoute(route: Sphinx.RouteBlinding.BlindedRoute, paymentInfo: PaymentInfo)
case class PaymentBlindedContactInfo(route: BlindedContactInfo, paymentInfo: PaymentInfo)

case class PaymentBlindedRoute(route: BlindedRoute, paymentInfo: PaymentInfo)

object Bolt12Invoice {
val hrp = "lni"
Expand All @@ -107,7 +110,7 @@ object Bolt12Invoice {
nodeKey: PrivateKey,
invoiceExpiry: FiniteDuration,
features: Features[Bolt12Feature],
paths: Seq[PaymentBlindedRoute],
paths: Seq[PaymentBlindedContactInfo],
additionalTlvs: Set[InvoiceTlv] = Set.empty,
customTlvs: Set[GenericTlv] = Set.empty): Bolt12Invoice = {
require(request.amount.nonEmpty || request.offer.amount.nonEmpty)
Expand Down
Loading