diff --git a/iroh-net/src/endpoint.rs b/iroh-net/src/endpoint.rs index 8915bdfa73..d13a6b1bf0 100644 --- a/iroh-net/src/endpoint.rs +++ b/iroh-net/src/endpoint.rs @@ -1274,8 +1274,9 @@ mod tests { #[tokio::test] async fn endpoint_conn_type_stream() { + const TIMEOUT: Duration = std::time::Duration::from_secs(15); let _logging_guard = iroh_test::logging::setup(); - let (relay_map, relay_url, _relay_guard) = run_relay_server().await.unwrap(); + let (relay_map, _relay_url, _relay_guard) = run_relay_server().await.unwrap(); let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); let ep1_secret_key = SecretKey::generate_with_rng(&mut rng); let ep2_secret_key = SecretKey::generate_with_rng(&mut rng); @@ -1296,31 +1297,25 @@ mod tests { .await .unwrap(); - async fn handle_direct_conn(ep: Endpoint, node_id: PublicKey) -> Result<()> { - let node_addr = NodeAddr::new(node_id); - ep.add_node_addr(node_addr)?; - let stream = ep.conn_type_stream(&node_id)?; - async fn get_direct_event( - src: &PublicKey, - dst: &PublicKey, - mut stream: ConnectionTypeStream, - ) -> Result<()> { - let src = src.fmt_short(); - let dst = dst.fmt_short(); - while let Some(conn_type) = stream.next().await { - tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type); - if matches!(conn_type, ConnectionType::Direct(_)) { - return Ok(()); - } + async fn handle_direct_conn(ep: &Endpoint, node_id: PublicKey) -> Result<()> { + let mut stream = ep.conn_type_stream(&node_id)?; + let src = ep.node_id().fmt_short(); + let dst = node_id.fmt_short(); + while let Some(conn_type) = stream.next().await { + tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type); + if matches!(conn_type, ConnectionType::Direct(_)) { + return Ok(()); } - anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`"); } - tokio::time::timeout( - Duration::from_secs(15), - get_direct_event(&ep.node_id(), &node_id, stream), - ) - .await??; - Ok(()) + anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`"); + } + + async fn accept(ep: &Endpoint) -> NodeId { + let incoming = ep.accept().await.unwrap(); + let conn = incoming.await.unwrap(); + let node_id = get_remote_node_id(&conn).unwrap(); + tracing::info!(node_id=%node_id.fmt_short(), "accepted connection"); + node_id } let ep1_nodeid = ep1.node_id(); @@ -1333,39 +1328,31 @@ mod tests { ); tracing::info!("node id 2 {ep2_nodeid}"); - let res_ep1 = tokio::spawn(handle_direct_conn(ep1.clone(), ep2_nodeid)); + let ep1_side = async move { + accept(&ep1).await; + handle_direct_conn(&ep1, ep2_nodeid).await + }; + + let ep2_side = async move { + ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap(); + handle_direct_conn(&ep2, ep1_nodeid).await + }; + + let res_ep1 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep1_side)); let ep1_abort_handle = res_ep1.abort_handle(); let _ep1_guard = CallOnDrop::new(move || { ep1_abort_handle.abort(); }); - let res_ep2 = tokio::spawn(handle_direct_conn(ep2.clone(), ep1_nodeid)); + let res_ep2 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep2_side)); let ep2_abort_handle = res_ep2.abort_handle(); let _ep2_guard = CallOnDrop::new(move || { ep2_abort_handle.abort(); }); - async fn accept(ep: Endpoint) -> NodeId { - let incoming = ep.accept().await.unwrap(); - let conn = incoming.await.unwrap(); - get_remote_node_id(&conn).unwrap() - } - - // create a node addr with no direct connections - let ep1_nodeaddr = NodeAddr::from_parts(ep1_nodeid, Some(relay_url), vec![]); - - let accept_res = tokio::spawn(accept(ep1.clone())); - let accept_abort_handle = accept_res.abort_handle(); - let _accept_guard = CallOnDrop::new(move || { - accept_abort_handle.abort(); - }); - - let _conn_2 = ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap(); - - let got_id = accept_res.await.unwrap(); - assert_eq!(ep2_nodeid, got_id); - res_ep1.await.unwrap().unwrap(); - res_ep2.await.unwrap().unwrap(); + let (r1, r2) = tokio::try_join!(res_ep1, res_ep2).unwrap(); + r1.expect("ep1 timeout").unwrap(); + r2.expect("ep2 timeout").unwrap(); } }