Skip to content

Commit

Permalink
hide refcells from public apis
Browse files Browse the repository at this point in the history
  • Loading branch information
Frando committed May 15, 2024
1 parent 7ef6085 commit 7c327bf
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 209 deletions.
18 changes: 8 additions & 10 deletions iroh-willow/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
meadowcap,
willow::{AuthorisedEntry, Entry},
},
session::{coroutine::ControlRoutine, Channels, Error, Role, SessionInit, Session},
session::{coroutine::ControlRoutine, Channels, Error, Role, Session, SessionInit},
store::{KeyStore, Store},
util::task_set::{TaskKey, TaskMap},
};
Expand Down Expand Up @@ -186,8 +186,7 @@ pub enum ToActor {
#[derive(Debug)]
struct ActiveSession {
on_finish: oneshot::Sender<Result<(), Error>>,
task_key: TaskKey
// state: SharedSessionState<S>
task_key: TaskKey, // state: SharedSessionState<S>
}

#[derive(Debug)]
Expand Down Expand Up @@ -245,19 +244,18 @@ impl<S: Store> StorageThread<S> {
} => {
let session_id = peer;
let Channels { send, recv } = channels;
let session = Session::new(
self.store.clone(),
send,
our_role,
initial_transmission,
);
let session =
Session::new(self.store.clone(), send, our_role, initial_transmission);

let task_key = self.session_tasks.spawn_local(
session_id,
ControlRoutine::run(session, recv, init)
.instrument(error_span!("session", peer = %peer.fmt_short())),
);
let active_session = ActiveSession { on_finish, task_key };
let active_session = ActiveSession {
on_finish,
task_key,
};
self.sessions.insert(session_id, active_session);
}
ToActor::GetEntries {
Expand Down
2 changes: 1 addition & 1 deletion iroh-willow/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod util;

pub use self::channels::Channels;
pub use self::error::Error;
pub use self::state::{SessionState, Session};
pub use self::state::Session;

/// To break symmetry, we refer to the peer that initiated the synchronisation session as Alfie,
/// and the other peer as Betty.
Expand Down
114 changes: 30 additions & 84 deletions iroh-willow/src/session/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{cell::RefMut, rc::Rc};
use std::rc::Rc;

use futures_lite::StreamExt;
use strum::IntoEnumIterator;
Expand All @@ -15,10 +15,7 @@ use crate::{
},
willow::AuthorisedEntry,
},
session::{
channels::LogicalChannelReceivers, Error, Scope, SessionInit, SessionState,
Session,
},
session::{channels::LogicalChannelReceivers, Error, Scope, Session, SessionInit},
store::{ReadonlyStore, SplitAction, Store, SyncConfig},
util::channel::{Receiver, WriteError},
};
Expand All @@ -33,7 +30,7 @@ const INITIAL_GUARANTEES: u64 = u64::MAX;
#[derive(derive_more::Debug)]
pub struct ControlRoutine<S> {
control_recv: Receiver<Message>,
state: Session<S>,
session: Session<S>,
init: Option<SessionInit>,
}

Expand All @@ -57,23 +54,23 @@ impl<S: Store> ControlRoutine<S> {
// Spawn a task to handle incoming static tokens.
session.spawn(error_span!("stt"), move |session| async move {
while let Some(message) = static_tokens_recv.try_next().await? {
session.state_mut().on_setup_bind_static_token(message);
session.on_setup_bind_static_token(message);
}
Ok(())
});

// Spawn a task to handle incoming capabilities.
session.spawn(error_span!("cap"), move |session| async move {
while let Some(message) = capability_recv.try_next().await? {
session.state_mut().on_setup_bind_read_capability(message)?;
session.on_setup_bind_read_capability(message)?;
}
Ok(())
});

// Spawn a task to handle incoming areas of interest.
session.spawn(error_span!("aoi"), move |session| async move {
while let Some(message) = aoi_recv.try_next().await? {
Self::on_bind_area_of_interest(session.clone(), message).await?;
session.on_bind_area_of_interest(message).await?;
}
Ok(())
});
Expand Down Expand Up @@ -103,7 +100,7 @@ impl<S: Store> ControlRoutine<S> {
// TODO: We'll want to emit the completion event back to the application and
// let it decide what to do (stop, keep open) - or pass relevant config in
// SessionInit.
if session.state_mut().reconciliation_is_complete() {
if session.reconciliation_is_complete() {
tracing::debug!("stop session: reconciliation is complete");
drop(guard);
break;
Expand All @@ -113,37 +110,33 @@ impl<S: Store> ControlRoutine<S> {
// Close all our send streams.
//
// This makes the networking send loops stop.
session.send.close_all();
session.close_senders();

Ok(())
}

pub fn new(
session: Session<S>,
control_recv: Receiver<Message>,
init: SessionInit,
) -> Self {
pub fn new(session: Session<S>, control_recv: Receiver<Message>, init: SessionInit) -> Self {
Self {
control_recv,
state: session,
session,
init: Some(init),
}
}

async fn run_inner(mut self) -> Result<(), Error> {
debug!(role = ?self.state().our_role, "start session");
debug!(role = ?self.session.our_role(), "start session");

// Reveal our nonce.
let reveal_message = self.state().commitment_reveal()?;
self.state.send(reveal_message).await?;
let reveal_message = self.session.reveal_commitment()?;
self.session.send(reveal_message).await?;

// Issue guarantees for all logical channels.
for channel in LogicalChannel::iter() {
let msg = ControlIssueGuarantee {
amount: INITIAL_GUARANTEES,
channel,
};
self.state.send(msg).await?;
self.session.send(msg).await?;
}

while let Some(message) = self.control_recv.try_next().await? {
Expand All @@ -157,39 +150,25 @@ impl<S: Store> ControlRoutine<S> {
debug!(%message, "recv");
match message {
Message::CommitmentReveal(msg) => {
self.state().on_commitment_reveal(msg)?;
self.session.on_commitment_reveal(msg)?;
let init = self
.init
.take()
.ok_or_else(|| Error::InvalidMessageInCurrentState)?;
self.state
self.session
.spawn(error_span!("setup"), |state| Self::setup(state, init));
}
Message::ControlIssueGuarantee(msg) => {
let ControlIssueGuarantee { amount, channel } = msg;
let sender = self.state.send.get_logical(channel);
debug!(?channel, %amount, "add guarantees");
sender.add_guarantees(amount);
self.session.add_guarantees(channel, amount);
}
_ => return Err(Error::UnsupportedMessage),
}

Ok(())
}

async fn on_bind_area_of_interest(
session: Session<S>,
message: SetupBindAreaOfInterest,
) -> Result<(), Error> {
session
.get_their_resource_eventually(|r| &mut r.capabilities, message.authorisation)
.await;
session
.state_mut()
.bind_area_of_interest(Scope::Theirs, message)?;
Ok(())
}

async fn setup(session: Session<S>, init: SessionInit) -> Result<(), Error> {
debug!(interests = init.interests.len(), "start setup");
for (capability, aois) in init.interests.into_iter() {
Expand All @@ -207,19 +186,13 @@ impl<S: Store> ControlRoutine<S> {
authorisation: our_capability_handle,
};
// TODO: We could skip the clone if we re-enabled sending by reference.
session
.state_mut()
.bind_area_of_interest(Scope::Ours, msg.clone())?;
session.bind_area_of_interest(Scope::Ours, msg.clone())?;
session.send(msg).await?;
}
}
debug!("setup done");
Ok(())
}

fn state(&mut self) -> RefMut<SessionState> {
self.state.state_mut()
}
}

#[derive(derive_more::Debug)]
Expand All @@ -243,7 +216,7 @@ impl<S: Store> Reconciler<S> {
}

pub async fn run(mut self) -> Result<(), Error> {
let our_role = self.state().our_role;
let our_role = self.session.our_role();
loop {
tokio::select! {
message = self.recv.try_next() => {
Expand All @@ -258,7 +231,7 @@ impl<S: Store> Reconciler<S> {
}
}
}
if self.state().reconciliation_is_complete() {
if self.session.reconciliation_is_complete() {
debug!("reconciliation complete, close session");
break;
}
Expand Down Expand Up @@ -288,7 +261,6 @@ impl<S: Store> Reconciler<S> {
} = intersection;
let range = intersection.into_range();
let fingerprint = self.snapshot.fingerprint(namespace, &range)?;
self.session.state_mut().reconciliation_started = true;
self.send_fingerprint(range, fingerprint, our_handle, their_handle, None)
.await?;
Ok(())
Expand All @@ -298,22 +270,16 @@ impl<S: Store> Reconciler<S> {
&mut self,
message: ReconciliationSendFingerprint,
) -> Result<(), Error> {
let namespace = self.session.on_send_fingerprint(&message)?;
trace!("on_send_fingerprint start");
let ReconciliationSendFingerprint {
range,
fingerprint: their_fingerprint,
sender_handle: their_handle,
receiver_handle: our_handle,
is_final_reply_for_range,
is_final_reply_for_range: _,
} = message;

let namespace = {
let mut state = self.state();
state.reconciliation_started = true;
state.clear_pending_range_if_some(our_handle, is_final_reply_for_range)?;
state.range_is_authorised(&range, &our_handle, &their_handle)?
};

let our_fingerprint = self.snapshot.fingerprint(namespace, &range)?;

// case 1: fingerprint match.
Expand Down Expand Up @@ -356,28 +322,17 @@ impl<S: Store> Reconciler<S> {
message: ReconciliationAnnounceEntries,
) -> Result<(), Error> {
trace!("on_announce_entries start");
let namespace = self.session.on_announce_entries(&message)?;
let ReconciliationAnnounceEntries {
range,
count,
count: _,
want_response,
will_sort: _,
sender_handle: their_handle,
receiver_handle: our_handle,
is_final_reply_for_range,
is_final_reply_for_range: _,
} = message;

let namespace = {
let mut state = self.state();
state.clear_pending_range_if_some(our_handle, is_final_reply_for_range)?;
if state.pending_entries.is_some() {
return Err(Error::InvalidMessageInCurrentState);
}
let namespace = state.range_is_authorised(&range, &our_handle, &their_handle)?;
if count != 0 {
state.pending_entries = Some(count);
}
namespace
};
if want_response {
self.announce_and_send_entries(
namespace,
Expand All @@ -400,7 +355,7 @@ impl<S: Store> Reconciler<S> {
.get_their_resource_eventually(|r| &mut r.static_tokens, message.static_token_handle)
.await;

self.state().on_send_entry()?;
self.session.on_send_entry()?;

let authorised_entry = AuthorisedEntry::try_from_parts(
message.entry.entry,
Expand All @@ -421,9 +376,7 @@ impl<S: Store> Reconciler<S> {
their_handle: AreaOfInterestHandle,
is_final_reply_for_range: Option<ThreeDRange>,
) -> anyhow::Result<()> {
self.state()
.pending_ranges
.insert((our_handle, range.clone()));
self.session.insert_pending_range(our_handle, range.clone());
let msg = ReconciliationSendFingerprint {
range,
fingerprint,
Expand All @@ -446,8 +399,7 @@ impl<S: Store> Reconciler<S> {
our_count: Option<u64>,
) -> Result<(), Error> {
if want_response {
let mut state = self.state();
state.pending_ranges.insert((our_handle, range.clone()));
self.session.insert_pending_range(our_handle, range.clone());
}
let our_count = match our_count {
Some(count) => count,
Expand All @@ -472,10 +424,8 @@ impl<S: Store> Reconciler<S> {
let (static_token, dynamic_token) = token.into_parts();
// TODO: partial payloads
let available = entry.payload_length;
let (static_token_handle, static_token_bind_msg) = self
.session
.state_mut()
.bind_our_static_token(static_token)?;
let (static_token_handle, static_token_bind_msg) =
self.session.bind_our_static_token(static_token);
if let Some(msg) = static_token_bind_msg {
self.send(msg).await?;
}
Expand Down Expand Up @@ -535,10 +485,6 @@ impl<S: Store> Reconciler<S> {
Ok(())
}

fn state(&mut self) -> RefMut<SessionState> {
self.session.state_mut()
}

async fn send(&self, message: impl Into<Message>) -> Result<(), WriteError> {
self.session.send(message).await
}
Expand Down
Loading

0 comments on commit 7c327bf

Please sign in to comment.