From 669d0de8217c0b7bfd5b84e9a6ebc5c0e3ed430b Mon Sep 17 00:00:00 2001 From: Jason Carver Date: Tue, 1 Aug 2023 20:32:04 -0700 Subject: [PATCH] Force close conn loop when socket channel dies --- src/conn.rs | 96 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 30 deletions(-) diff --git a/src/conn.rs b/src/conn.rs index 6778290..be272a3 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -27,6 +27,7 @@ enum Error { Reset, SynFromAcceptor, TimedOut, + InternalError, } impl fmt::Display for Error { @@ -42,6 +43,7 @@ impl fmt::Display for Error { Self::Reset => "received RESET packet from remote peer", Self::SynFromAcceptor => "received SYN packet from connection acceptor", Self::TimedOut => "connection timed out", + Self::InternalError => "utp library has an unexpected state, and cannot continue", }; write!(f, "{s}") @@ -56,6 +58,7 @@ impl From for io::ErrorKind { | SynFromAcceptor => io::ErrorKind::InvalidData, Reset => io::ErrorKind::ConnectionReset, TimedOut => io::ErrorKind::TimedOut, + InternalError => io::ErrorKind::Other, } } } @@ -256,7 +259,10 @@ impl Connection { let idle_deadline = tokio::time::Instant::now() + self.config.max_idle_timeout; idle_timeout.as_mut().reset(idle_deadline); - self.on_packet(&packet, Instant::now()); + if let Err(err) = self.on_packet(&packet, Instant::now()) { + tracing::error!("Unrecoverable error while processing an incoming packet, shutting down: {err}"); + shutting_down = true; + } } StreamEvent::Shutdown => { shutting_down = true; @@ -267,14 +273,20 @@ impl Connection { let (seq, packet) = timeout; tracing::debug!(seq, ack = %packet.ack_num(), packet = ?packet.packet_type(), "timeout"); - self.on_timeout(packet, Instant::now()); + if let Err(err) = self.on_timeout(packet, Instant::now()) { + tracing::error!("Unrecoverable error while processing a timeout, shutting down: {err}"); + shutting_down = true; + } } Some(write) = writes.recv(), if !shutting_down => { // Reset the idle timeout on any new write. let idle_deadline = tokio::time::Instant::now() + self.config.max_idle_timeout; idle_timeout.as_mut().reset(idle_deadline); - self.on_write(write); + if let Err(err) = self.on_write(write) { + tracing::error!("Unrecoverable error while receiving a write, shutting down: {err}"); + shutting_down = true; + } } Some(read) = reads.recv() => { self.on_read(read); @@ -283,7 +295,10 @@ impl Connection { self.process_reads(); } _ = self.writable.notified() => { - self.process_writes(Instant::now()); + if let Err(err) = self.process_writes(Instant::now()) { + tracing::error!("Unrecoverable error while processing writes, shutting down: {err}"); + shutting_down = true; + } } () = &mut idle_timeout => { if !std::matches!(self.state, State::Closed { .. }) { @@ -307,7 +322,11 @@ impl Connection { tracing::debug!(?err, "uTP conn closing..."); self.process_reads(); - self.process_writes(Instant::now()); + if let Err(err) = self.process_writes(Instant::now()) { + tracing::warn!("unable to process writes during shutdown: {err}"); + // We already know that we can't send socket events, so skip the next one... + break; + } if let Err(..) = self .socket_events @@ -352,7 +371,7 @@ impl Connection { local_fin = Some(seq_num); tracing::debug!(seq = %seq_num, "transmitting FIN"); - Self::transmit( + let tx_result = Self::transmit( sent_packets, &mut self.unacked, &mut self.socket_events, @@ -360,6 +379,13 @@ impl Connection { &self.cid.peer, Instant::now(), ); + if let Err(err) = tx_result { + tracing::error!(?err, "while transmitting FIN, in Established state"); + self.state = State::Closed { + err: Some(Error::InternalError), + }; + return; + } } self.state = State::Closing { @@ -400,7 +426,7 @@ impl Connection { *local_fin = Some(seq_num); tracing::debug!(seq = %seq_num, "transmitting FIN"); - Self::transmit( + let tx_result = Self::transmit( sent_packets, &mut self.unacked, &mut self.socket_events, @@ -408,15 +434,21 @@ impl Connection { &self.cid.peer, Instant::now(), ); + if let Err(err) = tx_result { + tracing::warn!(?err, "while transmitting FIN, in Closing state"); + self.state = State::Closed { + err: Some(Error::InternalError), + }; + } } } State::Closed { .. } => {} } } - fn process_writes(&mut self, now: Instant) { + fn process_writes(&mut self, now: Instant) -> Result<(), String> { let (send_buf, sent_packets, recv_buf) = match &mut self.state { - State::Connecting(..) => return, + State::Connecting(..) => return Ok(()), State::Established { send_buf, sent_packets, @@ -438,7 +470,7 @@ impl Connection { let (.., tx) = pending; let _ = tx.send(result.map_err(io::Error::from)); } - return; + return Ok(()); } }; @@ -498,12 +530,13 @@ impl Connection { packet, &self.cid.peer, now, - ); + )?; seq_num = seq_num.wrapping_add(1); } + Ok(()) } - fn on_write(&mut self, write: Write) { + fn on_write(&mut self, write: Write) -> Result<(), String> { let (data, tx) = write; match &mut self.state { @@ -530,9 +563,10 @@ impl Connection { } } - self.process_writes(Instant::now()); + self.process_writes(Instant::now())?; self.writable.notify_waiters(); + Ok(()) } fn process_reads(&mut self) { @@ -607,7 +641,7 @@ impl Connection { } } - fn on_timeout(&mut self, packet: Packet, now: Instant) { + fn on_timeout(&mut self, packet: Packet, now: Instant) -> Result<(), String> { match &mut self.state { State::Connecting(connected) => match &mut self.endpoint { Endpoint::Initiator((syn, attempts)) => { @@ -645,7 +679,7 @@ impl Connection { // If the timed out packet is a SYN, then do nothing since the connection has // already been established. if std::matches!(packet.packet_type(), PacketType::Syn) { - return; + return Ok(()); } // To prevent timeout amplification in the event that a batch of packets sent near @@ -682,13 +716,14 @@ impl Connection { packet, &self.cid.peer, now, - ); + )?; } State::Closed { .. } => {} } + Ok(()) } - fn on_packet(&mut self, packet: &Packet, now: Instant) { + fn on_packet(&mut self, packet: &Packet, now: Instant) -> Result<(), String> { let now_micros = crate::time::now_micros(); self.peer_recv_window = packet.window_size(); @@ -729,7 +764,7 @@ impl Connection { } // If there are any lost packets, then queue retransmissions. - self.retransmit_lost_packets(now); + self.retransmit_lost_packets(now)?; // Send a STATE packet if appropriate packet and connection in appropriate state. // @@ -740,10 +775,9 @@ impl Connection { PacketType::Syn | PacketType::Fin | PacketType::Data => { if let Some(state) = self.state_packet() { let event = SocketEvent::Outgoing((state, self.cid.peer.clone())); - if self.socket_events.send(event).is_err() { - tracing::warn!("Cannot transmit state packet: socket closed channel"); - return; - } + self.socket_events + .send(event) + .map_err(|_| "cannot transmit packet: socket closed channel".to_string())? } } PacketType::State | PacketType::Reset => {} @@ -775,6 +809,7 @@ impl Connection { self.state = State::Closed { err: None }; } } + Ok(()) } fn on_syn(&mut self, seq_num: u16) { @@ -1043,9 +1078,9 @@ impl Connection { } } - fn retransmit_lost_packets(&mut self, now: Instant) { + fn retransmit_lost_packets(&mut self, now: Instant) -> Result<(), String> { let (sent_packets, recv_buf) = match &mut self.state { - State::Connecting(..) | State::Closed { .. } => return, + State::Connecting(..) | State::Closed { .. } => return Ok(()), State::Established { sent_packets, recv_buf, @@ -1059,7 +1094,7 @@ impl Connection { }; if !sent_packets.has_lost_packets() { - return; + return Ok(()); } let conn_id = self.cid.send; @@ -1085,8 +1120,9 @@ impl Connection { packet, &self.cid.peer, now, - ); + )?; } + Ok(()) } fn transmit( @@ -1096,7 +1132,7 @@ impl Connection { packet: Packet, dest: &P, now: Instant, - ) { + ) -> Result<(), String> { let (payload, len) = if packet.payload().is_empty() { (None, 0) } else { @@ -1109,9 +1145,9 @@ impl Connection { sent_packets.on_transmit(packet.seq_num(), packet.packet_type(), payload, len, now); unacked.insert_at(packet.seq_num(), packet.clone(), sent_packets.timeout()); let outbound = SocketEvent::Outgoing((packet, dest.clone())); - if socket_events.send(outbound).is_err() { - tracing::warn!("Cannot transmit packet: socket closed channel"); - } + socket_events + .send(outbound) + .map_err(|_| "cannot transmit packet: socket closed channel".to_string()) } }