diff --git a/iroh-blobs/src/store/fs.rs b/iroh-blobs/src/store/fs.rs index 5febe54457..6308d58c61 100644 --- a/iroh-blobs/src/store/fs.rs +++ b/iroh-blobs/src/store/fs.rs @@ -1424,6 +1424,10 @@ impl super::Store for Store { self.0.temp.temp_tag(value) } + fn tag_drop(&self) -> Option<&dyn TagDrop> { + Some(self.0.temp.as_ref()) + } + async fn shutdown(&self) { self.0.shutdown().await; } diff --git a/iroh-blobs/src/store/mem.rs b/iroh-blobs/src/store/mem.rs index e10849e2b7..d98af09f04 100644 --- a/iroh-blobs/src/store/mem.rs +++ b/iroh-blobs/src/store/mem.rs @@ -222,6 +222,10 @@ impl super::Store for Store { self.inner.temp_tag(tag) } + fn tag_drop(&self) -> Option<&dyn TagDrop> { + Some(self.inner.as_ref()) + } + async fn gc_start(&self) -> io::Result<()> { Ok(()) } diff --git a/iroh-blobs/src/store/readonly_mem.rs b/iroh-blobs/src/store/readonly_mem.rs index 4b77698313..2ef0a2b89e 100644 --- a/iroh-blobs/src/store/readonly_mem.rs +++ b/iroh-blobs/src/store/readonly_mem.rs @@ -15,7 +15,7 @@ use crate::{ }, util::{ progress::{BoxedProgressSender, IdGenerator, ProgressSender}, - Tag, + Tag, TagDrop, }, BlobFormat, Hash, HashAndFormat, TempTag, IROH_BLOCK_SIZE, }; @@ -324,6 +324,10 @@ impl super::Store for Store { TempTag::new(inner, None) } + fn tag_drop(&self) -> Option<&dyn TagDrop> { + None + } + async fn gc_start(&self) -> io::Result<()> { Ok(()) } diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index e0ec3e6b39..49d0a43abd 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -19,7 +19,7 @@ use crate::{ protocol::RangeSpec, util::{ progress::{BoxedProgressSender, IdGenerator, ProgressSender}, - Tag, + Tag, TagDrop, }, BlobFormat, Hash, HashAndFormat, TempTag, IROH_BLOCK_SIZE, }; @@ -356,6 +356,9 @@ pub trait Store: ReadableStore + MapMut { /// Create a temporary pin for this store fn temp_tag(&self, value: HashAndFormat) -> TempTag; + /// Handle to use to drop tags + fn tag_drop(&self) -> Option<&dyn TagDrop>; + /// Notify the store that a new gc phase is about to start. /// /// This should not fail unless the store is shut down or otherwise in a diff --git a/iroh-blobs/src/util.rs b/iroh-blobs/src/util.rs index 751886492c..d1c3dd3ebd 100644 --- a/iroh-blobs/src/util.rs +++ b/iroh-blobs/src/util.rs @@ -255,11 +255,17 @@ impl TempTag { self.inner.format } + /// The hash and format of the pinned item + pub fn hash_and_format(&self) -> HashAndFormat { + self.inner + } + /// Keep the item alive until the end of the process pub fn leak(mut self) { // set the liveness tracker to None, so that the refcount is not decreased // during drop. This means that the refcount will never reach 0 and the - // item will not be gced until the end of the process. + // item will not be gced until the end of the process, unless you manually + // invoke on_drop. self.on_drop = None; } } diff --git a/iroh/src/client/blobs.rs b/iroh/src/client/blobs.rs index 64f4f10b82..2a88d9787c 100644 --- a/iroh/src/client/blobs.rs +++ b/iroh/src/client/blobs.rs @@ -34,12 +34,13 @@ use tokio_util::io::{ReaderStream, StreamReader}; use tracing::warn; use crate::rpc_protocol::{ - BatchAddStreamRequest, BatchAddStreamResponse, BatchAddStreamUpdate, BatchUpdate, - BlobAddPathRequest, BlobAddStreamRequest, BlobAddStreamUpdate, BlobConsistencyCheckRequest, - BlobDeleteBlobRequest, BlobDownloadRequest, BlobExportRequest, BlobGetCollectionRequest, - BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListIncompleteRequest, - BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, - CreateCollectionRequest, CreateCollectionResponse, NodeStatusRequest, RpcService, SetTagOption, + BatchAddStreamRequest, BatchAddStreamResponse, BatchAddStreamUpdate, BatchCreateRequest, + BatchCreateResponse, BatchUpdate, BlobAddPathRequest, BlobAddStreamRequest, + BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, + BlobExportRequest, BlobGetCollectionRequest, BlobGetCollectionResponse, + BlobListCollectionsRequest, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, + BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, + NodeStatusRequest, RpcService, SetTagOption, }; use super::{flatten, Iroh}; @@ -60,6 +61,14 @@ impl Client where C: ServiceConnection, { + /// Create a new batch for adding data. + pub async fn batch(&self) -> Result> { + let (updates, mut stream) = self.rpc.bidi(BatchCreateRequest).await?; + let updates = Mutex::new(updates); + let BatchCreateResponse::Id(id) = stream.next().await.context("expected scope id")??; + let rpc = self.rpc.clone(); + Ok(Batch(Arc::new(BatchInner { id, rpc, updates }))) + } /// Stream the contents of a a single blob. /// /// Returns a [`Reader`], which can report the size of the blob before reading it. @@ -956,7 +965,6 @@ pub enum DownloadMode { mod tests { use super::*; - use anyhow::Context as _; use rand::RngCore; use tokio::io::AsyncWriteExt; diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 058363276f..112a7868ca 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -3,16 +3,19 @@ //! A node is a server that serves various protocols. //! //! To shut down the node, call [`Node::shutdown`]. +use std::collections::BTreeMap; use std::fmt::Debug; use std::net::SocketAddr; use std::path::Path; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; +use iroh_blobs::util::TagDrop; +use iroh_blobs::{HashAndFormat, TempTag}; use iroh_net::util::AbortingJoinHandle; use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; use quic_rpc::transport::flume::FlumeConnection; @@ -62,6 +65,63 @@ struct NodeInner { rt: LocalPoolHandle, pub(crate) sync: Engine, downloader: Downloader, + blob_scopes: Mutex, +} + +#[derive(Debug, Default)] +struct BlobScopes { + scopes: BTreeMap, + max: u64, +} + +#[derive(Debug, Default)] +struct BlobScope { + tags: BTreeMap, +} + +impl BlobScopes { + /// Create a new blob scope. + fn create(&mut self) -> u64 { + let id = self.max; + self.max += 1; + id + } + + /// Store a tag in a scope. + fn store(&mut self, scope: u64, tt: TempTag) { + let entry = self.scopes.entry(scope).or_default(); + let count = entry.tags.entry(tt.hash_and_format()).or_default(); + tt.leak(); + *count += 1; + } + + /// Remove a tag from a scope. + fn remove_one(&mut self, scope: u64, content: &HashAndFormat, u: Option<&dyn TagDrop>) { + if let Some(scope) = self.scopes.get_mut(&scope) { + if let Some(counter) = scope.tags.get_mut(content) { + *counter -= 1; + if let Some(u) = u { + u.on_drop(content); + } + if *counter == 0 { + scope.tags.remove(content); + } + } + } + } + + /// Remove an entire scope. + fn remove(&mut self, scope: u64, u: Option<&dyn TagDrop>) { + if let Some(scope) = self.scopes.remove(&scope) { + for (content, count) in scope.tags { + if let Some(u) = u { + for _ in 0..count { + u.on_drop(&content); + } + } + } + } + } } /// In memory node. diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 7c9875f3c1..3015c4961c 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -480,7 +480,6 @@ where }; let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); - let inner = Arc::new(NodeInner { db: self.blobs_store, endpoint: endpoint.clone(), @@ -491,6 +490,7 @@ where rt: lp.clone(), sync, downloader, + blob_scopes: Default::default(), }); let task = { let gossip = gossip.clone(); diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 0f50a253e2..4aa4cb600d 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -854,23 +854,25 @@ impl Handler { _: BatchCreateRequest, mut updates: impl Stream + Send + Unpin + 'static, ) -> impl Stream { - let scope_id = 0; - // let scope_id = self.inner.temp_tags.lock().unwrap().create(); + let scope_id = self.inner.blob_scopes.lock().unwrap().create(); tokio::spawn(async move { while let Some(item) = updates.next().await { match item { BatchUpdate::Drop(content) => { - // println!("dropping tag {} {}", scope_id, tag_id); - // self.inner - // .temp_tags - // .lock() - // .unwrap() - // .remove_one(scope_id, tag_id); + self.inner.blob_scopes.lock().unwrap().remove_one( + scope_id, + &content, + self.inner.db.tag_drop(), + ); } } } println!("dropping scope {}", scope_id); - // self.inner.temp_tags.lock().unwrap().remove(scope_id); + self.inner + .blob_scopes + .lock() + .unwrap() + .remove(scope_id, self.inner.db.tag_drop()); }); futures_lite::stream::once(BatchCreateResponse::Id(scope_id)) } @@ -899,7 +901,6 @@ impl Handler { stream: impl Stream + Send + Unpin + 'static, progress: flume::Sender, ) -> anyhow::Result<()> { - println!("batch_add_stream0"); let progress = FlumeProgressSender::new(progress); let stream = stream.map(|item| match item { @@ -912,23 +913,17 @@ impl Handler { let import_progress = progress.clone().with_filter_map(move |x| match x { _ => None, }); - println!("collecting stream"); - let items: Vec<_> = stream.collect().await; - println!("stream collected"); - let stream = futures_lite::stream::iter(items.into_iter()); let (temp_tag, _len) = self .inner .db .import_stream(stream, BlobFormat::Raw, import_progress) .await?; - println!("stream imported {:?}", temp_tag.inner().hash); let hash = temp_tag.inner().hash; - // let tag = self - // .inner - // .temp_tags - // .lock() - // .unwrap() - // .create_one(msg.scope, temp_tag); + self.inner + .blob_scopes + .lock() + .unwrap() + .store(msg.scope, temp_tag); progress .send(BatchAddStreamResponse::Result { hash }) .await?; diff --git a/iroh/tests/batch.rs b/iroh/tests/batch.rs new file mode 100644 index 0000000000..e36e0c31de --- /dev/null +++ b/iroh/tests/batch.rs @@ -0,0 +1,54 @@ +use std::time::Duration; + +use bao_tree::blake3; +use iroh::node::GcPolicy; +use iroh_blobs::{store::mem::Store, BlobFormat}; + +async fn create_node() -> anyhow::Result> { + iroh::node::Node::memory() + .gc_policy(GcPolicy::Interval(Duration::from_millis(10))) + .spawn() + .await +} + +#[tokio::test] +async fn test_batch_create_1() -> anyhow::Result<()> { + let node = create_node().await?; + let client = &node.client().blobs; + let batch = client.batch().await?; + let expected_data: &[u8] = b"test"; + let expected_hash = blake3::hash(expected_data).into(); + let tag = batch.add_bytes(expected_data, BlobFormat::Raw).await?; + let hash = *tag.hash(); + assert_eq!(hash, expected_hash); + // Check that the store has the data and that it is protected from gc + tokio::time::sleep(Duration::from_millis(50)).await; + let data = client.read_to_bytes(hash).await?; + assert_eq!(data.as_ref(), expected_data); + drop(tag); + // Check that the store drops the data when the temp tag gets dropped + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(client.read_to_bytes(hash).await.is_err()); + Ok(()) +} + +#[tokio::test] +async fn test_batch_create_2() -> anyhow::Result<()> { + let node = create_node().await?; + let client = &node.client().blobs; + let batch = client.batch().await?; + let expected_data: &[u8] = b"test"; + let expected_hash = blake3::hash(expected_data).into(); + let tag = batch.add_bytes(expected_data, BlobFormat::Raw).await?; + let hash = *tag.hash(); + assert_eq!(hash, expected_hash); + // Check that the store has the data and that it is protected from gc + tokio::time::sleep(Duration::from_millis(50)).await; + let data = client.read_to_bytes(hash).await?; + assert_eq!(data.as_ref(), expected_data); + drop(batch); + // Check that the store drops the data when the temp tag gets dropped + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(client.read_to_bytes(hash).await.is_err()); + Ok(()) +}