diff --git a/iroh-willow/src/auth.rs b/iroh-willow/src/auth.rs index 54d685a511..23d71a631f 100644 --- a/iroh-willow/src/auth.rs +++ b/iroh-willow/src/auth.rs @@ -165,7 +165,7 @@ impl Auth { selector: &CapSelector, ) -> Result, AuthError> { let cap = self.caps.read().unwrap().get_write_cap(selector); - debug!(?selector, ?cap, "get write cap"); + // debug!(?selector, ?cap, "get write cap"); Ok(cap) } @@ -174,7 +174,7 @@ impl Auth { selector: &CapSelector, ) -> Result, AuthError> { let cap = self.caps.read().unwrap().get_read_cap(selector); - debug!(?selector, ?cap, "get read cap"); + // debug!(?selector, ?cap, "get read cap"); Ok(cap) } diff --git a/iroh-willow/src/proto/sync.rs b/iroh-willow/src/proto/sync.rs index 2568ee65d5..62b53e9c3b 100644 --- a/iroh-willow/src/proto/sync.rs +++ b/iroh-willow/src/proto/sync.rs @@ -338,6 +338,18 @@ impl Message { pub fn same_kind(&self, other: &Self) -> bool { std::mem::discriminant(self) == std::mem::discriminant(other) } + + pub fn covers_region(&self) -> Option<(AreaOfInterestHandle, u64)> { + match self { + Message::ReconciliationSendFingerprint(msg) => { + msg.covers.map(|covers| (msg.receiver_handle, covers)) + } + Message::ReconciliationAnnounceEntries(msg) => { + msg.covers.map(|covers| (msg.receiver_handle, covers)) + } + _ => None, + } + } } impl Encoder for Message { diff --git a/iroh-willow/src/session/data.rs b/iroh-willow/src/session/data.rs index 2b2575b4c3..d067667a75 100644 --- a/iroh-willow/src/session/data.rs +++ b/iroh-willow/src/session/data.rs @@ -1,5 +1,3 @@ -use futures_lite::StreamExt; - use tokio::sync::broadcast; use crate::{ @@ -11,7 +9,6 @@ use crate::{ store::{traits::Storage, Origin, Store}, }; -use super::channels::MessageReceiver; use super::payload::{send_payload_chunked, CurrentPayload}; use super::Session; @@ -81,26 +78,18 @@ pub struct DataReceiver { session: Session, store: Store, current_payload: CurrentPayload, - recv: MessageReceiver, } impl DataReceiver { - pub fn new(session: Session, store: Store, recv: MessageReceiver) -> Self { + pub fn new(session: Session, store: Store) -> Self { Self { session, store, current_payload: Default::default(), - recv, - } - } - pub async fn run(mut self) -> Result<(), Error> { - while let Some(message) = self.recv.try_next().await? { - self.on_message(message).await?; } - Ok(()) } - async fn on_message(&mut self, message: DataMessage) -> Result<(), Error> { + pub async fn on_message(&mut self, message: DataMessage) -> Result<(), Error> { match message { DataMessage::SendEntry(message) => self.on_send_entry(message).await?, DataMessage::SendPayload(message) => self.on_send_payload(message).await?, diff --git a/iroh-willow/src/session/reconciler.rs b/iroh-willow/src/session/reconciler.rs index 1ecb7a9a26..5f05c6f179 100644 --- a/iroh-willow/src/session/reconciler.rs +++ b/iroh-willow/src/session/reconciler.rs @@ -21,14 +21,14 @@ use crate::{ traits::{EntryReader, EntryStorage, SplitAction, SplitOpts, Storage}, Origin, Store, }, - util::channel::WriteError, + util::{channel::WriteError, stream::Cancelable}, }; #[derive(derive_more::Debug)] pub struct Reconciler { session: Session, store: Store, - recv: MessageReceiver, + recv: Cancelable>, snapshot: ::Snapshot, current_payload: CurrentPayload, } @@ -37,7 +37,7 @@ impl Reconciler { pub fn new( session: Session, store: Store, - recv: MessageReceiver, + recv: Cancelable>, ) -> Result { let snapshot = store.entries().snapshot()?; Ok(Self { @@ -68,11 +68,8 @@ impl Reconciler { } } } - if self.session.reconciliation_is_complete() - && !self.session.mode().is_live() - && !self.current_payload.is_active() - { - debug!("reconciliation complete and not in live mode: close session"); + if self.session.reconciliation_is_complete() && !self.current_payload.is_active() { + debug!("reconciliation complete"); break; } } @@ -222,7 +219,7 @@ impl Reconciler { their_handle: AreaOfInterestHandle, covers: Option, ) -> anyhow::Result<()> { - self.session.mark_range_pending(our_handle); + self.session.mark_our_range_pending(our_handle); let msg = ReconciliationSendFingerprint { range, fingerprint, @@ -259,7 +256,7 @@ impl Reconciler { covers, }; if want_response { - self.session.mark_range_pending(our_handle); + self.session.mark_our_range_pending(our_handle); } self.send(msg).await?; for authorised_entry in self diff --git a/iroh-willow/src/session/run.rs b/iroh-willow/src/session/run.rs index cea8219c7a..57b56c05bb 100644 --- a/iroh-willow/src/session/run.rs +++ b/iroh-willow/src/session/run.rs @@ -1,13 +1,13 @@ use futures_lite::StreamExt; use strum::IntoEnumIterator; use tokio_util::sync::CancellationToken; -use tracing::{debug, error_span, trace}; +use tracing::{debug, error_span, trace, warn}; use crate::{ proto::sync::{ControlIssueGuarantee, LogicalChannel, Message, SetupBindAreaOfInterest}, session::{channels::LogicalChannelReceivers, Error, Scope, Session, SessionInit}, store::{traits::Storage, Store}, - util::channel::Receiver, + util::{channel::Receiver, stream::Cancelable}, }; use super::{ @@ -32,13 +32,21 @@ impl Session { logical_recv: LogicalChannelReceivers { reconciliation_recv, - mut static_tokens_recv, - mut capability_recv, - mut aoi_recv, + static_tokens_recv, + capability_recv, + aoi_recv, data_recv, }, } = recv; + // Make all our receivers close once the cancel_token is triggered. + let control_recv = Cancelable::new(control_recv, cancel_token.clone()); + let reconciliation_recv = Cancelable::new(reconciliation_recv, cancel_token.clone()); + let mut static_tokens_recv = Cancelable::new(static_tokens_recv, cancel_token.clone()); + let mut capability_recv = Cancelable::new(capability_recv, cancel_token.clone()); + let mut aoi_recv = Cancelable::new(aoi_recv, cancel_token.clone()); + let mut data_recv = Cancelable::new(data_recv, cancel_token.clone()); + // Spawn a task to handle incoming static tokens. self.spawn(error_span!("stt"), move |session| async move { while let Some(message) = static_tokens_recv.try_next().await? { @@ -52,7 +60,10 @@ impl Session { self.spawn(error_span!("dat:r"), { let store = store.clone(); move |session| async move { - DataReceiver::new(session, store, data_recv).run().await?; + let mut data_receiver = DataReceiver::new(session, store); + while let Some(message) = data_recv.try_next().await? { + data_receiver.on_message(message).await?; + } Ok(()) } }); @@ -74,13 +85,11 @@ impl Session { }); // Spawn a task to handle incoming areas of interest. - self.spawn(error_span!("aoi"), { - move |session| async move { - while let Some(message) = aoi_recv.try_next().await? { - session.on_bind_area_of_interest(message).await?; - } - Ok(()) + self.spawn(error_span!("aoi"), move |session| async move { + while let Some(message) = aoi_recv.try_next().await? { + session.on_bind_area_of_interest(message).await?; } + Ok(()) }); // Spawn a task to handle reconciliation messages @@ -88,18 +97,21 @@ impl Session { let cancel_token = cancel_token.clone(); let store = store.clone(); move |session| async move { - let res = Reconciler::new(session, store, reconciliation_recv)? + let res = Reconciler::new(session.clone(), store, reconciliation_recv)? .run() .await; - cancel_token.cancel(); + if !session.mode().is_live() { + debug!("reconciliation complete and not in live mode: close session"); + cancel_token.cancel(); + } res } }); // Spawn a task to handle control messages self.spawn(error_span!("ctl"), { - let cancel_token = cancel_token.clone(); let store = store.clone(); + let cancel_token = cancel_token.clone(); move |session| async move { let res = control_loop(session, store, control_recv, init).await; cancel_token.cancel(); @@ -107,50 +119,55 @@ impl Session { } }); - // Spawn a task to handle session termination. - self.spawn(error_span!("fin"), { - let cancel_token = cancel_token.clone(); - move |session| async move { - // Wait until the session is cancelled: - // * either because SessionMode is ReconcileOnce and reconciliation finished - // * or because the session was cancelled from the outside session handle - cancel_token.cancelled().await; - debug!("closing session"); - // Then close all senders. This will make all other tasks terminate once the remote - // closed their senders as well. - session.close_senders(); - // Unsubscribe from the store. This stops the data send task. - store.entries().unsubscribe(session.id()); - Ok(()) + // Wait until the session is cancelled, or until a task fails. + let result = loop { + tokio::select! { + _ = cancel_token.cancelled() => { + break Ok(()); + }, + Some((span, result)) = self.join_next_task() => { + let _guard = span.enter(); + trace!(?result, remaining = self.remaining_tasks(), "task complete"); + if let Err(err) = result { + warn!(?err, "session task failed: abort session"); + break Err(err); + } + }, } - }); + }; + + if result.is_err() { + self.abort_all_tasks(); + } else { + debug!("closing session"); + } - // Wait for all tasks to complete. - // We are not cancelling here so we have to make sure that all tasks terminate (structured - // concurrency basically). - let mut final_result = Ok(()); + // Unsubscribe from the store. This stops the data send task. + store.entries().unsubscribe(self.id()); + + // Wait for remaining tasks to terminate to catch any panics. + // TODO: Add timeout? while let Some((span, result)) = self.join_next_task().await { let _guard = span.enter(); - // trace!(?result, remaining = self.remaining_tasks(), "task complete"); - debug!(?result, remaining = self.remaining_tasks(), "task complete"); + trace!(?result, remaining = self.remaining_tasks(), "task complete"); if let Err(err) = result { - tracing::warn!(?err, "task failed: {err}"); - cancel_token.cancel(); - // self.abort_all_tasks(); - if final_result.is_ok() { - final_result = Err(err); - } + warn!("task failed: {err:?}"); } } - debug!(success = final_result.is_ok(), "session complete"); - final_result + + // Close our channel senders. + // This will stop the network send loop after all pending data has been sent. + self.close_senders(); + + debug!(success = result.is_ok(), "session complete"); + result } } async fn control_loop( session: Session, store: Store, - mut control_recv: Receiver, + mut control_recv: Cancelable>, init: SessionInit, ) -> Result<(), Error> { debug!(role = ?session.our_role(), "start session"); diff --git a/iroh-willow/src/session/state.rs b/iroh-willow/src/session/state.rs index b2388f5cfb..2323ade36d 100644 --- a/iroh-willow/src/session/state.rs +++ b/iroh-willow/src/session/state.rs @@ -8,7 +8,7 @@ use std::{ }; use futures_lite::Stream; -use tracing::{Instrument, Span}; +use tracing::{debug, trace, Instrument, Span}; use crate::{ proto::{ @@ -103,9 +103,19 @@ impl Session { self.0.tasks.borrow_mut().abort_all(); } - pub fn remaining_tasks(&self) -> usize { + // pub fn remaining_tasks(&self) -> usize { + // let tasks = self.0.tasks.borrow(); + // tasks.len() + // } + + pub fn remaining_tasks(&self) -> String { let tasks = self.0.tasks.borrow(); - tasks.len() + let mut out = vec![]; + for (span, _k) in tasks.iter() { + let name = span.metadata().unwrap().name(); + out.push(name.to_string()); + } + out.join(",") } pub fn log_remaining_tasks(&self) { @@ -114,10 +124,21 @@ impl Session { .iter() .map(|t| t.0.metadata().unwrap().name()) .collect::>(); - tracing::debug!(tasks=?names, "active_tasks"); + debug!(tasks=?names, "active_tasks"); } pub async fn send(&self, message: impl Into) -> Result<(), WriteError> { + let message: Message = message.into(); + if let Some((their_handle, range_count)) = message.covers_region() { + if let Err(err) = self + .state_mut() + .mark_their_range_covered(their_handle, range_count) + { + // TODO: Is this really unreachable? I think so, as this would indicate a logic + // error purely on our side. + unreachable!("mark_their_range_covered: {err:?}"); + } + } self.0.send.send(message).await } @@ -196,7 +217,7 @@ impl Session { Ok((our_handle, maybe_message)) } - pub fn mark_range_pending(&self, our_handle: AreaOfInterestHandle) { + pub fn mark_our_range_pending(&self, our_handle: AreaOfInterestHandle) { let mut state = self.state_mut(); state.reconciliation_started = true; let range_count = state.our_range_counter; @@ -211,7 +232,7 @@ impl Session { let range_count = { let mut state = self.state_mut(); if let Some(range_count) = message.covers { - state.mark_range_covered(message.receiver_handle, range_count)?; + state.mark_our_range_covered(message.receiver_handle, range_count)?; } if state.pending_announced_entries.is_some() { return Err(Error::InvalidMessageInCurrentState); @@ -220,8 +241,7 @@ impl Session { state.pending_announced_entries = Some(message.count); } if message.want_response { - let range_count = state.their_range_counter; - state.their_range_counter += 1; + let range_count = state.add_pending_range_theirs(message.sender_handle); Some(range_count) } else { None @@ -245,11 +265,9 @@ impl Session { let mut state = self.state_mut(); state.reconciliation_started = true; if let Some(range_count) = message.covers { - state.mark_range_covered(message.receiver_handle, range_count)?; + state.mark_our_range_covered(message.receiver_handle, range_count)?; } - let range_count = state.their_range_counter; - state.their_range_counter += 1; - range_count + state.add_pending_range_theirs(message.sender_handle) }; let namespace = self @@ -295,7 +313,7 @@ impl Session { pub fn on_setup_bind_read_capability(&self, msg: SetupBindReadCapability) -> Result<(), Error> { // TODO: verify intersection handle - tracing::debug!("setup bind cap {msg:?}"); + trace!("received capability {msg:?}"); msg.capability.validate()?; let mut state = self.state_mut(); state @@ -308,14 +326,16 @@ impl Session { pub fn reconciliation_is_complete(&self) -> bool { let state = self.state(); // tracing::debug!( - // "reconciliation_is_complete started {} pending_ranges {}, pending_entries {:?} mode {:?}", + // "reconciliation_is_complete started {} our_pending_ranges {}, their_pending_ranges {}, pending_entries {:?} mode {:?}", // state.reconciliation_started, // state.our_uncovered_ranges.len(), + // state.their_uncovered_ranges.len(), // state.pending_announced_entries, // self.mode(), // ); state.reconciliation_started && state.our_uncovered_ranges.is_empty() + && state.their_uncovered_ranges.is_empty() && state.pending_announced_entries.is_none() } @@ -457,6 +477,7 @@ struct SessionState { our_range_counter: u64, their_range_counter: u64, our_uncovered_ranges: HashSet<(AreaOfInterestHandle, u64)>, + their_uncovered_ranges: HashSet<(AreaOfInterestHandle, u64)>, pending_announced_entries: Option, intersection_queue: Queue, } @@ -476,6 +497,7 @@ impl SessionState { our_range_counter: 0, their_range_counter: 0, our_uncovered_ranges: Default::default(), + their_uncovered_ranges: Default::default(), pending_announced_entries: Default::default(), intersection_queue: Default::default(), } @@ -542,7 +564,7 @@ impl SessionState { Ok(()) } - fn mark_range_covered( + fn mark_our_range_covered( &mut self, our_handle: AreaOfInterestHandle, range_count: u64, @@ -553,4 +575,29 @@ impl SessionState { Ok(()) } } + + fn mark_their_range_covered( + &mut self, + their_handle: AreaOfInterestHandle, + range_count: u64, + ) -> Result<(), Error> { + // trace!(?their_handle, ?range_count, "mark_their_range_covered"); + if !self + .their_uncovered_ranges + .remove(&(their_handle, range_count)) + { + Err(Error::InvalidMessageInCurrentState) + } else { + Ok(()) + } + } + + fn add_pending_range_theirs(&mut self, their_handle: AreaOfInterestHandle) -> u64 { + let range_count = self.their_range_counter; + self.their_range_counter += 1; + // debug!(?their_handle, ?range_count, "add_pending_range_theirs"); + self.their_uncovered_ranges + .insert((their_handle, range_count)); + range_count + } } diff --git a/iroh-willow/src/util.rs b/iroh-willow/src/util.rs index ff1d8002ba..f417fb773f 100644 --- a/iroh-willow/src/util.rs +++ b/iroh-willow/src/util.rs @@ -3,5 +3,6 @@ pub mod channel; pub mod codec; pub mod queue; +pub mod stream; pub mod task; pub mod time; diff --git a/iroh-willow/src/util/channel.rs b/iroh-willow/src/util/channel.rs index 87f56a0229..b86d179860 100644 --- a/iroh-willow/src/util/channel.rs +++ b/iroh-willow/src/util/channel.rs @@ -181,13 +181,13 @@ impl Shared { } fn writable_slice_exact(&mut self, len: usize) -> Option<&mut [u8]> { - tracing::trace!( - "write {}, remaining {} (guarantees {}, buf capacity {})", - len, - self.remaining_write_capacity(), - self.guarantees.get(), - self.max_buffer_size - self.buf.len() - ); + // tracing::trace!( + // "write {}, remaining {} (guarantees {}, buf capacity {})", + // len, + // self.remaining_write_capacity(), + // self.guarantees.get(), + // self.max_buffer_size - self.buf.len() + // ); if self.remaining_write_capacity() < len { None } else { diff --git a/iroh-willow/src/util/queue.rs b/iroh-willow/src/util/queue.rs index b131c9edbe..325cece9b4 100644 --- a/iroh-willow/src/util/queue.rs +++ b/iroh-willow/src/util/queue.rs @@ -8,10 +8,10 @@ use std::{ use futures_lite::Stream; -/// A simple unbounded queue. +/// A simple unbounded FIFO queue. /// /// Values are pushed into the queue, synchronously. -/// The queue can be polled for the next value from the start. +/// The queue can be polled for the next value asynchronously. #[derive(Debug)] pub struct Queue { items: VecDeque, diff --git a/iroh-willow/src/util/stream.rs b/iroh-willow/src/util/stream.rs new file mode 100644 index 0000000000..3ac9cc6776 --- /dev/null +++ b/iroh-willow/src/util/stream.rs @@ -0,0 +1,48 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::Stream; +use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned}; + +/// Wrapper around [`Stream`] that takes a cancel token to cancel the stream. +/// +/// Once the cancel token is cancelled, this stream will continue to yield all items which are +/// ready immediately and then return [`None`]. +#[derive(Debug)] +pub struct Cancelable { + stream: S, + cancelled: Pin>, + is_cancelled: bool, +} + +impl Cancelable { + pub fn new(stream: S, cancel_token: CancellationToken) -> Self { + Self { + stream, + cancelled: Box::pin(cancel_token.cancelled_owned()), + is_cancelled: false, + } + } +} + +impl Stream for Cancelable { + type Item = S::Item; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_cancelled { + return Poll::Ready(None); + } + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(r) => Poll::Ready(r), + Poll::Pending => match Pin::new(&mut self.cancelled).poll(cx) { + Poll::Ready(()) => { + self.is_cancelled = true; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + }, + } + } +} diff --git a/iroh-willow/src/util/task.rs b/iroh-willow/src/util/task.rs index 77dea91243..f6394303a0 100644 --- a/iroh-willow/src/util/task.rs +++ b/iroh-willow/src/util/task.rs @@ -51,8 +51,7 @@ impl JoinMap { pub fn spawn_local + 'static>(&mut self, key: K, future: F) -> TaskKey { let handle = tokio::task::spawn_local(future); let abort_handle = handle.abort_handle(); - let k = self.tasks.insert(handle); - let k = TaskKey(k); + let k = TaskKey(self.tasks.insert(handle)); self.keys.insert(k, key); self.abort_handles.insert(k, abort_handle); k