Skip to content

Commit

Permalink
feat: implement timeout in native connection manager
Browse files Browse the repository at this point in the history
  • Loading branch information
valeriansaliou committed Aug 5, 2024
1 parent aa1bab3 commit 8eb09bf
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 50 deletions.
118 changes: 82 additions & 36 deletions src-tauri/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ use log::{debug, error, info, warn};
use serde::Serialize;
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::Duration;
use tauri::plugin::{Builder, TauriPlugin};
use tauri::{Manager, Runtime, State, Window};
use thiserror::Error;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::task::{self, JoinHandle};
use tokio::time::timeout;
use tokio_xmpp::connect::ServerConnector;
use tokio_xmpp::{AsyncClient as Client, Error, Event, Packet};

Expand All @@ -28,6 +30,8 @@ use tokio_xmpp::{AsyncClient as Client, Error, Event, Packet};
const EVENT_STATE: &'static str = "connection:state";
const EVENT_RECEIVE: &'static str = "connection:receive";

const READ_TIMEOUT_MILLISECONDS: u64 = 300000;

/**************************************************************************
* TYPES
* ************************************************************************* */
Expand Down Expand Up @@ -74,6 +78,8 @@ pub enum PollInputError {
AuthenticationError,
#[error("Connection error")]
ConnectionError,
#[error("Timeout error")]
TimeoutError,
#[error("Other error")]
OtherError,
}
Expand Down Expand Up @@ -167,17 +173,71 @@ fn recover_closed_sender_channel<R: Runtime>(
async fn poll_input_events<R: Runtime, C: ServerConnector>(
window: &Window<R>,
id: &str,
read_timeout: Duration,
mut client_reader: SplitStream<Client<C>>,
) -> Result<(), PollInputError> {
while let Some(event) = client_reader.next().await {
// Wrap client reader in a timeout task; this is especially important \
// since the underlying 'tokio-xmpp' does not implement any kind of \
// timeout whatsoever. This timeout duration is served from the \
// connection initiator, and will most likely depend on the PING \
// interval set by the client.
while let Ok(event_maybe) = timeout(read_timeout, client_reader.next()).await {
// Handle next event
if let Some(result) = handle_next_input_event(window, id, event_maybe) {
// We received a non-empty result: we have to stop the loop there!
return result;
}
}

// The next event did not come in due time, consider as timed out
warn!(
"Timed out waiting {}ms for next event on: #{}",
read_timeout.as_millis(),
id
);

// Abort here (timed out)
// Notice: the event loop has timed out, abort connection and error out.
emit_connection_abort(window, id, ConnectionState::ConnectionTimeout);

Err(PollInputError::TimeoutError)
}

async fn poll_output_events<C: ServerConnector>(
id: &str,
mut client_writer: SplitSink<Client<C>, Packet>,
mut rx: UnboundedReceiver<Packet>,
) -> Result<(), PollOutputError> {
while let Some(packet) = rx.recv().await {
if let Err(err) = client_writer.send(packet).await {
error!(
"Failed sending packet over connection: #{} because: {}",
id, err
);

return Err(PollOutputError::PacketSendError);
}

debug!("Sent packet over connection: #{}", id);
}

Ok(())
}

fn handle_next_input_event<R: Runtime>(
window: &Window<R>,
id: &str,
event_maybe: Option<Event>,
) -> Option<Result<(), PollInputError>> {
if let Some(event) = event_maybe {
match event {
Event::Disconnected(Error::Disconnected) => {
info!("Received disconnected event on: #{}", id);

emit_connection_abort(window, id, ConnectionState::Disconnected);

// Abort here (success)
return Ok(());
Some(Ok(()))
}
Event::Disconnected(Error::Auth(err)) => {
warn!(
Expand All @@ -188,27 +248,26 @@ async fn poll_input_events<R: Runtime, C: ServerConnector>(
emit_connection_abort(window, id, ConnectionState::AuthenticationFailure);

// Abort here (error)
return Err(PollInputError::AuthenticationError);
Some(Err(PollInputError::AuthenticationError))
}
Event::Disconnected(Error::Connection(err)) => {
warn!(
"Received disconnected event: #{}, with connection error: {}",
id, err
);

// Notice: consider as timeout here.
emit_connection_abort(window, id, ConnectionState::ConnectionTimeout);
emit_connection_abort(window, id, ConnectionState::ConnectionError);

// Abort here (error)
return Err(PollInputError::ConnectionError);
Some(Err(PollInputError::ConnectionError))
}
Event::Disconnected(err) => {
warn!("Received disconnected event: #{}, with error: {}", id, err);

emit_connection_abort(window, id, ConnectionState::ConnectionError);

// Abort here (error)
return Err(PollInputError::OtherError);
Some(Err(PollInputError::OtherError))
}
Event::Online { .. } => {
info!("Received connected event on: #{}", id);
Expand All @@ -224,7 +283,7 @@ async fn poll_input_events<R: Runtime, C: ServerConnector>(
.unwrap();

// Continue
continue;
None
}
Event::Stanza(stanza) => {
debug!("Received stanza event on: #{}", id);
Expand All @@ -242,33 +301,13 @@ async fn poll_input_events<R: Runtime, C: ServerConnector>(
.unwrap();

// Continue
continue;
None
}
}
} else {
// Abort here (normal stop)
Some(Ok(()))
}

Ok(())
}

async fn poll_output_events<C: ServerConnector>(
id: &str,
mut client_writer: SplitSink<Client<C>, Packet>,
mut rx: UnboundedReceiver<Packet>,
) -> Result<(), PollOutputError> {
while let Some(packet) = rx.recv().await {
if let Err(err) = client_writer.send(packet).await {
error!(
"Failed sending packet over connection: #{} because: {}",
id, err
);

return Err(PollOutputError::PacketSendError);
}

debug!("Sent packet over connection: #{}", id);
}

Ok(())
}

/**************************************************************************
Expand All @@ -282,6 +321,7 @@ pub fn connect<R: Runtime>(
id: &str,
jid: &str,
password: &str,
timeout: Option<u64>,
) -> Result<(), ConnectError> {
info!("Connection #{} connect requested on JID: {}", id, jid);

Expand Down Expand Up @@ -321,9 +361,6 @@ pub fn connect<R: Runtime>(
// Connections are single-use only
client.set_reconnect(false);

// TODO: implement some kind of timeout, because connection can be left in \
// a dangling state at this point, not connected, not disconnected.

// Split client into RX (for writer) and TX (for reader)
let (tx, rx) = mpsc::unbounded_channel();
let (writer, reader) = client.split();
Expand All @@ -333,6 +370,8 @@ pub fn connect<R: Runtime>(
let id = id.to_owned();

task::spawn(async move {
info!("Connection #{} write poller has started", id);

// Poll for output events
if let Err(err) = poll_output_events(&id, writer, rx).await {
warn!(
Expand All @@ -347,10 +386,17 @@ pub fn connect<R: Runtime>(

let read_handle = {
let id = id.to_owned();
let read_timeout = Duration::from_millis(timeout.unwrap_or(READ_TIMEOUT_MILLISECONDS));

task::spawn(async move {
info!(
"Connection #{} read poller has started (with timeout: {}ms)",
id,
read_timeout.as_millis()
);

// Poll for input events
if let Err(err) = poll_input_events(&window, &id, reader).await {
if let Err(err) = poll_input_events(&window, &id, read_timeout, reader).await {
warn!(
"Connection #{} read poller terminated with error: {}",
id, err
Expand Down
11 changes: 7 additions & 4 deletions src/broker/connection/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ abstract class BrokerConnection {
return "XMPP";
}

protected _pingIntervalTime(): number {
return runtimeContext === "application"
? TIMER_PING_INTERVAL_APPLICATION
: TIMER_PING_INTERVAL_DEFAULT;
}

protected _onInput(data: string): void {
// Trace raw input?
if (this.__config.logReceivedStanzas === true) {
Expand Down Expand Up @@ -89,10 +95,7 @@ abstract class BrokerConnection {
}

// Acquire interval times
const pingEvery =
runtimeContext === "application"
? TIMER_PING_INTERVAL_APPLICATION
: TIMER_PING_INTERVAL_DEFAULT;
const pingEvery = this._pingIntervalTime();
const timeoutEvery = TIMER_TIMEOUT_INTERVAL;

// Schedule timers
Expand Down
35 changes: 27 additions & 8 deletions src/broker/connection/native.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ import {
} from "@/utilities/runtime";
import logger from "@/utilities/logger";

/**************************************************************************
* CONSTANTS
* ************************************************************************* */

const TIMEOUT_PING_DELAY = 15000; // 15 seconds

/**************************************************************************
* CLASS
* ************************************************************************* */
Expand All @@ -49,16 +55,29 @@ class BrokerConnectionNativeTauri
fail: reject
});

// Acquire connection timeout
// Notice: timeout should be a function of the ping interval, with a \
// slight delay on the top so that the server has time to respond to \
// a given ping. This sets up a realistic timeout where we are sure \
// that something occurs on the network channel in the timeout \
// timeframe. If nothing happens in this same timeframe, this means \
// that the connection is likely dead and should be considered as \
// timed out by the underlying connection manager.
const timeout = this._pingIntervalTime() + TIMEOUT_PING_DELAY;

// Request connection to connect
// Important: trigger reject handler if runtime request failed
UtilitiesRuntime.requestConnectionConnect(id, jidString, password).catch(
error => {
// Intercept error for logging purposes
logger.error("Broker failed to request a connection connect", error);

reject(error);
}
);
UtilitiesRuntime.requestConnectionConnect(
id,
jidString,
password,
timeout
).catch(error => {
// Intercept error for logging purposes
logger.error("Broker failed to request a connection connect", error);

reject(error);
});
});
}

Expand Down
6 changes: 4 additions & 2 deletions src/utilities/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -531,14 +531,16 @@ class UtilitiesRuntime {
async requestConnectionConnect(
id: RuntimeConnectionID,
jidString: string,
password: string
password: string,
timeout?: number
): Promise<void> {
if (this.__isApplication === true) {
// Request to connect via Tauri API (application build)
await tauriInvoke("plugin:connection|connect", {
jid: jidString,
password,
id
id,
timeout
});
} else {
// This method should NEVER be used on other platforms
Expand Down

0 comments on commit 8eb09bf

Please sign in to comment.