Skip to content

Commit

Permalink
Force close conn loop when socket channel dies
Browse files Browse the repository at this point in the history
  • Loading branch information
carver committed Aug 2, 2023
1 parent 84fb93b commit 669d0de
Showing 1 changed file with 66 additions and 30 deletions.
96 changes: 66 additions & 30 deletions src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ enum Error {
Reset,
SynFromAcceptor,
TimedOut,
InternalError,
}

impl fmt::Display for Error {
Expand All @@ -42,6 +43,7 @@ impl fmt::Display for Error {
Self::Reset => "received RESET packet from remote peer",
Self::SynFromAcceptor => "received SYN packet from connection acceptor",
Self::TimedOut => "connection timed out",
Self::InternalError => "utp library has an unexpected state, and cannot continue",
};

write!(f, "{s}")
Expand All @@ -56,6 +58,7 @@ impl From<Error> for io::ErrorKind {
| SynFromAcceptor => io::ErrorKind::InvalidData,
Reset => io::ErrorKind::ConnectionReset,
TimedOut => io::ErrorKind::TimedOut,
InternalError => io::ErrorKind::Other,
}
}
}
Expand Down Expand Up @@ -256,7 +259,10 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
let idle_deadline = tokio::time::Instant::now() + self.config.max_idle_timeout;
idle_timeout.as_mut().reset(idle_deadline);

self.on_packet(&packet, Instant::now());
if let Err(err) = self.on_packet(&packet, Instant::now()) {
tracing::error!("Unrecoverable error while processing an incoming packet, shutting down: {err}");
shutting_down = true;
}
}
StreamEvent::Shutdown => {
shutting_down = true;
Expand All @@ -267,14 +273,20 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
let (seq, packet) = timeout;
tracing::debug!(seq, ack = %packet.ack_num(), packet = ?packet.packet_type(), "timeout");

self.on_timeout(packet, Instant::now());
if let Err(err) = self.on_timeout(packet, Instant::now()) {
tracing::error!("Unrecoverable error while processing a timeout, shutting down: {err}");
shutting_down = true;
}
}
Some(write) = writes.recv(), if !shutting_down => {
// Reset the idle timeout on any new write.
let idle_deadline = tokio::time::Instant::now() + self.config.max_idle_timeout;
idle_timeout.as_mut().reset(idle_deadline);

self.on_write(write);
if let Err(err) = self.on_write(write) {
tracing::error!("Unrecoverable error while receiving a write, shutting down: {err}");
shutting_down = true;
}
}
Some(read) = reads.recv() => {
self.on_read(read);
Expand All @@ -283,7 +295,10 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
self.process_reads();
}
_ = self.writable.notified() => {
self.process_writes(Instant::now());
if let Err(err) = self.process_writes(Instant::now()) {
tracing::error!("Unrecoverable error while processing writes, shutting down: {err}");
shutting_down = true;
}
}
() = &mut idle_timeout => {
if !std::matches!(self.state, State::Closed { .. }) {
Expand All @@ -307,7 +322,11 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
tracing::debug!(?err, "uTP conn closing...");

self.process_reads();
self.process_writes(Instant::now());
if let Err(err) = self.process_writes(Instant::now()) {
tracing::warn!("unable to process writes during shutdown: {err}");
// We already know that we can't send socket events, so skip the next one...
break;
}

if let Err(..) = self
.socket_events
Expand Down Expand Up @@ -352,14 +371,21 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
local_fin = Some(seq_num);

tracing::debug!(seq = %seq_num, "transmitting FIN");
Self::transmit(
let tx_result = Self::transmit(
sent_packets,
&mut self.unacked,
&mut self.socket_events,
fin,
&self.cid.peer,
Instant::now(),
);
if let Err(err) = tx_result {
tracing::error!(?err, "while transmitting FIN, in Established state");
self.state = State::Closed {
err: Some(Error::InternalError),
};
return;
}
}

self.state = State::Closing {
Expand Down Expand Up @@ -400,23 +426,29 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
*local_fin = Some(seq_num);

tracing::debug!(seq = %seq_num, "transmitting FIN");
Self::transmit(
let tx_result = Self::transmit(
sent_packets,
&mut self.unacked,
&mut self.socket_events,
fin,
&self.cid.peer,
Instant::now(),
);
if let Err(err) = tx_result {
tracing::warn!(?err, "while transmitting FIN, in Closing state");
self.state = State::Closed {
err: Some(Error::InternalError),
};
}
}
}
State::Closed { .. } => {}
}
}

fn process_writes(&mut self, now: Instant) {
fn process_writes(&mut self, now: Instant) -> Result<(), String> {
let (send_buf, sent_packets, recv_buf) = match &mut self.state {
State::Connecting(..) => return,
State::Connecting(..) => return Ok(()),
State::Established {
send_buf,
sent_packets,
Expand All @@ -438,7 +470,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
let (.., tx) = pending;
let _ = tx.send(result.map_err(io::Error::from));
}
return;
return Ok(());
}
};

Expand Down Expand Up @@ -498,12 +530,13 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
packet,
&self.cid.peer,
now,
);
)?;
seq_num = seq_num.wrapping_add(1);
}
Ok(())
}

fn on_write(&mut self, write: Write) {
fn on_write(&mut self, write: Write) -> Result<(), String> {
let (data, tx) = write;

match &mut self.state {
Expand All @@ -530,9 +563,10 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
}
}

self.process_writes(Instant::now());
self.process_writes(Instant::now())?;

self.writable.notify_waiters();
Ok(())
}

fn process_reads(&mut self) {
Expand Down Expand Up @@ -607,7 +641,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
}
}

fn on_timeout(&mut self, packet: Packet, now: Instant) {
fn on_timeout(&mut self, packet: Packet, now: Instant) -> Result<(), String> {
match &mut self.state {
State::Connecting(connected) => match &mut self.endpoint {
Endpoint::Initiator((syn, attempts)) => {
Expand Down Expand Up @@ -645,7 +679,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
// If the timed out packet is a SYN, then do nothing since the connection has
// already been established.
if std::matches!(packet.packet_type(), PacketType::Syn) {
return;
return Ok(());
}

// To prevent timeout amplification in the event that a batch of packets sent near
Expand Down Expand Up @@ -682,13 +716,14 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
packet,
&self.cid.peer,
now,
);
)?;
}
State::Closed { .. } => {}
}
Ok(())
}

fn on_packet(&mut self, packet: &Packet, now: Instant) {
fn on_packet(&mut self, packet: &Packet, now: Instant) -> Result<(), String> {
let now_micros = crate::time::now_micros();
self.peer_recv_window = packet.window_size();

Expand Down Expand Up @@ -729,7 +764,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
}

// If there are any lost packets, then queue retransmissions.
self.retransmit_lost_packets(now);
self.retransmit_lost_packets(now)?;

// Send a STATE packet if appropriate packet and connection in appropriate state.
//
Expand All @@ -740,10 +775,9 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
PacketType::Syn | PacketType::Fin | PacketType::Data => {
if let Some(state) = self.state_packet() {
let event = SocketEvent::Outgoing((state, self.cid.peer.clone()));
if self.socket_events.send(event).is_err() {
tracing::warn!("Cannot transmit state packet: socket closed channel");
return;
}
self.socket_events
.send(event)
.map_err(|_| "cannot transmit packet: socket closed channel".to_string())?
}
}
PacketType::State | PacketType::Reset => {}
Expand Down Expand Up @@ -775,6 +809,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
self.state = State::Closed { err: None };
}
}
Ok(())
}

fn on_syn(&mut self, seq_num: u16) {
Expand Down Expand Up @@ -1043,9 +1078,9 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
}
}

fn retransmit_lost_packets(&mut self, now: Instant) {
fn retransmit_lost_packets(&mut self, now: Instant) -> Result<(), String> {
let (sent_packets, recv_buf) = match &mut self.state {
State::Connecting(..) | State::Closed { .. } => return,
State::Connecting(..) | State::Closed { .. } => return Ok(()),
State::Established {
sent_packets,
recv_buf,
Expand All @@ -1059,7 +1094,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
};

if !sent_packets.has_lost_packets() {
return;
return Ok(());
}

let conn_id = self.cid.send;
Expand All @@ -1085,8 +1120,9 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
packet,
&self.cid.peer,
now,
);
)?;
}
Ok(())
}

fn transmit(
Expand All @@ -1096,7 +1132,7 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
packet: Packet,
dest: &P,
now: Instant,
) {
) -> Result<(), String> {
let (payload, len) = if packet.payload().is_empty() {
(None, 0)
} else {
Expand All @@ -1109,9 +1145,9 @@ impl<const N: usize, P: ConnectionPeer> Connection<N, P> {
sent_packets.on_transmit(packet.seq_num(), packet.packet_type(), payload, len, now);
unacked.insert_at(packet.seq_num(), packet.clone(), sent_packets.timeout());
let outbound = SocketEvent::Outgoing((packet, dest.clone()));
if socket_events.send(outbound).is_err() {
tracing::warn!("Cannot transmit packet: socket closed channel");
}
socket_events
.send(outbound)
.map_err(|_| "cannot transmit packet: socket closed channel".to_string())
}
}

Expand Down

0 comments on commit 669d0de

Please sign in to comment.