diff --git a/src/fiber/graph.rs b/src/fiber/graph.rs index 2adb7b8c..e55bfd32 100644 --- a/src/fiber/graph.rs +++ b/src/fiber/graph.rs @@ -1,4 +1,3 @@ -use super::channel::ChannelActorStateStore; use super::history::{Direction, InternalResult, PaymentHistory, TimedResult}; use super::network::{get_chain_hash, SendPaymentData, SendPaymentResponse}; use super::path::NodeHeap; @@ -184,7 +183,7 @@ pub struct PathEdge { impl NetworkGraph where - S: ChannelActorStateStore + NetworkGraphStateStore + Clone + Send + Sync + 'static, + S: NetworkGraphStateStore + Clone + Send + Sync + 'static, { pub fn new(store: S, source: Pubkey) -> Self { let mut network_graph = Self { diff --git a/src/fiber/history.rs b/src/fiber/history.rs index 5a5ce490..c1384f6d 100644 --- a/src/fiber/history.rs +++ b/src/fiber/history.rs @@ -3,14 +3,13 @@ // we only use direct channel probability now. use super::{ - channel::ChannelActorStateStore, graph::{NetworkGraphStateStore, SessionRouteNode}, types::{Pubkey, TlcErr}, }; use crate::{fiber::types::TlcErrorCode, now_timestamp_as_millis_u64}; use ckb_types::packed::OutPoint; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use tracing::{debug, error}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -48,16 +47,10 @@ pub(crate) struct InternalPairResult { #[derive(Debug, Clone, PartialEq, Eq, Default)] pub(crate) struct InternalResult { pub pairs: HashMap<(OutPoint, Direction), InternalPairResult>, + pub nodes_to_channel_map: HashMap>, pub fail_node: Option, } -fn current_time() -> u128 { - std::time::UNIX_EPOCH - .elapsed() - .expect("unix epoch") - .as_millis() -} - pub(crate) fn output_direction(node1: Pubkey, node2: Pubkey) -> (Direction, Direction) { if node1 < node2 { (Direction::Forward, Direction::Backward) @@ -76,13 +69,24 @@ impl InternalResult { amount: u128, success: bool, ) { - let pair = InternalPairResult { - success, - time, - amount, - }; let (direction, _) = output_direction(node_1, node_2); - self.pairs.insert((channel, direction), pair); + self.add_node_channel_map(node_1, channel.clone()); + self.add_node_channel_map(node_2, channel.clone()); + self.pairs.insert( + (channel, direction), + InternalPairResult { + success, + time, + amount, + }, + ); + } + + fn add_node_channel_map(&mut self, node: Pubkey, channel: OutPoint) { + self.nodes_to_channel_map + .entry(node) + .or_default() + .insert(channel); } pub fn add_fail_pair(&mut self, from: Pubkey, target: Pubkey, channel: OutPoint) { @@ -294,7 +298,7 @@ impl InternalResult { #[derive(Debug, Clone)] pub(crate) struct PaymentHistory { pub inner: HashMap<(OutPoint, Direction), TimedResult>, - pub failed_nodes: HashMap, + pub nodes_to_channel_map: HashMap>, // The minimum interval between two failed payments in milliseconds pub min_fail_relax_interval: u64, pub bimodal_scale_msat: f64, @@ -307,13 +311,13 @@ pub(crate) struct PaymentHistory { impl PaymentHistory where - S: ChannelActorStateStore + NetworkGraphStateStore + Clone + Send + Sync + 'static, + S: NetworkGraphStateStore + Clone + Send + Sync + 'static, { pub(crate) fn new(source: Pubkey, min_fail_relax_interval: Option, store: S) -> Self { let mut s = PaymentHistory { source, inner: HashMap::new(), - failed_nodes: HashMap::new(), + nodes_to_channel_map: HashMap::new(), min_fail_relax_interval: min_fail_relax_interval .unwrap_or(DEFAULT_MIN_FAIL_RELAX_INTERVAL), bimodal_scale_msat: DEFAULT_BIMODAL_SCALE_SHANNONS, @@ -326,6 +330,7 @@ where #[cfg(test)] pub(crate) fn reset(&mut self) { self.inner.clear(); + self.nodes_to_channel_map.clear(); } pub(crate) fn add_result( @@ -343,11 +348,22 @@ where .insert_payment_history_result(channel, direction, result); } + fn add_node_channel_map(&mut self, node: Pubkey, channel: OutPoint) { + self.nodes_to_channel_map + .entry(node) + .or_default() + .insert(channel); + } + pub(crate) fn load_from_store(&mut self) { let results = self.store.get_payment_history_result(); for (channel, direction, result) in results.into_iter() { self.inner.insert((channel, direction), result); } + for channel in self.store.get_channels(None).iter() { + self.add_node_channel_map(channel.node1(), channel.out_point()); + self.add_node_channel_map(channel.node2(), channel.out_point()); + } } pub(crate) fn apply_pair_result( @@ -395,7 +411,11 @@ where } pub(crate) fn apply_internal_result(&mut self, result: InternalResult) { - let InternalResult { pairs, fail_node } = result; + let InternalResult { + pairs, + fail_node, + nodes_to_channel_map, + } = result; for ((channel, direction), pair_result) in pairs.into_iter() { self.apply_pair_result( channel, @@ -405,8 +425,32 @@ where pair_result.time, ); } + for (node, channels) in nodes_to_channel_map.into_iter() { + self.nodes_to_channel_map + .entry(node) + .or_default() + .extend(channels); + } if let Some(fail_node) = fail_node { - self.failed_nodes.insert(fail_node, current_time()); + let channels = self + .nodes_to_channel_map + .get(&fail_node) + .expect("channels not found"); + let pairs: Vec<(OutPoint, Direction)> = self + .inner + .iter() + .flat_map(|((outpoint, direction), _)| { + if channels.contains(outpoint) { + Some((outpoint.clone(), *direction)) + } else { + None + } + }) + .collect(); + + for (channel, direction) in pairs.into_iter() { + self.apply_pair_result(channel, direction, 0, false, now_timestamp_as_millis_u64()); + } } } diff --git a/src/fiber/tests/history.rs b/src/fiber/tests/history.rs index 438da557..fbc0cec4 100644 --- a/src/fiber/tests/history.rs +++ b/src/fiber/tests/history.rs @@ -434,6 +434,149 @@ fn test_history_apply_internal_result_fail_node() { )); } +#[test] +fn test_history_fail_node_with_multiple_channels() { + let mut internal_result = InternalResult::default(); + let mut history = PaymentHistory::new(generate_pubkey().into(), None, MemoryStore::default()); + let node1 = generate_pubkey(); + let node2 = generate_pubkey(); + let node3 = generate_pubkey(); + let channel_outpoint1 = gen_rand_outpoint(); + let channel_outpoint2 = gen_rand_outpoint(); + let channel_outpoint3 = gen_rand_outpoint(); + let channel_outpoint4 = gen_rand_outpoint(); + + let route1 = vec![ + SessionRouteNode { + pubkey: node1, + amount: 10, + channel_outpoint: channel_outpoint1.clone(), + }, + SessionRouteNode { + pubkey: node2, + amount: 5, + channel_outpoint: channel_outpoint2.clone(), + }, + SessionRouteNode { + pubkey: node3, + amount: 3, + channel_outpoint: OutPoint::default(), + }, + ]; + + let route2 = vec![ + SessionRouteNode { + pubkey: node1, + amount: 10, + channel_outpoint: channel_outpoint3.clone(), + }, + SessionRouteNode { + pubkey: node2, + amount: 5, + channel_outpoint: channel_outpoint4.clone(), + }, + SessionRouteNode { + pubkey: node3, + amount: 3, + channel_outpoint: OutPoint::default(), + }, + ]; + + let (direction1, rev_direction1) = output_direction(node1, node2); + let (direction2, rev_direction2) = output_direction(node2, node3); + + internal_result.succeed_range_pairs(&route1, 0, 2); + history.apply_internal_result(internal_result.clone()); + + assert!(matches!( + history.get_result(&channel_outpoint1, direction1), + Some(&TimedResult { + fail_amount: 0, + fail_time: 0, + success_amount: 10, + .. + }) + )); + + assert!(matches!( + history.get_result(&channel_outpoint2, direction2), + Some(&TimedResult { + fail_amount: 0, + fail_time: 0, + success_amount: 5, + .. + }) + )); + + internal_result.fail_node(&route2, 1); + assert_eq!(internal_result.pairs.len(), 6); + history.apply_internal_result(internal_result); + + assert!(matches!( + history.get_result(&channel_outpoint1, direction1), + Some(&TimedResult { + fail_amount: 0, + success_amount: 0, + .. + }) + )); + + assert!(matches!( + history.get_result(&channel_outpoint2, direction2), + Some(&TimedResult { + fail_amount: 0, + success_amount: 0, + .. + }) + )); + + assert!(matches!( + history.get_result(&channel_outpoint1, rev_direction1), + None, + )); + + assert!(matches!( + history.get_result(&channel_outpoint2, rev_direction2), + None, + )); + + assert!(matches!( + history.get_result(&channel_outpoint3, direction1), + Some(&TimedResult { + fail_amount: 0, + success_amount: 0, + .. + }) + )); + + assert!(matches!( + history.get_result(&channel_outpoint4, direction2), + Some(&TimedResult { + fail_amount: 0, + success_amount: 0, + .. + }) + )); + + assert!(matches!( + history.get_result(&channel_outpoint3, rev_direction1), + Some(&TimedResult { + fail_amount: 0, + success_amount: 0, + .. + }) + )); + + assert!(matches!( + history.get_result(&channel_outpoint4, rev_direction2), + Some(&TimedResult { + fail_amount: 0, + success_amount: 0, + .. + }) + )); +} + #[test] fn test_history_interal_success_fail() { let mut history = PaymentHistory::new(generate_pubkey().into(), None, MemoryStore::default());