diff --git a/iroh-willow/src/net.rs b/iroh-willow/src/net.rs index ffe9cdef817..8751e477cf5 100644 --- a/iroh-willow/src/net.rs +++ b/iroh-willow/src/net.rs @@ -1,21 +1,20 @@ use anyhow::ensure; +use futures::TryFutureExt; use iroh_base::{hash::Hash, key::NodeId}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot, task::JoinSet, }; -use tracing::{debug, error_span, instrument, trace, Instrument}; +use tracing::{debug, error_span, instrument, trace, warn, Instrument}; use crate::{ proto::wgps::{ AccessChallenge, ChallengeHash, LogicalChannel, Message, CHALLENGE_HASH_LENGTH, MAX_PAYLOAD_SIZE_POWER, }, - session::{ - coroutine::{Channels, Readyness}, - Role, SessionInit, SessionState, - }, - store::actor::{Interest, Notifier, StoreHandle, ToActor}, + session::{coroutine::Channels, Role, SessionInit, SessionState}, + store::actor::{StoreHandle, ToActor}, util::{ channel::{channel, Receiver, Sender}, Decoder, Encoder, @@ -24,6 +23,8 @@ use crate::{ const CHANNEL_CAP: usize = 1024 * 64; +const ERROR_CODE_CLOSE_GRACEFUL: u16 = 1; + #[instrument(skip_all, fields(me=%me.fmt_short(), role=?our_role))] pub async fn run( me: NodeId, @@ -33,6 +34,7 @@ pub async fn run( our_role: Role, init: SessionInit, ) -> anyhow::Result<()> { + let mut join_set = JoinSet::new(); let (mut control_send_stream, mut control_recv_stream) = match our_role { Role::Alfie => conn.open_bi().await?, Role::Betty => conn.accept_bi().await?, @@ -49,33 +51,30 @@ pub async fn run( .await?; debug!("exchanged comittments"); - let (mut reconciliation_send_stream, mut reconciliation_recv_stream) = match our_role { - Role::Alfie => conn.open_bi().await?, - Role::Betty => conn.accept_bi().await?, - }; - reconciliation_send_stream.write_u8(0u8).await?; - reconciliation_recv_stream.read_u8().await?; - debug!("reconcile channel open"); - - let mut join_set = JoinSet::new(); let (control_send, control_recv) = spawn_channel( &mut join_set, - &store, peer, LogicalChannel::Control, CHANNEL_CAP, control_send_stream, control_recv_stream, ); + + let (mut reconciliation_send_stream, mut reconciliation_recv_stream) = match our_role { + Role::Alfie => conn.open_bi().await?, + Role::Betty => conn.accept_bi().await?, + }; + reconciliation_send_stream.write_u8(0u8).await?; + reconciliation_recv_stream.read_u8().await?; let (reconciliation_send, reconciliation_recv) = spawn_channel( &mut join_set, - &store, peer, LogicalChannel::Reconciliation, CHANNEL_CAP, reconciliation_send_stream, reconciliation_recv_stream, ); + debug!("reconcile channel open"); let channels = Channels { control_send, @@ -83,44 +82,46 @@ pub async fn run( reconciliation_send, reconciliation_recv, }; - let state = SessionState::new( - our_role, - peer, - our_nonce, - received_commitment, - max_payload_size, - ); - let on_complete = state.notify_complete(); + let state = SessionState::new(our_role, our_nonce, received_commitment, max_payload_size); - // let control_loop = ControlLoop::new(state, channels.clone(), store.clone(), init); - // - // let control_fut = control_loop.run(); + let (reply, reply_rx) = oneshot::channel(); store .send(ToActor::InitSession { peer, state, - channels: channels.clone(), + channels, init, + reply, }) .await?; - let notified_fut = async move { - on_complete.notified().await; - tracing::info!("reconciliation complete"); - channels.close_send(); + join_set.spawn(async move { + reply_rx.await??; Ok(()) - }; - // join_set.spawn(control_fut.map_err(anyhow::Error::from)); - join_set.spawn(notified_fut); + }); + + join_all(join_set).await +} + +async fn join_all(mut join_set: JoinSet>) -> anyhow::Result<()> { + let mut final_result = Ok(()); while let Some(res) = join_set.join_next().await { - res??; + let res = match res { + Ok(Ok(())) => Ok(()), + Ok(Err(err)) => Err(err), + Err(err) => Err(err.into()), + }; + if res.is_err() && final_result.is_ok() { + final_result = res; + } else if res.is_err() { + warn!("join error after initial error: {res:?}"); + } } - Ok(()) + final_result } fn spawn_channel( join_set: &mut JoinSet>, - store: &StoreHandle, peer: NodeId, ch: LogicalChannel, cap: usize, @@ -130,70 +131,50 @@ fn spawn_channel( let (send_tx, send_rx) = channel(cap); let (recv_tx, recv_rx) = channel(cap); - let recv_fut = recv_loop( - recv_stream, - recv_tx, - store.notifier(peer, Readyness::Channel(ch, Interest::Recv)), - ) - .instrument(error_span!("recv", peer=%peer.fmt_short(), ch=%ch.fmt_short())); + let recv_fut = recv_loop(recv_stream, recv_tx) + .map_err(move |e| e.context(format!("receive loop for {ch:?} failed"))) + .instrument(error_span!("recv", peer=%peer.fmt_short(), ch=%ch.fmt_short())); join_set.spawn(recv_fut); - let send_fut = send_loop( - send_stream, - send_rx, - store.notifier(peer, Readyness::Channel(ch, Interest::Send)), - ) - .instrument(error_span!("send", peer=%peer.fmt_short(), ch=%ch.fmt_short())); + let send_fut = send_loop(send_stream, send_rx) + .map_err(move |e| e.context(format!("send loop for {ch:?} failed"))) + .instrument(error_span!("send", peer=%peer.fmt_short(), ch=%ch.fmt_short())); join_set.spawn(send_fut); (send_tx, recv_rx) } -// #[instrument(skip_all, fields(ch=%notifier.channel().fmt_short()))] async fn recv_loop( mut recv_stream: quinn::RecvStream, - channel_sender: Sender, - notifier: Notifier, + channel_tx: Sender, ) -> anyhow::Result<()> { - loop { - let buf = recv_stream.read_chunk(CHANNEL_CAP, true).await?; - if let Some(buf) = buf { - channel_sender.write_slice_async(&buf.bytes[..]).await; - trace!(len = buf.bytes.len(), "recv"); - if channel_sender.is_receivable_notify_set() { - trace!("notify"); - notifier.notify().await?; - } - } else { - break; - } + while let Some(buf) = recv_stream.read_chunk(CHANNEL_CAP, true).await? { + channel_tx.write_slice_async(&buf.bytes[..]).await?; + trace!(len = buf.bytes.len(), "recv"); } - channel_sender.close(); - debug!("recv_loop close"); + recv_stream.stop(ERROR_CODE_CLOSE_GRACEFUL.into()).ok(); + channel_tx.close(); Ok(()) } -// #[instrument(skip_all, fields(ch=%notifier.channel().fmt_short()))] async fn send_loop( mut send_stream: quinn::SendStream, - channel_receiver: Receiver, - notifier: Notifier, + channel_rx: Receiver, ) -> anyhow::Result<()> { - while let Some(data) = channel_receiver.read_bytes_async().await { + while let Some(data) = channel_rx.read_bytes_async().await { let len = data.len(); send_stream.write_chunk(data).await?; - debug!(len, "sent"); - if channel_receiver.is_sendable_notify_set() { - debug!("notify"); - notifier.notify().await?; - } + trace!(len, "sent"); + } + match send_stream.finish().await { + Ok(()) => {} + // If the other side closed gracefully, we are good. + Err(quinn::WriteError::Stopped(code)) + if code.into_inner() == ERROR_CODE_CLOSE_GRACEFUL as u64 => {} + Err(err) => return Err(err.into()), } - send_stream.flush().await?; - // send_stream.stopped().await?; - send_stream.finish().await.ok(); - debug!("send_loop close"); Ok(()) } @@ -225,20 +206,22 @@ mod tests { use iroh_base::{hash::Hash, key::SecretKey}; use iroh_net::MagicEndpoint; use rand::SeedableRng; + use rand_core::CryptoRngCore; use tracing::{debug, info}; use crate::{ net::run, proto::{ grouping::AreaOfInterest, - keys::{NamespaceId, NamespaceKind, NamespaceSecretKey, UserSecretKey}, + keys::{NamespaceId, NamespaceKind, NamespaceSecretKey, UserPublicKey, UserSecretKey}, meadowcap::{AccessMode, McCapability, OwnedCapability}, - willow::{Entry, Path}, + wgps::ReadCapability, + willow::{Entry, InvalidPath, Path, WriteCapability}, }, session::{Role, SessionInit}, store::{ actor::{StoreHandle, ToActor}, - MemoryStore, Store, + MemoryStore, }, }; @@ -291,84 +274,36 @@ mod tests { let start = Instant::now(); let mut expected_entries = HashSet::new(); - let mut store_alfie = MemoryStore::default(); - let init_alfie = { - let secret_key = UserSecretKey::generate(&mut rng); - let public_key = secret_key.public_key(); - let read_capability = McCapability::Owned(OwnedCapability::new( - &namespace_secret, - public_key, - AccessMode::Read, - )); - let write_capability = McCapability::Owned(OwnedCapability::new( - &namespace_secret, - public_key, - AccessMode::Write, - )); - for i in 0..n_alfie { - let p = format!("alfie{i}"); - let entry = Entry { - namespace_id, - subspace_id: public_key.into(), - path: Path::new(&[p.as_bytes()])?, - timestamp: 10, - payload_length: 2, - payload_digest: Hash::new("cool things"), - }; - expected_entries.insert(entry.clone()); - let entry = entry.attach_authorisation(write_capability.clone(), &secret_key)?; - store_alfie.ingest_entry(&entry)?; - } - let area_of_interest = AreaOfInterest::full(); - SessionInit { - user_secret_key: secret_key, - capability: read_capability, - area_of_interest, - } - }; - let mut store_betty = MemoryStore::default(); - let init_betty = { - let secret_key = UserSecretKey::generate(&mut rng); - let public_key = secret_key.public_key(); - let read_capability = McCapability::Owned(OwnedCapability::new( - &namespace_secret, - public_key, - AccessMode::Read, - )); - let write_capability = McCapability::Owned(OwnedCapability::new( - &namespace_secret, - public_key, - AccessMode::Write, - )); - for i in 0..n_betty { - let p = format!("betty{i}"); - let entry = Entry { - namespace_id, - subspace_id: public_key.into(), - path: Path::new(&[p.as_bytes()])?, - timestamp: 10, - payload_length: 2, - payload_digest: Hash::new("cool things"), - }; - expected_entries.insert(entry.clone()); - let entry = entry.attach_authorisation(write_capability.clone(), &secret_key)?; - store_betty.ingest_entry(&entry)?; - } - let area_of_interest = AreaOfInterest::full(); - SessionInit { - user_secret_key: secret_key, - capability: read_capability, - area_of_interest, - } - }; + let store_alfie = MemoryStore::default(); + let handle_alfie = StoreHandle::spawn(store_alfie, node_id_alfie); + + let store_betty = MemoryStore::default(); + let handle_betty = StoreHandle::spawn(store_betty, node_id_betty); + + let init_alfie = setup_and_insert( + &mut rng, + &handle_alfie, + &namespace_secret, + n_alfie, + &mut expected_entries, + |n| Path::new(&[b"alfie", n.to_string().as_bytes()]), + ) + .await?; + let init_betty = setup_and_insert( + &mut rng, + &handle_betty, + &namespace_secret, + n_betty, + &mut expected_entries, + |n| Path::new(&[b"betty", n.to_string().as_bytes()]), + ) + .await?; debug!("init constructed"); println!("init took {:?}", start.elapsed()); let start = Instant::now(); - let handle_alfie = StoreHandle::spawn(store_alfie, node_id_alfie); - let handle_betty = StoreHandle::spawn(store_betty, node_id_betty); let (res_alfie, res_betty) = tokio::join!( run( node_id_alfie, @@ -405,11 +340,13 @@ mod tests { assert!(res_betty.is_ok()); assert_eq!( get_entries(&handle_alfie, namespace_id).await?, - expected_entries + expected_entries, + "alfie expected entries" ); assert_eq!( get_entries(&handle_betty, namespace_id).await?, - expected_entries + expected_entries, + "bettyexpected entries" ); Ok(()) @@ -429,6 +366,55 @@ mod tests { Ok(entries) } + async fn setup_and_insert( + rng: &mut impl CryptoRngCore, + store: &StoreHandle, + namespace_secret: &NamespaceSecretKey, + count: usize, + track_entries: &mut impl Extend, + path_fn: impl Fn(usize) -> Result, + ) -> anyhow::Result { + let user_secret = UserSecretKey::generate(rng); + let (read_cap, write_cap) = create_capabilities(namespace_secret, user_secret.public_key()); + let subspace_id = user_secret.id(); + let namespace_id = namespace_secret.id(); + for i in 0..count { + let path = path_fn(i); + let entry = Entry { + namespace_id, + subspace_id, + path: path.expect("invalid path"), + timestamp: 10, + payload_length: 2, + payload_digest: Hash::new("cool things"), + }; + track_entries.extend([entry.clone()]); + let entry = entry.attach_authorisation(write_cap.clone(), &user_secret)?; + info!("INGEST {entry:?}"); + store.ingest_entry(entry).await?; + } + let init = SessionInit::with_interest(user_secret, read_cap, AreaOfInterest::full()); + Ok(init) + } + + fn create_capabilities( + namespace_secret: &NamespaceSecretKey, + user_public_key: UserPublicKey, + ) -> (ReadCapability, WriteCapability) { + let read_capability = McCapability::Owned(OwnedCapability::new( + &namespace_secret, + user_public_key, + AccessMode::Read, + )); + let write_capability = McCapability::Owned(OwnedCapability::new( + &namespace_secret, + user_public_key, + AccessMode::Write, + )); + (read_capability, write_capability) + // let init = SessionInit::with_interest(secret_key, read_capability, AreaOfInterest::full()) + } + // async fn get_entries_debug( // store: &StoreHandle, // namespace: NamespaceId, diff --git a/iroh-willow/src/proto/grouping.rs b/iroh-willow/src/proto/grouping.rs index e913e2012e7..9ecaeb70bac 100644 --- a/iroh-willow/src/proto/grouping.rs +++ b/iroh-willow/src/proto/grouping.rs @@ -212,7 +212,7 @@ impl RangeEnd { } /// A grouping of Entries that are among the newest in some store. -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)] pub struct AreaOfInterest { /// To be included in this AreaOfInterest, an Entry must be included in the area. pub area: Area, diff --git a/iroh-willow/src/proto/keys.rs b/iroh-willow/src/proto/keys.rs index 2d343116dce..9d43e758937 100644 --- a/iroh-willow/src/proto/keys.rs +++ b/iroh-willow/src/proto/keys.rs @@ -332,12 +332,24 @@ pub struct NamespaceSignature(ed25519_dalek::Signature); bytestring!(NamespaceSignature, SIGNATURE_LENGTH); +impl std::hash::Hash for NamespaceSignature { + fn hash(&self, state: &mut H) { + self.0.to_bytes().hash(state); + } +} + /// The signature obtained by signing a message with a [`UserSecretKey`]. #[derive(Serialize, Deserialize, Clone, From, PartialEq, Eq, Deref)] pub struct UserSignature(ed25519_dalek::Signature); bytestring!(UserSignature, SIGNATURE_LENGTH); +impl std::hash::Hash for UserSignature { + fn hash(&self, state: &mut H) { + self.0.to_bytes().hash(state); + } +} + /// [`UserPublicKey`] in bytes #[derive( Default, diff --git a/iroh-willow/src/proto/meadowcap.rs b/iroh-willow/src/proto/meadowcap.rs index 22bac8f7bc4..3190158a0c2 100644 --- a/iroh-willow/src/proto/meadowcap.rs +++ b/iroh-willow/src/proto/meadowcap.rs @@ -102,7 +102,7 @@ impl From<(McCapability, UserSignature)> for MeadowcapAuthorisationToken { } } -#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, derive_more::From)] +#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash, derive_more::From)] pub enum McCapability { Communal(CommunalCapability), Owned(OwnedCapability), @@ -158,14 +158,14 @@ impl McCapability { } } -#[derive(Debug, Serialize, Deserialize, Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Hash)] pub enum AccessMode { Read, Write, } /// A capability that authorizes reads or writes in communal namespaces. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] pub struct CommunalCapability { /// The kind of access this grants. access_mode: AccessMode, @@ -206,7 +206,7 @@ impl CommunalCapability { } /// A capability that authorizes reads or writes in owned namespaces. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] pub struct OwnedCapability { /// The kind of access this grants. access_mode: AccessMode, diff --git a/iroh-willow/src/proto/wgps.rs b/iroh-willow/src/proto/wgps.rs index a282d7772c7..054d2133490 100644 --- a/iroh-willow/src/proto/wgps.rs +++ b/iroh-willow/src/proto/wgps.rs @@ -100,8 +100,8 @@ pub enum LogicalChannel { impl LogicalChannel { pub fn fmt_short(&self) -> &str { match self { - LogicalChannel::Control => "C", - LogicalChannel::Reconciliation => "R", + LogicalChannel::Control => "Ctl", + LogicalChannel::Reconciliation => "Rec", } } } diff --git a/iroh-willow/src/proto/willow.rs b/iroh-willow/src/proto/willow.rs index 89bbc7c4624..b1dd9dc19fd 100644 --- a/iroh-willow/src/proto/willow.rs +++ b/iroh-willow/src/proto/willow.rs @@ -16,6 +16,9 @@ pub type NamespaceId = keys::NamespaceId; /// A type for identifying subspaces. pub type SubspaceId = keys::UserId; +/// The capability type needed to authorize writes. +pub type WriteCapability = McCapability; + /// A Timestamp is a 64-bit unsigned integer, that is, a natural number between zero (inclusive) and 2^64 - 1 (exclusive). /// Timestamps are to be interpreted as a time in microseconds since the Unix epoch. pub type Timestamp = u64; diff --git a/iroh-willow/src/session.rs b/iroh-willow/src/session.rs index 247c9d6df49..8bee8468106 100644 --- a/iroh-willow/src/session.rs +++ b/iroh-willow/src/session.rs @@ -1,3 +1,5 @@ +use std::collections::{HashMap, HashSet}; + use crate::proto::{grouping::AreaOfInterest, keys::UserSecretKey, wgps::ReadCapability}; pub mod coroutine; @@ -9,23 +11,42 @@ mod util; pub use self::error::Error; pub use self::state::{SessionState, SharedSessionState}; +/// To break symmetry, we refer to the peer that initiated the synchronisation session as Alfie, +/// and the other peer as Betty. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum Role { - Betty, + /// The peer that initiated the synchronisation session. Alfie, + /// The peer that accepted the synchronisation session. + Betty, } +/// The bind scope for resources. +/// +/// Resources are bound by either peer #[derive(Copy, Clone, Debug)] pub enum Scope { + /// Resources bound by ourselves. Ours, + /// Resources bound by the other peer. Theirs, } #[derive(Debug)] pub struct SessionInit { pub user_secret_key: UserSecretKey, - // TODO: allow multiple capabilities? - pub capability: ReadCapability, - // TODO: allow multiple areas of interest? - pub area_of_interest: AreaOfInterest, + pub interests: HashMap>, +} + +impl SessionInit { + pub fn with_interest( + user_secret_key: UserSecretKey, + capability: ReadCapability, + area_of_interest: AreaOfInterest, + ) -> Self { + Self { + user_secret_key, + interests: HashMap::from_iter([(capability, HashSet::from_iter([area_of_interest]))]), + } + } } diff --git a/iroh-willow/src/session/coroutine.rs b/iroh-willow/src/session/coroutine.rs index 389d5e64a87..6afbd53b745 100644 --- a/iroh-willow/src/session/coroutine.rs +++ b/iroh-willow/src/session/coroutine.rs @@ -3,6 +3,7 @@ use std::{ rc::Rc, }; +use anyhow::anyhow; use genawaiter::sync::Co; use iroh_net::NodeId; @@ -15,14 +16,11 @@ use crate::{ wgps::{ AreaOfInterestHandle, Fingerprint, LengthyEntry, LogicalChannel, Message, ReconciliationAnnounceEntries, ReconciliationSendEntry, ReconciliationSendFingerprint, - ResourceHandle, StaticToken, StaticTokenHandle, + ResourceHandle, SetupBindAreaOfInterest, StaticToken, StaticTokenHandle, }, willow::AuthorisedEntry, }, - store::{ - actor::{CoroutineNotifier, Interest}, - ReadonlyStore, SplitAction, Store, SyncConfig, - }, + store::{actor::Interest, ReadonlyStore, SplitAction, Store, SyncConfig}, util::channel::{ReadOutcome, Receiver, Sender, WriteOutcome}, }; @@ -46,7 +44,7 @@ pub struct Coroutine { pub store_writer: Rc>, pub channels: Channels, pub state: SharedSessionState, - pub notifier: CoroutineNotifier, + // pub waker: CoroutineWaker, #[debug(skip)] pub co: Co, } @@ -60,6 +58,12 @@ pub struct Channels { } impl Channels { + pub fn close_all(&self) { + self.control_send.close(); + self.control_recv.close(); + self.reconciliation_send.close(); + self.reconciliation_recv.close(); + } pub fn close_send(&self) { self.control_send.close(); self.reconciliation_send.close(); @@ -85,6 +89,7 @@ impl Coroutine { mut self, start: Option<(AreaOfInterestHandle, AreaOfInterestHandle)>, ) -> Result<(), Error> { + debug!(init = start.is_some(), "start reconciliation"); if let Some((our_handle, their_handle)) = start { self.init_reconciliation(our_handle, their_handle).await?; } @@ -92,8 +97,8 @@ impl Coroutine { while let Some(message) = self.recv(LogicalChannel::Reconciliation).await { let message = message?; self.on_reconciliation_message(message).await?; - if self.state_mut().trigger_notify_if_complete() { - break; + if self.state_mut().reconciliation_is_complete() { + self.channels.close_send(); } } @@ -104,46 +109,35 @@ impl Coroutine { let reveal_message = self.state_mut().commitment_reveal()?; self.send_control(reveal_message).await?; + let mut init = Some(init); while let Some(message) = self.recv(LogicalChannel::Control).await { let message = message?; - debug!(%message, "run_control recv"); - self.on_control_message(message, &init).await?; - if self.state_mut().trigger_notify_if_complete() { - break; - } - } - - Ok(()) - } - - async fn on_control_message( - &mut self, - message: Message, - init: &SessionInit, - ) -> Result<(), Error> { - match message { - Message::CommitmentReveal(msg) => { - let setup_messages = self.state_mut().on_commitment_reveal(msg, &init)?; - for message in setup_messages { - debug!(%message, "send"); - self.send_control(message).await?; + match message { + Message::CommitmentReveal(msg) => { + self.state_mut().on_commitment_reveal(msg)?; + let init = init + .take() + .ok_or_else(|| Error::InvalidMessageInCurrentState)?; + self.setup(init).await?; } + Message::SetupBindReadCapability(msg) => { + self.state_mut().on_setup_bind_read_capability(msg)?; + } + Message::SetupBindStaticToken(msg) => { + self.state_mut().on_setup_bind_static_token(msg); + } + Message::SetupBindAreaOfInterest(msg) => { + let start = self.state_mut().on_setup_bind_area_of_interest(msg)?; + // if let Some(start) = st + self.co.yield_(Yield::StartReconciliation(start)).await; + } + Message::ControlFreeHandle(_msg) => { + // TODO: Free handles + } + _ => return Err(Error::UnsupportedMessage), } - Message::SetupBindReadCapability(msg) => { - self.state_mut().on_setup_bind_read_capability(msg)?; - } - Message::SetupBindStaticToken(msg) => { - self.state_mut().on_setup_bind_static_token(msg); - } - Message::SetupBindAreaOfInterest(msg) => { - let (_peer, start) = self.state_mut().on_setup_bind_area_of_interest(msg)?; - self.co.yield_(Yield::StartReconciliation(start)).await; - } - Message::ControlFreeHandle(_msg) => { - // TODO: Free handles - } - _ => return Err(Error::UnsupportedMessage), } + Ok(()) } @@ -162,6 +156,44 @@ impl Coroutine { Ok(()) } + async fn setup(&mut self, init: SessionInit) -> Result<(), Error> { + debug!(?init, "init"); + for (capability, aois) in init.interests.into_iter() { + if *capability.receiver() != init.user_secret_key.public_key() { + return Err(Error::WrongSecretKeyForCapability); + } + + // TODO: implement private area intersection + let intersection_handle = 0.into(); + let (our_capability_handle, message) = self.state_mut().bind_and_sign_capability( + &init.user_secret_key, + intersection_handle, + capability, + )?; + if let Some(message) = message { + self.send_control(message).await?; + } + + for area_of_interest in aois { + // for area in areas_of_interest { + let msg = SetupBindAreaOfInterest { + area_of_interest, + authorisation: our_capability_handle, + }; + let (_our_handle, is_new) = self + .state_mut() + .our_resources + .areas_of_interest + // TODO: avoid clone + .bind_if_new(msg.clone()); + if is_new { + self.send_control(msg).await?; + } + } + } + Ok(()) + } + async fn init_reconciliation( &mut self, our_handle: AreaOfInterestHandle, @@ -169,13 +201,16 @@ impl Coroutine { ) -> Result<(), Error> { debug!("init reconciliation"); let mut state = self.state_mut(); - let our_aoi = state.our_resources.areas_of_interest.get(&our_handle)?; - let their_aoi = state.their_resources.areas_of_interest.get(&their_handle)?; + let our_aoi = state.our_resources.areas_of_interest.try_get(&our_handle)?; + let their_aoi = state + .their_resources + .areas_of_interest + .try_get(&their_handle)?; let our_capability = state .our_resources .capabilities - .get(&our_aoi.authorisation)?; + .try_get(&our_aoi.authorisation)?; let namespace: NamespaceId = our_capability.granted_namespace().into(); let common_aoi = &our_aoi @@ -293,7 +328,9 @@ impl Coroutine { } async fn on_send_entry(&mut self, message: ReconciliationSendEntry) -> Result<(), Error> { - let static_token = self.get_static_token(message.static_token_handle).await; + let static_token = self + .get_static_token_eventually(message.static_token_handle) + .await; self.state_mut().on_send_entry()?; @@ -308,16 +345,10 @@ impl Coroutine { Ok(()) } - async fn get_static_token(&mut self, handle: StaticTokenHandle) -> StaticToken { + async fn get_static_token_eventually(&mut self, handle: StaticTokenHandle) -> StaticToken { loop { - let mut state = self.state.borrow_mut(); - match state - .their_resources - .static_tokens - .get_or_notify(&handle, || { - self.notifier - .notifier(self.peer, Readyness::Resource(handle.into())) - }) { + let state = self.state.borrow_mut(); + match state.their_resources.static_tokens.get(&handle) { Some(token) => break token.clone(), None => { drop(state); @@ -459,7 +490,7 @@ impl Coroutine { async fn recv(&self, channel: LogicalChannel) -> Option> { let receiver = self.channels.receiver(channel); loop { - match receiver.read_message_or_set_notify() { + match receiver.read_message() { Err(err) => return Some(Err(err)), Ok(outcome) => match outcome { ReadOutcome::Closed => { @@ -472,7 +503,7 @@ impl Coroutine { .await; } ReadOutcome::Item(message) => { - debug!(%message, "recv"); + debug!(ch=%channel.fmt_short(), %message, "recv"); return Some(Ok(message)); } }, @@ -495,7 +526,11 @@ impl Coroutine { let sender = self.channels.sender(channel); loop { - match sender.send_or_set_notify(&message)? { + match sender.send(&message)? { + WriteOutcome::Closed => { + debug!("send: closed"); + return Err(anyhow!("channel closed")); + } WriteOutcome::Ok => { debug!(msg=%message, ch=%channel.fmt_short(), "sent"); break Ok(()); @@ -510,41 +545,3 @@ impl Coroutine { } } } -// async fn recv_bulk( -// &self, -// channel: LogicalChannel, -// ) -> Option>> { -// let receiver = self.channels.receiver(channel); -// let mut buf = SmallVec::<[Message; N]>::new(); -// loop { -// match receiver.read_message_or_set_notify() { -// Err(err) => return Some(Err(err)), -// Ok(outcome) => match outcome { -// ReadOutcome::Closed => { -// if buf.is_empty() { -// debug!("recv: closed"); -// return None; -// } else { -// return Some(Ok(buf)); -// } -// } -// ReadOutcome::ReadBufferEmpty => { -// if buf.is_empty() { -// self.co -// .yield_(Yield::Pending(Readyness::Channel(channel, Interest::Recv))) -// .await; -// } else { -// return Some(Ok(buf)); -// } -// } -// ReadOutcome::Item(message) => { -// debug!(%message, "recv"); -// buf.push(message); -// if buf.len() == N { -// return Some(Ok(buf)); -// } -// } -// }, -// } -// } -// } diff --git a/iroh-willow/src/session/error.rs b/iroh-willow/src/session/error.rs index da918d3c0a0..7ceb6c7abe8 100644 --- a/iroh-willow/src/session/error.rs +++ b/iroh-willow/src/session/error.rs @@ -30,6 +30,8 @@ pub enum Error { BrokenCommittement, #[error("received an actor message for unknown session")] SessionNotFound, + #[error("invalid parameters: {0}")] + InvalidParameters(&'static str) } impl From for Error { diff --git a/iroh-willow/src/session/resource.rs b/iroh-willow/src/session/resource.rs index b27ab07da34..06db355f1ef 100644 --- a/iroh-willow/src/session/resource.rs +++ b/iroh-willow/src/session/resource.rs @@ -5,7 +5,7 @@ use crate::{ AreaOfInterestHandle, CapabilityHandle, IsHandle, ReadCapability, ResourceHandle, SetupBindAreaOfInterest, StaticToken, StaticTokenHandle, }, - store::actor::Notifier, + store::actor::AssignedWaker, }; use super::Error; @@ -17,12 +17,12 @@ pub struct ScopedResources { pub static_tokens: ResourceMap, } impl ScopedResources { - pub fn register_notify(&mut self, handle: ResourceHandle, notify: Notifier) { - tracing::debug!(?handle, "register_notify"); + pub fn register_waker(&mut self, handle: ResourceHandle, waker: AssignedWaker) { + tracing::trace!(?handle, "register_notify"); match handle { - ResourceHandle::AreaOfInterest(h) => self.areas_of_interest.register_notify(h, notify), - ResourceHandle::Capability(h) => self.capabilities.register_notify(h, notify), - ResourceHandle::StaticToken(h) => self.static_tokens.register_notify(h, notify), + ResourceHandle::AreaOfInterest(h) => self.areas_of_interest.register_waker(h, waker), + ResourceHandle::Capability(h) => self.capabilities.register_waker(h, waker), + ResourceHandle::StaticToken(h) => self.static_tokens.register_waker(h, waker), ResourceHandle::Intersection(_h) => unimplemented!(), } } @@ -41,7 +41,7 @@ impl ScopedResources { pub struct ResourceMap { next_handle: u64, map: HashMap>, - notify: HashMap>, + wakers: HashMap>, } impl Default for ResourceMap { @@ -49,7 +49,7 @@ impl Default for ResourceMap { Self { next_handle: 0, map: Default::default(), - notify: Default::default(), + wakers: Default::default(), } } } @@ -59,16 +59,20 @@ where H: IsHandle, R: Eq + PartialEq, { + pub fn iter(&self) -> impl Iterator + '_ { + self.map.iter().map(|(h, r)| (h, &r.value)) + } + pub fn bind(&mut self, resource: R) -> H { let handle: H = self.next_handle.into(); self.next_handle += 1; let resource = Resource::new(resource); self.map.insert(handle, resource); - tracing::debug!(?handle, "bind"); - if let Some(mut notify) = self.notify.remove(&handle) { - tracing::debug!(?handle, "notify {}", notify.len()); - for notify in notify.drain(..) { - if let Err(err) = notify.notify_sync() { + tracing::trace!(?handle, "bind"); + if let Some(mut wakers) = self.wakers.remove(&handle) { + tracing::trace!(?handle, "notify {}", wakers.len()); + for waker in wakers.drain(..) { + if let Err(err) = waker.wake() { tracing::warn!(?err, "notify failed for {handle:?}"); } } @@ -76,8 +80,8 @@ where handle } - pub fn register_notify(&mut self, handle: H, notifier: Notifier) { - self.notify.entry(handle).or_default().push_back(notifier) + pub fn register_waker(&mut self, handle: H, notifier: AssignedWaker) { + self.wakers.entry(handle).or_default().push_back(notifier) } pub fn bind_if_new(&mut self, resource: R) -> (H, bool) { @@ -94,7 +98,7 @@ where } } - pub fn get(&self, handle: &H) -> Result<&R, Error> { + pub fn try_get(&self, handle: &H) -> Result<&R, Error> { self.map .get(handle) .as_ref() @@ -102,17 +106,26 @@ where .ok_or_else(|| Error::MissingResource((*handle).into())) } - pub fn get_or_notify(&mut self, handle: &H, notify: impl FnOnce() -> Notifier) -> Option<&R> { - if let Some(resource) = self.map.get(handle).as_ref().map(|r| &r.value) { - Some(resource) - } else { - self.notify - .entry(*handle) - .or_default() - .push_back((notify)()); - None - } + pub fn get(&self, handle: &H) -> Option<&R> { + self.map.get(handle).as_ref().map(|r| &r.value) } + + // pub async fn get_eventually(&self, handle: &H) -> Result<&R, Error> { + // if let Some(resource) = self.map.get(handle).as_ref().map(|r| &r.value) { + // Some(resource) + // } else { + // // self.on_notify(handle) + // } + // } + + // pub fn get_or_notify(&mut self, handle: &H, notifier: CoroutineWaker) -> Option<&R> { + // if let Some(resource) = self.map.get(handle).as_ref().map(|r| &r.value) { + // Some(resource) + // } else { + // self.register_waker(*handle, notifier); + // None + // } + // } } // #[derive(Debug)] diff --git a/iroh-willow/src/session/state.rs b/iroh-willow/src/session/state.rs index 24fdd284b5d..e860f7ab0fd 100644 --- a/iroh-willow/src/session/state.rs +++ b/iroh-willow/src/session/state.rs @@ -1,42 +1,35 @@ -use std::{cell::RefCell, collections::HashSet, rc::Rc, sync::Arc}; +use std::{cell::RefCell, collections::HashSet, rc::Rc}; -use iroh_net::NodeId; - -use tokio::sync::Notify; -use tracing::{debug, trace, warn}; +use tracing::{trace, warn}; use crate::proto::{ challenge::ChallengeState, grouping::ThreeDRange, - keys::NamespaceId, + keys::{NamespaceId, UserSecretKey}, wgps::{ - AccessChallenge, AreaOfInterestHandle, ChallengeHash, CommitmentReveal, Message, - SetupBindAreaOfInterest, SetupBindReadCapability, SetupBindStaticToken, StaticToken, - StaticTokenHandle, + AccessChallenge, AreaOfInterestHandle, CapabilityHandle, ChallengeHash, CommitmentReveal, + IntersectionHandle, Message, ReadCapability, SetupBindAreaOfInterest, + SetupBindReadCapability, SetupBindStaticToken, StaticToken, StaticTokenHandle, }, }; -use super::{resource::ScopedResources, Error, Role, Scope, SessionInit}; +use super::{resource::ScopedResources, Error, Role, Scope}; pub type SharedSessionState = Rc>; #[derive(Debug)] pub struct SessionState { pub our_role: Role, - peer: NodeId, pub our_resources: ScopedResources, pub their_resources: ScopedResources, pub reconciliation_started: bool, pub pending_ranges: HashSet<(AreaOfInterestHandle, ThreeDRange)>, pub pending_entries: Option, - notify_complete: Arc, - challenge: ChallengeState, - our_current_aoi: Option, + pub challenge: ChallengeState, } impl SessionState { pub fn new( our_role: Role, - peer: NodeId, our_nonce: AccessChallenge, received_commitment: ChallengeHash, _their_maximum_payload_size: usize, @@ -47,15 +40,12 @@ impl SessionState { }; Self { our_role, - peer, challenge: challenge_state, reconciliation_started: false, our_resources: Default::default(), their_resources: Default::default(), pending_ranges: Default::default(), pending_entries: Default::default(), - notify_complete: Default::default(), - our_current_aoi: Default::default(), } } fn resources(&self, scope: Scope) -> &ScopedResources { @@ -64,7 +54,7 @@ impl SessionState { Scope::Theirs => &self.their_resources, } } - pub fn is_complete(&self) -> bool { + pub fn reconciliation_is_complete(&self) -> bool { let is_complete = self.reconciliation_started && self.pending_ranges.is_empty() && self.pending_entries.is_none(); @@ -77,19 +67,28 @@ impl SessionState { is_complete } - pub fn trigger_notify_if_complete(&mut self) -> bool { - if self.is_complete() { - self.notify_complete.notify_waiters(); - true - } else { - false - } - } + pub fn bind_and_sign_capability( + &mut self, + user_secret_key: &UserSecretKey, + our_intersection_handle: IntersectionHandle, + capability: ReadCapability, + ) -> Result<(CapabilityHandle, Option), Error> { + let signature = self.challenge.sign(user_secret_key)?; - pub fn notify_complete(&self) -> Arc { - Arc::clone(&self.notify_complete) + let (our_handle, is_new) = self + .our_resources + .capabilities + .bind_if_new(capability.clone()); + let maybe_message = is_new.then(|| SetupBindReadCapability { + capability, + handle: our_intersection_handle, + signature, + }); + Ok((our_handle, maybe_message)) } + // pub fn bind_aoi() + pub fn commitment_reveal(&mut self) -> Result { match self.challenge { ChallengeState::Committed { our_nonce, .. } => { @@ -100,13 +99,9 @@ impl SessionState { // let msg = CommitmentReveal { nonce: our_nonce }; } - pub fn on_commitment_reveal( - &mut self, - msg: CommitmentReveal, - init: &SessionInit, - ) -> Result<[Message; 2], Error> { + pub fn on_commitment_reveal(&mut self, msg: CommitmentReveal) -> Result<(), Error> { self.challenge.reveal(self.our_role, msg.nonce)?; - self.setup(init) + Ok(()) } pub fn on_setup_bind_read_capability( @@ -125,55 +120,43 @@ impl SessionState { self.their_resources.static_tokens.bind(msg.static_token); } - fn setup(&mut self, init: &SessionInit) -> Result<[Message; 2], Error> { - let area_of_interest = init.area_of_interest.clone(); - let capability = init.capability.clone(); - - debug!(?init, "init"); - if *capability.receiver() != init.user_secret_key.public_key() { - return Err(Error::WrongSecretKeyForCapability); - } - - // TODO: implement private area intersection - let intersection_handle = 0.into(); - let signature = self.challenge.sign(&init.user_secret_key)?; - - let our_capability_handle = self.our_resources.capabilities.bind(capability.clone()); - let msg1 = SetupBindReadCapability { - capability, - handle: intersection_handle, - signature, - }; - - let msg2 = SetupBindAreaOfInterest { - area_of_interest, - authorisation: our_capability_handle, - }; - let our_aoi_handle = self.our_resources.areas_of_interest.bind(msg2.clone()); - self.our_current_aoi = Some(our_aoi_handle); - Ok([msg1.into(), msg2.into()]) - } - pub fn on_setup_bind_area_of_interest( &mut self, msg: SetupBindAreaOfInterest, - ) -> Result<(NodeId, Option<(AreaOfInterestHandle, AreaOfInterestHandle)>), Error> { + ) -> Result, Error> { let capability = self - .resources(Scope::Theirs) + .their_resources .capabilities - .get(&msg.authorisation)?; + .try_get(&msg.authorisation)?; capability.try_granted_area(&msg.area_of_interest.area)?; let their_handle = self.their_resources.areas_of_interest.bind(msg); + + // only initiate reconciliation if we are alfie, and if we have a shared aoi + // TODO: abort if no shared aoi? let start = if self.our_role == Role::Alfie { - let our_handle = self - .our_current_aoi - .clone() - .ok_or(Error::InvalidMessageInCurrentState)?; - Some((our_handle, their_handle)) + self.find_shared_aoi(&their_handle)? + .map(|our_handle| (our_handle, their_handle)) } else { None }; - Ok((self.peer, start)) + Ok(start) + } + + pub fn find_shared_aoi( + &self, + their_handle: &AreaOfInterestHandle, + ) -> Result, Error> { + let their_aoi = self + .their_resources + .areas_of_interest + .try_get(their_handle)?; + let maybe_our_handle = self + .our_resources + .areas_of_interest + .iter() + .find(|(_handle, aoi)| aoi.area().intersection(their_aoi.area()).is_some()) + .map(|(handle, _aoi)| *handle); + Ok(maybe_our_handle) } pub fn on_send_entry(&mut self) -> Result<(), Error> { @@ -223,8 +206,11 @@ impl SessionState { scope: Scope, handle: &AreaOfInterestHandle, ) -> Result { - let aoi = self.resources(scope).areas_of_interest.get(handle)?; - let capability = self.resources(scope).capabilities.get(&aoi.authorisation)?; + let aoi = self.resources(scope).areas_of_interest.try_get(handle)?; + let capability = self + .resources(scope) + .capabilities + .try_get(&aoi.authorisation)?; let namespace_id = capability.granted_namespace().into(); Ok(namespace_id) } @@ -254,6 +240,6 @@ impl SessionState { scope: Scope, handle: &AreaOfInterestHandle, ) -> Result<&SetupBindAreaOfInterest, Error> { - self.resources(scope).areas_of_interest.get(handle) + self.resources(scope).areas_of_interest.try_get(handle) } } diff --git a/iroh-willow/src/store/actor.rs b/iroh-willow/src/store/actor.rs index 8c1c2dde6ad..70797c40d63 100644 --- a/iroh-willow/src/store/actor.rs +++ b/iroh-willow/src/store/actor.rs @@ -9,12 +9,16 @@ use std::{ use futures::{future::LocalBoxFuture, FutureExt}; use genawaiter::{sync::Gen, GeneratorState}; use tokio::sync::oneshot; -use tracing::{debug, error, error_span, instrument, warn, Span}; +use tracing::{debug, error, error_span, instrument, trace, warn, Span}; // use iroh_net::NodeId; use super::Store; use crate::{ - proto::{grouping::ThreeDRange, keys::NamespaceId, willow::Entry}, + proto::{ + grouping::ThreeDRange, + keys::NamespaceId, + willow::{AuthorisedEntry, Entry}, + }, session::{ coroutine::{Channels, Coroutine, Readyness, Yield}, Error, SessionInit, SessionState, SharedSessionState, @@ -36,60 +40,68 @@ pub enum Interest { Recv, } -#[derive(Debug)] -pub struct CoroutineNotifier { - tx: flume::Sender, +// #[derive(Debug)] +// pub struct Notifier { +// tx: flume::Sender, +// } +// impl Notifier { +// pub async fn notify(&self, peer: NodeId, notify: Readyness) -> anyhow::Result<()> { +// let msg = ToActor::Resume { peer, notify }; +// self.tx.send_async(msg).await?; +// Ok(()) +// } +// pub fn notify_sync(&self, peer: NodeId, notify: Readyness) -> anyhow::Result<()> { +// let msg = ToActor::Resume { peer, notify }; +// self.tx.send(msg)?; +// Ok(()) +// } +// pub fn notifier(&self, peer: NodeId) -> Notifier { +// Notifier { +// tx: self.tx.clone(), +// } +// } +// } + +#[derive(Debug, Clone)] +pub struct AssignedWaker { + waker: CoroutineWaker, + peer: NodeId, + notify: Readyness, } -impl CoroutineNotifier { - pub async fn notify(&self, peer: NodeId, notify: Readyness) -> anyhow::Result<()> { - let msg = ToActor::Resume { peer, notify }; - self.tx.send_async(msg).await?; - Ok(()) - } - pub fn notify_sync(&self, peer: NodeId, notify: Readyness) -> anyhow::Result<()> { - let msg = ToActor::Resume { peer, notify }; - self.tx.send(msg)?; - Ok(()) - } - pub fn notifier(&self, peer: NodeId, notify: Readyness) -> Notifier { - Notifier { - tx: self.tx.clone(), - peer, - notify, - } + +impl AssignedWaker { + pub fn wake(&self) -> anyhow::Result<()> { + self.waker.wake(self.peer, self.notify) } } #[derive(Debug, Clone)] -pub struct Notifier { +pub struct CoroutineWaker { tx: flume::Sender, - notify: Readyness, - peer: NodeId, } -impl Notifier { - pub async fn notify(&self) -> anyhow::Result<()> { - let msg = ToActor::Resume { - peer: self.peer, - notify: self.notify, - }; - self.tx.send_async(msg).await?; - Ok(()) - } - pub fn notify_sync(&self) -> anyhow::Result<()> { - let msg = ToActor::Resume { - peer: self.peer, - notify: self.notify, - }; +impl CoroutineWaker { + pub fn wake(&self, peer: NodeId, notify: Readyness) -> anyhow::Result<()> { + let msg = ToActor::Resume { peer, notify }; + // TODO: deadlock self.tx.send(msg)?; Ok(()) } + + pub fn with_notify(&self, peer: NodeId, notify: Readyness) -> AssignedWaker { + AssignedWaker { + waker: self.clone(), + peer, + notify, + } + } } impl StoreHandle { pub fn spawn(store: S, me: NodeId) -> StoreHandle { let (tx, rx) = flume::bounded(CHANNEL_CAP); - let actor_tx = tx.clone(); + // let actor_tx = tx.clone(); + let waker = CoroutineWaker { tx: tx.clone() }; let join_handle = std::thread::Builder::new() .name("sync-actor".to_string()) .spawn(move || { @@ -100,7 +112,7 @@ impl StoreHandle { store: Rc::new(RefCell::new(store)), sessions: Default::default(), actor_rx: rx, - actor_tx, + waker, }; if let Err(error) = actor.run() { error!(?error, "storage thread failed"); @@ -118,13 +130,22 @@ impl StoreHandle { self.tx.send(action)?; Ok(()) } - pub fn notifier(&self, peer: NodeId, notify: Readyness) -> Notifier { - Notifier { + pub fn waker(&self) -> CoroutineWaker { + CoroutineWaker { tx: self.tx.clone(), - peer, - notify, } } + pub async fn ingest_entry(&self, entry: AuthorisedEntry) -> anyhow::Result<()> { + let (reply, reply_rx) = oneshot::channel(); + self.send(ToActor::IngestEntry { entry, reply }).await?; + reply_rx.await??; + Ok(()) + } + // + // pub fn ingest_stream(&self, stream: impl Stream) -> Result<()> { + // } + // pub fn ingest_iter(&self, iter: impl ) -> Result<()> { + // } } impl Drop for StoreHandle { @@ -148,10 +169,11 @@ pub enum ToActor { #[debug(skip)] channels: Channels, init: SessionInit, + reply: oneshot::Sender>, }, - DropSession { - peer: NodeId, - }, + // DropSession { + // peer: NodeId, + // }, Resume { peer: NodeId, notify: Readyness, @@ -161,6 +183,10 @@ pub enum ToActor { #[debug(skip)] reply: flume::Sender, }, + IngestEntry { + entry: AuthorisedEntry, + reply: oneshot::Sender>, + }, Shutdown { #[debug(skip)] reply: Option>, @@ -168,26 +194,27 @@ pub enum ToActor { } #[derive(Debug)] -struct StorageSession { +struct Session { state: SharedSessionState, channels: Channels, pending: PendingCoroutines, + on_done: oneshot::Sender>, } #[derive(derive_more::Debug, Default)] struct PendingCoroutines { #[debug(skip)] - inner: HashMap>, + inner: HashMap>, } impl PendingCoroutines { - fn get_mut(&mut self, pending_on: Readyness) -> &mut VecDeque { + fn get_mut(&mut self, pending_on: Readyness) -> &mut VecDeque { self.inner.entry(pending_on).or_default() } - fn push_back(&mut self, pending_on: Readyness, generator: ReconcileGen) { + fn push_back(&mut self, pending_on: Readyness, generator: CoroutineState) { self.get_mut(pending_on).push_back(generator); } - fn pop_front(&mut self, pending_on: Readyness) -> Option { + fn pop_front(&mut self, pending_on: Readyness) -> Option { self.get_mut(pending_on).pop_front() } // fn push_front(&mut self, pending_on: Readyness, generator: ReconcileGen) { @@ -205,13 +232,13 @@ impl PendingCoroutines { #[derive(Debug)] pub struct StorageThread { store: Rc>, - sessions: HashMap, + sessions: HashMap, actor_rx: flume::Receiver, - actor_tx: flume::Sender, + waker: CoroutineWaker, // actor_tx: flume::Sender, } type ReconcileFut = LocalBoxFuture<'static, Result<(), Error>>; -type ReconcileGen = (Span, Gen); +type ReconcileGen = Gen; impl StorageThread { pub fn run(&mut self) -> anyhow::Result<()> { @@ -234,34 +261,47 @@ impl StorageThread { } fn handle_message(&mut self, message: ToActor) -> Result<(), Error> { - debug!(%message, "tick: handle_message"); + trace!(%message, "tick: handle_message"); match message { ToActor::Shutdown { .. } => unreachable!("handled in run"), ToActor::InitSession { peer, state, channels, - init, // start, + init, + reply, } => { - let session = StorageSession { + let session = Session { state: Rc::new(RefCell::new(state)), channels, pending: Default::default(), + on_done: reply, }; self.sessions.insert(peer, session); + debug!("start coroutine control"); - self.start_coroutine( + + if let Err(error) = self.start_coroutine( peer, |routine| routine.run_control(init).boxed_local(), error_span!("control", peer=%peer.fmt_short()), - )?; - } - ToActor::DropSession { peer } => { - self.sessions.remove(&peer); + true, + ) { + warn!(?error, peer=%peer.fmt_short(), "abort session: starting failed"); + self.remove_session(&peer, Err(error)); + } } ToActor::Resume { peer, notify } => { - self.resume_next(peer, notify)?; + if self.sessions.contains_key(&peer) { + if let Err(error) = self.resume_next(peer, notify) { + warn!(?error, peer=%peer.fmt_short(), "abort session: coroutine failed"); + self.remove_session(&peer, Err(error)); + } + } } + // ToActor::DropSession { peer } => { + // self.remove_session(&peer, Ok(())); + // } ToActor::GetEntries { namespace, reply } => { let store = self.store.borrow(); let entries = store @@ -271,11 +311,22 @@ impl StorageThread { reply.send(entry).ok(); } } + ToActor::IngestEntry { entry, reply } => { + let res = self.store.borrow_mut().ingest_entry(&entry); + reply.send(res).ok(); + } } Ok(()) } - fn session_mut(&mut self, peer: &NodeId) -> Result<&mut StorageSession, Error> { - self.sessions.get_mut(peer).ok_or(Error::SessionNotFound) + + fn remove_session(&mut self, peer: &NodeId, result: Result<(), Error>) { + let session = self.sessions.remove(peer); + if let Some(session) = session { + session.channels.close_all(); + session.on_done.send(result).ok(); + } else { + warn!("remove_session called for unknown session"); + } } fn start_coroutine( @@ -283,6 +334,7 @@ impl StorageThread { peer: NodeId, producer: impl FnOnce(Coroutine) -> ReconcileFut, span: Span, + finalizes_session: bool, ) -> Result<(), Error> { let session = self.sessions.get_mut(&peer).ok_or(Error::SessionNotFound)?; let store_snapshot = Rc::new(self.store.borrow_mut().snapshot()?); @@ -290,23 +342,26 @@ impl StorageThread { let channels = session.channels.clone(); let state = session.state.clone(); let store_writer = Rc::clone(&self.store); - let notifier = CoroutineNotifier { - tx: self.actor_tx.clone(), - }; + // let waker = self.waker.clone(); - let generator = Gen::new(move |co| { + let gen = Gen::new(move |co| { let routine = Coroutine { peer, store_snapshot, store_writer, - notifier, + // waker, channels, state, co, }; (producer)(routine) }); - self.resume_coroutine(peer, (span, generator)) + let state = CoroutineState { + gen, + span, + finalizes_session, + }; + self.resume_coroutine(peer, state) } #[instrument(skip_all, fields(session=%peer.fmt_short()))] @@ -322,19 +377,40 @@ impl StorageThread { } } - fn resume_coroutine(&mut self, peer: NodeId, generator: ReconcileGen) -> Result<(), Error> { - let (span, mut generator) = generator; - let _guard = span.enter(); - debug!("resume"); + fn resume_coroutine(&mut self, peer: NodeId, mut state: CoroutineState) -> Result<(), Error> { + let _guard = state.span.enter(); + trace!(peer=%peer.fmt_short(), "resume"); loop { - match generator.resume() { + match state.gen.resume() { GeneratorState::Yielded(yielded) => { - debug!(?yielded, "yield"); + trace!(?yielded, "yield"); match yielded { - Yield::Pending(notify) => { - let session = self.session_mut(&peer)?; + Yield::Pending(resume_on) => { + let session = + self.sessions.get_mut(&peer).ok_or(Error::SessionNotFound)?; drop(_guard); - session.pending.push_back(notify, (span, generator)); + match resume_on { + Readyness::Channel(ch, interest) => { + let waker = self + .waker + .with_notify(peer, Readyness::Channel(ch, interest)); + match interest { + Interest::Send => { + session.channels.sender(ch).register_waker(waker) + } + Interest::Recv => { + session.channels.receiver(ch).register_waker(waker) + } + }; + } + Readyness::Resource(handle) => { + let waker = + self.waker.with_notify(peer, Readyness::Resource(handle)); + let mut state = session.state.borrow_mut(); + state.their_resources.register_waker(handle, waker); + } + } + session.pending.push_back(resume_on, state); break Ok(()); } Yield::StartReconciliation(start) => { @@ -343,15 +419,25 @@ impl StorageThread { peer, |routine| routine.run_reconciliation(start).boxed_local(), error_span!("reconcile"), + false, )?; } } } GeneratorState::Complete(res) => { debug!(?res, "complete"); - break res; + if res.is_err() || state.finalizes_session { + self.remove_session(&peer, res) + } + break Ok(()); } } } } } + +struct CoroutineState { + gen: ReconcileGen, + span: Span, + finalizes_session: bool, +} diff --git a/iroh-willow/src/util/channel.rs b/iroh-willow/src/util/channel.rs index 7990c16b775..6a25a4eb6db 100644 --- a/iroh-willow/src/util/channel.rs +++ b/iroh-willow/src/util/channel.rs @@ -4,21 +4,51 @@ use std::{ sync::{Arc, Mutex}, }; +use anyhow::anyhow; use bytes::{Buf, Bytes, BytesMut}; use tokio::sync::Notify; -use tracing::{debug, trace}; +use tracing::trace; + +use crate::store::actor::AssignedWaker; use super::{DecodeOutcome, Decoder, Encoder}; +pub fn channel(cap: usize) -> (Sender, Receiver) { + let shared = Shared::new(cap); + let shared = Arc::new(Mutex::new(shared)); + let sender = Sender { + shared: shared.clone(), + _ty: PhantomData, + }; + let receiver = Receiver { + shared, + _ty: PhantomData, + }; + (sender, receiver) +} + +#[derive(Debug)] +pub enum ReadOutcome { + ReadBufferEmpty, + Closed, + Item(T), +} + +#[derive(Debug)] +pub enum WriteOutcome { + BufferFull, + Closed, + Ok, +} + #[derive(Debug)] struct Shared { buf: BytesMut, max_buffer_size: usize, notify_readable: Arc, notify_writable: Arc, - write_blocked: bool, - need_read_notify: bool, - need_write_notify: bool, + wakers_on_writable: Vec, + wakers_on_readable: Vec, closed: bool, } @@ -29,48 +59,46 @@ impl Shared { max_buffer_size: cap, notify_readable: Default::default(), notify_writable: Default::default(), - write_blocked: false, - need_read_notify: false, - need_write_notify: false, + wakers_on_writable: Default::default(), + wakers_on_readable: Default::default(), closed: false, } } fn close(&mut self) { self.closed = true; - self.notify_writable.notify_waiters(); - self.notify_readable.notify_waiters(); + self.notify_writable(); + self.notify_readable(); } + fn closed(&self) -> bool { self.closed } - fn read_slice(&self) -> &[u8] { + + fn peek_read(&self) -> &[u8] { &self.buf[..] } - fn read_buf_empty(&self) -> bool { + fn read_is_empty(&self) -> bool { self.buf.is_empty() } fn read_advance(&mut self, cnt: usize) { self.buf.advance(cnt); if cnt > 0 { - // self.write_blocked = false; - self.notify_writable.notify_waiters(); + self.notify_writable(); } } fn read_bytes(&mut self) -> Bytes { let len = self.buf.len(); if len > 0 { - // self.write_blocked = false; - self.notify_writable.notify_waiters(); + self.notify_writable(); } self.buf.split_to(len).freeze() } fn write_slice(&mut self, len: usize) -> Option<&mut [u8]> { if self.remaining_write_capacity() < len { - self.write_blocked = true; None } else { let old_len = self.buf.len(); @@ -83,15 +111,13 @@ impl Shared { fn write_message(&mut self, item: &T) -> anyhow::Result { let len = item.encoded_len(); - // debug!(?item, len = len, "write_message"); + if self.closed() { + return Ok(WriteOutcome::Closed); + } if let Some(slice) = self.write_slice(len) { - // debug!(len = slice.len(), "write_message got slice"); let mut cursor = io::Cursor::new(slice); item.encode_into(&mut cursor)?; - // debug!("RES {res:?}"); - // res?; - self.notify_readable.notify_one(); - // debug!("wrote and notified"); + self.notify_readable(); Ok(WriteOutcome::Ok) } else { Ok(WriteOutcome::BufferFull) @@ -99,7 +125,7 @@ impl Shared { } fn read_message(&mut self) -> anyhow::Result> { - let data = self.read_slice(); + let data = self.peek_read(); trace!("read, remaining {}", data.len()); let res = match T::decode_from(data)? { DecodeOutcome::NeedMoreData => { @@ -117,29 +143,22 @@ impl Shared { Ok(res) } - // fn receiver_want_notify(&mut self::) { - // self.need_read_notify = true; - // } - // fn need_write_notify(&mut self) { - // self.need_write_notify = true; - // } - fn remaining_write_capacity(&self) -> usize { self.max_buffer_size - self.buf.len() } -} - -#[derive(Debug)] -pub enum ReadOutcome { - ReadBufferEmpty, - Closed, - Item(T), -} -#[derive(Debug)] -pub enum WriteOutcome { - BufferFull, - Ok, + fn notify_readable(&mut self) { + self.notify_readable.notify_waiters(); + for waker in self.wakers_on_readable.drain(..) { + waker.wake().ok(); + } + } + fn notify_writable(&mut self) { + self.notify_writable.notify_waiters(); + for waker in self.wakers_on_writable.drain(..) { + waker.wake().ok(); + } + } } #[derive(Debug)] @@ -158,6 +177,10 @@ impl Clone for Receiver { } impl Receiver { + pub fn close(&self) { + self.shared.lock().unwrap().close() + } + pub fn read_bytes(&self) -> Bytes { self.shared.lock().unwrap().read_bytes() } @@ -166,7 +189,7 @@ impl Receiver { loop { let notify = { let mut shared = self.shared.lock().unwrap(); - if !shared.read_buf_empty() { + if !shared.read_is_empty() { return Some(shared.read_bytes()); } if shared.closed() { @@ -178,29 +201,14 @@ impl Receiver { } } - pub fn read_message_or_set_notify(&self) -> anyhow::Result> { + pub fn read_message(&self) -> anyhow::Result> { let mut shared = self.shared.lock().unwrap(); let outcome = shared.read_message()?; - if matches!(outcome, ReadOutcome::ReadBufferEmpty) { - shared.need_read_notify = true; - } Ok(outcome) } - pub fn set_notify_on_receivable(&self) { - self.shared.lock().unwrap().need_read_notify = true; - } - pub fn is_sendable_notify_set(&self) -> bool { - self.shared.lock().unwrap().need_write_notify - } - pub async fn notify_readable(&self) { - let shared = self.shared.lock().unwrap(); - if !shared.read_slice().is_empty() { - return; - } - let notify = shared.notify_readable.clone(); - drop(shared); - notify.notified().await + pub fn register_waker(&self, waker: AssignedWaker) { + self.shared.lock().unwrap().wakers_on_readable.push(waker); } pub async fn recv_async(&self) -> Option> { @@ -213,15 +221,12 @@ impl Receiver { ReadOutcome::ReadBufferEmpty => shared.notify_readable.clone(), ReadOutcome::Closed => return None, ReadOutcome::Item(item) => { - // debug!("read_message_async read"); return Some(Ok(item)); } }, } }; - // debug!("read_message_async NeedMoreData wait"); notify.notified().await; - // debug!("read_message_async NeedMoreData notified"); } } } @@ -242,100 +247,57 @@ impl Clone for Sender { } impl Sender { - // fn write_slice_into(&self, len: usize) -> Option<&mut [u8]> { - // let mut shared = self.shared.lock().unwrap(); - // shared.write_slice(len) - // } - pub fn set_notify_on_sendable(&self) { - self.shared.lock().unwrap().need_write_notify = true; + pub fn close(&self) { + self.shared.lock().unwrap().close() } - pub fn is_receivable_notify_set(&self) -> bool { - self.shared.lock().unwrap().need_read_notify + pub fn register_waker(&self, waker: AssignedWaker) { + self.shared.lock().unwrap().wakers_on_writable.push(waker); } - pub fn close(&self) { - self.shared.lock().unwrap().close() + pub async fn notify_closed(&self) { + tracing::info!("notify_close IN"); + loop { + let notify = { + let shared = self.shared.lock().unwrap(); + if shared.closed() { + tracing::info!("notify_close closed!"); + return; + } else { + tracing::info!("notify_close not closed - wait"); + + } + shared.notify_writable.clone() + }; + notify.notified().await; + } } - // fn write_slice(&self, data: &[u8]) -> bool { - // let mut shared = self.shared.lock().unwrap(); - // match shared.write_slice(data.len()) { - // None => false, - // Some(out) => { - // out.copy_from_slice(data); - // true - // } - // } - // } - - pub async fn write_slice_async(&self, data: &[u8]) { + pub async fn write_slice_async(&self, data: &[u8]) -> anyhow::Result<()> { loop { let notify = { let mut shared = self.shared.lock().unwrap(); + if shared.closed() { + break Err(anyhow!("channel closed")); + } if shared.remaining_write_capacity() < data.len() { let notify = shared.notify_writable.clone(); notify.clone() } else { let out = shared.write_slice(data.len()).expect("just checked"); out.copy_from_slice(data); - shared.notify_readable.notify_waiters(); - break; - // return true; + shared.notify_readable(); + break Ok(()); } }; notify.notified().await; } } - pub async fn notify_writable(&self) { - let shared = self.shared.lock().unwrap(); - if shared.remaining_write_capacity() > 0 { - return; - } - let notify = shared.notify_readable.clone(); - drop(shared); - notify.notified().await; - } - - // fn remaining_write_capacity(&self) -> usize { - // self.shared.lock().unwrap().remaining_write_capacity() - // } - - pub fn send_or_set_notify(&self, message: &T) -> anyhow::Result { - let mut shared = self.shared.lock().unwrap(); - let outcome = shared.write_message(message)?; - if matches!(outcome, WriteOutcome::BufferFull) { - shared.need_write_notify = true; - } - debug!("send buf remaining: {}", shared.remaining_write_capacity()); - Ok(outcome) - } - pub fn send(&self, message: &T) -> anyhow::Result { self.shared.lock().unwrap().write_message(message) } - // pub async fn sNamespacePublicKeyend_co( - // &self, - // message: &T, - // yield_fn: F, - // // co: &genawaiter::sync::Co, - // // yield_value: Y, - // ) -> anyhow::Result<()> - // where - // F: Fn() -> Fut, - // Fut: std::future::Future, - // { - // loop { - // let res = self.shared.lock().unwrap().write_message(message)?; - // match res { - // WriteOutcome::BufferFull => (yield_fn)().await, - // WriteOutcome::Ok => break Ok(()), - // } - // } - // } - pub async fn send_async(&self, message: &T) -> anyhow::Result<()> { loop { let notify = { @@ -343,6 +305,7 @@ impl Sender { match shared.write_message(message)? { WriteOutcome::Ok => return Ok(()), WriteOutcome::BufferFull => shared.notify_writable.clone(), + WriteOutcome::Closed => return Err(anyhow!("channel is closed")), } }; notify.notified().await; @@ -350,105 +313,13 @@ impl Sender { } } -pub fn channel(cap: usize) -> (Sender, Receiver) { - let shared = Shared::new(cap); - let shared = Arc::new(Mutex::new(shared)); - let sender = Sender { - shared: shared.clone(), - _ty: PhantomData, - }; - let receiver = Receiver { - shared, - _ty: PhantomData, - }; - (sender, receiver) -} - -// #[derive(Debug)] -// pub struct ChannelSender { -// id: u64, -// buf: rtrb::Producer, -// // waker: Option, -// } -// -// impl ChannelSender { -// pub fn remaining_capacity(&self) -> usize { -// self.buf.slots() +// pub async fn notify_readable(&self) { +// let shared = self.shared.lock().unwrap(); +// if !shared.peek_read().is_empty() { +// return; // } -// pub fn can_write_message(&mut self, message: &Message) -> bool { -// message.encoded_len() <= self.remaining_capacity() -// } -// -// pub fn write_message(&mut self, message: &Message) -> bool { -// let encoded_len = message.encoded_len(); -// if encoded_len > self.remaining_capacity() { -// return false; -// } -// message.encode_into(&mut self.buf).expect("length checked"); -// if let Some(waker) = self.waker.take() { -// waker.wake(); -// } -// true -// } -// -// pub fn set_waker(&mut self, waker: Waker) { -// self.waker = Some(waker); -// } -// } -// -// #[derive(Debug)] -// pub enum ToStoreActor { -// // NotifyWake(u64, Arc), -// Resume(u64), +// let notify = shared.notify_readable.clone(); +// drop(shared); +// notify.notified().await // } // -// #[derive(Debug)] -// pub struct ChannelReceiver { -// id: u64, -// // buf: rtrb::Consumer, -// buf: BytesMut, -// to_actor: flume::Sender, -// notify_readable: Arc, -// } -// -// impl ChannelReceiver { -// pub async fn read_chunk(&mut self) -> Result, ChunkError> { -// if self.is_empty() { -// self.acquire().await; -// } -// self.buf.read_chunk(self.readable_len()) -// } -// -// pub fn is_empty(&self) -> bool { -// self.buf.is_empty() -// } -// -// pub fn readable_len(&self) -> usize { -// self.buf.slots() -// } -// -// pub async fn resume(&mut self) { -// self.to_actor -// .send_async(ToStoreActor::Resume(self.id)) -// .await -// .unwrap(); -// } -// -// pub async fn acquire(&mut self) { -// if !self.is_empty() { -// return; -// } -// self.notify_readable.notified().await; -// } -// } -// -// pub struct ChannelSender { -// id: u64, -// buf: rtrb::Producer, -// to_actor: flume::Sender, -// notify_readable: Arc, -// } -// -// impl ChannelSender { -// pub -// }