diff --git a/src/fiber/channel.rs b/src/fiber/channel.rs index 1ae4ccb4..71d3b3b4 100644 --- a/src/fiber/channel.rs +++ b/src/fiber/channel.rs @@ -480,14 +480,14 @@ where FiberChannelMessage::AddTlc(add_tlc) => { // TODO: here we only check the error which sender didn't follow agreed rules, // if any error happened here we need go to shutdown procedure - state.check_for_tlc_update(Some(add_tlc.amount))?; + state.check_for_tlc_update(Some(add_tlc.amount), false)?; state .staging_tlc_operations .push(TlcOperation::AddTlc(add_tlc.clone())); Ok(()) } FiberChannelMessage::RemoveTlc(remove_tlc) => { - state.check_for_tlc_update(None)?; + state.check_for_tlc_update(None, false)?; state .staging_tlc_operations .push(TlcOperation::RemoveTlc(remove_tlc.clone())); @@ -1164,7 +1164,7 @@ where ) -> Result { debug!("handle add tlc command : {:?}", &command); state.check_ack_status_for_tlc()?; - state.check_for_tlc_update(Some(command.amount))?; + state.check_for_tlc_update(Some(command.amount), true)?; state.check_tlc_expiry(command.expiry)?; let tlc = state.create_outbounding_tlc(command); state.insert_tlc(tlc.clone())?; @@ -1206,7 +1206,7 @@ where command: RemoveTlcCommand, ) -> ProcessingChannelResult { state.check_ack_status_for_tlc()?; - state.check_for_tlc_update(None)?; + state.check_for_tlc_update(None, false)?; let tlc = state.remove_tlc_with_reason(TLCId::Received(command.id), &command.reason)?; let msg = FiberMessageWithPeerId::new( state.get_remote_peer_id(), @@ -4228,7 +4228,11 @@ impl ChannelActorState { Ok(()) } - fn check_for_tlc_update(&self, add_tlc_amount: Option) -> ProcessingChannelResult { + fn check_for_tlc_update( + &self, + add_tlc_amount: Option, + check_for_remote: bool, + ) -> ProcessingChannelResult { match self.state { ChannelState::ChannelReady() => {} ChannelState::ShuttingDown(_) if add_tlc_amount.is_none() => {} @@ -4246,23 +4250,37 @@ impl ChannelActorState { } if let Some(add_amount) = add_tlc_amount { - let active_tls_number = self.get_active_offered_tlcs(true).count() - + self.get_active_received_tlcs(true).count(); - - if active_tls_number as u64 + 1 > self.max_tlc_number_in_flight { - return Err(ProcessingChannelError::TlcNumberExceedLimit); + self.check_tlc_limits(add_amount, true)?; + if check_for_remote { + // TODO: this should be replaced by using the remote channel's max_tlc_number_in_flight and max_tlc_value_in_flight + self.check_tlc_limits(add_amount, false)?; } + } + Ok(()) + } - if self - .get_active_received_tlcs(true) - .chain(self.get_active_offered_tlcs(true)) - .fold(0_u128, |sum, tlc| sum + tlc.tlc.amount) - + add_amount - > self.max_tlc_value_in_flight - { - return Err(ProcessingChannelError::TlcValueInflightExceedLimit); - } + fn check_tlc_limits( + &self, + add_amount: u128, + local: bool, + ) -> Result<(), ProcessingChannelError> { + let active_tls_number = self.get_active_offered_tlcs(local).count() + + self.get_active_received_tlcs(local).count(); + + if active_tls_number as u64 + 1 > self.max_tlc_number_in_flight { + return Err(ProcessingChannelError::TlcNumberExceedLimit); + } + + if self + .get_active_received_tlcs(local) + .chain(self.get_active_offered_tlcs(local)) + .fold(0_u128, |sum, tlc| sum + tlc.tlc.amount) + + add_amount + > self.max_tlc_value_in_flight + { + return Err(ProcessingChannelError::TlcValueInflightExceedLimit); } + Ok(()) }