diff --git a/src/client.rs b/src/client.rs index b80df31..1965b53 100644 --- a/src/client.rs +++ b/src/client.rs @@ -326,7 +326,7 @@ pub async fn connect( client_fn: impl FnOnce(Client) -> E, config: ClientConfig, ) -> (Client, impl Future>) { - let (to_socket_sender, to_socket_receiver) = mpsc::unbounded_channel(); + let (to_socket_sender, mut to_socket_receiver) = mpsc::unbounded_channel(); let (client_call_sender, client_call_receiver) = mpsc::unbounded_channel(); let handle = Client { to_socket_sender, @@ -334,15 +334,14 @@ pub async fn connect( }; let mut client = client_fn(handle.clone()); let future = tokio::spawn(async move { - let http_request = config.connect_http_request(); tracing::info!("connecting to {}...", config.url); - let (stream, _) = tokio_tungstenite::connect_async(http_request).await?; - if let Err(err) = client.on_connect().await { - tracing::error!("calling on_connect() failed due to {}", err); - return Err(err); - } - let socket = Socket::new(stream, config.socket_config.clone().unwrap_or_default()); + let Some(socket) = + client_connect(1usize, &config, &mut to_socket_receiver, &mut client).await? + else { + return Ok(()); + }; tracing::info!("connected to {}", config.url); + let mut actor = ClientActor { client, to_socket_receiver, @@ -393,9 +392,15 @@ impl ClientActor { match self.client.on_close(frame).await? { ClientCloseMode::Reconnect => { - if !self.reconnect().await? { - return Ok(()) - } + let Some(socket) = client_connect( + usize::MAX, + &self.config, + &mut self.to_socket_receiver, + &mut self.client, + ).await? else { + return Ok(()); + }; + self.socket = socket; }, ClientCloseMode::Close => return Ok(()) } @@ -410,9 +415,15 @@ impl ClientActor { match self.client.on_disconnect().await? { ClientCloseMode::Reconnect => { - if !self.reconnect().await? { - return Ok(()) - } + let Some(socket) = client_connect( + usize::MAX, + &self.config, + &mut self.to_socket_receiver, + &mut self.client, + ).await? else { + return Ok(()); + }; + self.socket = socket; }, ClientCloseMode::Close => return Ok(()) } @@ -425,61 +436,73 @@ impl ClientActor { Ok(()) } +} - /// Returns Ok(true) if reconnecting succeeded, Ok(false) if the client closed itself, and `Err` if an error occurred. - async fn reconnect(&mut self) -> Result { - for i in 1.. { - tracing::info!("reconnecting attempt no: {}...", i); - let connect_http_request = self.config.connect_http_request(); - let result = tokio_tungstenite::connect_async(connect_http_request).await; - match result { - Ok((socket, _)) => { - tracing::info!("successfully reconnected"); - if let Err(err) = self.client.on_connect().await { - tracing::error!("calling on_connect() failed due to {}", err); - } - self.socket = Socket::new( - socket, - self.config.socket_config.clone().unwrap_or_default(), - ); - return Ok(true); - } - Err(err) => { - tracing::warn!( - "reconnecting failed due to {}. will retry in {}s", - err, - self.config.reconnect_interval.as_secs() - ); - } - }; - // Discard messages until either the reconnect interval passes, the socket receiver disconnects, or - // the user sends a close message. - let sleep = tokio::time::sleep(self.config.reconnect_interval); - tokio::pin!(sleep); - loop { - tokio::select! { - _ = &mut sleep => break, - Some(inmessage) = self.to_socket_receiver.recv() => { - match &inmessage.message - { - Some(Message::Close(frame)) => { - tracing::trace!(?frame, "client closed itself while reconnecting"); - return Ok(false); - } - _ => { - tracing::warn!("client is reconnecting, discarding message from user"); - continue; - } - } - }, - else => { - tracing::warn!("client is dead, aborting reconnect"); - return Err(Error::from("client died while trying to reconnect")); - }, +/// Returns Ok(Some(socket)) if connecting succeeded, Ok(None) if the client closed itself, and `Err` if an error occurred. +async fn client_connect( + max_attempts: usize, + config: &ClientConfig, + to_socket_receiver: &mut mpsc::UnboundedReceiver, + client: &mut E, +) -> Result, Error> { + for i in 1.. { + // connection attempt + tracing::info!("connecting attempt no: {}...", i); + let connect_http_request = config.connect_http_request(); + let result = tokio_tungstenite::connect_async(connect_http_request).await; + match result { + Ok((socket, _)) => { + tracing::info!("successfully connected"); + if let Err(err) = client.on_connect().await { + tracing::error!("calling on_connect() failed due to {}", err); } + let socket = Socket::new(socket, config.socket_config.clone().unwrap_or_default()); + return Ok(Some(socket)); } + Err(err) => { + tracing::warn!( + "connecting failed due to {}, will retry in {}s", + err, + config.reconnect_interval.as_secs() + ); + } + }; + + // abort if we have exceeded the max attempts + if i >= max_attempts { + return Err(Error::from(format!( + "failed to connect after {} attempt(s), aborting...", + i + ))); } - Err(Error::from("client failed to reconnect")) + // Discard messages until either the connect interval passes, the socket receiver disconnects, or + // the user sends a close message. + let sleep = tokio::time::sleep(config.reconnect_interval); + tokio::pin!(sleep); + loop { + tokio::select! { + _ = &mut sleep => break, + Some(inmessage) = to_socket_receiver.recv() => { + match &inmessage.message + { + Some(Message::Close(frame)) => { + tracing::trace!(?frame, "client closed itself while connecting"); + return Ok(None); + } + _ => { + tracing::warn!("client is connecting, discarding message from user"); + continue; + } + } + }, + else => { + tracing::warn!("client is dead, aborting connection attempts"); + return Err(Error::from("client died while trying to connect")); + }, + } + } } + + Err(Error::from("client failed to connect")) }