diff --git a/src/address_family.rs b/src/address_family.rs index 75cbaea..f49e966 100644 --- a/src/address_family.rs +++ b/src/address_family.rs @@ -1,5 +1,5 @@ use super::MDNS_PORT; -use if_addrs::{get_if_addrs, Interface}; +use if_addrs::{get_if_addrs, IfAddr}; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::collections::HashSet; use std::io; @@ -47,14 +47,13 @@ impl AddressFamily for Inet { const DOMAIN: Domain = Domain::IPV4; fn join_multicast(socket: &Socket, multiaddr: &Self::Addr) -> io::Result<()> { - let ifaces = get_iface_list()?; - if ifaces.is_empty() { + let addrs = get_one_nonloopback_ipv4_addr_per_iface()?; + if addrs.is_empty() { socket.join_multicast_v4(multiaddr, &Ipv4Addr::UNSPECIFIED) } else { - for iface in ifaces { - if let IpAddr::V4(ip) = iface.ip() { - socket.join_multicast_v4(multiaddr, &ip)?; - } + // TODO: If any join succeeds return success (log failures) + for ip in addrs { + socket.join_multicast_v4(multiaddr, &ip)?; } Ok(()) } @@ -70,28 +69,62 @@ impl AddressFamily for Inet6 { const DOMAIN: Domain = Domain::IPV6; fn join_multicast(socket: &Socket, multiaddr: &Self::Addr) -> io::Result<()> { - let ifaces = get_iface_list()?; - if ifaces.is_empty() { + let indexes = get_one_nonloopback_ipv6_index_per_iface()?; + if indexes.is_empty() { socket.join_multicast_v6(multiaddr, 0) } else { - // TODO: Make each interface resilient to failures on another. - for iface in ifaces { - if let (IpAddr::V6(_), Some(ipv6_index)) = (iface.ip(), iface.index) { - socket.join_multicast_v6(multiaddr, ipv6_index)?; - } + // TODO: If any join succeeds return success (log failures) + for ipv6_index in indexes { + socket.join_multicast_v6(multiaddr, ipv6_index)?; } Ok(()) } } } -fn get_iface_list() -> io::Result> { +fn get_one_nonloopback_ipv6_index_per_iface() -> io::Result> { + // There may be multiple ip addresses on a single interface and we join multicast by interface. + // Joining multicast on the same interface multiple times returns an error + // so we filter duplicate interfaces. + let mut collected_interfaces = HashSet::new(); + Ok(get_if_addrs()? + .into_iter() + .filter_map(|iface| { + if iface.is_loopback() { + None + } else if matches!(iface.addr, IfAddr::V6(_)) { + if collected_interfaces.insert(iface.name.clone()) { + iface.index + } else { + None + } + } else { + None + } + }) + .collect()) +} + +fn get_one_nonloopback_ipv4_addr_per_iface() -> io::Result> { // There may be multiple ip addresses on a single interface and we join multicast by interface. // Joining multicast on the same interface multiple times returns an error // so we filter duplicate interfaces. let mut collected_interfaces = HashSet::new(); Ok(get_if_addrs()? .into_iter() - .filter(|iface| !iface.is_loopback() && collected_interfaces.insert(iface.name.clone())) + .filter_map(|iface| { + if iface.is_loopback() { + None + } else if let IpAddr::V4(ip) = iface.ip() { + if collected_interfaces.insert(iface.name.clone()) { + Some(ip) + } else { + None + } + } else { + None + } + }) .collect()) } +