Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: ReplicationCore get a snapshot directly from state machine, via SnapshotReader #1084

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::core::raft_msg::RaftMsg;
use crate::core::raft_msg::ResultSender;
use crate::core::raft_msg::VoteTx;
use crate::core::sm;
use crate::core::sm::handle;
use crate::core::sm::CommandSeq;
use crate::core::ServerState;
use crate::display_ext::DisplayOption;
Expand Down Expand Up @@ -175,7 +176,7 @@ where
pub(crate) log_store: LS,

/// A controlling handle to the [`RaftStateMachine`] worker.
pub(crate) sm_handle: sm::Handle<C>,
pub(crate) sm_handle: handle::Handle<C>,

pub(crate) engine: Engine<C>,

Expand Down Expand Up @@ -821,6 +822,7 @@ where
network,
snapshot_network,
self.log_store.get_log_reader().await,
self.sm_handle.new_snapshot_reader(),
self.tx_notify.clone(),
tracing::span!(parent: &self.span, Level::DEBUG, "replication", id=display(self.id), target=display(target)),
)
Expand Down Expand Up @@ -1673,21 +1675,10 @@ where
let _ = node.tx_repl.send(Replicate::logs(RequestId::new_append_entries(id), log_id_range));
}
Inflight::Snapshot { id, last_log_id } => {
let _ = last_log_id;

// Create a channel so that the state machine worker can send the snapshot and the replication
// worker can receive it.
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
.send(cmd)
.map_err(|e| StorageIOError::read_snapshot(None, AnyError::error(e)))?;

// unwrap: The replication channel must not be dropped or it is a bug.
node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), rx)).map_err(|_e| {
StorageIOError::read_snapshot(None, AnyError::error("replication channel closed"))
})?;
node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), last_log_id)).map_err(
|_e| StorageIOError::read_snapshot(None, AnyError::error("replication channel closed")),
)?;
}
}
} else {
Expand Down
3 changes: 0 additions & 3 deletions openraft/src/core/raft_msg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::raft::SnapshotResponse;
use crate::raft::VoteRequest;
use crate::raft::VoteResponse;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::ChangeMembers;
Expand All @@ -27,8 +26,6 @@ pub(crate) mod external_command;
/// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`.
///
/// The value sent is a `Result<T, E>`. When `E` is not given it defaults to [`Infallible`],
/// i.e. the sender can only ever carry an `Ok`.
pub(crate) type ResultSender<C, T, E = Infallible> = OneshotSenderOf<C, Result<T, E>>;

/// A oneshot RX carrying the same `Result<T, E>` payload as [`ResultSender`].
///
/// When `E` is not given it defaults to [`Infallible`].
pub(crate) type ResultReceiver<C, T, E = Infallible> = OneshotReceiverOf<C, Result<T, E>>;

/// Oneshot TX for a [`VoteResponse`].
///
/// It is a [`ResultSender`] with the default error type, so the response is always sent as `Ok`.
pub(crate) type VoteTx<C> = ResultSender<C, VoteResponse<C>>;

Expand Down
84 changes: 84 additions & 0 deletions openraft/src/core/sm/handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//! State machine control handle

use tokio::sync::mpsc;

use crate::core::sm;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::JoinHandleOf;
use crate::AsyncRuntime;
use crate::RaftTypeConfig;
use crate::Snapshot;

/// Control handle through which commands are sent to the state machine worker.
pub(crate) struct Handle<C: RaftTypeConfig> {
    /// Sending half of the unbounded channel that carries [`sm::Command`]s to the worker.
    pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender<sm::Command<C>>,

    /// Join handle, presumably of the spawned worker task (it is filled in by code outside
    /// this file).
    ///
    /// Nothing in this file reads it, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub(in crate::core::sm) join_handle: JoinHandleOf<C, ()>,
}

impl<C: RaftTypeConfig> Handle<C> {
    /// Forward `cmd` to the state machine worker over the command channel.
    ///
    /// The command is logged at debug level first. If the channel refuses the command, the
    /// returned [`mpsc::error::SendError`] hands it back to the caller.
    pub(crate) fn send(&mut self, cmd: sm::Command<C>) -> Result<(), mpsc::error::SendError<sm::Command<C>>> {
        tracing::debug!("sending command to state machine worker: {:?}", cmd);

        let sender = &self.cmd_tx;
        sender.send(cmd)
    }

    /// Build a [`SnapshotReader`] that can ask the state machine for its current snapshot.
    ///
    /// The reader only receives a weak (downgraded) clone of the command sender.
    pub(crate) fn new_snapshot_reader(&self) -> SnapshotReader<C> {
        let cmd_tx = self.cmd_tx.downgrade();
        SnapshotReader { cmd_tx }
    }
}

/// Read-side handle used to fetch a snapshot from the state machine.
pub(crate) struct SnapshotReader<C: RaftTypeConfig> {
    /// Weak sender of commands to the state machine worker.
    ///
    /// Only a weak reference is kept: the [`Worker`] treats the closing of this channel as its
    /// shutdown signal, so a reader must not keep the channel open on its own.
    ///
    /// [`Worker`]: sm::worker::Worker
    cmd_tx: mpsc::WeakUnboundedSender<sm::Command<C>>,
}

impl<C: RaftTypeConfig> SnapshotReader<C> {
    /// Ask the state machine worker for its current snapshot and wait for the reply.
    ///
    /// Returns `Ok(None)` when the state machine reports that no snapshot is available.
    ///
    /// # Errors
    ///
    /// Returns a static message when the worker appears to have shut down: either the weak
    /// command sender can no longer be upgraded, or the reply channel is closed before a
    /// snapshot arrives.
    pub(crate) async fn get_snapshot(&self) -> Result<Option<Snapshot<C>>, &'static str> {
        let (tx, rx) = AsyncRuntimeOf::<C>::oneshot();

        let cmd = sm::Command::get_snapshot(tx);
        tracing::debug!("SnapshotReader sending command to sm::Worker: {:?}", cmd);

        let strong_tx = match self.cmd_tx.upgrade() {
            Some(sender) => sender,
            None => {
                tracing::info!("failed to upgrade cmd_tx, sm::Worker may have shutdown");
                return Err("failed to upgrade cmd_tx, sm::Worker may have shutdown");
            }
        };

        // A failed send is deliberately ignored: the rejected `cmd` is dropped together with
        // `tx`, which makes the `rx.await` below fail and report the error.
        let _ = strong_tx.send(cmd);

        let reply = rx.await.map_err(|_e| {
            tracing::error!("failed to receive snapshot, sm::Worker may have shutdown");
            "failed to receive snapshot, sm::Worker may have shutdown"
        })?;

        // Safe unwrap(): the error type is Infallible.
        Ok(reply.unwrap())
    }
}
224 changes: 2 additions & 222 deletions openraft/src/core/sm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,233 +4,13 @@
//! It is responsible for applying log entries, building/receiving snapshot and sending responses
//! to the RaftCore.

use tokio::sync::mpsc;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::core::ApplyResult;
use crate::core::ApplyingEntry;
use crate::entry::RaftPayload;
use crate::storage::RaftStateMachine;
use crate::summary::MessageSummary;
use crate::AsyncRuntime;
use crate::RaftLogId;
use crate::RaftSnapshotBuilder;
use crate::RaftTypeConfig;
use crate::Snapshot;
use crate::StorageError;

pub(crate) mod command;
pub(crate) mod handle;
pub(crate) mod response;
pub(crate) mod worker;

pub(crate) use command::Command;
pub(crate) use command::CommandPayload;
#[allow(unused_imports)] pub(crate) use command::CommandSeq;
pub(crate) use response::CommandResult;
pub(crate) use response::Response;

use crate::core::notify::Notify;
use crate::core::raft_msg::ResultSender;
use crate::type_config::alias::JoinHandleOf;

/// Control handle through which commands are sent to the state machine worker.
pub(crate) struct Handle<C: RaftTypeConfig> {
    /// Sending half of the unbounded channel that carries [`Command`]s to the worker.
    cmd_tx: mpsc::UnboundedSender<Command<C>>,

    /// Join handle of the spawned worker task, as returned when the worker is spawned.
    ///
    /// It is stored but not read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    join_handle: JoinHandleOf<C, ()>,
}

impl<C: RaftTypeConfig> Handle<C> {
    /// Forward `cmd` to the state machine worker over the command channel.
    ///
    /// The command is logged at debug level first. If the channel refuses the command, the
    /// returned [`mpsc::error::SendError`] hands it back to the caller.
    pub(crate) fn send(&mut self, cmd: Command<C>) -> Result<(), mpsc::error::SendError<Command<C>>> {
        tracing::debug!("sending command to state machine worker: {:?}", cmd);

        let sender = &self.cmd_tx;
        sender.send(cmd)
    }
}

/// The task-side state of the state machine worker: it owns the state machine and serves the
/// commands arriving on its command channel.
pub(crate) struct Worker<C: RaftTypeConfig, SM: RaftStateMachine<C>> {
    /// The state machine that commands are executed against.
    state_machine: SM,

    /// Receiving half of the command channel; the worker loop ends once it is closed.
    cmd_rx: mpsc::UnboundedReceiver<Command<C>>,

    /// Channel on which command results are reported as [`Notify`] messages.
    resp_tx: mpsc::UnboundedSender<Notify<C>>,
}

impl<C: RaftTypeConfig, SM: RaftStateMachine<C>> Worker<C, SM> {
    /// Start a state machine worker task and hand back the [`Handle`] that controls it.
    ///
    /// A fresh unbounded command channel is created: its receiving end moves into the worker,
    /// its sending end is stored in the returned handle. Command results are reported on
    /// `resp_tx`.
    pub(crate) fn spawn(state_machine: SM, resp_tx: mpsc::UnboundedSender<Notify<C>>) -> Handle<C> {
        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();

        let join_handle = Self {
            state_machine,
            cmd_rx,
            resp_tx,
        }
        .do_spawn();

        Handle { cmd_tx, join_handle }
    }

    /// Move the worker into a spawned task that runs [`Self::worker_loop`].
    ///
    /// If the loop ends with an error, the error is logged and then reported on `resp_tx` as a
    /// state machine notification with `command_seq` set to 0.
    fn do_spawn(mut self) -> JoinHandleOf<C, ()> {
        let task = async move {
            let Err(err) = self.worker_loop().await else {
                return;
            };

            tracing::error!("{} while execute state machine command", err,);

            let command_result = CommandResult {
                command_seq: 0,
                result: Err(err),
            };
            let _ = self.resp_tx.send(Notify::StateMachine { command_result });
        };

        C::AsyncRuntime::spawn(task)
    }

    /// Receive commands one at a time and execute them until the command channel is closed.
    ///
    /// Returns `Ok(())` when the channel is closed, or the first storage error raised while
    /// executing a command.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn worker_loop(&mut self) -> Result<(), StorageError<C::NodeId>> {
        loop {
            let Some(cmd) = self.cmd_rx.recv().await else {
                tracing::info!("{}: rx closed, state machine worker quit", func_name!());
                return Ok(());
            };

            tracing::debug!("{}: received command: {:?}", func_name!(), cmd);

            let seq = cmd.seq;

            match cmd.payload {
                CommandPayload::BuildSnapshot => {
                    tracing::info!("{}: build snapshot", func_name!());

                    // Building is a read operation that runs in a spawned task; that task
                    // sends the response itself.
                    let resp_tx = self.resp_tx.clone();
                    self.build_snapshot(seq, resp_tx).await;
                }
                CommandPayload::GetSnapshot { tx } => {
                    tracing::info!("{}: get snapshot", func_name!());

                    // The snapshot is returned through `tx`; nothing is sent to RaftCore.
                    self.get_snapshot(tx).await?;
                }
                CommandPayload::InstallFullSnapshot { snapshot } => {
                    tracing::info!("{}: install complete snapshot", func_name!());

                    let meta = snapshot.meta.clone();
                    self.state_machine.install_snapshot(&meta, snapshot.snapshot).await?;

                    tracing::info!("Done install complete snapshot, meta: {}", meta);

                    let installed = Response::InstallSnapshot(Some(meta));
                    let _ = self.resp_tx.send(Notify::sm(CommandResult::new(seq, Ok(installed))));
                }
                CommandPayload::BeginReceivingSnapshot { tx } => {
                    tracing::info!("{}: BeginReceivingSnapshot", func_name!());

                    let snapshot_data = self.state_machine.begin_receiving_snapshot().await?;

                    // The data is returned through `tx`; nothing is sent to RaftCore.
                    let _ = tx.send(Ok(snapshot_data));
                }
                CommandPayload::Apply { entries } => {
                    let applied = self.apply(entries).await?;
                    let res = CommandResult::new(seq, Ok(Response::Apply(applied)));
                    let _ = self.resp_tx.send(Notify::sm(res));
                }
            }
        }
    }

    /// Apply `entries` to the state machine and describe what was applied.
    ///
    /// The log id and membership of every entry are recorded before the entries are handed to
    /// the state machine.
    ///
    /// # Panics
    ///
    /// Panics if `entries` is empty.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn apply(&mut self, entries: Vec<C::Entry>) -> Result<ApplyResult<C>, StorageError<C::NodeId>> {
        // TODO: prepare response before apply_to_state_machine,
        // so that an Entry does not need to be Clone,
        // and no references will be used by apply_to_state_machine

        let first_log_id = *entries.first().unwrap().get_log_id();
        let last_applied = *entries.last().unwrap().get_log_id();

        let since = first_log_id.index;
        let end = last_applied.index + 1;

        // The summaries must be taken now: `entries` is moved into the state machine below.
        let mut applying_entries = Vec::with_capacity(entries.len());
        for entry in &entries {
            applying_entries.push(ApplyingEntry::new(*entry.get_log_id(), entry.get_membership().cloned()));
        }

        let n_entries = applying_entries.len();

        let apply_results = self.state_machine.apply(entries).await?;

        let n_replies = apply_results.len();

        debug_assert_eq!(
            n_entries, n_replies,
            "n_entries: {} should equal n_replies: {}",
            n_entries, n_replies
        );

        Ok(ApplyResult {
            since,
            end,
            last_applied,
            applying_entries,
            apply_results,
        })
    }

    /// Start building a snapshot of the state machine in a separate task.
    ///
    /// Only the builder is obtained here; the build itself runs in a spawned task, which sends
    /// its result on `resp_tx` tagged with `seq`. This method returns without waiting for it.
    ///
    /// Running the build in parallel relies on the [`RaftSnapshotBuilder`] returned by
    /// [`get_snapshot_builder()`](`RaftStateMachine::get_snapshot_builder()`). The builder must:
    /// - hold a consistent view of the state machine that further writes, such as applying a
    ///   log entry, cannot affect,
    /// - or be able to acquire a lock that prevents any write operation.
    #[tracing::instrument(level = "info", skip_all)]
    async fn build_snapshot(&mut self, seq: CommandSeq, resp_tx: mpsc::UnboundedSender<Notify<C>>) {
        // TODO: need to be abortable?
        // use futures::future::abortable;
        // let (fu, abort_handle) = abortable(async move { builder.build_snapshot().await });

        tracing::info!("{}", func_name!());

        let mut builder = self.state_machine.get_snapshot_builder().await;

        let build_and_report = async move {
            let outcome = builder.build_snapshot().await.map(|snap| Response::BuildSnapshot(snap.meta));
            let _ = resp_tx.send(Notify::sm(CommandResult::new(seq, outcome)));
        };
        let _handle = C::AsyncRuntime::spawn(build_and_report);

        tracing::info!("{} returning; spawned building snapshot task", func_name!());
    }

    /// Read the current snapshot from the state machine and send it back through `tx`.
    ///
    /// A failure to deliver the reply is ignored; only a storage error from reading the
    /// snapshot is returned.
    #[tracing::instrument(level = "info", skip_all)]
    async fn get_snapshot(&mut self, tx: ResultSender<C, Option<Snapshot<C>>>) -> Result<(), StorageError<C::NodeId>> {
        tracing::info!("{}", func_name!());

        let current = self.state_machine.get_current_snapshot().await?;

        tracing::info!(
            "sending back snapshot: meta: {:?}",
            current.as_ref().map(|s| s.meta.summary())
        );

        let _ = tx.send(Ok(current));
        Ok(())
    }
}
Loading
Loading