diff --git a/vsock_proxy/src/dns.rs b/vsock_proxy/src/dns.rs index bd809972..d0935059 100644 --- a/vsock_proxy/src/dns.rs +++ b/vsock_proxy/src/dns.rs @@ -1,18 +1,53 @@ // Copyright 2019-2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 + #![deny(warnings)] -/// Contains code for Proxy, a library used for translating vsock traffic to -/// TCP traffic -/// +use std::net::IpAddr; + +use chrono::{DateTime, Duration, Utc}; use hickory_resolver::config::*; use hickory_resolver::Resolver; use idna::domain_to_ascii; -use crate::{DnsResolveResult, IpAddrType, VsockProxyResult}; +use crate::{IpAddrType, VsockProxyResult}; + +/// `DnsResolutionInfo` represents DNS resolution information, including the resolved +/// IP address, TTL value and last resolution time. +#[derive(Copy, Clone, Debug)] +pub struct DnsResolutionInfo { + /// The IP address that the hostname was resolved to. + ip_addr: IpAddr, + /// The configured duration after which the DNS resolution should be refreshed. + ttl: Duration, + /// The timestamp representing the last time the DNS resolution was performed. + last_dns_resolution_time: DateTime, +} + +impl DnsResolutionInfo { + pub fn is_expired(&self) -> bool { + Utc::now() > self.last_dns_resolution_time + self.ttl + } + + fn new(new_ip_addr: IpAddr, new_ttl: Duration) -> Self { + DnsResolutionInfo { + ip_addr: new_ip_addr, + ttl: new_ttl, + last_dns_resolution_time: Utc::now(), + } + } + + pub fn ip_addr(&self) -> IpAddr { + self.ip_addr + } + + pub fn ttl(&self) -> Duration { + self.ttl + } +} /// Resolve a DNS name (IDNA format) into multiple IP addresses (v4 or v6) -pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult> { +pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult> { // IDNA parsing let addr = domain_to_ascii(addr).map_err(|_| "Could not parse domain name")?; @@ -21,7 +56,7 @@ pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult = resolver + let rresults: Vec = resolver .lookup_ip(addr) .map_err(|_| "DNS lookup failed!")? .as_lookup() @@ -29,9 +64,9 @@ pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult VsockProxyResult, Vec<_>) = - rresults.into_iter().partition(|result| result.ip.is_ipv4()); + let (rresults_with_ipv4, rresults_with_ipv6): (Vec<_>, Vec<_>) = rresults + .into_iter() + .partition(|result| result.ip_addr().is_ipv4()); if IpAddrType::IPAddrV4Only == ip_addr_type && !rresults_with_ipv4.is_empty() { Ok(rresults_with_ipv4) @@ -61,7 +97,7 @@ pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult VsockProxyResult { +pub fn resolve_single(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult { let rresults = resolve(addr, ip_addr_type)?; // Return the first resolved IP address and its TTL value. rresults @@ -127,14 +163,14 @@ mod tests { fn test_resolve_ipv4_only() { let domain = unsafe { IPV4_ONLY_TEST_DOMAIN }; let rresults = resolve(domain, IpAddrType::IPAddrV4Only).unwrap(); - assert!(rresults.iter().all(|item| item.ip.is_ipv4())); + assert!(rresults.iter().all(|item| item.ip_addr().is_ipv4())); } #[test] fn test_resolve_ipv6_only() { let domain = unsafe { IPV6_ONLY_TEST_DOMAIN }; let rresults = resolve(domain, IpAddrType::IPAddrV6Only).unwrap(); - assert!(rresults.iter().all(|item| item.ip.is_ipv6())); + assert!(rresults.iter().all(|item| item.ip_addr().is_ipv6())); } #[test] @@ -148,7 +184,7 @@ mod tests { fn test_resolve_single_address() { let domain = unsafe { IPV4_ONLY_TEST_DOMAIN }; let rresult = resolve_single(domain, IpAddrType::IPAddrMixed).unwrap(); - assert!(rresult.ip.is_ipv4()); - assert!(rresult.ttl != 0); + assert!(rresult.ip_addr().is_ipv4()); + assert!(rresult.ttl != Duration::seconds(0)); } } diff --git a/vsock_proxy/src/lib.rs b/vsock_proxy/src/lib.rs index b6c4e3a3..492cadb6 100644 --- a/vsock_proxy/src/lib.rs +++ b/vsock_proxy/src/lib.rs @@ -4,8 +4,6 @@ pub mod dns; pub mod proxy; -use std::net::IpAddr; - #[derive(Copy, Clone, PartialEq)] pub enum IpAddrType { /// Only allows IP4 addresses @@ -16,13 +14,5 @@ pub enum IpAddrType { IPAddrMixed, } -#[derive(Copy, Clone, Debug)] -pub struct DnsResolveResult { - ///Resolved address - pub ip: IpAddr, - ///DNS TTL value - pub ttl: u32, -} - /// The most common result type provided by VsockProxy operations. pub type VsockProxyResult = Result; diff --git a/vsock_proxy/src/proxy.rs b/vsock_proxy/src/proxy.rs index 4ea49240..92d291b6 100644 --- a/vsock_proxy/src/proxy.rs +++ b/vsock_proxy/src/proxy.rs @@ -4,7 +4,6 @@ /// Contains code for Proxy, a library used for translating vsock traffic to /// TCP traffic -use chrono::{DateTime, Duration, Utc}; use log::{info, warn}; use nix::sys::select::{select, FdSet}; use nix::sys::socket::SockType; @@ -16,6 +15,7 @@ use threadpool::ThreadPool; use vsock::{VsockAddr, VsockListener}; use yaml_rust::YamlLoader; +use crate::dns::DnsResolutionInfo; use crate::{dns, IpAddrType, VsockProxyResult}; const BUFF_SIZE: usize = 8192; @@ -43,7 +43,7 @@ pub fn check_allowlist( // Obtain the remote server's IP address. let dns_result = dns::resolve_single(remote_host, ip_addr_type)?; - let remote_addr = dns_result.ip; + let remote_addr = dns_result.ip_addr(); for raw_service in services { let addr = raw_service["address"].as_str().ok_or("No address field")?; @@ -69,7 +69,7 @@ pub fn check_allowlist( let remote_addr_matched = rresults .into_iter() .flatten() - .find(|rresult| rresult.ip == remote_addr) + .find(|rresult| rresult.ip_addr() == remote_addr) .map(|_| remote_addr); match remote_addr_matched { @@ -89,10 +89,8 @@ pub fn check_allowlist( pub struct Proxy { local_port: u32, remote_host: String, - remote_addr: Option, remote_port: u16, - dns_resolve_date: Option>, - dns_refresh_interval: Option, + dns_resolution_info: Option, pool: ThreadPool, sock_type: SockType, ip_addr_type: IpAddrType, @@ -108,17 +106,13 @@ impl Proxy { ) -> VsockProxyResult { let pool = ThreadPool::new(num_workers); let sock_type = SockType::Stream; - let remote_addr: Option = None; - let dns_resolve_date: Option> = None; - let dns_refresh_interval: Option = None; + let dns_resolution_info: Option = None; Ok(Proxy { local_port, remote_host, - remote_addr, remote_port, - dns_resolve_date, - dns_refresh_interval, + dns_resolution_info, pool, sock_type, ip_addr_type, @@ -145,28 +139,31 @@ impl Proxy { .map_err(|_| "Could not accept connection")?; info!("Accepted connection on {:?}", client_addr); - let needs_resolve = - |d: DateTime, i: Duration| (Utc::now() - d + Duration::seconds(2)) > i; + let dns_needs_resolution = self + .dns_resolution_info + .map_or(true, |info| info.is_expired()); - if self.dns_resolve_date.is_none() - || needs_resolve( - self.dns_resolve_date.unwrap(), - self.dns_refresh_interval.unwrap(), - ) - { + let remote_addr = if dns_needs_resolution { info!("Resolving hostname: {}.", self.remote_host); - let result = dns::resolve_single(&self.remote_host, self.ip_addr_type)?; - self.dns_resolve_date = Some(Utc::now()); - self.dns_refresh_interval = Some(Duration::seconds(result.ttl as i64)); - self.remote_addr = Some(result.ip); + + let dns_resolution = dns::resolve_single(&self.remote_host, self.ip_addr_type)?; info!( "Using IP \"{:?}\" for the given server \"{}\". (TTL: {} secs)", - result.ip, self.remote_host, result.ttl + dns_resolution.ip_addr(), + self.remote_host, + dns_resolution.ttl().num_seconds() ); - } - let sockaddr = SocketAddr::new(self.remote_addr.unwrap(), self.remote_port); + self.dns_resolution_info = Some(dns_resolution.clone()); + dns_resolution.ip_addr() + } else { + self.dns_resolution_info + .ok_or("DNS resolution failed!")? + .ip_addr() + }; + + let sockaddr = SocketAddr::new(remote_addr, self.remote_port); let sock_type = self.sock_type; self.pool.execute(move || { let mut server = match sock_type {