diff --git a/Cargo.lock b/Cargo.lock
index d7437b25e8..af7b2e0f66 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1854,6 +1854,22 @@ version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
 
+[[package]]
+name = "humantime"
+version = "2.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
+
+[[package]]
+name = "humantime-serde"
+version = "1.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57a3db5ea5923d99402c94e9feb261dc5ee9b4efa158b0315f788cf549cc200c"
+dependencies = [
+ "humantime",
+ "serde",
+]
+
 [[package]]
 name = "hyper"
 version = "1.5.1"
@@ -2335,6 +2351,7 @@ dependencies = [
  "hickory-resolver",
  "hickory-server",
  "http 1.1.0",
+ "humantime-serde",
  "iroh",
  "iroh-metrics",
  "iroh-test 0.29.0",
@@ -2349,6 +2366,7 @@ dependencies = [
  "serde",
  "struct_iterable",
  "strum",
+ "testresult",
  "tokio",
  "tokio-rustls",
  "tokio-rustls-acme",
diff --git a/iroh-dns-server/Cargo.toml b/iroh-dns-server/Cargo.toml
index 4b7900760e..a0e3f3c61c 100644
--- a/iroh-dns-server/Cargo.toml
+++ b/iroh-dns-server/Cargo.toml
@@ -29,7 +29,8 @@ governor = "0.6.3" #needs new release of tower_governor for 0.7.0
 hickory-proto = "=0.25.0-alpha.2"
 hickory-server = { version = "=0.25.0-alpha.2", features = ["dns-over-rustls"] }
 http = "1.0.0"
-iroh-metrics = "0.29"
+humantime-serde = "1.1.1"
+iroh-metrics = { version = "0.29.0" }
 lru = "0.12.3"
 parking_lot = "0.12.1"
 pkarr = { version = "2.2.0", features = [ "async", "relay", "dht"], default-features = false }
@@ -64,6 +65,7 @@ hickory-resolver = "=0.25.0-alpha.2"
 iroh = { version = "0.29.0", path = "../iroh" }
 iroh-test = { version = "0.29.0", path = "../iroh-test" }
 pkarr = { version = "2.2.0", features = ["rand"] }
+testresult = "0.4.1"
 
 [[bench]]
 name = "write"
diff --git a/iroh-dns-server/benches/write.rs b/iroh-dns-server/benches/write.rs
index 143a2b0917..52924672f3 100644
--- a/iroh-dns-server/benches/write.rs
+++ b/iroh-dns-server/benches/write.rs
@@ -7,7 +7,7 @@ use tokio::runtime::Runtime;
 const LOCALHOST_PKARR: &str = "http://localhost:8080/pkarr";
 
 async fn start_dns_server(config: Config) -> Result<Server> {
-    let store = ZoneStore::persistent(Config::signed_packet_store_path()?)?;
+    let store = ZoneStore::persistent(Config::signed_packet_store_path()?, Default::default())?;
     Server::spawn(config, store).await
 }
 
diff --git a/iroh-dns-server/src/config.rs b/iroh-dns-server/src/config.rs
index 732d65e4e8..ba8409d24a 100644
--- a/iroh-dns-server/src/config.rs
+++ b/iroh-dns-server/src/config.rs
@@ -4,6 +4,7 @@ use std::{
     env,
     net::{IpAddr, Ipv4Addr, SocketAddr},
     path::{Path, PathBuf},
+    time::Duration,
 };
 
 use anyhow::{anyhow, Context, Result};
@@ -13,6 +14,7 @@ use tracing::info;
 use crate::{
     dns::DnsConfig,
     http::{CertMode, HttpConfig, HttpsConfig, RateLimitConfig},
+    store::ZoneStoreOptions,
 };
 
 const DEFAULT_METRICS_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9117);
@@ -44,11 +46,61 @@ pub struct Config {
     /// Config for the mainline lookup.
     pub mainline: Option<MainlineConfig>,
 
+    /// Config for the zone store.
+    pub zone_store: Option<StoreConfig>,
+
     /// Config for pkarr rate limit
     #[serde(default)]
     pub pkarr_put_rate_limit: RateLimitConfig,
 }
 
+/// The config for the store.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct StoreConfig {
+    /// Maximum number of packets to process in a single write transaction.
+    max_batch_size: usize,
+
+    /// Maximum time to keep a write transaction open.
+    #[serde(with = "humantime_serde")]
+    max_batch_time: Duration,
+
+    /// Time to keep packets in the store before eviction.
+    #[serde(with = "humantime_serde")]
+    eviction: Duration,
+
+    /// Pause between eviction checks.
+    #[serde(with = "humantime_serde")]
+    eviction_interval: Duration,
+}
+
+impl Default for StoreConfig {
+    fn default() -> Self {
+        ZoneStoreOptions::default().into()
+    }
+}
+
+impl From<ZoneStoreOptions> for StoreConfig {
+    fn from(value: ZoneStoreOptions) -> Self {
+        Self {
+            max_batch_size: value.max_batch_size,
+            max_batch_time: value.max_batch_time,
+            eviction: value.eviction,
+            eviction_interval: value.eviction_interval,
+        }
+    }
+}
+
+impl From<StoreConfig> for ZoneStoreOptions {
+    fn from(value: StoreConfig) -> Self {
+        Self {
+            max_batch_size: value.max_batch_size,
+            max_batch_time: value.max_batch_time,
+            eviction: value.eviction,
+            eviction_interval: value.eviction_interval,
+        }
+    }
+}
+
 /// The config for the metrics server.
 #[derive(Debug, Serialize, Deserialize)]
 pub struct MetricsConfig {
@@ -187,6 +239,7 @@ impl Default for Config {
                 rr_aaaa: None,
                 rr_ns: Some("ns1.irohdns.example.".to_string()),
             },
+            zone_store: None,
             metrics: None,
             mainline: None,
             pkarr_put_rate_limit: RateLimitConfig::default(),
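For reference, the new `StoreConfig` deserializes its durations with humantime_serde, so a `zone_store` section of the server config can use human-readable durations. A minimal sketch, assuming the `toml` crate the config file is already parsed with; the values are illustrative, not defaults:

#[test]
fn parse_zone_store_section() -> anyhow::Result<()> {
    // Hypothetical test, not part of this patch. The keys match the StoreConfig
    // fields added above; "1s", "7days" and "10s" are humantime duration formats.
    let _cfg: StoreConfig = toml::from_str(
        r#"
            max_batch_size = 65536
            max_batch_time = "1s"
            eviction = "7days"
            eviction_interval = "10s"
        "#,
    )?;
    Ok(())
}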
diff --git a/iroh-dns-server/src/lib.rs b/iroh-dns-server/src/lib.rs
index 8c18327e6e..9cea6bf51e 100644
--- a/iroh-dns-server/src/lib.rs
+++ b/iroh-dns-server/src/lib.rs
@@ -16,7 +16,10 @@ pub use store::ZoneStore;
 
 #[cfg(test)]
 mod tests {
-    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
+    use std::{
+        net::{Ipv4Addr, Ipv6Addr, SocketAddr},
+        time::Duration,
+    };
 
     use anyhow::Result;
     use hickory_resolver::{
@@ -29,9 +32,16 @@ mod tests {
         key::SecretKey,
     };
     use pkarr::{PkarrClient, SignedPacket};
+    use testresult::TestResult;
     use url::Url;
 
-    use crate::{config::BootstrapOption, server::Server};
+    use crate::{
+        config::BootstrapOption,
+        server::Server,
+        store::{PacketSource, ZoneStoreOptions},
+        util::PublicKeyBytes,
+        ZoneStore,
+    };
 
     #[tokio::test]
     async fn pkarr_publish_dns_resolve() -> Result<()> {
@@ -178,6 +188,36 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn store_eviction() -> TestResult<()> {
+        iroh_test::logging::setup_multithreaded();
+        let options = ZoneStoreOptions {
+            eviction: Duration::from_millis(100),
+            eviction_interval: Duration::from_millis(100),
+            max_batch_time: Duration::from_millis(100),
+            ..Default::default()
+        };
+        let store = ZoneStore::in_memory(options)?;
+
+        // create a signed packet
+        let signed_packet = random_signed_packet()?;
+        let key = PublicKeyBytes::from_signed_packet(&signed_packet);
+
+        store
+            .insert(signed_packet, PacketSource::PkarrPublish)
+            .await?;
+
+        tokio::time::sleep(Duration::from_secs(1)).await;
+        for _ in 0..10 {
+            let entry = store.get_signed_packet(&key).await?;
+            if entry.is_none() {
+                return Ok(());
+            }
+            tokio::time::sleep(Duration::from_secs(1)).await;
+        }
+        panic!("store did not evict packet");
+    }
+
     #[tokio::test]
     async fn integration_mainline() -> Result<()> {
         iroh_test::logging::setup_multithreaded();
@@ -188,7 +228,8 @@ mod tests {
 
         // spawn our server with mainline support
         let (server, nameserver, _http_url) =
-            Server::spawn_for_tests_with_mainline(Some(BootstrapOption::Custom(bootstrap))).await?;
+            Server::spawn_for_tests_with_options(Some(BootstrapOption::Custom(bootstrap)), None)
+                .await?;
 
         let origin = "irohdns.example.";
 
@@ -228,4 +269,12 @@ mod tests {
         config.add_name_server(nameserver_config);
         AsyncResolver::tokio(config, Default::default())
     }
+
+    fn random_signed_packet() -> Result<SignedPacket> {
+        let secret_key = SecretKey::generate();
+        let node_id = secret_key.public();
+        let relay_url: Url = "https://relay.example.".parse()?;
+        let node_info = NodeInfo::new(node_id, Some(relay_url.clone()), Default::default());
+        node_info.to_pkarr_signed_packet(&secret_key, 30)
+    }
 }
diff --git a/iroh-dns-server/src/metrics.rs b/iroh-dns-server/src/metrics.rs
index 82d715dd5a..1e0cc21088 100644
--- a/iroh-dns-server/src/metrics.rs
+++ b/iroh-dns-server/src/metrics.rs
@@ -22,6 +22,7 @@ pub struct Metrics {
     pub store_packets_inserted: Counter,
     pub store_packets_removed: Counter,
     pub store_packets_updated: Counter,
+    pub store_packets_expired: Counter,
 }
 
 impl Default for Metrics {
@@ -44,6 +45,7 @@ impl Default for Metrics {
             store_packets_inserted: Counter::new("Signed packets inserted into the store"),
             store_packets_removed: Counter::new("Signed packets removed from the store"),
             store_packets_updated: Counter::new("Number of updates to existing packets"),
+            store_packets_expired: Counter::new("Number of expired packets"),
         }
     }
 }
diff --git a/iroh-dns-server/src/server.rs b/iroh-dns-server/src/server.rs
index 865c5cecf8..e40e3de1fd 100644
--- a/iroh-dns-server/src/server.rs
+++ b/iroh-dns-server/src/server.rs
@@ -14,7 +14,11 @@ use crate::{
 
 /// Spawn the server and run until the `Ctrl-C` signal is received, then shutdown.
 pub async fn run_with_config_until_ctrl_c(config: Config) -> Result<()> {
-    let mut store = ZoneStore::persistent(Config::signed_packet_store_path()?)?;
+    let zone_store_options = config.zone_store.clone().unwrap_or_default();
+    let mut store = ZoneStore::persistent(
+        Config::signed_packet_store_path()?,
+        zone_store_options.into(),
+    )?;
     if let Some(bootstrap) = config.mainline_enabled() {
         info!("mainline fallback enabled");
         store = store.with_mainline_fallback(bootstrap);
@@ -96,14 +100,15 @@ impl Server {
     /// HTTP server.
     #[cfg(test)]
     pub async fn spawn_for_tests() -> Result<(Self, std::net::SocketAddr, url::Url)> {
-        Self::spawn_for_tests_with_mainline(None).await
+        Self::spawn_for_tests_with_options(None, None).await
     }
 
     /// Spawn a server suitable for testing, while optionally enabling mainline with custom
     /// bootstrap addresses.
     #[cfg(test)]
-    pub async fn spawn_for_tests_with_mainline(
+    pub async fn spawn_for_tests_with_options(
         mainline: Option<BootstrapOption>,
+        options: Option<ZoneStoreOptions>,
     ) -> Result<(Self, std::net::SocketAddr, url::Url)> {
         use std::net::{IpAddr, Ipv4Addr};
 
@@ -117,7 +122,7 @@ impl Server {
         config.https = None;
         config.metrics = Some(MetricsConfig::disabled());
 
-        let mut store = ZoneStore::in_memory()?;
+        let mut store = ZoneStore::in_memory(options.unwrap_or_default())?;
         if let Some(bootstrap) = mainline {
             info!("mainline fallback enabled");
             store = store.with_mainline_fallback(bootstrap);
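Since spawn_for_tests_with_options now takes the zone store options as a second argument, tests can exercise the eviction path with much shorter timings. A sketch only; the durations are made up for illustration and are not taken from this patch:

// Hypothetical usage inside a #[tokio::test]:
let options = ZoneStoreOptions {
    eviction: Duration::from_secs(1),
    eviction_interval: Duration::from_millis(100),
    ..Default::default()
};
let (_server, _nameserver, _http_url) =
    Server::spawn_for_tests_with_options(None, Some(options)).await?;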
diff --git a/iroh-dns-server/src/store.rs b/iroh-dns-server/src/store.rs
index 89f3ca9f43..3286fe0132 100644
--- a/iroh-dns-server/src/store.rs
+++ b/iroh-dns-server/src/store.rs
@@ -19,6 +19,7 @@ use crate::{
 };
 
 mod signed_packets;
+pub use signed_packets::Options as ZoneStoreOptions;
 
 /// Cache up to 1 million pkarr zones by default
 pub const DEFAULT_CACHE_CAPACITY: usize = 1024 * 1024;
@@ -44,14 +45,14 @@ pub struct ZoneStore {
 
 impl ZoneStore {
     /// Create a persistent store
-    pub fn persistent(path: impl AsRef<Path>) -> Result<Self> {
-        let packet_store = SignedPacketStore::persistent(path)?;
+    pub fn persistent(path: impl AsRef<Path>, options: ZoneStoreOptions) -> Result<Self> {
+        let packet_store = SignedPacketStore::persistent(path, options)?;
         Ok(Self::new(packet_store))
     }
 
     /// Create an in-memory store.
-    pub fn in_memory() -> Result<Self> {
-        let packet_store = SignedPacketStore::in_memory()?;
+    pub fn in_memory(options: ZoneStoreOptions) -> Result<Self> {
+        let packet_store = SignedPacketStore::in_memory(options)?;
         Ok(Self::new(packet_store))
     }
 
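The ZoneStore constructors now take the options explicitly; existing callers keep the previous behaviour (plus the new 7-day eviction default) by passing Default::default(). A sketch of the call-site migration, with `path` assumed to be an existing PathBuf:

// before: ZoneStore::persistent(path)? and ZoneStore::in_memory()?
let persistent = ZoneStore::persistent(path, ZoneStoreOptions::default())?;
let in_memory = ZoneStore::in_memory(ZoneStoreOptions::default())?;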
diff --git a/iroh-dns-server/src/store/signed_packets.rs b/iroh-dns-server/src/store/signed_packets.rs
index 63c5f66c7e..34de1abe94 100644
--- a/iroh-dns-server/src/store/signed_packets.rs
+++ b/iroh-dns-server/src/store/signed_packets.rs
@@ -1,40 +1,41 @@
-use std::{path::Path, result, time::Duration};
+use std::{future::Future, path::Path, result, time::Duration};
 
 use anyhow::{Context, Result};
+use bytes::Bytes;
 use iroh_metrics::inc;
-use pkarr::SignedPacket;
-use redb::{backends::InMemoryBackend, Database, ReadableTable, TableDefinition};
+use pkarr::{system_time, SignedPacket};
+use redb::{
+    backends::InMemoryBackend, Database, MultimapTableDefinition, ReadableTable, TableDefinition,
+};
 use tokio::sync::{mpsc, oneshot};
 use tokio_util::sync::CancellationToken;
-use tracing::info;
+use tracing::{debug, error, info, trace};
 
 use crate::{metrics::Metrics, util::PublicKeyBytes};
 
 pub type SignedPacketsKey = [u8; 32];
 const SIGNED_PACKETS_TABLE: TableDefinition<&SignedPacketsKey, &[u8]> =
     TableDefinition::new("signed-packets-1");
-const MAX_BATCH_SIZE: usize = 1024 * 64;
-const MAX_BATCH_TIME: Duration = Duration::from_secs(1);
+const UPDATE_TIME_TABLE: MultimapTableDefinition<[u8; 8], SignedPacketsKey> =
+    MultimapTableDefinition::new("update-time-1");
 
 #[derive(Debug)]
 pub struct SignedPacketStore {
     send: mpsc::Sender<Message>,
     cancel: CancellationToken,
-    thread: Option<std::thread::JoinHandle<()>>,
+    _write_thread: IoThread,
+    _evict_thread: IoThread,
 }
 
 impl Drop for SignedPacketStore {
     fn drop(&mut self) {
         // cancel the actor
         self.cancel.cancel();
-        // join the thread. This is important so that Drop implementations that
-        // are called from the actor thread can complete before we return.
-        if let Some(thread) = self.thread.take() {
-            let _ = thread.join();
-        }
+        // after cancellation, the two threads will be joined
     }
 }
 
+#[derive(derive_more::Debug)]
 enum Message {
     Upsert {
         packet: SignedPacket,
@@ -48,14 +49,48 @@ enum Message {
         key: PublicKeyBytes,
         res: oneshot::Sender<bool>,
     },
+    Snapshot {
+        #[debug(skip)]
+        res: oneshot::Sender<Snapshot>,
+    },
+    CheckExpired {
+        time: [u8; 8],
+        key: PublicKeyBytes,
+    },
 }
 
 struct Actor {
     db: Database,
-    recv: mpsc::Receiver<Message>,
+    recv: PeekableReceiver<Message>,
     cancel: CancellationToken,
-    max_batch_size: usize,
-    max_batch_time: Duration,
+    options: Options,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct Options {
+    /// Maximum number of packets to process in a single write transaction.
+    pub max_batch_size: usize,
+    /// Maximum time to keep a write transaction open.
+    pub max_batch_time: Duration,
+    /// Time to keep packets in the store before eviction.
+    pub eviction: Duration,
+    /// Pause between eviction checks.
+    pub eviction_interval: Duration,
+}
+
+impl Default for Options {
+    fn default() -> Self {
+        Self {
+            // 64k packets
+            max_batch_size: 1024 * 64,
+            // this means we lose at most 1 second of data in case of a crash
+            max_batch_time: Duration::from_secs(1),
+            // 7 days
+            eviction: Duration::from_secs(3600 * 24 * 7),
+            // eviction can run frequently since it does not do a full scan
+            eviction_interval: Duration::from_secs(10),
+        }
+    }
 }
 
 impl Actor {
@@ -63,19 +98,31 @@
         match self.run0().await {
             Ok(()) => {}
             Err(e) => {
-                self.cancel.cancel();
                 tracing::error!("packet store actor failed: {:?}", e);
+                self.cancel.cancel();
             }
         }
     }
 
     async fn run0(&mut self) -> anyhow::Result<()> {
-        loop {
+        let expiry_us = self.options.eviction.as_micros() as u64;
+        while let Some(msg) = self.recv.recv().await {
+            // if we get a snapshot message here we don't need to do a write transaction
+            let msg = if let Message::Snapshot { res } = msg {
+                let snapshot = Snapshot::new(&self.db)?;
+                res.send(snapshot).ok();
+                continue;
+            } else {
+                msg
+            };
+            trace!("batch");
+            self.recv.push_back(msg).unwrap();
             let transaction = self.db.begin_write()?;
             let mut tables = Tables::new(&transaction)?;
-            let timeout = tokio::time::sleep(self.max_batch_time);
+            let timeout = tokio::time::sleep(self.options.max_batch_time);
+            let expired = system_time() - expiry_us;
             tokio::pin!(timeout);
-            for _ in 0..self.max_batch_size {
+            for _ in 0..self.options.max_batch_size {
                 tokio::select! {
                     _ = self.cancel.cancelled() => {
                         drop(tables);
@@ -86,22 +133,28 @@
                     Some(msg) = self.recv.recv() => {
                         match msg {
                             Message::Get { key, res } => {
+                                trace!("get {}", key);
                                 let packet = get_packet(&tables.signed_packets, &key)?;
                                 res.send(packet).ok();
                             }
                             Message::Upsert { packet, res } => {
                                 let key = PublicKeyBytes::from_signed_packet(&packet);
-                                let mut replaced = false;
-                                if let Some(existing) = get_packet(&tables.signed_packets, &key)? {
+                                trace!("upsert {}", key);
+                                let replaced = if let Some(existing) = get_packet(&tables.signed_packets, &key)? {
                                     if existing.more_recent_than(&packet) {
                                         res.send(false).ok();
                                         continue;
                                     } else {
-                                        replaced = true;
+                                        // remove the existing packet's entry from the update time index
+                                        tables.update_time.remove(&existing.timestamp().to_be_bytes(), key.as_bytes())?;
+                                        true
                                     }
-                                }
+                                } else {
+                                    false
+                                };
                                 let value = packet.as_bytes();
                                 tables.signed_packets.insert(key.as_bytes(), &value[..])?;
+                                tables.update_time.insert(&packet.timestamp().to_be_bytes(), key.as_bytes())?;
                                 if replaced {
                                     inc!(Metrics, store_packets_updated);
                                 } else {
@@ -110,14 +163,34 @@
                                 res.send(true).ok();
                             }
                             Message::Remove { key, res } => {
-                                let updated =
-                                    tables.signed_packets.remove(key.as_bytes())?.is_some()
-                                ;
+                                trace!("remove {}", key);
+                                let updated = if let Some(row) = tables.signed_packets.remove(key.as_bytes())? {
+                                    let packet = SignedPacket::from_bytes(&Bytes::copy_from_slice(row.value()))?;
+                                    // also drop the packet's entry from the update time index
+                                    tables.update_time.remove(&packet.timestamp().to_be_bytes(), key.as_bytes())?;
+                                    true
+                                } else {
+                                    false
+                                };
                                 if updated {
                                     inc!(Metrics, store_packets_removed);
                                 }
                                 res.send(updated).ok();
                             }
+                            Message::Snapshot { res } => {
+                                trace!("snapshot");
+                                res.send(Snapshot::new(&self.db)?).ok();
+                            }
+                            Message::CheckExpired { key, time } => {
+                                trace!("check expired {} at {}", key, u64::from_be_bytes(time));
+                                if let Some(packet) = get_packet(&tables.signed_packets, &key)? {
+                                    if packet.timestamp() < expired {
+                                        tables.update_time.remove(&time, key.as_bytes())?;
+                                        let _ = tables.signed_packets.remove(key.as_bytes())?;
+                                        inc!(Metrics, store_packets_expired);
+                                    }
+                                }
+                            }
                         }
                     }
                 }
@@ -125,6 +198,7 @@
             drop(tables);
             transaction.commit()?;
         }
+        Ok(())
     }
 }
 
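A note on the arithmetic in the batch loop above: pkarr's system_time() and SignedPacket::timestamp() are microseconds since the UNIX epoch, and the update-time index stores those timestamps as big-endian bytes so that byte-wise ordering equals numeric ordering. A small illustrative sketch, not part of the patch:

let eviction = Duration::from_secs(3600 * 24 * 7);
let expired: u64 = system_time() - eviction.as_micros() as u64;
// big-endian encoding preserves u64 ordering, so a plain byte-range query over
// UPDATE_TIME_TABLE up to `expired` visits exactly the entries old enough to evict
assert!(1_000u64.to_be_bytes() < 2_000u64.to_be_bytes());
let _upper_bound = expired.to_be_bytes(); // the bound used by the eviction scan added later in this patch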
@@ -132,18 +206,36 @@
 /// signed packet store.
 pub(super) struct Tables<'a> {
     pub signed_packets: redb::Table<'a, &'static SignedPacketsKey, &'static [u8]>,
+    pub update_time: redb::MultimapTable<'a, [u8; 8], SignedPacketsKey>,
 }
 
 impl<'txn> Tables<'txn> {
     pub fn new(tx: &'txn redb::WriteTransaction) -> result::Result<Self, redb::TableError> {
         Ok(Self {
             signed_packets: tx.open_table(SIGNED_PACKETS_TABLE)?,
+            update_time: tx.open_multimap_table(UPDATE_TIME_TABLE)?,
+        })
+    }
+}
+
+pub(super) struct Snapshot {
+    #[allow(dead_code)]
+    pub signed_packets: redb::ReadOnlyTable<&'static SignedPacketsKey, &'static [u8]>,
+    pub update_time: redb::ReadOnlyMultimapTable<[u8; 8], SignedPacketsKey>,
+}
+
+impl Snapshot {
+    pub fn new(db: &Database) -> Result<Self> {
+        let tx = db.begin_read()?;
+        Ok(Self {
+            signed_packets: tx.open_table(SIGNED_PACKETS_TABLE)?,
+            update_time: tx.open_multimap_table(UPDATE_TIME_TABLE)?,
         })
     }
 }
 
 impl SignedPacketStore {
-    pub fn persistent(path: impl AsRef<Path>) -> Result<Self> {
+    pub fn persistent(path: impl AsRef<Path>, options: Options) -> Result<Self> {
         let path = path.as_ref();
         info!("loading packet database from {}", path.to_string_lossy());
         if let Some(parent) = path.parent() {
@@ -157,42 +249,42 @@
         let db = Database::builder()
             .create(path)
             .context("failed to open packet database")?;
-        Self::open(db)
+        Self::open(db, options)
     }
 
-    pub fn in_memory() -> Result<Self> {
+    pub fn in_memory(options: Options) -> Result<Self> {
         info!("using in-memory packet database");
         let db = Database::builder().create_with_backend(InMemoryBackend::new())?;
-        Self::open(db)
+        Self::open(db, options)
     }
 
-    pub fn open(db: Database) -> Result<Self> {
+    pub fn open(db: Database, options: Options) -> Result<Self> {
         // create tables
         let write_tx = db.begin_write()?;
         let _ = Tables::new(&write_tx)?;
         write_tx.commit()?;
         let (send, recv) = mpsc::channel(1024);
+        let send2 = send.clone();
         let cancel = CancellationToken::new();
         let cancel2 = cancel.clone();
+        let cancel3 = cancel.clone();
         let actor = Actor {
             db,
-            recv,
+            recv: PeekableReceiver::new(recv),
             cancel: cancel2,
-            max_batch_size: MAX_BATCH_SIZE,
-            max_batch_time: MAX_BATCH_TIME,
+            options,
         };
         // start an io thread and donate it to the tokio runtime so we can do blocking IO
         // inside the thread despite being in a tokio runtime
-        let handle = tokio::runtime::Handle::try_current()?;
-        let thread = std::thread::Builder::new()
-            .name("packet-store-actor".into())
-            .spawn(move || {
-                handle.block_on(actor.run());
-            })?;
+        let _write_thread = IoThread::new("packet-store-actor", move || actor.run())?;
+        let _evict_thread = IoThread::new("packet-store-evict", move || {
+            evict_task(send2, options, cancel3)
+        })?;
         Ok(Self {
             send,
             cancel,
-            thread: Some(thread),
+            _write_thread,
+            _evict_thread,
         })
     }
 
@@ -227,3 +319,142 @@ fn get_packet(
     let packet = SignedPacket::from_bytes(&row.value().to_vec().into())?;
     Ok(Some(packet))
 }
+
+async fn evict_task(send: mpsc::Sender<Message>, options: Options, cancel: CancellationToken) {
+    let cancel2 = cancel.clone();
+    let _ = cancel2
+        .run_until_cancelled(async move {
+            info!("starting evict task");
+            if let Err(cause) = evict_task_inner(send, options).await {
+                error!("evict task failed: {:?}", cause);
+            }
+            // when we are done for whatever reason we want to shut down the actor
+            cancel.cancel();
+        })
+        .await;
+}
+
+/// Periodically check for expired packets and remove them.
+async fn evict_task_inner(send: mpsc::Sender<Message>, options: Options) -> anyhow::Result<()> {
+    let expiry_us = options.eviction.as_micros() as u64;
+    loop {
+        let (tx, rx) = oneshot::channel();
+        let _ = send.send(Message::Snapshot { res: tx }).await.ok();
+        // if we can't get the snapshot we exit the loop, main actor dead
+        let Ok(snapshot) = rx.await else {
+            anyhow::bail!("failed to get snapshot");
+        };
+        let expired = system_time() - expiry_us;
+        trace!("evicting packets older than {}", expired);
+        // if getting the range fails we exit the loop and shut down
+        // if individual reads fail we log the error and limp on
+        for item in snapshot.update_time.range(..expired.to_be_bytes())? {
+            let (time, keys) = match item {
+                Ok(v) => v,
+                Err(e) => {
+                    error!("failed to read update_time row {:?}", e);
+                    continue;
+                }
+            };
+            let time = time.value();
+            trace!("evicting expired packets at {}", u64::from_be_bytes(time));
+            for item in keys {
+                let key = match item {
+                    Ok(v) => v,
+                    Err(e) => {
+                        error!(
+                            "failed to read update_time item at {}: {:?}",
+                            u64::from_be_bytes(time),
+                            e
+                        );
+                        continue;
+                    }
+                };
+                let key = PublicKeyBytes::new(key.value());
+                debug!(
+                    "evicting expired packet {} {}",
+                    u64::from_be_bytes(time),
+                    key
+                );
+                send.send(Message::CheckExpired { time, key }).await?;
+            }
+        }
+        // sleep for the eviction interval so we don't constantly check
+        tokio::time::sleep(options.eviction_interval).await;
+    }
+}
+
+/// An io thread that drives a future to completion on the current tokio runtime
+///
+/// Inside the future, blocking IO can be done without blocking one of the tokio
+/// pool threads.
+#[derive(Debug)]
+struct IoThread {
+    handle: Option<std::thread::JoinHandle<()>>,
+}
+
+impl IoThread {
+    /// Spawn a new io thread.
+    ///
+    /// Calling this function requires that the current thread is running in a
+    /// tokio runtime. It is up to the caller to make sure the future exits,
+    /// e.g. by using a cancellation token. Otherwise, drop will block.
+    fn new<F, Fut>(name: &str, f: F) -> Result<Self>
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: Future<Output = ()>,
+    {
+        let rt = tokio::runtime::Handle::try_current()?;
+        let handle = std::thread::Builder::new()
+            .name(name.into())
+            .spawn(move || rt.block_on(f()))
+            .context("failed to spawn thread")?;
+        Ok(Self {
+            handle: Some(handle),
+        })
+    }
+}
+
+impl Drop for IoThread {
+    fn drop(&mut self) {
+        if let Some(handle) = self.handle.take() {
+            let _ = handle.join();
+        }
+    }
+}
+
+/// A wrapper for a tokio mpsc receiver that allows peeking at the next message.
+#[derive(Debug)]
+pub(super) struct PeekableReceiver<T> {
+    msg: Option<T>,
+    recv: tokio::sync::mpsc::Receiver<T>,
+}
+
+#[allow(dead_code)]
+impl<T> PeekableReceiver<T> {
+    pub fn new(recv: tokio::sync::mpsc::Receiver<T>) -> Self {
+        Self { msg: None, recv }
+    }
+
+    /// Receive the next message.
+    ///
+    /// Will block if there are no messages.
+    /// Returns None only if there are no more messages (sender is dropped).
+    pub async fn recv(&mut self) -> Option<T> {
+        if let Some(msg) = self.msg.take() {
+            return Some(msg);
+        }
+        self.recv.recv().await
+    }
+
+    /// Push back a message. This will only work if there is room for it.
+    /// Otherwise, it will fail and return the message.
+    pub fn push_back(&mut self, msg: T) -> std::result::Result<(), T> {
+        if self.msg.is_none() {
+            self.msg = Some(msg);
+            Ok(())
+        } else {
+            Err(msg)
+        }
+    }
+}
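PeekableReceiver is what lets the write actor look at the first message of a batch and put it back before opening a transaction (see run0 above). A standalone usage sketch inside an async context, not part of the patch:

let (tx, rx) = tokio::sync::mpsc::channel::<u32>(8);
let mut rx = PeekableReceiver::new(rx);
tx.send(1).await.unwrap();
let first = rx.recv().await.unwrap();
rx.push_back(first).unwrap();         // put the message back...
assert_eq!(rx.recv().await, Some(1)); // ...and it is delivered again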
diff --git a/iroh-dns-server/src/util.rs b/iroh-dns-server/src/util.rs
index 6fc28b4d1f..b395b91d06 100644
--- a/iroh-dns-server/src/util.rs
+++ b/iroh-dns-server/src/util.rs
@@ -22,6 +22,10 @@ use pkarr::SignedPacket;
 pub struct PublicKeyBytes([u8; 32]);
 
 impl PublicKeyBytes {
+    pub fn new(bytes: [u8; 32]) -> Self {
+        Self(bytes)
+    }
+
     pub fn from_z32(s: &str) -> Result<Self> {
         let bytes = z32::decode(s.as_bytes())?;
         let bytes: [u8; 32] = bytes.try_into().map_err(|_| anyhow!("invalid length"))?;
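The new PublicKeyBytes::new constructor is what the eviction scan uses to turn the raw 32-byte values stored in the update-time index back into keys. A trivial round-trip sketch with an illustrative byte array:

let bytes = [7u8; 32];
let key = PublicKeyBytes::new(bytes);
assert_eq!(key.as_bytes(), &bytes);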