Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add timeouts for accepting connections #78

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ delay_map = "0.1.2"
futures = "0.3.26"
rand = "0.8.5"
tokio = { version = "1.25.0", features = ["io-util", "rt-multi-thread", "macros", "net", "sync", "time"] }
tokio-util = { version = "0.7.8", features = ["time"] }
tracing = { version = "0.1.37", features = ["std", "attributes", "log"] }

[dev-dependencies]
Expand Down
11 changes: 11 additions & 0 deletions src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,25 @@ enum State<const N: usize> {
pub type Write = (Vec<u8>, oneshot::Sender<io::Result<usize>>);
pub type Read = (usize, oneshot::Sender<io::Result<Vec<u8>>>);

/// The configuration for a single uTP connection (i.e. stream).
#[derive(Clone, Copy, Debug)]
pub struct ConnectionConfig {
/// The maximum packet size that the connection can transmit.
pub max_packet_size: u16,
/// The maximum number of connection attempts to make for an outgoing connection.
pub max_conn_attempts: usize,
/// The maximum duration that the connection will remain idle before termination. The idle
/// countdown resets upon a local write on the connection and upon receipt of a remote packet.
pub max_idle_timeout: Duration,
/// The initial timeout to establish the connection.
pub initial_timeout: Duration,
/// The minimum timeout that can be assigned to an outgoing packet.
pub min_timeout: Duration,
/// The maximum timeout that can be assigned to an outgoing packet.
///
/// Note: In most circumstances, `max_timeout` should be strictly less than `max_idle_timeout`.
pub max_timeout: Duration,
/// The target packet delay (used to calibrate congestion control).
pub target_delay: Duration,
}

Expand Down
117 changes: 84 additions & 33 deletions src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use std::collections::HashMap;
use std::io;
use std::marker::Unpin;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

use delay_map::HashMapDelay;
use futures::StreamExt;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot};
use tokio_util::time::DelayQueue;

use crate::cid::{ConnectionId, ConnectionIdGenerator, ConnectionPeer, StdConnectionIdGenerator};
use crate::conn::ConnectionConfig;
Expand All @@ -13,6 +18,8 @@ use crate::packet::{Packet, PacketType};
use crate::stream::UtpStream;
use crate::udp::AsyncUdpSocket;

const DEFAULT_ACCEPT_TIMEOUT: Duration = Duration::MAX;

type ConnChannel = mpsc::UnboundedSender<StreamEvent>;

struct Accept<P> {
Expand All @@ -25,7 +32,7 @@ const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize;
pub struct UtpSocket<P> {
conns: Arc<RwLock<HashMap<ConnectionId<P>, ConnChannel>>>,
cid_gen: Mutex<StdConnectionIdGenerator<P>>,
accepts: mpsc::UnboundedSender<(Accept<P>, Option<ConnectionId<P>>)>,
accepts: mpsc::UnboundedSender<(Accept<P>, Option<ConnectionId<P>>, Instant)>,
socket_events: mpsc::UnboundedSender<SocketEvent<P>>,
}

Expand All @@ -39,7 +46,7 @@ impl UtpSocket<SocketAddr> {

impl<P> UtpSocket<P>
where
P: ConnectionPeer + 'static,
P: ConnectionPeer + Unpin + 'static,
{
pub fn with_socket<S>(socket: S) -> Self
where
Expand All @@ -50,8 +57,9 @@ where

let cid_gen = Mutex::new(StdConnectionIdGenerator::new());

let awaiting: HashMap<ConnectionId<P>, Accept<P>> = HashMap::new();
let awaiting = Arc::new(RwLock::new(awaiting));
let mut awaiting_cid: HashMapDelay<ConnectionId<P>, Accept<P>> =
HashMapDelay::new(DEFAULT_ACCEPT_TIMEOUT);
let mut awaiting: DelayQueue<Accept<P>> = DelayQueue::new();

let mut incoming_conns = HashMap::new();

Expand Down Expand Up @@ -92,31 +100,42 @@ where
None => {
if std::matches!(packet.packet_type(), PacketType::Syn) {
let cid = cid_from_packet(&packet, &src, true);
let mut awaiting = awaiting.write().unwrap();

// If there was an awaiting connection with the CID, then
// create a new stream for that connection. Otherwise, add the
// connection to the incoming connections.
if let Some(accept) = awaiting.remove(&cid) {
let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();

conns.insert(cid.clone(), events_tx);

let stream = UtpStream::new(
cid,
accept.config,
Some(packet),
socket_event_tx.clone(),
events_rx,
connected_tx
);

tokio::spawn(async move {
Self::await_connected(stream, accept, connected_rx).await
});

// First check whether there is a pending accept that specifies
// `cid`. If no pending accept was found, then try to fulfill
// the pending accept with the nearest timeout deadline.
let accept = if let Some(accept) = awaiting_cid.remove(&cid) {
Some(accept)
} else {
incoming_conns.insert(cid, packet);
awaiting.peek().map(|key| awaiting.remove(&key).into_inner())
};

// If there was a suitable waiting accept, then create a new
// stream for the connection. Otherwise, add the CID and SYN to
// the incoming connections.
match accept {
Some(accept) => {
let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();

conns.insert(cid.clone(), events_tx);

let stream = UtpStream::new(
cid,
accept.config,
Some(packet),
socket_event_tx.clone(),
events_rx,
connected_tx
);

tokio::spawn(async move {
Self::await_connected(stream, accept, connected_rx).await
});
}
None => {
incoming_conns.insert(cid, packet);
}
}
} else {
tracing::debug!(
Expand Down Expand Up @@ -151,7 +170,26 @@ where
}
}
}
Some((accept, cid)) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
Some((accept, cid, deadline)) = accepts_rx.recv() => {
// If the deadline has passed, then send the timeout to the acceptor.
let now = Instant::now();
if deadline < now {
// Drop the accept sender. A timeout error will be sent back to the
// caller.
tracing::warn!("accept timed out, dropping accept attempt");
continue;
}

// Compute the timeout duration. Given the check above, the subtraction
// cannot fail.
let timeout = deadline - now;

// If there are no incoming connections, then queue the accept.
if incoming_conns.is_empty() {
awaiting.insert(accept, timeout);
continue;
}

let (cid, syn) = match cid {
// If a CID was given, then check for an incoming connection with that
// CID. If one is found, then use that connection. Otherwise, add the
Expand All @@ -160,13 +198,13 @@ where
if let Some(syn) = incoming_conns.remove(&cid) {
(cid, syn)
} else {
awaiting.write().unwrap().insert(cid, accept);
awaiting_cid.insert_at(cid, accept, timeout);
continue;
}
}
// If a CID was not given, then pull an incoming connection, and use
// that connection's CID. An incoming connection is known to exist
// because of the condition in the `select` arm.
// because of the check above.
None => {
let cid = incoming_conns.keys().next().unwrap().clone();
let syn = incoming_conns.remove(&cid).unwrap();
Expand All @@ -193,10 +231,19 @@ where
connected_tx,
);


tokio::spawn(async move {
Self::await_connected(stream, accept, connected_rx).await
});
}
Some(Ok(_accept)) = awaiting_cid.next() => {
// The accept timed out, so drop it.
continue
}
Some(_accept) = awaiting.next() => {
// The accept timed out, so drop it.
continue
}
}
}
});
Expand All @@ -214,8 +261,10 @@ where
stream: stream_tx,
config,
};

let deadline = Instant::now() + accept.config.initial_timeout;
self.accepts
.send((accept, None))
.send((accept, None, deadline))
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
match stream_rx.await {
Ok(stream) => Ok(stream?),
Expand All @@ -233,8 +282,10 @@ where
stream: stream_tx,
config,
};

let deadline = Instant::now() + accept.config.initial_timeout;
self.accepts
.send((accept, Some(cid)))
.send((accept, Some(cid), deadline))
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
match stream_rx.await {
Ok(stream) => Ok(stream?),
Expand Down