From 1b55a6a1abd9bc2f726430e4c049f5abadf497e8 Mon Sep 17 00:00:00 2001 From: UkoeHB <37489173+UkoeHB@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:13:15 -0500 Subject: [PATCH] emit result when sink shuts down so client can correctly repond to the shut down reason (#89) --- src/client.rs | 21 ++++++++++++++++++-- src/session.rs | 3 ++- src/socket.rs | 52 +++++++++++++++++++++++++++++++++----------------- 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/src/client.rs b/src/client.rs index 506eed9..d69722e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -64,6 +64,7 @@ use std::time::Duration; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio_tungstenite::tungstenite; +use tokio_tungstenite::tungstenite::error::Error as WSError; use url::Url; pub const DEFAULT_RECONNECT_INTERVAL: Duration = Duration::new(5, 0); @@ -394,9 +395,24 @@ impl ClientActor { loop { tokio::select! { Some(inmessage) = self.to_socket_receiver.recv() => { - let mut closed_self = matches!(inmessage.message, Some(Message::Close(_))); + let closed_self = matches!(inmessage.message, Some(Message::Close(_))); if self.socket.send(inmessage).await.is_err() { - closed_self = true; + match self.socket.await_sink_close().await { + Err(WSError::ConnectionClosed) | + Err(WSError::AlreadyClosed) | + Err(WSError::Io(_)) | + Err(WSError::Tls(_)) => { + // either: + // A) The connection was closed via the close protocol, so we will allow the stream to + // handle it. + // B) We already tried and failed to submit another message, so now we are + // waiting for other parts of the tokio::select to shut us down. + // C) An IO error means the connection closed unexpectedly, so we can try to reconnect when + // the stream fails. + } + Err(_) if !closed_self => return Err(Error::from("unexpected sink error, aborting client actor")), + _ => (), + } } if closed_self { tracing::trace!("client closed itself"); @@ -433,6 +449,7 @@ impl ClientActor { }; } Some(Err(error)) => { + let error = Error::from(error); tracing::warn!("connection error: {error}"); } None => { diff --git a/src/session.rs b/src/session.rs index 87c8f5d..837fb80 100644 --- a/src/session.rs +++ b/src/session.rs @@ -94,7 +94,8 @@ impl Session { #[doc(hidden)] /// WARN: Use only if really nessesary. /// - /// this uses some hack, which takes ownership of underlying `oneshot::Receiver`, making it inaccessible for all future calls of this method. + /// This uses some hack, which takes ownership of underlying `oneshot::Receiver`, making it inaccessible + /// for all future calls of this method. pub(super) async fn await_close(&self) -> Result, Error> { let mut closed_indicator = self.closed_indicator.lock().await; let closed_indicator = closed_indicator diff --git a/src/socket.rs b/src/socket.rs index c982cbc..ea15706 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,4 +1,3 @@ -use crate::Error; use futures::{SinkExt, StreamExt, TryStreamExt}; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; @@ -11,6 +10,7 @@ use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::sync::Mutex; use tokio::task::JoinHandle; +use tokio_tungstenite::tungstenite::error::Error as WSError; /// Wrapper trait for `Fn(Duration) -> RawMessage`. pub trait SocketHeartbeatPingFn: Fn(Duration) -> RawMessage + Sync + Send {} @@ -317,7 +317,7 @@ impl Drop for InMessage { struct SinkActor where M: From, - S: SinkExt + Unpin, + S: SinkExt + Unpin, { receiver: mpsc::UnboundedReceiver, abort_receiver: oneshot::Receiver<()>, @@ -328,9 +328,9 @@ where impl SinkActor where M: From, - S: SinkExt + Unpin, + S: SinkExt + Unpin, { - async fn run(&mut self) -> Result<(), Error> { + async fn run(&mut self) -> Result<(), WSError> { loop { tokio::select! { Some(mut inmessage) = self.receiver.recv() => { @@ -368,10 +368,10 @@ impl Sink { fn new( sink: S, abort_receiver: oneshot::Receiver<()>, - ) -> (tokio::task::JoinHandle>, Self) + ) -> (tokio::task::JoinHandle>, Self) where M: From + Send + 'static, - S: SinkExt + Unpin + Send + 'static, + S: SinkExt + Unpin + Send + 'static, { let (sender, receiver) = mpsc::unbounded_channel(); let mut actor = SinkActor { @@ -407,9 +407,9 @@ impl Sink { struct StreamActor where M: Into, - S: StreamExt> + Unpin, + S: StreamExt> + Unpin, { - sender: mpsc::UnboundedSender>, + sender: mpsc::UnboundedSender>, stream: S, last_alive: Arc>, } @@ -417,7 +417,7 @@ where impl StreamActor where M: Into, - S: StreamExt> + Unpin, + S: StreamExt> + Unpin, { async fn run(mut self) { while let Some(result) = self.stream.next().await { @@ -438,7 +438,7 @@ where let timestamp = Duration::from_millis(timestamp as u64); // TODO: handle overflow let latency = SystemTime::now() .duration_since(UNIX_EPOCH + timestamp) - .unwrap_or(Duration::default()); + .unwrap_or_default(); // TODO: handle time zone tracing::trace!("latency: {}ms", latency.as_millis()); } @@ -470,14 +470,14 @@ where #[derive(Debug)] pub struct Stream { - receiver: mpsc::UnboundedReceiver>, + receiver: mpsc::UnboundedReceiver>, } impl Stream { fn new(stream: S, last_alive: Arc>) -> (JoinHandle<()>, Self) where M: Into + std::fmt::Debug + Send + 'static, - S: StreamExt> + Unpin + Send + 'static, + S: StreamExt> + Unpin + Send + 'static, { let (sender, receiver) = mpsc::unbounded_channel(); let actor = StreamActor { @@ -490,7 +490,7 @@ impl Stream { (future, Self { receiver }) } - pub async fn recv(&mut self) -> Option> { + pub async fn recv(&mut self) -> Option> { self.receiver.recv().await } } @@ -499,13 +499,14 @@ impl Stream { pub struct Socket { pub sink: Sink, pub stream: Stream, + sink_result_receiver: Option>>, } impl Socket { pub fn new(socket: S, config: SocketConfig) -> Self where M: Into + From + std::fmt::Debug + Send + 'static, - E: Into, + E: Into, S: SinkExt + Unpin + StreamExt> + Unpin + Send + 'static, { let last_alive = Instant::now(); @@ -535,7 +536,7 @@ impl Socket { } let timestamp = SystemTime::now() .duration_since(std::time::UNIX_EPOCH) - .unwrap_or(Duration::default()); + .unwrap_or_default(); if sink .send_raw(InRawMessage::new((config.heartbeat_ping_msg_fn)(timestamp))) .await @@ -547,14 +548,20 @@ impl Socket { } }); + let (sink_result_sender, sink_result_receiver) = oneshot::channel(); tokio::spawn(async move { let _ = stream_future.await; let _ = sink_abort_sender.send(()); heartbeat_future.abort(); - let _ = sink_future.await; //todo: send result to socket + let _ = + sink_result_sender.send(sink_future.await.unwrap_or(Err(WSError::AlreadyClosed))); }); - Self { sink, stream } + Self { + sink, + stream, + sink_result_receiver: Some(sink_result_receiver), + } } pub async fn send( @@ -571,7 +578,16 @@ impl Socket { self.sink.send_raw(message).await } - pub async fn recv(&mut self) -> Option> { + pub async fn recv(&mut self) -> Option> { self.stream.recv().await } + + pub(crate) async fn await_sink_close(&mut self) -> Result<(), WSError> { + let Some(sink_result_receiver) = self.sink_result_receiver.take() else { + return Err(WSError::AlreadyClosed); + }; + sink_result_receiver + .await + .unwrap_or(Err(WSError::AlreadyClosed)) + } }