diff --git a/Cargo.lock b/Cargo.lock index 0bb3a7fd79..2ab9d71e30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2901,6 +2901,7 @@ dependencies = [ "futures", "futures-concurrency", "futures-lite 2.3.0", + "futures-util", "genawaiter", "hex", "iroh-base", diff --git a/iroh-willow/Cargo.toml b/iroh-willow/Cargo.toml index aecbcd75f7..b87c4a9101 100644 --- a/iroh-willow/Cargo.toml +++ b/iroh-willow/Cargo.toml @@ -54,6 +54,7 @@ smallvec = "1.13.2" itertools = "0.12.1" futures-lite = "2.3.0" futures-concurrency = "7.6.0" +futures-util = "0.3.30" [dev-dependencies] iroh-test = { path = "../iroh-test" } diff --git a/iroh-willow/src/actor.rs b/iroh-willow/src/actor.rs index 45b14f0b9d..b20dacffe7 100644 --- a/iroh-willow/src/actor.rs +++ b/iroh-willow/src/actor.rs @@ -9,7 +9,11 @@ use std::{ thread::JoinHandle, }; -use futures::{future::LocalBoxFuture, FutureExt}; +use futures_lite::{ + future::{Boxed as BoxFuture, BoxedLocal as LocalBoxFuture}, + stream::Stream, +}; +use futures_util::future::{FutureExt, Shared}; use genawaiter::{ sync::{Co, Gen}, GeneratorState, @@ -19,6 +23,7 @@ use tokio::sync::oneshot; use tracing::{debug, error, error_span, trace, warn, Span}; use crate::{ + net::InitialTransmission, proto::{ grouping::ThreeDRange, keys::NamespaceId, @@ -27,7 +32,7 @@ use crate::{ }, session::{ coroutine::{ControlRoutine, ReconcileRoutine}, - Channels, Error, SessionInit, SessionState, SharedSessionState, + Channels, Error, Role, SessionInit, SessionState, SharedSessionState, }, store::Store, }; @@ -37,7 +42,7 @@ pub const INBOX_CAP: usize = 1024; pub type SessionId = NodeId; #[derive(Debug, Clone)] -pub struct StoreHandle { +pub struct WillowHandle { tx: flume::Sender, join_handle: Arc>>, } @@ -85,8 +90,8 @@ impl Notifier { } } -impl StoreHandle { - pub fn spawn(store: S, me: NodeId) -> StoreHandle { +impl WillowHandle { + pub fn spawn(store: S, me: NodeId) -> WillowHandle { let (tx, rx) = flume::bounded(INBOX_CAP); // This channel only tracks wake to resume messages to coroutines, which are a sinlge u64 // per wakeup. We want to issue wake calls synchronosuly without blocking, so we use an @@ -116,7 +121,7 @@ impl StoreHandle { }) .expect("failed to spawn thread"); let join_handle = Arc::new(Some(join_handle)); - StoreHandle { tx, join_handle } + WillowHandle { tx, join_handle } } pub async fn send(&self, action: ToActor) -> anyhow::Result<()> { self.tx.send_async(action).await?; @@ -132,9 +137,56 @@ impl StoreHandle { reply_rx.await??; Ok(()) } + + pub async fn get_entries( + &self, + namespace: NamespaceId, + range: ThreeDRange, + ) -> anyhow::Result> { + let (tx, rx) = flume::bounded(1024); + self.send(ToActor::GetEntries { + namespace, + reply: tx, + range, + }) + .await?; + Ok(rx.into_stream()) + } + + pub async fn init_session( + &self, + peer: NodeId, + our_role: Role, + initial_transmission: InitialTransmission, + channels: Channels, + init: SessionInit, + ) -> anyhow::Result { + let state = SessionState::new(our_role, initial_transmission); + + let (on_finish_tx, on_finish_rx) = oneshot::channel(); + self.send(ToActor::InitSession { + peer, + state, + channels, + init, + on_finish: on_finish_tx, + }) + .await?; + + let on_finish = on_finish_rx + .map(|r| match r { + Ok(Ok(())) => Ok(()), + Ok(Err(err)) => Err(Arc::new(err.into())), + Err(_) => Err(Arc::new(Error::ActorFailed)), + }) + .boxed(); + let on_finish = on_finish.shared(); + let handle = SessionHandle { on_finish }; + Ok(handle) + } } -impl Drop for StoreHandle { +impl Drop for WillowHandle { fn drop(&mut self) { // this means we're dropping the last reference if let Some(handle) = Arc::get_mut(&mut self.join_handle) { @@ -146,6 +198,21 @@ impl Drop for StoreHandle { } } } + +#[derive(Debug)] +pub struct SessionHandle { + on_finish: Shared>>>, +} + +impl SessionHandle { + /// Wait for the session to finish. + /// + /// Returns an error if the session failed to complete. + pub async fn on_finish(self) -> Result<(), Arc> { + self.on_finish.await + } +} + #[derive(derive_more::Debug, strum::Display)] pub enum ToActor { InitSession { @@ -155,7 +222,7 @@ pub enum ToActor { #[debug(skip)] channels: Channels, init: SessionInit, - on_done: oneshot::Sender>, + on_finish: oneshot::Sender>, }, GetEntries { namespace: NamespaceId, @@ -195,7 +262,7 @@ pub struct StorageThread { next_coro_id: u64, } -type CoroFut = LocalBoxFuture<'static, Result<(), Error>>; +type CoroFut = LocalBoxFuture>; #[derive(derive_more::Debug)] struct CoroutineState { @@ -257,7 +324,7 @@ impl StorageThread { state, channels, init, - on_done, + on_finish: on_done, } => { let span = error_span!("session", peer=%peer.fmt_short()); let session = Session { diff --git a/iroh-willow/src/net.rs b/iroh-willow/src/net.rs index 9de0d5d46d..a6d24ad278 100644 --- a/iroh-willow/src/net.rs +++ b/iroh-willow/src/net.rs @@ -4,20 +4,19 @@ use futures_concurrency::future::TryJoin; use iroh_base::{hash::Hash, key::NodeId}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - sync::oneshot, task::JoinSet, }; use tracing::{debug, error_span, instrument, trace, warn, Instrument}; use crate::{ - actor::{StoreHandle, ToActor}, + actor::WillowHandle, proto::wgps::{ AccessChallenge, ChallengeHash, LogicalChannel, Message, CHALLENGE_HASH_LENGTH, MAX_PAYLOAD_SIZE_POWER, }, session::{ channels::{Channels, LogicalChannelReceivers, LogicalChannelSenders}, - Role, SessionInit, SessionState, + Role, SessionInit, }, util::channel::{ inbound_channel, outbound_channel, Guarantees, Reader, Receiver, Sender, Writer, @@ -29,7 +28,7 @@ pub const CHANNEL_CAP: usize = 1024 * 64; #[instrument(skip_all, name = "willow_net", fields(me=%me.fmt_short(), peer=%peer.fmt_short()))] pub async fn run( me: NodeId, - store: StoreHandle, + store: WillowHandle, conn: quinn::Connection, peer: NodeId, our_role: Role, @@ -37,19 +36,15 @@ pub async fn run( ) -> anyhow::Result<()> { debug!(?our_role, "connected"); let mut join_set = JoinSet::new(); + let (mut control_send_stream, mut control_recv_stream) = match our_role { Role::Alfie => conn.open_bi().await?, Role::Betty => conn.accept_bi().await?, }; control_send_stream.set_priority(i32::MAX)?; - let our_nonce: AccessChallenge = rand::random(); - let (received_commitment, max_payload_size) = exchange_commitments( - &mut control_send_stream, - &mut control_recv_stream, - &our_nonce, - ) - .await?; + let initial_transmission = + exchange_commitments(&mut control_send_stream, &mut control_recv_stream).await?; debug!("commitments exchanged"); let (control_send, control_recv) = spawn_channel( @@ -70,21 +65,12 @@ pub async fn run( logical_send, logical_recv, }; - let state = SessionState::new(our_role, our_nonce, received_commitment, max_payload_size); - - let (on_done, on_done_rx) = oneshot::channel(); - store - .send(ToActor::InitSession { - peer, - state, - channels, - init, - on_done, - }) + let handle = store + .init_session(peer, our_role, initial_transmission, channels, init) .await?; join_set.spawn(async move { - on_done_rx.await??; + handle.on_finish().await?; Ok(()) }); @@ -93,23 +79,6 @@ pub async fn run( Ok(()) } -async fn join_all(mut join_set: JoinSet>) -> anyhow::Result<()> { - let mut final_result = Ok(()); - while let Some(res) = join_set.join_next().await { - let res = match res { - Ok(Ok(())) => Ok(()), - Ok(Err(err)) => Err(err), - Err(err) => Err(err.into()), - }; - if res.is_err() && final_result.is_ok() { - final_result = res; - } else if res.is_err() { - warn!("join error after initial error: {res:?}"); - } - } - final_result -} - #[derive(Debug, thiserror::Error)] #[error("missing channel: {0:?}")] struct MissingChannel(LogicalChannel); @@ -122,6 +91,8 @@ async fn open_logical_channels( let cap = CHANNEL_CAP; let channels = [LogicalChannel::Reconciliation, LogicalChannel::StaticToken]; let mut channels = match our_role { + // Alfie opens a quic stream for each logical channel, and sends a single byte with the + // channel id. Role::Alfie => { channels .map(|ch| { @@ -136,6 +107,8 @@ async fn open_logical_channels( .try_join() .await } + // Alfie accepts as many quick streams as there are logical channels, and reads a single + // byte on each, which is expected to contain a channel id. Role::Betty => { channels .map(|_| async { @@ -149,7 +122,7 @@ async fn open_logical_channels( } }?; - let mut take_channel = |ch| { + let mut take_and_spawn_channel = |ch| { channels .iter_mut() .find(|(c, _)| *c == ch) @@ -169,8 +142,8 @@ async fn open_logical_channels( }) }; - let rec = take_channel(LogicalChannel::Reconciliation)?; - let stt = take_channel(LogicalChannel::StaticToken)?; + let rec = take_and_spawn_channel(LogicalChannel::Reconciliation)?; + let stt = take_and_spawn_channel(LogicalChannel::StaticToken)?; Ok(( LogicalChannelSenders { reconciliation: rec.0, @@ -183,25 +156,6 @@ async fn open_logical_channels( )) } -// async fn open_logical_channel( -// join_set: &mut JoinSet>, -// conn: &quinn::Connection, -// ch: LogicalChannel, -// ) -> anyhow::Result<(Sender, Receiver)> { -// let (mut send_stream, recv_stream) = conn.open_bi().await?; -// send_stream.write_u8(ch as u8).await?; -// let cap = CHANNEL_CAP; -// Ok(spawn_channel( -// join_set, -// ch, -// cap, -// cap, -// Guarantees::Limited(0), -// send_stream, -// recv_stream, -// )) -// } - fn spawn_channel( join_set: &mut JoinSet>, ch: LogicalChannel, @@ -233,7 +187,8 @@ async fn recv_loop( mut recv_stream: quinn::RecvStream, mut channel_writer: Writer, ) -> anyhow::Result<()> { - while let Some(buf) = recv_stream.read_chunk(CHANNEL_CAP, true).await? { + let max_buffer_size = channel_writer.max_buffer_size(); + while let Some(buf) = recv_stream.read_chunk(max_buffer_size, true).await? { channel_writer.write_all(&buf.bytes[..]).await?; trace!(len = buf.bytes.len(), "recv"); } @@ -255,30 +210,61 @@ async fn send_loop( } async fn exchange_commitments( - send: &mut quinn::SendStream, - recv: &mut quinn::RecvStream, - our_nonce: &AccessChallenge, -) -> anyhow::Result<(ChallengeHash, usize)> { + send_stream: &mut quinn::SendStream, + recv_stream: &mut quinn::RecvStream, +) -> anyhow::Result { + let our_nonce: AccessChallenge = rand::random(); let challenge_hash = Hash::new(&our_nonce); - send.write_u8(MAX_PAYLOAD_SIZE_POWER).await?; - send.write_all(challenge_hash.as_bytes()).await?; + send_stream.write_u8(MAX_PAYLOAD_SIZE_POWER).await?; + send_stream.write_all(challenge_hash.as_bytes()).await?; let their_max_payload_size = { - let power = recv.read_u8().await?; + let power = recv_stream.read_u8().await?; ensure!(power <= 64, "max payload size too large"); - 2usize.pow(power as u32) + 2u64.pow(power as u32) }; let mut received_commitment = [0u8; CHALLENGE_HASH_LENGTH]; - recv.read_exact(&mut received_commitment).await?; - Ok((received_commitment, their_max_payload_size)) + recv_stream.read_exact(&mut received_commitment).await?; + Ok(InitialTransmission { + our_nonce, + received_commitment, + their_max_payload_size, + }) +} + +#[derive(Debug)] +pub struct InitialTransmission { + pub our_nonce: AccessChallenge, + pub received_commitment: ChallengeHash, + pub their_max_payload_size: u64, +} + +async fn join_all(mut join_set: JoinSet>) -> anyhow::Result<()> { + let mut final_result = Ok(()); + while let Some(res) = join_set.join_next().await { + let res = match res { + Ok(Ok(())) => Ok(()), + Ok(Err(err)) => Err(err), + Err(err) => Err(err.into()), + }; + if res.is_err() && final_result.is_ok() { + final_result = res; + } else if res.is_err() { + warn!("join error after initial error: {res:?}"); + } + } + final_result } #[cfg(test)] mod tests { - use std::{collections::HashSet, time::Instant}; + use std::{ + collections::HashSet, + time::{Instant}, + }; - use futures::StreamExt; + use futures_lite::StreamExt; use iroh_base::{hash::Hash, key::SecretKey}; use iroh_net::MagicEndpoint; use rand::SeedableRng; @@ -286,7 +272,7 @@ mod tests { use tracing::{debug, info}; use crate::{ - actor::{StoreHandle, ToActor}, + actor::WillowHandle, net::run, proto::{ grouping::{AreaOfInterest, ThreeDRange}, @@ -350,10 +336,10 @@ mod tests { let mut expected_entries = HashSet::new(); let store_alfie = MemoryStore::default(); - let handle_alfie = StoreHandle::spawn(store_alfie, node_id_alfie); + let handle_alfie = WillowHandle::spawn(store_alfie, node_id_alfie); let store_betty = MemoryStore::default(); - let handle_betty = StoreHandle::spawn(store_betty, node_id_betty); + let handle_betty = WillowHandle::spawn(store_betty, node_id_betty); let init_alfie = setup_and_insert( &mut rng, @@ -378,6 +364,34 @@ mod tests { println!("init took {:?}", start.elapsed()); let start = Instant::now(); + // tokio::task::spawn({ + // let handle_alfie = handle_alfie.clone(); + // let handle_betty = handle_betty.clone(); + // async move { + // loop { + // info!( + // "alfie count: {}", + // handle_alfie + // .get_entries(namespace_id, ThreeDRange::full()) + // .await + // .unwrap() + // .count() + // .await + // ); + // info!( + // "betty count: {}", + // handle_betty + // .get_entries(namespace_id, ThreeDRange::full()) + // .await + // .unwrap() + // .count() + // .await + // ); + // tokio::time::sleep(Duration::from_secs(1)).await; + // } + // } + // }); + let (res_alfie, res_betty) = tokio::join!( run( node_id_alfie, @@ -426,24 +440,20 @@ mod tests { Ok(()) } async fn get_entries( - store: &StoreHandle, + store: &WillowHandle, namespace: NamespaceId, ) -> anyhow::Result> { - let (tx, rx) = flume::bounded(1024); - store - .send(ToActor::GetEntries { - namespace, - reply: tx, - range: ThreeDRange::full(), - }) - .await?; - let entries: HashSet<_> = rx.into_stream().collect::>().await; + let entries: HashSet<_> = store + .get_entries(namespace, ThreeDRange::full()) + .await? + .collect::>() + .await; Ok(entries) } async fn setup_and_insert( rng: &mut impl CryptoRngCore, - store: &StoreHandle, + store: &WillowHandle, namespace_secret: &NamespaceSecretKey, count: usize, track_entries: &mut impl Extend, diff --git a/iroh-willow/src/session/coroutine.rs b/iroh-willow/src/session/coroutine.rs index 28a81176b0..847fdb6960 100644 --- a/iroh-willow/src/session/coroutine.rs +++ b/iroh-willow/src/session/coroutine.rs @@ -8,7 +8,6 @@ use tracing::{debug, trace}; use crate::{ actor::{InitWithArea, WakeableCo, Yield}, - net::CHANNEL_CAP, proto::{ grouping::ThreeDRange, keys::NamespaceId, @@ -24,6 +23,8 @@ use crate::{ util::channel::{ReadError, WriteError}, }; +const INITIAL_GUARANTEES: u64 = u64::MAX; + #[derive(derive_more::Debug)] pub struct ControlRoutine { channels: Channels, @@ -43,7 +44,7 @@ impl ControlRoutine { let reveal_message = self.state().commitment_reveal()?; self.send(reveal_message).await?; let msg = ControlIssueGuarantee { - amount: CHANNEL_CAP as u64, + amount: INITIAL_GUARANTEES, channel: LogicalChannel::Reconciliation, }; self.send(msg).await?; diff --git a/iroh-willow/src/session/error.rs b/iroh-willow/src/session/error.rs index 694ba965bd..9580fa5b32 100644 --- a/iroh-willow/src/session/error.rs +++ b/iroh-willow/src/session/error.rs @@ -41,6 +41,8 @@ pub enum Error { InvalidParameters(&'static str), #[error("reached an invalid state")] InvalidState(&'static str), + #[error("actor failed to respond")] + ActorFailed, } impl From for Error { diff --git a/iroh-willow/src/session/state.rs b/iroh-willow/src/session/state.rs index 5f18efd6c3..fb3d2f2de7 100644 --- a/iroh-willow/src/session/state.rs +++ b/iroh-willow/src/session/state.rs @@ -2,14 +2,17 @@ use std::{cell::RefCell, collections::HashSet, rc::Rc}; use tracing::warn; -use crate::proto::{ - challenge::ChallengeState, - grouping::ThreeDRange, - keys::{NamespaceId, UserSecretKey}, - wgps::{ - AccessChallenge, AreaOfInterestHandle, CapabilityHandle, ChallengeHash, CommitmentReveal, - IntersectionHandle, Message, ReadCapability, SetupBindAreaOfInterest, - SetupBindReadCapability, SetupBindStaticToken, StaticToken, StaticTokenHandle, +use crate::{ + net::InitialTransmission, + proto::{ + challenge::ChallengeState, + grouping::ThreeDRange, + keys::{NamespaceId, UserSecretKey}, + wgps::{ + AreaOfInterestHandle, CapabilityHandle, CommitmentReveal, IntersectionHandle, Message, + ReadCapability, SetupBindAreaOfInterest, SetupBindReadCapability, SetupBindStaticToken, + StaticToken, StaticTokenHandle, + }, }, }; @@ -28,16 +31,12 @@ pub struct SessionState { } impl SessionState { - pub fn new( - our_role: Role, - our_nonce: AccessChallenge, - received_commitment: ChallengeHash, - _their_maximum_payload_size: usize, - ) -> Self { + pub fn new(our_role: Role, initial_transmission: InitialTransmission) -> Self { let challenge_state = ChallengeState::Committed { - our_nonce, - received_commitment, + our_nonce: initial_transmission.our_nonce, + received_commitment: initial_transmission.received_commitment, }; + // TODO: make use of initial_transmission.their_max_payload_size. Self { our_role, challenge: challenge_state, diff --git a/iroh-willow/src/util/channel.rs b/iroh-willow/src/util/channel.rs index 6fe887649e..5a7e4c408a 100644 --- a/iroh-willow/src/util/channel.rs +++ b/iroh-willow/src/util/channel.rs @@ -158,6 +158,13 @@ impl Shared { } fn writable_slice_exact(&mut self, len: usize) -> Option<&mut [u8]> { + tracing::trace!( + "write {}, remaining {} (guarantees {}, buf capacity {})", + len, + self.remaining_write_capacity(), + self.guarantees.get(), + self.max_buffer_size - self.buf.len() + ); if self.remaining_write_capacity() < len { None } else { @@ -290,6 +297,9 @@ impl Writer { pub fn close(&self) { self.shared.lock().unwrap().close() } + pub fn max_buffer_size(&self) -> usize { + self.shared.lock().unwrap().max_buffer_size + } } impl AsyncWrite for Writer {