From bb75f7b5145dd5c7c79bfea1522f2762cec9dee2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Mon, 25 Mar 2024 14:10:43 +0800 Subject: [PATCH 1/2] Refactor: move state machine Handle and Worker into separate files --- openraft/src/core/raft_core.rs | 3 +- openraft/src/core/sm/handle.rs | 26 ++++ openraft/src/core/sm/mod.rs | 224 +-------------------------------- openraft/src/core/sm/worker.rs | 208 ++++++++++++++++++++++++++++++ openraft/src/engine/command.rs | 6 +- openraft/src/raft/mod.rs | 4 +- 6 files changed, 243 insertions(+), 228 deletions(-) create mode 100644 openraft/src/core/sm/handle.rs create mode 100644 openraft/src/core/sm/worker.rs diff --git a/openraft/src/core/raft_core.rs b/openraft/src/core/raft_core.rs index 2150cd139..85dac7392 100644 --- a/openraft/src/core/raft_core.rs +++ b/openraft/src/core/raft_core.rs @@ -34,6 +34,7 @@ use crate::core::raft_msg::RaftMsg; use crate::core::raft_msg::ResultSender; use crate::core::raft_msg::VoteTx; use crate::core::sm; +use crate::core::sm::handle; use crate::core::sm::CommandSeq; use crate::core::ServerState; use crate::display_ext::DisplayOption; @@ -175,7 +176,7 @@ where pub(crate) log_store: LS, /// A controlling handle to the [`RaftStateMachine`] worker. - pub(crate) sm_handle: sm::Handle, + pub(crate) sm_handle: handle::Handle, pub(crate) engine: Engine, diff --git a/openraft/src/core/sm/handle.rs b/openraft/src/core/sm/handle.rs new file mode 100644 index 000000000..9ba60f859 --- /dev/null +++ b/openraft/src/core/sm/handle.rs @@ -0,0 +1,26 @@ +//! State machine control handle + +use tokio::sync::mpsc; + +use crate::alias::JoinHandleOf; +use crate::core::sm::Command; +use crate::RaftTypeConfig; + +/// State machine worker handle for sending command to it. +pub(crate) struct Handle +where C: RaftTypeConfig +{ + pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender>, + + #[allow(dead_code)] + pub(in crate::core::sm) join_handle: JoinHandleOf, +} + +impl Handle +where C: RaftTypeConfig +{ + pub(crate) fn send(&mut self, cmd: Command) -> Result<(), mpsc::error::SendError>> { + tracing::debug!("sending command to state machine worker: {:?}", cmd); + self.cmd_tx.send(cmd) + } +} diff --git a/openraft/src/core/sm/mod.rs b/openraft/src/core/sm/mod.rs index 871d56425..c423f3c03 100644 --- a/openraft/src/core/sm/mod.rs +++ b/openraft/src/core/sm/mod.rs @@ -4,233 +4,13 @@ //! It is responsible for applying log entries, building/receiving snapshot and sending responses //! to the RaftCore. -use tokio::sync::mpsc; - -use crate::async_runtime::AsyncOneshotSendExt; -use crate::core::ApplyResult; -use crate::core::ApplyingEntry; -use crate::entry::RaftPayload; -use crate::storage::RaftStateMachine; -use crate::summary::MessageSummary; -use crate::AsyncRuntime; -use crate::RaftLogId; -use crate::RaftSnapshotBuilder; -use crate::RaftTypeConfig; -use crate::Snapshot; -use crate::StorageError; - pub(crate) mod command; +pub(crate) mod handle; pub(crate) mod response; +pub(crate) mod worker; pub(crate) use command::Command; pub(crate) use command::CommandPayload; #[allow(unused_imports)] pub(crate) use command::CommandSeq; pub(crate) use response::CommandResult; pub(crate) use response::Response; - -use crate::core::notify::Notify; -use crate::core::raft_msg::ResultSender; -use crate::type_config::alias::JoinHandleOf; - -/// State machine worker handle for sending command to it. -pub(crate) struct Handle -where C: RaftTypeConfig -{ - cmd_tx: mpsc::UnboundedSender>, - #[allow(dead_code)] - join_handle: JoinHandleOf, -} - -impl Handle -where C: RaftTypeConfig -{ - pub(crate) fn send(&mut self, cmd: Command) -> Result<(), mpsc::error::SendError>> { - tracing::debug!("sending command to state machine worker: {:?}", cmd); - self.cmd_tx.send(cmd) - } -} - -pub(crate) struct Worker -where - C: RaftTypeConfig, - SM: RaftStateMachine, -{ - state_machine: SM, - - cmd_rx: mpsc::UnboundedReceiver>, - - resp_tx: mpsc::UnboundedSender>, -} - -impl Worker -where - C: RaftTypeConfig, - SM: RaftStateMachine, -{ - /// Spawn a new state machine worker, return a controlling handle. - pub(crate) fn spawn(state_machine: SM, resp_tx: mpsc::UnboundedSender>) -> Handle { - let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); - - let worker = Worker { - state_machine, - cmd_rx, - resp_tx, - }; - - let join_handle = worker.do_spawn(); - - Handle { cmd_tx, join_handle } - } - - fn do_spawn(mut self) -> JoinHandleOf { - C::AsyncRuntime::spawn(async move { - let res = self.worker_loop().await; - - if let Err(err) = res { - tracing::error!("{} while execute state machine command", err,); - - let _ = self.resp_tx.send(Notify::StateMachine { - command_result: CommandResult { - command_seq: 0, - result: Err(err), - }, - }); - } - }) - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn worker_loop(&mut self) -> Result<(), StorageError> { - loop { - let cmd = self.cmd_rx.recv().await; - let cmd = match cmd { - None => { - tracing::info!("{}: rx closed, state machine worker quit", func_name!()); - return Ok(()); - } - Some(x) => x, - }; - - tracing::debug!("{}: received command: {:?}", func_name!(), cmd); - - match cmd.payload { - CommandPayload::BuildSnapshot => { - tracing::info!("{}: build snapshot", func_name!()); - - // It is a read operation and is spawned, and it responds in another task - self.build_snapshot(cmd.seq, self.resp_tx.clone()).await; - } - CommandPayload::GetSnapshot { tx } => { - tracing::info!("{}: get snapshot", func_name!()); - - self.get_snapshot(tx).await?; - // GetSnapshot does not respond to RaftCore - } - CommandPayload::InstallFullSnapshot { snapshot } => { - tracing::info!("{}: install complete snapshot", func_name!()); - - let meta = snapshot.meta.clone(); - self.state_machine.install_snapshot(&meta, snapshot.snapshot).await?; - - tracing::info!("Done install complete snapshot, meta: {}", meta); - - let res = CommandResult::new(cmd.seq, Ok(Response::InstallSnapshot(Some(meta)))); - let _ = self.resp_tx.send(Notify::sm(res)); - } - CommandPayload::BeginReceivingSnapshot { tx } => { - tracing::info!("{}: BeginReceivingSnapshot", func_name!()); - - let snapshot_data = self.state_machine.begin_receiving_snapshot().await?; - - let _ = tx.send(Ok(snapshot_data)); - // No response to RaftCore - } - CommandPayload::Apply { entries } => { - let resp = self.apply(entries).await?; - let res = CommandResult::new(cmd.seq, Ok(Response::Apply(resp))); - let _ = self.resp_tx.send(Notify::sm(res)); - } - }; - } - } - #[tracing::instrument(level = "debug", skip_all)] - async fn apply(&mut self, entries: Vec) -> Result, StorageError> { - // TODO: prepare response before apply_to_state_machine, - // so that an Entry does not need to be Clone, - // and no references will be used by apply_to_state_machine - - let since = entries.first().map(|x| x.get_log_id().index).unwrap(); - let end = entries.last().map(|x| x.get_log_id().index + 1).unwrap(); - let last_applied = entries.last().map(|x| *x.get_log_id()).unwrap(); - - // Fake complain: avoid using `collect()` when not needed - #[allow(clippy::needless_collect)] - let applying_entries = entries - .iter() - .map(|e| ApplyingEntry::new(*e.get_log_id(), e.get_membership().cloned())) - .collect::>(); - - let n_entries = applying_entries.len(); - - let apply_results = self.state_machine.apply(entries).await?; - - let n_replies = apply_results.len(); - - debug_assert_eq!( - n_entries, n_replies, - "n_entries: {} should equal n_replies: {}", - n_entries, n_replies - ); - - let resp = ApplyResult { - since, - end, - last_applied, - applying_entries, - apply_results, - }; - - Ok(resp) - } - - /// Build a snapshot from the state machine. - /// - /// Building snapshot is a read-only operation, so it can be run in another task in parallel. - /// This parallelization depends on the [`RaftSnapshotBuilder`] implementation returned by - /// [`get_snapshot_builder()`](`RaftStateMachine::get_snapshot_builder()`): The builder must: - /// - hold a consistent view of the state machine that won't be affected by further writes such - /// as applying a log entry, - /// - or it must be able to acquire a lock that prevents any write operations. - #[tracing::instrument(level = "info", skip_all)] - async fn build_snapshot(&mut self, seq: CommandSeq, resp_tx: mpsc::UnboundedSender>) { - // TODO: need to be abortable? - // use futures::future::abortable; - // let (fu, abort_handle) = abortable(async move { builder.build_snapshot().await }); - - tracing::info!("{}", func_name!()); - - let mut builder = self.state_machine.get_snapshot_builder().await; - - let _handle = C::AsyncRuntime::spawn(async move { - let res = builder.build_snapshot().await; - let res = res.map(|snap| Response::BuildSnapshot(snap.meta)); - let cmd_res = CommandResult::new(seq, res); - let _ = resp_tx.send(Notify::sm(cmd_res)); - }); - tracing::info!("{} returning; spawned building snapshot task", func_name!()); - } - - #[tracing::instrument(level = "info", skip_all)] - async fn get_snapshot(&mut self, tx: ResultSender>>) -> Result<(), StorageError> { - tracing::info!("{}", func_name!()); - - let snapshot = self.state_machine.get_current_snapshot().await?; - - tracing::info!( - "sending back snapshot: meta: {:?}", - snapshot.as_ref().map(|s| s.meta.summary()) - ); - let _ = tx.send(Ok(snapshot)); - Ok(()) - } -} diff --git a/openraft/src/core/sm/worker.rs b/openraft/src/core/sm/worker.rs new file mode 100644 index 000000000..079d3a1c5 --- /dev/null +++ b/openraft/src/core/sm/worker.rs @@ -0,0 +1,208 @@ +use tokio::sync::mpsc; + +use crate::alias::JoinHandleOf; +use crate::async_runtime::AsyncOneshotSendExt; +use crate::core::notify::Notify; +use crate::core::raft_msg::ResultSender; +use crate::core::sm::handle::Handle; +use crate::core::sm::Command; +use crate::core::sm::CommandPayload; +use crate::core::sm::CommandResult; +use crate::core::sm::CommandSeq; +use crate::core::sm::Response; +use crate::core::ApplyResult; +use crate::core::ApplyingEntry; +use crate::display_ext::DisplayOptionExt; +use crate::entry::RaftPayload; +use crate::storage::RaftStateMachine; +use crate::AsyncRuntime; +use crate::RaftLogId; +use crate::RaftSnapshotBuilder; +use crate::RaftTypeConfig; +use crate::Snapshot; +use crate::StorageError; + +pub(crate) struct Worker +where + C: RaftTypeConfig, + SM: RaftStateMachine, +{ + state_machine: SM, + + cmd_rx: mpsc::UnboundedReceiver>, + + resp_tx: mpsc::UnboundedSender>, +} + +impl Worker +where + C: RaftTypeConfig, + SM: RaftStateMachine, +{ + /// Spawn a new state machine worker, return a controlling handle. + pub(crate) fn spawn(state_machine: SM, resp_tx: mpsc::UnboundedSender>) -> Handle { + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); + + let worker = Worker { + state_machine, + cmd_rx, + resp_tx, + }; + + let join_handle = worker.do_spawn(); + + Handle { cmd_tx, join_handle } + } + + fn do_spawn(mut self) -> JoinHandleOf { + C::AsyncRuntime::spawn(async move { + let res = self.worker_loop().await; + + if let Err(err) = res { + tracing::error!("{} while execute state machine command", err,); + + let _ = self.resp_tx.send(Notify::StateMachine { + command_result: CommandResult { + command_seq: 0, + result: Err(err), + }, + }); + } + }) + } + + #[tracing::instrument(level = "debug", skip_all)] + async fn worker_loop(&mut self) -> Result<(), StorageError> { + loop { + let cmd = self.cmd_rx.recv().await; + let cmd = match cmd { + None => { + tracing::info!("{}: rx closed, state machine worker quit", func_name!()); + return Ok(()); + } + Some(x) => x, + }; + + tracing::debug!("{}: received command: {:?}", func_name!(), cmd); + + match cmd.payload { + CommandPayload::BuildSnapshot => { + tracing::info!("{}: build snapshot", func_name!()); + + // It is a read operation and is spawned, and it responds in another task + self.build_snapshot(cmd.seq, self.resp_tx.clone()).await; + } + CommandPayload::GetSnapshot { tx } => { + tracing::info!("{}: get snapshot", func_name!()); + + self.get_snapshot(tx).await?; + // GetSnapshot does not respond to RaftCore + } + CommandPayload::InstallFullSnapshot { snapshot } => { + tracing::info!("{}: install complete snapshot", func_name!()); + + let meta = snapshot.meta.clone(); + self.state_machine.install_snapshot(&meta, snapshot.snapshot).await?; + + tracing::info!("Done install complete snapshot, meta: {}", meta); + + let res = CommandResult::new(cmd.seq, Ok(Response::InstallSnapshot(Some(meta)))); + let _ = self.resp_tx.send(Notify::sm(res)); + } + CommandPayload::BeginReceivingSnapshot { tx } => { + tracing::info!("{}: BeginReceivingSnapshot", func_name!()); + + let snapshot_data = self.state_machine.begin_receiving_snapshot().await?; + + let _ = tx.send(Ok(snapshot_data)); + // No response to RaftCore + } + CommandPayload::Apply { entries } => { + let resp = self.apply(entries).await?; + let res = CommandResult::new(cmd.seq, Ok(Response::Apply(resp))); + let _ = self.resp_tx.send(Notify::sm(res)); + } + }; + } + } + #[tracing::instrument(level = "debug", skip_all)] + async fn apply(&mut self, entries: Vec) -> Result, StorageError> { + // TODO: prepare response before apply_to_state_machine, + // so that an Entry does not need to be Clone, + // and no references will be used by apply_to_state_machine + + let since = entries.first().map(|x| x.get_log_id().index).unwrap(); + let end = entries.last().map(|x| x.get_log_id().index + 1).unwrap(); + let last_applied = entries.last().map(|x| *x.get_log_id()).unwrap(); + + // Fake complain: avoid using `collect()` when not needed + #[allow(clippy::needless_collect)] + let applying_entries = entries + .iter() + .map(|e| ApplyingEntry::new(*e.get_log_id(), e.get_membership().cloned())) + .collect::>(); + + let n_entries = applying_entries.len(); + + let apply_results = self.state_machine.apply(entries).await?; + + let n_replies = apply_results.len(); + + debug_assert_eq!( + n_entries, n_replies, + "n_entries: {} should equal n_replies: {}", + n_entries, n_replies + ); + + let resp = ApplyResult { + since, + end, + last_applied, + applying_entries, + apply_results, + }; + + Ok(resp) + } + + /// Build a snapshot from the state machine. + /// + /// Building snapshot is a read-only operation, so it can be run in another task in parallel. + /// This parallelization depends on the [`RaftSnapshotBuilder`] implementation returned by + /// [`get_snapshot_builder()`](`RaftStateMachine::get_snapshot_builder()`): The builder must: + /// - hold a consistent view of the state machine that won't be affected by further writes such + /// as applying a log entry, + /// - or it must be able to acquire a lock that prevents any write operations. + #[tracing::instrument(level = "info", skip_all)] + async fn build_snapshot(&mut self, seq: CommandSeq, resp_tx: mpsc::UnboundedSender>) { + // TODO: need to be abortable? + // use futures::future::abortable; + // let (fu, abort_handle) = abortable(async move { builder.build_snapshot().await }); + + tracing::info!("{}", func_name!()); + + let mut builder = self.state_machine.get_snapshot_builder().await; + + let _handle = C::AsyncRuntime::spawn(async move { + let res = builder.build_snapshot().await; + let res = res.map(|snap| Response::BuildSnapshot(snap.meta)); + let cmd_res = CommandResult::new(seq, res); + let _ = resp_tx.send(Notify::sm(cmd_res)); + }); + tracing::info!("{} returning; spawned building snapshot task", func_name!()); + } + + #[tracing::instrument(level = "info", skip_all)] + async fn get_snapshot(&mut self, tx: ResultSender>>) -> Result<(), StorageError> { + tracing::info!("{}", func_name!()); + + let snapshot = self.state_machine.get_current_snapshot().await?; + + tracing::info!( + "sending back snapshot: meta: {}", + snapshot.as_ref().map(|s| &s.meta).display() + ); + let _ = tx.send(Ok(snapshot)); + Ok(()) + } +} diff --git a/openraft/src/engine/command.rs b/openraft/src/engine/command.rs index c81c3c2b0..2e6856323 100644 --- a/openraft/src/engine/command.rs +++ b/openraft/src/engine/command.rs @@ -89,9 +89,9 @@ where C: RaftTypeConfig DeleteConflictLog { since: LogId }, // TODO(1): current it is only used to replace BuildSnapshot, InstallSnapshot, CancelSnapshot. - /// A command send to state machine worker [`sm::Worker`]. + /// A command send to state machine worker [`worker::Worker`]. /// - /// The runtime(`RaftCore`) will just forward this command to [`sm::Worker`]. + /// The runtime(`RaftCore`) will just forward this command to [`worker::Worker`]. /// The response will be sent back in a `RaftMsg::StateMachine` message to `RaftCore`. StateMachine { command: sm::Command }, @@ -212,7 +212,7 @@ where NID: NodeId #[allow(dead_code)] Applied { log_id: Option> }, - /// Wait until a [`sm::Worker`] command is finished. + /// Wait until a [`worker::Worker`] command is finished. #[allow(dead_code)] StateMachineCommand { command_seq: sm::CommandSeq }, } diff --git a/openraft/src/raft/mod.rs b/openraft/src/raft/mod.rs index f910cb1c6..dc9a445b6 100644 --- a/openraft/src/raft/mod.rs +++ b/openraft/src/raft/mod.rs @@ -40,7 +40,7 @@ use crate::core::command_state::CommandState; use crate::core::raft_msg::external_command::ExternalCommand; use crate::core::raft_msg::RaftMsg; use crate::core::replication_lag; -use crate::core::sm; +use crate::core::sm::worker; use crate::core::RaftCore; use crate::core::Tick; use crate::engine::Engine; @@ -209,7 +209,7 @@ where C: RaftTypeConfig let engine = Engine::new(state, eng_config); - let sm_handle = sm::Worker::spawn(state_machine, tx_notify.clone()); + let sm_handle = worker::Worker::spawn(state_machine, tx_notify.clone()); let core: RaftCore = RaftCore { id, From c818a5818389fbbf9de995356115b9082c038cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Mon, 25 Mar 2024 15:35:16 +0800 Subject: [PATCH 2/2] Refactor: `ReplicationCore` get a snapshot directly from state machine, via `SnapshotReader` --- openraft/src/core/raft_core.rs | 18 ++------ openraft/src/core/raft_msg/mod.rs | 3 -- openraft/src/core/sm/handle.rs | 66 +++++++++++++++++++++++++++-- openraft/src/core/sm/worker.rs | 2 +- openraft/src/replication/mod.rs | 22 +++++----- openraft/src/replication/request.rs | 14 +++--- 6 files changed, 86 insertions(+), 39 deletions(-) diff --git a/openraft/src/core/raft_core.rs b/openraft/src/core/raft_core.rs index 85dac7392..4381f7e98 100644 --- a/openraft/src/core/raft_core.rs +++ b/openraft/src/core/raft_core.rs @@ -822,6 +822,7 @@ where network, snapshot_network, self.log_store.get_log_reader().await, + self.sm_handle.new_snapshot_reader(), self.tx_notify.clone(), tracing::span!(parent: &self.span, Level::DEBUG, "replication", id=display(self.id), target=display(target)), ) @@ -1674,21 +1675,10 @@ where let _ = node.tx_repl.send(Replicate::logs(RequestId::new_append_entries(id), log_id_range)); } Inflight::Snapshot { id, last_log_id } => { - let _ = last_log_id; - - // Create a channel to let state machine worker to send the snapshot and the replication - // worker to receive it. - let (tx, rx) = C::AsyncRuntime::oneshot(); - - let cmd = sm::Command::get_snapshot(tx); - self.sm_handle - .send(cmd) - .map_err(|e| StorageIOError::read_snapshot(None, AnyError::error(e)))?; - // unwrap: The replication channel must not be dropped or it is a bug. - node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), rx)).map_err(|_e| { - StorageIOError::read_snapshot(None, AnyError::error("replication channel closed")) - })?; + node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), last_log_id)).map_err( + |_e| StorageIOError::read_snapshot(None, AnyError::error("replication channel closed")), + )?; } } } else { diff --git a/openraft/src/core/raft_msg/mod.rs b/openraft/src/core/raft_msg/mod.rs index 72ddfbc7e..f72e90409 100644 --- a/openraft/src/core/raft_msg/mod.rs +++ b/openraft/src/core/raft_msg/mod.rs @@ -13,7 +13,6 @@ use crate::raft::SnapshotResponse; use crate::raft::VoteRequest; use crate::raft::VoteResponse; use crate::type_config::alias::LogIdOf; -use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::OneshotSenderOf; use crate::type_config::alias::SnapshotDataOf; use crate::ChangeMembers; @@ -27,8 +26,6 @@ pub(crate) mod external_command; /// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`. pub(crate) type ResultSender = OneshotSenderOf>; -pub(crate) type ResultReceiver = OneshotReceiverOf>; - /// TX for Vote Response pub(crate) type VoteTx = ResultSender>; diff --git a/openraft/src/core/sm/handle.rs b/openraft/src/core/sm/handle.rs index 9ba60f859..8a718663c 100644 --- a/openraft/src/core/sm/handle.rs +++ b/openraft/src/core/sm/handle.rs @@ -2,15 +2,18 @@ use tokio::sync::mpsc; -use crate::alias::JoinHandleOf; -use crate::core::sm::Command; +use crate::core::sm; +use crate::type_config::alias::AsyncRuntimeOf; +use crate::type_config::alias::JoinHandleOf; +use crate::AsyncRuntime; use crate::RaftTypeConfig; +use crate::Snapshot; /// State machine worker handle for sending command to it. pub(crate) struct Handle where C: RaftTypeConfig { - pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender>, + pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender>, #[allow(dead_code)] pub(in crate::core::sm) join_handle: JoinHandleOf, @@ -19,8 +22,63 @@ where C: RaftTypeConfig impl Handle where C: RaftTypeConfig { - pub(crate) fn send(&mut self, cmd: Command) -> Result<(), mpsc::error::SendError>> { + pub(crate) fn send(&mut self, cmd: sm::Command) -> Result<(), mpsc::error::SendError>> { tracing::debug!("sending command to state machine worker: {:?}", cmd); self.cmd_tx.send(cmd) } + + /// Create a [`SnapshotReader`] to get the current snapshot from the state machine. + pub(crate) fn new_snapshot_reader(&self) -> SnapshotReader { + SnapshotReader { + cmd_tx: self.cmd_tx.downgrade(), + } + } +} + +/// A handle for retrieving a snapshot from the state machine. +pub(crate) struct SnapshotReader +where C: RaftTypeConfig +{ + /// Weak command sender to the state machine worker. + /// + /// It is weak because the [`Worker`] watches the close event of this channel for shutdown. + /// + /// [`Worker`]: sm::worker::Worker + cmd_tx: mpsc::WeakUnboundedSender>, +} + +impl SnapshotReader +where C: RaftTypeConfig +{ + /// Get a snapshot from the state machine. + /// + /// If the state machine worker has shutdown, it will return an error. + /// If there is not snapshot available, it will return `Ok(None)`. + pub(crate) async fn get_snapshot(&self) -> Result>, &'static str> { + let (tx, rx) = AsyncRuntimeOf::::oneshot(); + + let cmd = sm::Command::get_snapshot(tx); + tracing::debug!("SnapshotReader sending command to sm::Worker: {:?}", cmd); + + let Some(cmd_tx) = self.cmd_tx.upgrade() else { + tracing::info!("failed to upgrade cmd_tx, sm::Worker may have shutdown"); + return Err("failed to upgrade cmd_tx, sm::Worker may have shutdown"); + }; + + // If fail to send command, cmd is dropped and tx will be dropped. + let _ = cmd_tx.send(cmd); + + let got = match rx.await { + Ok(x) => x, + Err(_e) => { + tracing::error!("failed to receive snapshot, sm::Worker may have shutdown"); + return Err("failed to receive snapshot, sm::Worker may have shutdown"); + } + }; + + // Safe unwrap(): error is Infallible. + let snapshot = got.unwrap(); + + Ok(snapshot) + } } diff --git a/openraft/src/core/sm/worker.rs b/openraft/src/core/sm/worker.rs index 079d3a1c5..1e69d886e 100644 --- a/openraft/src/core/sm/worker.rs +++ b/openraft/src/core/sm/worker.rs @@ -1,6 +1,5 @@ use tokio::sync::mpsc; -use crate::alias::JoinHandleOf; use crate::async_runtime::AsyncOneshotSendExt; use crate::core::notify::Notify; use crate::core::raft_msg::ResultSender; @@ -15,6 +14,7 @@ use crate::core::ApplyingEntry; use crate::display_ext::DisplayOptionExt; use crate::entry::RaftPayload; use crate::storage::RaftStateMachine; +use crate::type_config::alias::JoinHandleOf; use crate::AsyncRuntime; use crate::RaftLogId; use crate::RaftSnapshotBuilder; diff --git a/openraft/src/replication/mod.rs b/openraft/src/replication/mod.rs index 80d31975c..d437607ce 100644 --- a/openraft/src/replication/mod.rs +++ b/openraft/src/replication/mod.rs @@ -26,7 +26,7 @@ use tracing_futures::Instrument; use crate::config::Config; use crate::core::notify::Notify; -use crate::core::raft_msg::ResultReceiver; +use crate::core::sm::handle::SnapshotReader; use crate::display_ext::DisplayOptionExt; use crate::error::HigherVote; use crate::error::PayloadTooLarge; @@ -53,6 +53,7 @@ use crate::storage::Snapshot; use crate::type_config::alias::AsyncRuntimeOf; use crate::type_config::alias::InstantOf; use crate::type_config::alias::JoinHandleOf; +use crate::type_config::alias::LogIdOf; use crate::AsyncRuntime; use crate::Instant; use crate::LogId; @@ -127,6 +128,9 @@ where /// The [`RaftLogStorage::LogReader`] interface. log_reader: LS::LogReader, + /// The handle to get a snapshot directly from state machine. + snapshot_reader: SnapshotReader, + /// The Raft's runtime config. config: Arc, @@ -163,6 +167,7 @@ where network: N::Network, snapshot_network: N::Network, log_reader: LS::LogReader, + snapshot_reader: SnapshotReader, tx_raft_core: mpsc::UnboundedSender>, span: tracing::Span, ) -> ReplicationHandle { @@ -185,6 +190,7 @@ where snapshot_state: None, backoff: None, log_reader, + snapshot_reader, config, committed, matching, @@ -697,21 +703,17 @@ where #[tracing::instrument(level = "info", skip_all)] async fn stream_snapshot( &mut self, - snapshot_rx: DataWithId>>>, + snapshot_req: DataWithId>>, ) -> Result>, ReplicationError> { - let request_id = snapshot_rx.request_id(); - let rx = snapshot_rx.into_data(); + let request_id = snapshot_req.request_id(); tracing::info!(request_id = display(request_id), "{}", func_name!()); - let snapshot = rx.await.map_err(|e| { - let io_err = StorageIOError::read_snapshot(None, AnyError::error(e)); - StorageError::IO { source: io_err } + let snapshot = self.snapshot_reader.get_snapshot().await.map_err(|reason| { + tracing::warn!(error = display(&reason), "failed to get snapshot from state machine"); + ReplicationClosed::new(reason) })?; - // Safe unwrap(): the error is Infallible, so it is safe to unwrap. - let snapshot = snapshot.unwrap(); - tracing::info!( "received snapshot: request_id={}; meta:{}", request_id, diff --git a/openraft/src/replication/request.rs b/openraft/src/replication/request.rs index df1b6aa13..92bc2d573 100644 --- a/openraft/src/replication/request.rs +++ b/openraft/src/replication/request.rs @@ -1,5 +1,7 @@ use std::fmt; +use crate::type_config::alias::LogIdOf; + /// A replication request sent by RaftCore leader state to replication stream. #[derive(Debug)] pub(crate) enum Replicate @@ -22,8 +24,8 @@ where C: RaftTypeConfig Self::Data(Data::new_logs(id, log_id_range)) } - pub(crate) fn snapshot(id: RequestId, snapshot_rx: ResultReceiver>>) -> Self { - Self::Data(Data::new_snapshot(id, snapshot_rx)) + pub(crate) fn snapshot(id: RequestId, last_log_id: Option>) -> Self { + Self::Data(Data::new_snapshot(id, last_log_id)) } pub(crate) fn new_data(data: Data) -> Self { @@ -49,7 +51,6 @@ where C: RaftTypeConfig } } -use crate::core::raft_msg::ResultReceiver; use crate::display_ext::DisplayOptionExt; use crate::error::Fatal; use crate::error::StreamingError; @@ -61,7 +62,6 @@ use crate::type_config::alias::InstantOf; use crate::LogId; use crate::MessageSummary; use crate::RaftTypeConfig; -use crate::Snapshot; use crate::SnapshotMeta; /// Request to replicate a chunk of data, logs or snapshot. @@ -74,7 +74,7 @@ where C: RaftTypeConfig { Heartbeat, Logs(DataWithId>), - Snapshot(DataWithId>>>), + Snapshot(DataWithId>>), SnapshotCallback(DataWithId>), } @@ -143,8 +143,8 @@ where C: RaftTypeConfig Self::Logs(DataWithId::new(request_id, log_id_range)) } - pub(crate) fn new_snapshot(request_id: RequestId, snapshot_rx: ResultReceiver>>) -> Self { - Self::Snapshot(DataWithId::new(request_id, snapshot_rx)) + pub(crate) fn new_snapshot(request_id: RequestId, last_log_id: Option>) -> Self { + Self::Snapshot(DataWithId::new(request_id, last_log_id)) } pub(crate) fn new_snapshot_callback(