Skip to content

Commit

Permalink
emit result when sink shuts down so client can correctly repond to th…
Browse files Browse the repository at this point in the history
…e shut down reason (#89)
  • Loading branch information
UkoeHB authored Oct 6, 2023
1 parent 7fe4bfe commit 1b55a6a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 21 deletions.
21 changes: 19 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::tungstenite::error::Error as WSError;
use url::Url;

pub const DEFAULT_RECONNECT_INTERVAL: Duration = Duration::new(5, 0);
Expand Down Expand Up @@ -394,9 +395,24 @@ impl<E: ClientExt> ClientActor<E> {
loop {
tokio::select! {
Some(inmessage) = self.to_socket_receiver.recv() => {
let mut closed_self = matches!(inmessage.message, Some(Message::Close(_)));
let closed_self = matches!(inmessage.message, Some(Message::Close(_)));
if self.socket.send(inmessage).await.is_err() {
closed_self = true;
match self.socket.await_sink_close().await {
Err(WSError::ConnectionClosed) |
Err(WSError::AlreadyClosed) |
Err(WSError::Io(_)) |
Err(WSError::Tls(_)) => {
// either:
// A) The connection was closed via the close protocol, so we will allow the stream to
// handle it.
// B) We already tried and failed to submit another message, so now we are
// waiting for other parts of the tokio::select to shut us down.
// C) An IO error means the connection closed unexpectedly, so we can try to reconnect when
// the stream fails.
}
Err(_) if !closed_self => return Err(Error::from("unexpected sink error, aborting client actor")),
_ => (),
}
}
if closed_self {
tracing::trace!("client closed itself");
Expand Down Expand Up @@ -433,6 +449,7 @@ impl<E: ClientExt> ClientActor<E> {
};
}
Some(Err(error)) => {
let error = Error::from(error);
tracing::warn!("connection error: {error}");
}
None => {
Expand Down
3 changes: 2 additions & 1 deletion src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ impl<I: std::fmt::Display + Clone, C> Session<I, C> {
#[doc(hidden)]
/// WARN: Use only if really nessesary.
///
/// this uses some hack, which takes ownership of underlying `oneshot::Receiver`, making it inaccessible for all future calls of this method.
/// This uses some hack, which takes ownership of underlying `oneshot::Receiver`, making it inaccessible
/// for all future calls of this method.
pub(super) async fn await_close(&self) -> Result<Option<CloseFrame>, Error> {
let mut closed_indicator = self.closed_indicator.lock().await;
let closed_indicator = closed_indicator
Expand Down
52 changes: 34 additions & 18 deletions src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::Error;
use futures::{SinkExt, StreamExt, TryStreamExt};
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
Expand All @@ -11,6 +10,7 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::error::Error as WSError;

/// Wrapper trait for `Fn(Duration) -> RawMessage`.
pub trait SocketHeartbeatPingFn: Fn(Duration) -> RawMessage + Sync + Send {}
Expand Down Expand Up @@ -317,7 +317,7 @@ impl Drop for InMessage {
struct SinkActor<M, S>
where
M: From<RawMessage>,
S: SinkExt<M, Error = Error> + Unpin,
S: SinkExt<M, Error = WSError> + Unpin,
{
receiver: mpsc::UnboundedReceiver<InRawMessage>,
abort_receiver: oneshot::Receiver<()>,
Expand All @@ -328,9 +328,9 @@ where
impl<M, S> SinkActor<M, S>
where
M: From<RawMessage>,
S: SinkExt<M, Error = Error> + Unpin,
S: SinkExt<M, Error = WSError> + Unpin,
{
async fn run(&mut self) -> Result<(), Error> {
async fn run(&mut self) -> Result<(), WSError> {
loop {
tokio::select! {
Some(mut inmessage) = self.receiver.recv() => {
Expand Down Expand Up @@ -368,10 +368,10 @@ impl Sink {
fn new<M, S>(
sink: S,
abort_receiver: oneshot::Receiver<()>,
) -> (tokio::task::JoinHandle<Result<(), Error>>, Self)
) -> (tokio::task::JoinHandle<Result<(), WSError>>, Self)
where
M: From<RawMessage> + Send + 'static,
S: SinkExt<M, Error = Error> + Unpin + Send + 'static,
S: SinkExt<M, Error = WSError> + Unpin + Send + 'static,
{
let (sender, receiver) = mpsc::unbounded_channel();
let mut actor = SinkActor {
Expand Down Expand Up @@ -407,17 +407,17 @@ impl Sink {
struct StreamActor<M, S>
where
M: Into<RawMessage>,
S: StreamExt<Item = Result<M, Error>> + Unpin,
S: StreamExt<Item = Result<M, WSError>> + Unpin,
{
sender: mpsc::UnboundedSender<Result<Message, Error>>,
sender: mpsc::UnboundedSender<Result<Message, WSError>>,
stream: S,
last_alive: Arc<Mutex<Instant>>,
}

impl<M, S> StreamActor<M, S>
where
M: Into<RawMessage>,
S: StreamExt<Item = Result<M, Error>> + Unpin,
S: StreamExt<Item = Result<M, WSError>> + Unpin,
{
async fn run(mut self) {
while let Some(result) = self.stream.next().await {
Expand All @@ -438,7 +438,7 @@ where
let timestamp = Duration::from_millis(timestamp as u64); // TODO: handle overflow
let latency = SystemTime::now()
.duration_since(UNIX_EPOCH + timestamp)
.unwrap_or(Duration::default());
.unwrap_or_default();
// TODO: handle time zone
tracing::trace!("latency: {}ms", latency.as_millis());
}
Expand Down Expand Up @@ -470,14 +470,14 @@ where

#[derive(Debug)]
pub struct Stream {
receiver: mpsc::UnboundedReceiver<Result<Message, Error>>,
receiver: mpsc::UnboundedReceiver<Result<Message, WSError>>,
}

impl Stream {
fn new<M, S>(stream: S, last_alive: Arc<Mutex<Instant>>) -> (JoinHandle<()>, Self)
where
M: Into<RawMessage> + std::fmt::Debug + Send + 'static,
S: StreamExt<Item = Result<M, Error>> + Unpin + Send + 'static,
S: StreamExt<Item = Result<M, WSError>> + Unpin + Send + 'static,
{
let (sender, receiver) = mpsc::unbounded_channel();
let actor = StreamActor {
Expand All @@ -490,7 +490,7 @@ impl Stream {
(future, Self { receiver })
}

pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
self.receiver.recv().await
}
}
Expand All @@ -499,13 +499,14 @@ impl Stream {
pub struct Socket {
pub sink: Sink,
pub stream: Stream,
sink_result_receiver: Option<oneshot::Receiver<Result<(), WSError>>>,
}

impl Socket {
pub fn new<M, E: std::error::Error, S>(socket: S, config: SocketConfig) -> Self
where
M: Into<RawMessage> + From<RawMessage> + std::fmt::Debug + Send + 'static,
E: Into<Error>,
E: Into<WSError>,
S: SinkExt<M, Error = E> + Unpin + StreamExt<Item = Result<M, E>> + Unpin + Send + 'static,
{
let last_alive = Instant::now();
Expand Down Expand Up @@ -535,7 +536,7 @@ impl Socket {
}
let timestamp = SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or(Duration::default());
.unwrap_or_default();
if sink
.send_raw(InRawMessage::new((config.heartbeat_ping_msg_fn)(timestamp)))
.await
Expand All @@ -547,14 +548,20 @@ impl Socket {
}
});

let (sink_result_sender, sink_result_receiver) = oneshot::channel();
tokio::spawn(async move {
let _ = stream_future.await;
let _ = sink_abort_sender.send(());
heartbeat_future.abort();
let _ = sink_future.await; //todo: send result to socket
let _ =
sink_result_sender.send(sink_future.await.unwrap_or(Err(WSError::AlreadyClosed)));
});

Self { sink, stream }
Self {
sink,
stream,
sink_result_receiver: Some(sink_result_receiver),
}
}

pub async fn send(
Expand All @@ -571,7 +578,16 @@ impl Socket {
self.sink.send_raw(message).await
}

pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
self.stream.recv().await
}

pub(crate) async fn await_sink_close(&mut self) -> Result<(), WSError> {
let Some(sink_result_receiver) = self.sink_result_receiver.take() else {
return Err(WSError::AlreadyClosed);
};
sink_result_receiver
.await
.unwrap_or(Err(WSError::AlreadyClosed))
}
}

0 comments on commit 1b55a6a

Please sign in to comment.