Skip to content

Commit

Permalink
WIP add batch API
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn committed Jun 3, 2024
1 parent 573051c commit b443bec
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 16 deletions.
101 changes: 96 additions & 5 deletions iroh/src/client/blobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,36 @@ use std::{
io,
path::PathBuf,
pin::Pin,
sync::Arc,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use bytes::Bytes;
use futures_lite::{Stream, StreamExt};
use futures_util::SinkExt;
use futures_util::{FutureExt, SinkExt};
use iroh_base::{node_addr::AddrInfoOptions, ticket::BlobTicket};
use iroh_blobs::{
export::ExportProgress as BytesExportProgress,
format::collection::Collection,
get::db::DownloadProgress as BytesDownloadProgress,
store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ValidateProgress},
BlobFormat, Hash, Tag,
util::TagDrop,
BlobFormat, Hash, HashAndFormat, Tag, TempTag,
};
use iroh_net::NodeAddr;
use portable_atomic::{AtomicU64, Ordering};
use quic_rpc::{client::BoxStreamSync, RpcClient, ServiceConnection};
use quic_rpc::{
client::{BoxStreamSync, UpdateSink},
RpcClient, ServiceConnection,
};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
use tokio_util::io::{ReaderStream, StreamReader};
use tracing::warn;

use crate::rpc_protocol::{
BatchAddStreamRequest, BatchAddStreamResponse, BatchAddStreamUpdate, BatchUpdate,
BlobAddPathRequest, BlobAddStreamRequest, BlobAddStreamUpdate, BlobConsistencyCheckRequest,
BlobDeleteBlobRequest, BlobDownloadRequest, BlobExportRequest, BlobGetCollectionRequest,
BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListIncompleteRequest,
Expand Down Expand Up @@ -368,6 +373,92 @@ where
}
}

/// Shared state of a scope in which blobs can be added.
///
/// Holds everything a [`Batch`] needs to talk to the node: the scope id, an rpc
/// client to open new requests, and the update sink of the scope.
#[derive(derive_more::Debug)]
struct BatchInner<C: ServiceConnection<RpcService>> {
/// The id of the scope, sent as `scope` with every add request.
id: u64,
/// The rpc client used to open requests within this scope.
rpc: RpcClient<RpcService, C>,
/// The sink used to send [`BatchUpdate`] messages (currently only drop
/// notifications) for this scope.
///
/// Wrapped in a mutex because sending needs `&mut` access while the drop
/// callback only gets `&self`.
#[debug(skip)]
updates: Mutex<UpdateSink<RpcService, C, BatchUpdate>>,
}

/// Client-side handle to a scope in which blobs can be added.
///
/// Blobs added through a batch are returned as [`TempTag`]s. The shared state
/// is kept behind an `Arc` so that each temp tag can hold a weak reference to
/// it and report its drop through it.
#[derive(derive_more::Debug)]
pub struct Batch<C: ServiceConnection<RpcService>>(Arc<BatchInner<C>>);

impl<C: ServiceConnection<RpcService>> TagDrop for BatchInner<C> {
fn on_drop(&self, content: &HashAndFormat) {
let mut updates = self.updates.lock().unwrap();
updates.send(BatchUpdate::Drop(*content)).now_or_never();
}
}

impl<C: ServiceConnection<RpcService>> Batch<C> {
    /// Add a blob from an in-memory buffer.
    ///
    /// Convenience wrapper around [`Self::add_stream`] that sends `bytes` as a
    /// single chunk.
    pub async fn add_bytes(&self, bytes: impl Into<Bytes>, format: BlobFormat) -> Result<TempTag> {
        let data: Bytes = bytes.into();
        self.add_stream(futures_lite::stream::once(Ok(data)), format)
            .await
    }

    /// Add a blob from a stream of byte chunks.
    ///
    /// Every chunk of `input` is forwarded to the node. If reading from
    /// `input` fails, an abort message is sent instead and no further chunks
    /// are forwarded. Afterwards the response stream is drained.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be opened, if sending or closing the update
    /// sink fails, if the node answers with an abort, or if the response
    /// stream ends without a result.
    pub async fn add_stream(
        &self,
        mut input: impl Stream<Item = io::Result<Bytes>> + Send + Unpin + 'static,
        format: BlobFormat,
    ) -> Result<TempTag> {
        let request = BatchAddStreamRequest {
            scope: self.0.id,
            format,
        };
        let (mut sink, mut stream) = self.0.rpc.bidi(request).await?;

        // Forward the input; a read error turns into a final `Abort` update.
        while let Some(item) = input.next().await {
            let (update, is_last) = match item {
                Ok(chunk) => (BatchAddStreamUpdate::Chunk(chunk), false),
                Err(err) => {
                    warn!("Abort send, reason: failed to read from source stream: {err:?}");
                    (BatchAddStreamUpdate::Abort, true)
                }
            };
            sink.send(update)
                .await
                .map_err(|err| anyhow!("Failed to send input stream to remote: {err:?}"))?;
            if is_last {
                break;
            }
        }

        sink.close()
            .await
            .map_err(|err| anyhow!("Failed to close the stream: {err:?}"))?;
        // Dropping the sink is required for the remote to notice that the
        // stream is closed.
        drop(sink);

        // Drain the responses, remembering the last reported hash.
        let mut reported = None;
        while let Some(response) = stream.next().await {
            match response? {
                BatchAddStreamResponse::Abort(cause) => return Err(cause.into()),
                BatchAddStreamResponse::Result { hash } => reported = Some(hash),
            }
        }
        let hash = reported.context("Missing answer")?;

        // The tag only keeps a weak reference to the batch for its drop
        // notification.
        let on_drop: Arc<dyn TagDrop> = self.0.clone();
        let content = HashAndFormat { hash, format };
        Ok(TempTag::new(content, Some(Arc::downgrade(&on_drop))))
    }
}

/// Whether to wrap the added data in a collection.
#[derive(Debug, Serialize, Deserialize)]
pub enum WrapOption {
Expand Down
117 changes: 106 additions & 11 deletions iroh/src/node/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,19 @@ use crate::client::blobs::{
use crate::client::tags::TagInfo;
use crate::client::NodeStatus;
use crate::rpc_protocol::{
BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse,
BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest,
BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobGetCollectionRequest,
BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListIncompleteRequest,
BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest,
CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest,
DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocSetHashRequest,
ListTagsRequest, NodeAddrRequest, NodeConnectionInfoRequest, NodeConnectionInfoResponse,
NodeConnectionsRequest, NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest,
NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest,
NodeWatchResponse, Request, RpcService, SetTagOption,
BatchAddStreamRequest, BatchAddStreamResponse, BatchAddStreamUpdate, BatchCreateRequest,
BatchCreateResponse, BatchUpdate, BlobAddPathRequest, BlobAddPathResponse,
BlobAddStreamRequest, BlobAddStreamResponse, BlobAddStreamUpdate, BlobConsistencyCheckRequest,
BlobDeleteBlobRequest, BlobDownloadRequest, BlobDownloadResponse, BlobExportRequest,
BlobExportResponse, BlobGetCollectionRequest, BlobGetCollectionResponse,
BlobListCollectionsRequest, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest,
BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse,
DeleteTagRequest, DocExportFileRequest, DocExportFileResponse, DocImportFileRequest,
DocImportFileResponse, DocSetHashRequest, ListTagsRequest, NodeAddrRequest,
NodeConnectionInfoRequest, NodeConnectionInfoResponse, NodeConnectionsRequest,
NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, NodeShutdownRequest,
NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, NodeWatchResponse,
Request, RpcService, SetTagOption,
};

use super::NodeInner;
Expand Down Expand Up @@ -99,6 +101,11 @@ impl<D: BaoStore> Handler<D> {
}
CreateCollection(msg) => chan.rpc(msg, handler, Self::create_collection).await,
BlobGetCollection(msg) => chan.rpc(msg, handler, Self::blob_get_collection).await,
BatchAddStreamRequest(msg) => {
chan.bidi_streaming(msg, handler, Self::batch_add_stream)
.await
}
BatchAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage),
ListTags(msg) => {
chan.server_streaming(msg, handler, Self::blob_list_tags)
.await
Expand Down Expand Up @@ -131,6 +138,8 @@ impl<D: BaoStore> Handler<D> {
.await
}
BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage),
BatchCreate(msg) => chan.bidi_streaming(msg, handler, Self::batch_create).await,
BatchUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage),
AuthorList(msg) => {
chan.server_streaming(msg, handler, |handler, req| {
handler.inner.sync.author_list(req)
Expand Down Expand Up @@ -840,6 +849,92 @@ impl<D: BaoStore> Handler<D> {
})
}

fn batch_create(
self,
_: BatchCreateRequest,
mut updates: impl Stream<Item = BatchUpdate> + Send + Unpin + 'static,
) -> impl Stream<Item = BatchCreateResponse> {
let scope_id = 0;
// let scope_id = self.inner.temp_tags.lock().unwrap().create();
tokio::spawn(async move {
while let Some(item) = updates.next().await {
match item {
BatchUpdate::Drop(content) => {
// println!("dropping tag {} {}", scope_id, tag_id);
// self.inner
// .temp_tags
// .lock()
// .unwrap()
// .remove_one(scope_id, tag_id);
}
}
}
println!("dropping scope {}", scope_id);
// self.inner.temp_tags.lock().unwrap().remove(scope_id);
});
futures_lite::stream::once(BatchCreateResponse::Id(scope_id))
}

/// Handle a request to add a blob from a stream within a batch.
///
/// The import itself runs in [`Self::batch_add_stream0`] on the local pool.
/// Its responses are delivered through a bounded channel whose receiving end
/// is returned as the response stream. If the import fails, the error is
/// reported to the client as a final [`BatchAddStreamResponse::Abort`].
fn batch_add_stream(
    self,
    msg: BatchAddStreamRequest,
    stream: impl Stream<Item = BatchAddStreamUpdate> + Send + Unpin + 'static,
) -> impl Stream<Item = BatchAddStreamResponse> {
    let (responses_tx, responses_rx) = flume::bounded(32);
    let handler = self.clone();

    self.rt().spawn_pinned(|| async move {
        let outcome = handler
            .batch_add_stream0(msg, stream, responses_tx.clone())
            .await;
        if let Err(cause) = outcome {
            let abort = BatchAddStreamResponse::Abort(cause.into());
            // The receiver may already be gone; nothing more can be done then.
            responses_tx.send_async(abort).await.ok();
        }
    });
    responses_rx.into_stream()
}

async fn batch_add_stream0(
self,
msg: BatchAddStreamRequest,
stream: impl Stream<Item = BatchAddStreamUpdate> + Send + Unpin + 'static,
progress: flume::Sender<BatchAddStreamResponse>,
) -> anyhow::Result<()> {
println!("batch_add_stream0");
let progress = FlumeProgressSender::new(progress);

let stream = stream.map(|item| match item {
BatchAddStreamUpdate::Chunk(chunk) => Ok(chunk),
BatchAddStreamUpdate::Abort => {
Err(io::Error::new(io::ErrorKind::Interrupted, "Remote abort"))
}
});

let import_progress = progress.clone().with_filter_map(move |x| match x {
_ => None,
});
println!("collecting stream");
let items: Vec<_> = stream.collect().await;
println!("stream collected");
let stream = futures_lite::stream::iter(items.into_iter());
let (temp_tag, _len) = self
.inner
.db
.import_stream(stream, BlobFormat::Raw, import_progress)
.await?;
println!("stream imported {:?}", temp_tag.inner().hash);
let hash = temp_tag.inner().hash;
// let tag = self
// .inner
// .temp_tags
// .lock()
// .unwrap()
// .create_one(msg.scope, temp_tag);
progress
.send(BatchAddStreamResponse::Result { hash })
.await?;
Ok(())
}

fn blob_add_stream(
self,
msg: BlobAddStreamRequest,
Expand Down
71 changes: 71 additions & 0 deletions iroh/src/rpc_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use iroh_blobs::{
format::collection::Collection,
store::{BaoBlobSize, ConsistencyCheckProgress},
util::Tag,
HashAndFormat,
};
use iroh_net::{
endpoint::{ConnectionInfo, NodeAddr},
Expand Down Expand Up @@ -53,6 +54,33 @@ use crate::{
};
pub use iroh_blobs::util::SetTagOption;

/// Request to create a new scope for temp tags.
///
/// This is a bidi streaming request: the client sends [`BatchUpdate`]s and
/// receives [`BatchCreateResponse`]s.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchCreateRequest;

/// Update to a temp tag scope, sent by the client on the update stream of a
/// [`BatchCreateRequest`].
#[derive(Debug, Serialize, Deserialize)]
pub enum BatchUpdate {
/// A temp tag for the given content was dropped on the client side.
Drop(HashAndFormat),
}

/// Response to a [`BatchCreateRequest`].
#[derive(Debug, Serialize, Deserialize)]
pub enum BatchCreateResponse {
/// The id of the scope that was created.
Id(u64),
}

// Creating a batch opens a bidirectional stream.
impl Msg<RpcService> for BatchCreateRequest {
type Pattern = BidiStreaming;
}

// The client streams scope updates; the node streams create responses.
impl BidiStreamingMsg<RpcService> for BatchCreateRequest {
type Update = BatchUpdate;
type Response = BatchCreateResponse;
}

/// A request to the node to provide the data at the given path
///
/// Will produce a stream of [`AddProgress`] messages.
Expand Down Expand Up @@ -1015,6 +1043,40 @@ impl BidiStreamingMsg<RpcService> for BlobAddStreamRequest {
#[derive(Debug, Serialize, Deserialize, derive_more::Into)]
pub struct BlobAddStreamResponse(pub AddProgress);

/// Write a blob from a byte stream within a temp tag scope.
///
/// The blob data follows as [`BatchAddStreamUpdate`] messages; the node
/// answers with [`BatchAddStreamResponse`] messages.
#[derive(Serialize, Deserialize, Debug)]
pub struct BatchAddStreamRequest {
/// What format to use for the blob
pub format: BlobFormat,
/// Scope to create the temp tag in
pub scope: u64,
}

/// Update message for a [`BatchAddStreamRequest`], carrying the blob data.
#[derive(Serialize, Deserialize, Debug)]
pub enum BatchAddStreamUpdate {
/// A chunk of stream data
Chunk(Bytes),
/// Abort the request due to an error on the client side
Abort,
}

// Adding a blob from a stream opens a bidirectional stream.
impl Msg<RpcService> for BatchAddStreamRequest {
type Pattern = BidiStreaming;
}

// The client streams data updates; the node streams add responses.
impl BidiStreamingMsg<RpcService> for BatchAddStreamRequest {
type Update = BatchAddStreamUpdate;
type Response = BatchAddStreamResponse;
}

/// Response to a [`BatchAddStreamRequest`].
#[derive(Debug, Serialize, Deserialize)]
pub enum BatchAddStreamResponse {
/// The request was aborted; carries the error that caused the abort.
Abort(RpcError),
/// The blob was added; carries the hash of the imported data.
Result { hash: Hash },
}

/// Get stats for the running Iroh node
#[derive(Serialize, Deserialize, Debug)]
pub struct NodeStatsRequest {}
Expand Down Expand Up @@ -1072,6 +1134,11 @@ pub enum Request {
CreateCollection(CreateCollectionRequest),
BlobGetCollection(BlobGetCollectionRequest),

BatchCreate(BatchCreateRequest),
BatchUpdate(BatchUpdate),
BatchAddStreamRequest(BatchAddStreamRequest),
BatchAddStreamUpdate(BatchAddStreamUpdate),

DeleteTag(DeleteTagRequest),
ListTags(ListTagsRequest),

Expand Down Expand Up @@ -1133,6 +1200,10 @@ pub enum Response {
CreateCollection(RpcResult<CreateCollectionResponse>),
BlobGetCollection(RpcResult<BlobGetCollectionResponse>),

BatchCreateResponse(BatchCreateResponse),
BatchRequest(BatchCreateRequest),
BatchAddStream(BatchAddStreamResponse),

ListTags(TagInfo),
DeleteTag(RpcResult<()>),

Expand Down

0 comments on commit b443bec

Please sign in to comment.