Skip to content

Commit

Permalink
Use balance estimates from past payments in path-finding
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Nov 12, 2024
1 parent f02c98b commit 5fc7841
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 211 deletions.
2 changes: 2 additions & 0 deletions eclair-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,8 @@ eclair {
// probability of success, however is penalizes less the paths with a low probability of success.
use-log-probability = false

use-past-relay-data = false

mpp {
min-amount-satoshis = 15000 // minimum amount sent via partial HTLCs
max-parts = 5 // maximum number of HTLCs sent per payment: increasing this value will impact performance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ object NodeParams extends Logging {
failureCost = getRelayFees(config.getConfig("failure-cost")),
hopCost = getRelayFees(config.getConfig("hop-cost")),
useLogProbability = config.getBoolean("use-log-probability"),
usePastRelaysData = config.getBoolean("use-past-relay-data"),
))
},
mpp = MultiPartParams(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ object EclairInternalsSerializer {
("lockedFundsRisk" | double) ::
("failureCost" | relayFeesCodec) ::
("hopCost" | relayFeesCodec) ::
("useLogProbability" | bool(8))).as[HeuristicsConstants]
("useLogProbability" | bool(8)) ::
("usePastRelaysData" | bool(8))).as[HeuristicsConstants]

val multiPartParamsCodec: Codec[MultiPartParams] = (
("minPartAmount" | millisatoshi) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Route}
import fr.acinq.eclair.wire.protocol.NodeAnnouncement
import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond, TimestampSecondLong, ToMilliSatoshiConversion}

import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.concurrent.duration.FiniteDuration

/**
* Estimates the balance between a pair of nodes
Expand Down Expand Up @@ -235,6 +235,8 @@ object BalanceEstimate {
case class BalancesEstimates(balances: Map[(PublicKey, PublicKey), BalanceEstimate], defaultHalfLife: FiniteDuration) {
private def get(a: PublicKey, b: PublicKey): Option[BalanceEstimate] = balances.get((a, b))

def get(edge: GraphEdge): BalanceEstimate = get(edge.desc.a, edge.desc.b).getOrElse(BalanceEstimate.empty(defaultHalfLife).addEdge(edge))

def addEdge(edge: GraphEdge): BalancesEstimates = BalancesEstimates(
balances.updatedWith((edge.desc.a, edge.desc.b))(balance =>
Some(balance.getOrElse(BalanceEstimate.empty(defaultHalfLife)).addEdge(edge))
Expand Down Expand Up @@ -284,7 +286,7 @@ case class BalancesEstimates(balances: Map[(PublicKey, PublicKey), BalanceEstima

}

case class GraphWithBalanceEstimates(graph: DirectedGraph, private val balances: BalancesEstimates) {
case class GraphWithBalanceEstimates(graph: DirectedGraph, balances: BalancesEstimates) {
def addOrUpdateVertex(ann: NodeAnnouncement): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.addOrUpdateVertex(ann), balances)

def addEdge(edge: GraphEdge): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.addEdge(edge), balances.addEdge(edge))
Expand Down Expand Up @@ -317,13 +319,6 @@ case class GraphWithBalanceEstimates(graph: DirectedGraph, private val balances:
def channelCouldNotSend(hop: ChannelHop, amount: MilliSatoshi): GraphWithBalanceEstimates = {
GraphWithBalanceEstimates(graph, balances.channelCouldNotSend(hop, amount))
}

def canSend(amount: MilliSatoshi, edge: GraphEdge): Double = {
balances.balances.get((edge.desc.a, edge.desc.b)) match {
case Some(estimate) => estimate.canSend(amount)
case None => BalanceEstimate.empty(1 hour).addEdge(edge).canSend(amount)
}
}
}

object GraphWithBalanceEstimates {
Expand Down
46 changes: 27 additions & 19 deletions eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ object Graph {
* The fee for a failed attempt and the fee per hop are never actually spent, they are used to incentivize shorter
* paths or path with higher success probability.
*
* @param lockedFundsRisk cost of having funds locked in htlc in msat per msat per block
* @param failureCost fee for a failed attempt
* @param hopCost virtual fee per hop (how much we're willing to pay to make the route one hop shorter)
* @param lockedFundsRisk cost of having funds locked in htlc in msat per msat per block
* @param failureCost fee for a failed attempt
* @param hopCost virtual fee per hop (how much we're willing to pay to make the route one hop shorter)
* @param usePastRelaysData use data from past relays to estimate the balance of the channels
*/
case class HeuristicsConstants(lockedFundsRisk: Double, failureCost: RelayFees, hopCost: RelayFees, useLogProbability: Boolean)
case class HeuristicsConstants(lockedFundsRisk: Double, failureCost: RelayFees, hopCost: RelayFees, useLogProbability: Boolean, usePastRelaysData: Boolean)

case class WeightedNode(key: PublicKey, weight: RichWeight)

Expand Down Expand Up @@ -109,7 +110,7 @@ object Graph {
* @param boundaries a predicate function that can be used to impose limits on the outcome of the search
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
def yenKshortestPaths(graph: DirectedGraph,
def yenKshortestPaths(g: GraphWithBalanceEstimates,
sourceNode: PublicKey,
targetNode: PublicKey,
amount: MilliSatoshi,
Expand All @@ -123,7 +124,7 @@ object Graph {
includeLocalChannelCost: Boolean): Seq[WeightedPath] = {
// find the shortest path (k = 0)
val targetWeight = RichWeight(amount, 0, CltvExpiryDelta(0), 1.0, 0 msat, 0 msat, 0.0)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
val shortestPath = dijkstraShortestPath(g, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
if (shortestPath.isEmpty) {
return Seq.empty // if we can't even find a single path, avoid returning a Seq(Seq.empty)
}
Expand All @@ -135,7 +136,7 @@ object Graph {

var allSpurPathsFound = false
val shortestPaths = new mutable.Queue[PathWithSpur]
shortestPaths.enqueue(PathWithSpur(WeightedPath(shortestPath, pathWeight(sourceNode, shortestPath, amount, currentBlockHeight, wr, includeLocalChannelCost)), 0))
shortestPaths.enqueue(PathWithSpur(WeightedPath(shortestPath, pathWeight(g.balances, sourceNode, shortestPath, amount, currentBlockHeight, wr, includeLocalChannelCost)), 0))
// stores the candidates for the k-th shortest path, sorted by path cost
val candidates = new mutable.PriorityQueue[PathWithSpur]

Expand All @@ -160,12 +161,12 @@ object Graph {
val alreadyExploredEdges = shortestPaths.collect { case p if p.p.path.takeRight(i) == rootPathEdges => p.p.path(p.p.path.length - 1 - i).desc }.toSet
// we also want to ignore any vertex on the root path to prevent loops
val alreadyExploredVertices = rootPathEdges.map(_.desc.b).toSet
val rootPathWeight = pathWeight(sourceNode, rootPathEdges, amount, currentBlockHeight, wr, includeLocalChannelCost)
val rootPathWeight = pathWeight(g.balances, sourceNode, rootPathEdges, amount, currentBlockHeight, wr, includeLocalChannelCost)
// find the "spur" path, a sub-path going from the spur node to the target avoiding previously found sub-paths
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
val spurPath = dijkstraShortestPath(g, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
if (spurPath.nonEmpty) {
val completePath = spurPath ++ rootPathEdges
val candidatePath = WeightedPath(completePath, pathWeight(sourceNode, completePath, amount, currentBlockHeight, wr, includeLocalChannelCost))
val candidatePath = WeightedPath(completePath, pathWeight(g.balances, sourceNode, completePath, amount, currentBlockHeight, wr, includeLocalChannelCost))
candidates.enqueue(PathWithSpur(candidatePath, i))
}
}
Expand Down Expand Up @@ -200,7 +201,7 @@ object Graph {
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
private def dijkstraShortestPath(g: DirectedGraph,
private def dijkstraShortestPath(g: GraphWithBalanceEstimates,
sourceNode: PublicKey,
targetNode: PublicKey,
ignoredEdges: Set[ChannelDesc],
Expand All @@ -212,8 +213,8 @@ object Graph {
wr: Either[WeightRatios, HeuristicsConstants],
includeLocalChannelCost: Boolean): Seq[GraphEdge] = {
// the graph does not contain source/destination nodes
val sourceNotInGraph = !g.containsVertex(sourceNode) && !extraEdges.exists(_.desc.a == sourceNode)
val targetNotInGraph = !g.containsVertex(targetNode) && !extraEdges.exists(_.desc.b == targetNode)
val sourceNotInGraph = !g.graph.containsVertex(sourceNode) && !extraEdges.exists(_.desc.a == sourceNode)
val targetNotInGraph = !g.graph.containsVertex(targetNode) && !extraEdges.exists(_.desc.b == targetNode)
if (sourceNotInGraph || targetNotInGraph) {
return Seq.empty
}
Expand Down Expand Up @@ -242,7 +243,7 @@ object Graph {
val neighborEdges = {
val extraNeighbors = extraEdges.filter(_.desc.b == current.key)
// the resulting set must have only one element per shortChannelId; we prioritize extra edges
g.getIncomingEdgesOf(current.key).collect{case e: GraphEdge if !extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId) => e} ++ extraNeighbors
g.graph.getIncomingEdgesOf(current.key).collect{case e: GraphEdge if !extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId) => e} ++ extraNeighbors
}
neighborEdges.foreach { edge =>
val neighbor = edge.desc.a
Expand All @@ -254,7 +255,7 @@ object Graph {
!ignoredVertices.contains(neighbor)) {
// NB: this contains the amount (including fees) that will need to be sent to `neighbor`, but the amount that
// will be relayed through that edge is the one in `currentWeight`.
val neighborWeight = addEdgeWeight(sourceNode, edge, current.weight, currentBlockHeight, wr, includeLocalChannelCost)
val neighborWeight = addEdgeWeight(sourceNode, edge, g.balances.get(edge), current.weight, currentBlockHeight, wr, includeLocalChannelCost)
if (boundaries(neighborWeight)) {
val previousNeighborWeight = bestWeights.getOrElse(neighbor, RichWeight(MilliSatoshi(Long.MaxValue), Int.MaxValue, CltvExpiryDelta(Int.MaxValue), 0.0, MilliSatoshi(Long.MaxValue), MilliSatoshi(Long.MaxValue), Double.MaxValue))
// if this path between neighbor and the target has a shorter distance than previously known, we select it
Expand Down Expand Up @@ -298,7 +299,7 @@ object Graph {
* @param weightRatios ratios used to 'weight' edges when searching for the shortest path
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
private def addEdgeWeight(sender: PublicKey, edge: GraphEdge, prev: RichWeight, currentBlockHeight: BlockHeight, weightRatios: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = {
private def addEdgeWeight(sender: PublicKey, edge: GraphEdge, balance: BalanceEstimate, prev: RichWeight, currentBlockHeight: BlockHeight, weightRatios: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = {
val totalAmount = if (edge.desc.a == sender && !includeLocalChannelCost) prev.amount else addEdgeFees(edge, prev.amount)
val fee = totalAmount - prev.amount
val totalFees = prev.fees + fee
Expand Down Expand Up @@ -335,7 +336,14 @@ object Graph {
val hopCost = nodeFee(heuristicsConstants.hopCost, prev.amount)
val totalHopsCost = prev.virtualFees + hopCost
// If we know the balance of the channel, then we will check separately that it can relay the payment.
val successProbability = if (edge.balance_opt.nonEmpty) 1.0 else 1.0 - prev.amount.toLong.toDouble / edge.capacity.toMilliSatoshi.toLong.toDouble
val successProbability =
if (edge.balance_opt.nonEmpty){
1.0
} else if (heuristicsConstants.usePastRelaysData) {
balance.canSend(prev.amount)
} else {
1.0 - prev.amount.toLong.toDouble / edge.capacity.toMilliSatoshi.toLong.toDouble
}
if (successProbability < 0) {
throw NegativeProbability(edge, prev, heuristicsConstants)
}
Expand Down Expand Up @@ -396,9 +404,9 @@ object Graph {
* @param wr ratios used to 'weight' edges when searching for the shortest path
* @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel
*/
def pathWeight(sender: PublicKey, path: Seq[GraphEdge], amount: MilliSatoshi, currentBlockHeight: BlockHeight, wr: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = {
def pathWeight(balances: BalancesEstimates, sender: PublicKey, path: Seq[GraphEdge], amount: MilliSatoshi, currentBlockHeight: BlockHeight, wr: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = {
path.foldRight(RichWeight(amount, 0, CltvExpiryDelta(0), 1.0, 0 msat, 0 msat, 0.0)) { (edge, prev) =>
addEdgeWeight(sender, edge, prev, currentBlockHeight, wr, includeLocalChannelCost)
addEdgeWeight(sender, edge, balances.get(edge), prev, currentBlockHeight, wr, includeLocalChannelCost)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ object RouteCalculation {
val tags = TagSet.Empty.withTag(Tags.MultiPart, r.allowMultiPart).withTag(Tags.Amount, Tags.amountBucket(amountToSend))
KamonExt.time(Metrics.FindRouteDuration.withTags(tags.withTag(Tags.NumberOfRoutes, routesToFind.toLong))) {
val result = if (r.allowMultiPart) {
findMultiPartRoute(d.graphWithBalances.graph, r.source, targetNodeId, amountToSend, maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, r.routeParams, currentBlockHeight)
findMultiPartRoute(d.graphWithBalances, r.source, targetNodeId, amountToSend, maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, r.routeParams, currentBlockHeight)
} else {
findRoute(d.graphWithBalances.graph, r.source, targetNodeId, amountToSend, maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, r.routeParams, currentBlockHeight)
findRoute(d.graphWithBalances, r.source, targetNodeId, amountToSend, maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, r.routeParams, currentBlockHeight)
}
result.map(routes => addFinalHop(r.target, routes)) match {
case Success(routes) =>
Expand Down Expand Up @@ -294,7 +294,7 @@ object RouteCalculation {
* @param routeParams a set of parameters that can restrict the route search
* @return the computed routes to the destination @param targetNodeId
*/
def findRoute(g: DirectedGraph,
def findRoute(g: GraphWithBalanceEstimates,
localNodeId: PublicKey,
targetNodeId: PublicKey,
amount: MilliSatoshi,
Expand All @@ -312,7 +312,7 @@ object RouteCalculation {
}

@tailrec
private def findRouteInternal(g: DirectedGraph,
private def findRouteInternal(g: GraphWithBalanceEstimates,
localNodeId: PublicKey,
targetNodeId: PublicKey,
amount: MilliSatoshi,
Expand Down Expand Up @@ -370,7 +370,7 @@ object RouteCalculation {
* @param routeParams a set of parameters that can restrict the route search
* @return a set of disjoint routes to the destination @param targetNodeId with the payment amount split between them
*/
def findMultiPartRoute(g: DirectedGraph,
def findMultiPartRoute(g: GraphWithBalanceEstimates,
localNodeId: PublicKey,
targetNodeId: PublicKey,
amount: MilliSatoshi,
Expand All @@ -394,7 +394,7 @@ object RouteCalculation {
}
}

private def findMultiPartRouteInternal(g: DirectedGraph,
private def findMultiPartRouteInternal(g: GraphWithBalanceEstimates,
localNodeId: PublicKey,
targetNodeId: PublicKey,
amount: MilliSatoshi,
Expand All @@ -409,7 +409,7 @@ object RouteCalculation {
// When the recipient is a direct peer, we have complete visibility on our local channels so we can use more accurate MPP parameters.
val routeParams1 = {
case class DirectChannel(balance: MilliSatoshi, isEmpty: Boolean)
val directChannels = g.getEdgesBetween(localNodeId, targetNodeId).collect {
val directChannels = g.graph.getEdgesBetween(localNodeId, targetNodeId).collect {
// We should always have balance information available for local channels.
// NB: htlcMinimumMsat is set by our peer and may be 0 msat (even though it's not recommended).
case GraphEdge(_, params, _, Some(balance)) => DirectChannel(balance, balance <= 0.msat || balance < params.htlcMinimum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,16 @@ class BalanceEstimateSpec extends AnyFunSuite {
val edge_ab = makeEdge(a, b, 1, 10 sat)
val edge_ba = makeEdge(b, a, 1, 10 sat)
val edge_bc = makeEdge(b, c, 6, 10 sat)
assert(graphWithBalances.canSend(27500 msat, edge_ab) === 0.75 +- 0.01)
assert(graphWithBalances.canSend(55000 msat, edge_ab) === 0.5 +- 0.01)
assert(graphWithBalances.canSend(30000 msat, edge_ba) === 0.75 +- 0.01)
assert(graphWithBalances.canSend(60000 msat, edge_ba) === 0.5 +- 0.01)
assert(graphWithBalances.canSend(75000 msat, edge_bc) === 0.5 +- 0.01)
assert(graphWithBalances.canSend(100000 msat, edge_bc) === 0.33 +- 0.01)
assert(graphWithBalances.balances.get(edge_ab).canSend(27500 msat) === 0.75 +- 0.01)
assert(graphWithBalances.balances.get(edge_ab).canSend(55000 msat) === 0.5 +- 0.01)
assert(graphWithBalances.balances.get(edge_ba).canSend(30000 msat) === 0.75 +- 0.01)
assert(graphWithBalances.balances.get(edge_ba).canSend(60000 msat) === 0.5 +- 0.01)
assert(graphWithBalances.balances.get(edge_bc).canSend(75000 msat) === 0.5 +- 0.01)
assert(graphWithBalances.balances.get(edge_bc).canSend(100000 msat) === 0.33 +- 0.01)
val unknownEdge = makeEdge(42, 40 sat)
assert(graphWithBalances.canSend(10000 msat, unknownEdge) === 0.75 +- 0.01)
assert(graphWithBalances.canSend(20000 msat, unknownEdge) === 0.5 +- 0.01)
assert(graphWithBalances.canSend(30000 msat, unknownEdge) === 0.25 +- 0.01)
assert(graphWithBalances.balances.get(unknownEdge).canSend(10000 msat) === 0.75 +- 0.01)
assert(graphWithBalances.balances.get(unknownEdge).canSend(20000 msat) === 0.5 +- 0.01)
assert(graphWithBalances.balances.get(unknownEdge).canSend(30000 msat) === 0.25 +- 0.01)
}

}
Loading

0 comments on commit 5fc7841

Please sign in to comment.