From b443becc65c76379cc32431e57d35738a4007529 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 3 Jun 2024 12:59:34 +0300 Subject: [PATCH] WIP add batch API --- iroh/src/client/blobs.rs | 101 +++++++++++++++++++++++++++++++-- iroh/src/node/rpc.rs | 117 +++++++++++++++++++++++++++++++++++---- iroh/src/rpc_protocol.rs | 71 ++++++++++++++++++++++++ 3 files changed, 273 insertions(+), 16 deletions(-) diff --git a/iroh/src/client/blobs.rs b/iroh/src/client/blobs.rs index 61d075e7fc..64f4f10b82 100644 --- a/iroh/src/client/blobs.rs +++ b/iroh/src/client/blobs.rs @@ -5,31 +5,36 @@ use std::{ io, path::PathBuf, pin::Pin, - sync::Arc, + sync::{Arc, Mutex}, task::{Context, Poll}, }; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use bytes::Bytes; use futures_lite::{Stream, StreamExt}; -use futures_util::SinkExt; +use futures_util::{FutureExt, SinkExt}; use iroh_base::{node_addr::AddrInfoOptions, ticket::BlobTicket}; use iroh_blobs::{ export::ExportProgress as BytesExportProgress, format::collection::Collection, get::db::DownloadProgress as BytesDownloadProgress, store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ValidateProgress}, - BlobFormat, Hash, Tag, + util::TagDrop, + BlobFormat, Hash, HashAndFormat, Tag, TempTag, }; use iroh_net::NodeAddr; use portable_atomic::{AtomicU64, Ordering}; -use quic_rpc::{client::BoxStreamSync, RpcClient, ServiceConnection}; +use quic_rpc::{ + client::{BoxStreamSync, UpdateSink}, + RpcClient, ServiceConnection, +}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use tokio_util::io::{ReaderStream, StreamReader}; use tracing::warn; use crate::rpc_protocol::{ + BatchAddStreamRequest, BatchAddStreamResponse, BatchAddStreamUpdate, BatchUpdate, BlobAddPathRequest, BlobAddStreamRequest, BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, BlobExportRequest, BlobGetCollectionRequest, BlobGetCollectionResponse, 
BlobListCollectionsRequest, BlobListIncompleteRequest, @@ -368,6 +373,92 @@ where } } +/// A scope in which blobs can be added. +#[derive(derive_more::Debug)] +struct BatchInner> { + /// The id of the scope. + id: u64, + /// The rpc client. + rpc: RpcClient, + /// The sink used to send [`BatchUpdate::Drop`] messages when a temp tag is dropped. + #[debug(skip)] + updates: Mutex>, +} + +/// A handle to a scope in which blobs can be added, wrapping the shared [`BatchInner`]. + +#[derive(derive_more::Debug)] +pub struct Batch>(Arc>); + +impl> TagDrop for BatchInner { + fn on_drop(&self, content: &HashAndFormat) { + let mut updates = self.updates.lock().unwrap(); + updates.send(BatchUpdate::Drop(*content)).now_or_never(); + } +} + +impl> Batch { + /// Write a blob by passing bytes. + pub async fn add_bytes(&self, bytes: impl Into, format: BlobFormat) -> Result { + let input = futures_lite::stream::once(Ok(bytes.into())); + self.add_stream(input, format).await + } + + /// Write a blob by passing a stream of bytes. + pub async fn add_stream( + &self, + mut input: impl Stream> + Send + Unpin + 'static, + format: BlobFormat, + ) -> Result { + let (mut sink, mut stream) = self + .0 + .rpc + .bidi(BatchAddStreamRequest { + scope: self.0.id, + format, + }) + .await?; + while let Some(item) = input.next().await { + match item { + Ok(chunk) => { + sink.send(BatchAddStreamUpdate::Chunk(chunk)) + .await + .map_err(|err| anyhow!("Failed to send input stream to remote: {err:?}"))?; + } + Err(err) => { + warn!("Abort send, reason: failed to read from source stream: {err:?}"); + sink.send(BatchAddStreamUpdate::Abort) + .await + .map_err(|err| anyhow!("Failed to send input stream to remote: {err:?}"))?; + break; + } + } + } + sink.close() + .await + .map_err(|err| anyhow!("Failed to close the stream: {err:?}"))?; + // this is needed for the remote to notice that the stream is closed + drop(sink); + let mut res = None; + while let Some(item) = stream.next().await { + match item? 
{ + BatchAddStreamResponse::Abort(cause) => { + Err(cause)?; + } + BatchAddStreamResponse::Result { hash } => { + res = Some(hash); + } + } + } + let hash = res.context("Missing answer")?; + let t: Arc = self.0.clone(); + Ok(TempTag::new( + HashAndFormat { hash, format }, + Some(Arc::downgrade(&t)), + )) + } +} + /// Whether to wrap the added data in a collection. #[derive(Debug, Serialize, Deserialize)] pub enum WrapOption { diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index ba03e10486..0f50a253e2 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -39,17 +39,19 @@ use crate::client::blobs::{ use crate::client::tags::TagInfo; use crate::client::NodeStatus; use crate::rpc_protocol::{ - BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse, - BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, - BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobGetCollectionRequest, - BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListIncompleteRequest, - BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, - CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, - DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocSetHashRequest, - ListTagsRequest, NodeAddrRequest, NodeConnectionInfoRequest, NodeConnectionInfoResponse, - NodeConnectionsRequest, NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, - NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, - NodeWatchResponse, Request, RpcService, SetTagOption, + BatchAddStreamRequest, BatchAddStreamResponse, BatchAddStreamUpdate, BatchCreateRequest, + BatchCreateResponse, BatchUpdate, BlobAddPathRequest, BlobAddPathResponse, + BlobAddStreamRequest, BlobAddStreamResponse, BlobAddStreamUpdate, BlobConsistencyCheckRequest, + BlobDeleteBlobRequest, BlobDownloadRequest, BlobDownloadResponse, 
BlobExportRequest, + BlobExportResponse, BlobGetCollectionRequest, BlobGetCollectionResponse, + BlobListCollectionsRequest, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, + BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, + DeleteTagRequest, DocExportFileRequest, DocExportFileResponse, DocImportFileRequest, + DocImportFileResponse, DocSetHashRequest, ListTagsRequest, NodeAddrRequest, + NodeConnectionInfoRequest, NodeConnectionInfoResponse, NodeConnectionsRequest, + NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, NodeShutdownRequest, + NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, NodeWatchResponse, + Request, RpcService, SetTagOption, }; use super::NodeInner; @@ -99,6 +101,11 @@ impl Handler { } CreateCollection(msg) => chan.rpc(msg, handler, Self::create_collection).await, BlobGetCollection(msg) => chan.rpc(msg, handler, Self::blob_get_collection).await, + BatchAddStreamRequest(msg) => { + chan.bidi_streaming(msg, handler, Self::batch_add_stream) + .await + } + BatchAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), ListTags(msg) => { chan.server_streaming(msg, handler, Self::blob_list_tags) .await @@ -131,6 +138,8 @@ impl Handler { .await } BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), + BatchCreate(msg) => chan.bidi_streaming(msg, handler, Self::batch_create).await, + BatchUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), AuthorList(msg) => { chan.server_streaming(msg, handler, |handler, req| { handler.inner.sync.author_list(req) @@ -840,6 +849,92 @@ impl Handler { }) } + fn batch_create( + self, + _: BatchCreateRequest, + mut updates: impl Stream + Send + Unpin + 'static, + ) -> impl Stream { + let scope_id = 0; + // let scope_id = self.inner.temp_tags.lock().unwrap().create(); + tokio::spawn(async move { + while let Some(item) = updates.next().await { + match item { + BatchUpdate::Drop(content) => { + // 
println!("dropping tag {} {}", scope_id, tag_id); + // self.inner + // .temp_tags + // .lock() + // .unwrap() + // .remove_one(scope_id, tag_id); + } + } + } + println!("dropping scope {}", scope_id); + // self.inner.temp_tags.lock().unwrap().remove(scope_id); + }); + futures_lite::stream::once(BatchCreateResponse::Id(scope_id)) + } + + fn batch_add_stream( + self, + msg: BatchAddStreamRequest, + stream: impl Stream + Send + Unpin + 'static, + ) -> impl Stream { + let (tx, rx) = flume::bounded(32); + let this = self.clone(); + + self.rt().spawn_pinned(|| async move { + if let Err(err) = this.batch_add_stream0(msg, stream, tx.clone()).await { + tx.send_async(BatchAddStreamResponse::Abort(err.into())) + .await + .ok(); + } + }); + rx.into_stream() + } + + async fn batch_add_stream0( + self, + msg: BatchAddStreamRequest, + stream: impl Stream + Send + Unpin + 'static, + progress: flume::Sender, + ) -> anyhow::Result<()> { + println!("batch_add_stream0"); + let progress = FlumeProgressSender::new(progress); + + let stream = stream.map(|item| match item { + BatchAddStreamUpdate::Chunk(chunk) => Ok(chunk), + BatchAddStreamUpdate::Abort => { + Err(io::Error::new(io::ErrorKind::Interrupted, "Remote abort")) + } + }); + + let import_progress = progress.clone().with_filter_map(move |x| match x { + _ => None, + }); + println!("collecting stream"); + let items: Vec<_> = stream.collect().await; + println!("stream collected"); + let stream = futures_lite::stream::iter(items.into_iter()); + let (temp_tag, _len) = self + .inner + .db + .import_stream(stream, BlobFormat::Raw, import_progress) + .await?; + println!("stream imported {:?}", temp_tag.inner().hash); + let hash = temp_tag.inner().hash; + // let tag = self + // .inner + // .temp_tags + // .lock() + // .unwrap() + // .create_one(msg.scope, temp_tag); + progress + .send(BatchAddStreamResponse::Result { hash }) + .await?; + Ok(()) + } + fn blob_add_stream( self, msg: BlobAddStreamRequest, diff --git 
a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 7bfb5d60b3..1117773e0f 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -17,6 +17,7 @@ use iroh_blobs::{ format::collection::Collection, store::{BaoBlobSize, ConsistencyCheckProgress}, util::Tag, + HashAndFormat, }; use iroh_net::{ endpoint::{ConnectionInfo, NodeAddr}, @@ -53,6 +54,33 @@ use crate::{ }; pub use iroh_blobs::util::SetTagOption; +/// Request to create a new scope for temp tags +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchCreateRequest; + +/// Update to a temp tag scope +#[derive(Debug, Serialize, Deserialize)] +pub enum BatchUpdate { + /// Drop of a remote temp tag + Drop(HashAndFormat), +} + +/// Response to a [`BatchCreateRequest`] +#[derive(Debug, Serialize, Deserialize)] +pub enum BatchCreateResponse { + /// The id of the newly created scope + Id(u64), +} + +impl Msg for BatchCreateRequest { + type Pattern = BidiStreaming; +} + +impl BidiStreamingMsg for BatchCreateRequest { + type Update = BatchUpdate; + type Response = BatchCreateResponse; +} + /// A request to the node to provide the data at the given path /// /// Will produce a stream of [`AddProgress`] messages. 
@@ -1015,6 +1043,40 @@ impl BidiStreamingMsg for BlobAddStreamRequest { #[derive(Debug, Serialize, Deserialize, derive_more::Into)] pub struct BlobAddStreamResponse(pub AddProgress); +/// Write a blob from a byte stream +#[derive(Serialize, Deserialize, Debug)] +pub struct BatchAddStreamRequest { + /// What format to use for the blob + pub format: BlobFormat, + /// Scope to create the temp tag in + pub scope: u64, +} + +/// Update sent while streaming the bytes of a blob for a [`BatchAddStreamRequest`] +#[derive(Serialize, Deserialize, Debug)] +pub enum BatchAddStreamUpdate { + /// A chunk of stream data + Chunk(Bytes), + /// Abort the request due to an error on the client side + Abort, +} + +impl Msg for BatchAddStreamRequest { + type Pattern = BidiStreaming; +} + +impl BidiStreamingMsg for BatchAddStreamRequest { + type Update = BatchAddStreamUpdate; + type Response = BatchAddStreamResponse; +} + +/// Response to a [`BatchAddStreamRequest`]: an abort with its cause, or the hash of the added blob. +#[derive(Debug, Serialize, Deserialize)] +pub enum BatchAddStreamResponse { + Abort(RpcError), + Result { hash: Hash }, +} + /// Get stats for the running Iroh node #[derive(Serialize, Deserialize, Debug)] pub struct NodeStatsRequest {} @@ -1072,6 +1134,11 @@ pub enum Request { CreateCollection(CreateCollectionRequest), BlobGetCollection(BlobGetCollectionRequest), + BatchCreate(BatchCreateRequest), + BatchUpdate(BatchUpdate), + BatchAddStreamRequest(BatchAddStreamRequest), + BatchAddStreamUpdate(BatchAddStreamUpdate), + DeleteTag(DeleteTagRequest), ListTags(ListTagsRequest), @@ -1133,6 +1200,10 @@ pub enum Response { CreateCollection(RpcResult), BlobGetCollection(RpcResult), + BatchCreateResponse(BatchCreateResponse), + BatchRequest(BatchCreateRequest), + BatchAddStream(BatchAddStreamResponse), + ListTags(TagInfo), DeleteTag(RpcResult<()>),