Add serialization to s3 storage
mpiannucci committed Dec 9, 2024
1 parent 39b3383 commit 1427488
Showing 8 changed files with 147 additions and 56 deletions.
4 changes: 2 additions & 2 deletions icechunk/src/lib.rs
@@ -33,9 +33,9 @@ pub mod strategies;
 pub mod zarr;
 
 pub use repo::RepositoryConfig;
-pub use repository::{Repository, RepositoryBuilder, SnapshotMetadata};
+pub use repository::{Repository, RepositoryBuilder};
 pub use storage::{MemCachingStorage, ObjectStorage, Storage, StorageError};
-pub use zarr::Store;
+pub use store::Store;
 
 mod private {
     /// Used to seal traits we don't want user code to implement, to maintain compatibility.
20 changes: 16 additions & 4 deletions icechunk/src/repo.rs
@@ -2,21 +2,21 @@ use std::{iter, sync::Arc};
 use futures::Stream;
 use itertools::Either;
+use serde::{Deserialize, Serialize};
 
 use crate::{
-    format::{snapshot::Snapshot, SnapshotId},
+    format::{snapshot::{Snapshot, SnapshotMetadata}, SnapshotId},
     refs::{
         create_tag, fetch_branch_tip, fetch_ref, fetch_tag, list_branches, list_tags,
         update_branch, BranchVersion, Ref, RefError,
     },
     repository::{raise_if_invalid_snapshot_id, RepositoryError, RepositoryResult},
     session::Session,
     storage::virtual_ref::VirtualChunkResolver,
-    zarr::VersionInfo,
-    MemCachingStorage, SnapshotMetadata, Storage,
+    MemCachingStorage, Storage,
 };
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
 pub struct RepositoryConfig {
     // Chunks smaller than this will be stored inline in the manifst
     pub inline_chunk_threshold_bytes: u16,
@@ -33,6 +33,18 @@ impl Default for RepositoryConfig {
     }
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+#[non_exhaustive]
+pub enum VersionInfo {
+    #[serde(rename = "snapshot_id")]
+    SnapshotId(SnapshotId),
+    #[serde(rename = "tag")]
+    TagRef(String),
+    #[serde(rename = "branch")]
+    BranchTipRef(String),
+}
+
+
 #[derive(Debug)]
 pub struct Repository {
     config: RepositoryConfig,
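The serde(rename) attributes above pin a stable external tag on each variant. A minimal sketch of the resulting wire shape, assuming serde_json as the format (the commit itself does not fix one) and substituting String for icechunk's SnapshotId:

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
enum VersionInfo {
    #[serde(rename = "snapshot_id")]
    SnapshotId(String), // stand-in for icechunk's SnapshotId
    #[serde(rename = "tag")]
    TagRef(String),
    #[serde(rename = "branch")]
    BranchTipRef(String),
}

fn main() -> serde_json::Result<()> {
    let v = VersionInfo::BranchTipRef("main".to_string());
    // Externally tagged enum: the renamed variant name becomes the JSON key.
    let json = serde_json::to_string(&v)?;
    assert_eq!(json, r#"{"branch":"main"}"#);
    assert_eq!(serde_json::from_str::<VersionInfo>(&json)?, v);
    Ok(())
}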
3 changes: 2 additions & 1 deletion icechunk/src/storage/mod.rs
@@ -8,6 +8,7 @@ use aws_sdk_s3::{
     primitives::ByteStreamError,
 };
 use chrono::{DateTime, Utc};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use core::fmt;
 use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt};
 use std::{ffi::OsString, sync::Arc};
@@ -84,7 +85,7 @@ const TRANSACTION_PREFIX: &str = "transactions/";
 /// Different implementation can cache the files differently, or not at all.
 /// Implementations are free to assume files are never overwritten.
 #[async_trait]
-pub trait Storage: fmt::Debug + private::Sealed {
+pub trait Storage<'de>: fmt::Debug + private::Sealed + Serialize + Deserialize<'de> {
     async fn fetch_snapshot(&self, id: &SnapshotId) -> StorageResult<Arc<Snapshot>>;
     async fn fetch_attributes(
         &self,
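Any concrete backend satisfying the widened trait can now be handed to serde generically. A rough sketch of what the bound permits (save_storage is a hypothetical helper, not part of this commit, and assumes serde_json):

fn save_storage<'de, S: Storage<'de>>(storage: &S) -> serde_json::Result<String> {
    // Works for any backend because Storage<'de> now requires Serialize.
    serde_json::to_string(storage)
}

One design note: threading the 'de lifetime through the trait means every bound on Storage has to carry it; the serde::de::DeserializeOwned import added in the same hunk points at the lifetime-free alternative (Serialize + DeserializeOwned) that serde's documentation suggests for bounds of this kind.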
4 changes: 3 additions & 1 deletion icechunk/src/storage/object_store.rs
@@ -17,6 +17,7 @@ use object_store::{
     AttributeValue, Attributes, GetOptions, GetRange, ObjectMeta, ObjectStore, PutMode,
     PutOptions, PutPayload,
 };
+use serde::Serialize;
 use std::{
     fs::create_dir_all,
     future::ready,
@@ -47,8 +48,9 @@ impl From<&ByteRange> for Option<GetRange> {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Serialize)]
 pub struct ObjectStorage {
+    #[serde(skip)]
     store: Arc<dyn ObjectStore>,
     prefix: String,
     // We need this because object_store's local file implementation doesn't sort refs. Since this
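The #[serde(skip)] on the live store handle is what makes the derive possible at all: a trait object like Arc<dyn ObjectStore> has no serde impls, so only the plain-data fields reach the output. Note that unlike S3Storage below, ObjectStorage gains only Serialize in this commit. A small self-contained sketch of the skip pattern (Example is illustrative, not a type from the commit):

use serde::Serialize;
use std::sync::Arc;

#[derive(Debug, Serialize)]
struct Example {
    #[serde(skip)]
    handle: Arc<String>, // stand-in for Arc<dyn ObjectStore>
    prefix: String,
}

fn main() -> serde_json::Result<()> {
    let e = Example { handle: Arc::new("live client".into()), prefix: "repo/".into() };
    // The skipped field is simply absent from the serialized form.
    assert_eq!(serde_json::to_string(&e)?, r#"{"prefix":"repo/"}"#);
    Ok(())
}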
76 changes: 47 additions & 29 deletions icechunk/src/storage/s3.rs
@@ -25,7 +25,7 @@ use futures::{
     stream::{self, BoxStream},
     StreamExt, TryStreamExt,
 };
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::{
     format::{
@@ -43,11 +43,11 @@ use super::{
     TRANSACTION_PREFIX,
 };
 
-#[derive(Debug)]
+#[derive(Debug, Serialize)]
 pub struct S3Storage {
+    #[serde(skip)]
     client: Arc<Client>,
-    prefix: String,
-    bucket: String,
+    config: S3Config,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
@@ -70,23 +70,31 @@ pub enum S3Credentials {
 }
 
 #[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
-pub struct S3Config {
+pub struct S3ClientOptions {
     pub region: Option<String>,
     pub endpoint: Option<String>,
     pub credentials: S3Credentials,
     pub allow_http: bool,
 }
 
-pub async fn mk_client(config: Option<&S3Config>) -> Client {
+#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
+pub struct S3Config {
+    pub bucket: String,
+    pub prefix: String,
+    pub options: Option<S3ClientOptions>,
+}
+
+pub async fn mk_client(config: Option<&S3ClientOptions>) -> Client {
     let region = config
         .as_ref()
         .and_then(|c| c.region.as_ref())
         .map(|r| RegionProviderChain::first_try(Some(Region::new(r.clone()))))
         .unwrap_or_else(RegionProviderChain::default_provider);
 
-    let endpoint = config.and_then(|c| c.endpoint.clone());
-    let allow_http = config.map(|c| c.allow_http).unwrap_or(false);
+    let endpoint = config.as_ref().and_then(|c| c.endpoint.clone());
+    let allow_http = config.as_ref().map(|c| c.allow_http).unwrap_or(false);
     let credentials =
-        config.map(|c| c.credentials.clone()).unwrap_or(S3Credentials::FromEnv);
+        config.as_ref().map(|c| c.credentials.clone()).unwrap_or(S3Credentials::FromEnv);
     #[allow(clippy::unwrap_used)]
     let app_name = AppName::new("icechunk").unwrap();
     let mut aws_config = aws_config::defaults(BehaviorVersion::v2024_03_28())
@@ -123,17 +131,13 @@ pub async fn mk_client(config: Option<&S3ClientOptions>) -> Client {
 }
 
 impl S3Storage {
-    pub async fn new_s3_store(
-        bucket_name: impl Into<String>,
-        prefix: impl Into<String>,
-        config: Option<&S3Config>,
-    ) -> Result<S3Storage, StorageError> {
-        let client = Arc::new(mk_client(config).await);
-        Ok(S3Storage { client, prefix: prefix.into(), bucket: bucket_name.into() })
+    pub async fn new_s3_store(config: &S3Config) -> Result<S3Storage, StorageError> {
+        let client = Arc::new(mk_client(config.options.as_ref()).await);
+        Ok(S3Storage { client, config: config.clone() })
     }
 
     fn get_path_str(&self, file_prefix: &str, id: &str) -> StorageResult<String> {
-        let path = PathBuf::from_iter([self.prefix.as_str(), file_prefix, id]);
+        let path = PathBuf::from_iter([self.config.prefix.as_str(), file_prefix, id]);
         path.into_os_string().into_string().map_err(StorageError::BadPrefix)
     }
 
@@ -163,15 +167,15 @@ impl S3Storage {
     }
 
     fn ref_key(&self, ref_key: &str) -> StorageResult<String> {
-        let path = PathBuf::from_iter([self.prefix.as_str(), REF_PREFIX, ref_key]);
+        let path = PathBuf::from_iter([self.config.prefix.as_str(), REF_PREFIX, ref_key]);
         path.into_os_string().into_string().map_err(StorageError::BadPrefix)
     }
 
     async fn get_object(&self, key: &str) -> StorageResult<Bytes> {
         Ok(self
             .client
             .get_object()
-            .bucket(self.bucket.clone())
+            .bucket(self.config.bucket.clone())
             .key(key)
             .send()
             .await?
@@ -186,7 +190,7 @@ impl S3Storage {
         key: &str,
         range: &ByteRange,
     ) -> StorageResult<Bytes> {
-        let mut b = self.client.get_object().bucket(self.bucket.clone()).key(key);
+        let mut b = self.client.get_object().bucket(self.config.bucket.clone()).key(key);
 
         if let Some(header) = range_to_header(range) {
             b = b.range(header)
@@ -204,7 +208,7 @@ impl S3Storage {
         metadata: I,
         bytes: impl Into<ByteStream>,
     ) -> StorageResult<()> {
-        let mut b = self.client.put_object().bucket(self.bucket.clone()).key(key);
+        let mut b = self.client.put_object().bucket(self.config.bucket.clone()).key(key);
 
         if let Some(ct) = content_type {
             b = b.content_type(ct)
@@ -241,7 +245,7 @@ impl S3Storage {
         let res = self
             .client
             .delete_objects()
-            .bucket(self.bucket.clone())
+            .bucket(self.config.bucket.clone())
             .delete(delete)
             .send()
             .await?;
@@ -263,8 +267,22 @@ pub fn range_to_header(range: &ByteRange) -> Option<String> {
 
 impl private::Sealed for S3Storage {}
 
+impl<'de> Deserialize<'de> for S3Storage {
+    fn deserialize<D>(deserializer: D) -> Result<S3Storage, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let config = S3Config::deserialize(deserializer)?;
+        #[allow(clippy::expect_used)]
+        let runtime =
+            tokio::runtime::Runtime::new().expect("Could not create tokio runtime");
+        let client = Arc::new(runtime.block_on(mk_client(config.options.as_ref())));
+        Ok(S3Storage { client, config })
+    }
+}
+
 #[async_trait]
-impl Storage for S3Storage {
+impl Storage<'_> for S3Storage {
     async fn fetch_snapshot(&self, id: &SnapshotId) -> StorageResult<Arc<Snapshot>> {
         let key = self.get_snapshot_path(id)?;
         let bytes = self.get_object(key.as_str()).await?;
@@ -389,7 +407,7 @@ impl Storage for S3Storage {
         let res = self
             .client
             .get_object()
-            .bucket(self.bucket.clone())
+            .bucket(self.config.bucket.clone())
             .key(key.clone())
             .send()
             .await;
@@ -413,7 +431,7 @@ impl Storage for S3Storage {
         let mut paginator = self
             .client
             .list_objects_v2()
-            .bucket(self.bucket.clone())
+            .bucket(self.config.bucket.clone())
             .prefix(prefix.clone())
             .delimiter("/")
             .into_paginator()
@@ -445,7 +463,7 @@ impl Storage for S3Storage {
         let mut paginator = self
             .client
             .list_objects_v2()
-            .bucket(self.bucket.clone())
+            .bucket(self.config.bucket.clone())
             .prefix(prefix.clone())
             .into_paginator()
             .send();
@@ -471,7 +489,7 @@ impl Storage for S3Storage {
     ) -> StorageResult<()> {
         let key = self.ref_key(ref_key)?;
         let mut builder =
-            self.client.put_object().bucket(self.bucket.clone()).key(key.clone());
+            self.client.put_object().bucket(self.config.bucket.clone()).key(key.clone());
 
         if !overwrite_refs {
             builder = builder.if_none_match("*")
@@ -498,14 +516,14 @@ impl Storage for S3Storage {
         &'a self,
         prefix: &str,
     ) -> StorageResult<BoxStream<'a, StorageResult<ListInfo<String>>>> {
-        let prefix = PathBuf::from_iter([self.prefix.as_str(), prefix])
+        let prefix = PathBuf::from_iter([self.config.prefix.as_str(), prefix])
             .into_os_string()
             .into_string()
             .map_err(StorageError::BadPrefix)?;
         let stream = self
             .client
             .list_objects_v2()
-            .bucket(self.bucket.clone())
+            .bucket(self.config.bucket.clone())
             .prefix(prefix)
             .into_paginator()
             .send()
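Net effect: an S3Storage is now fully described by its S3Config, and the manual Deserialize impl rebuilds the SDK client rather than trying to serialize it. A hedged sketch of the round-trip, with illustrative values and assuming serde_json:

// Illustrative values; assumes serde_json and the types above in scope.
let config = S3Config {
    bucket: "my-bucket".to_string(),
    prefix: "repo-root".to_string(),
    options: None, // None falls back to region/credentials from the environment
};
let storage = S3Storage::new_s3_store(&config).await?; // async: builds the client
let json = serde_json::to_string(&storage)?;           // emits only `config`

// Later, on a thread with no running tokio runtime:
let restored: S3Storage = serde_json::from_str(&json)?; // rebuilds the client

The placement of that last call matters: tokio's block_on panics when invoked from a thread that is already driving a runtime, so deserialization here appears intended for synchronous contexts, such as loading persisted configuration at startup.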
6 changes: 3 additions & 3 deletions icechunk/src/storage/virtual_ref.rs
@@ -12,7 +12,7 @@ use std::fmt::Debug;
 use tokio::sync::OnceCell;
 use url::{self, Url};
 
-use super::s3::{mk_client, range_to_header, S3Config};
+use super::s3::{mk_client, range_to_header, S3ClientOptions};
 
 #[async_trait]
 pub trait VirtualChunkResolver: Debug + private::Sealed {
@@ -25,7 +25,7 @@ pub trait VirtualChunkResolver: Debug + private::Sealed {
 
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
 pub enum ObjectStoreVirtualChunkResolverConfig {
-    S3(S3Config),
+    S3(S3ClientOptions),
 }
 
 #[derive(Debug)]
@@ -45,7 +45,7 @@ impl ObjectStoreVirtualChunkResolver {
             .get_or_init(|| async move {
                 match config.as_ref() {
                     Some(ObjectStoreVirtualChunkResolverConfig::S3(config)) => {
-                        mk_client(Some(config)).await
+                        mk_client(Some(&config)).await
                    }
                     None => mk_client(None).await,
                 }
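Here the resolver keeps its pre-existing shape (region, endpoint, credentials, allow_http) under the new S3ClientOptions name, while the bucket-and-prefix-bearing S3Config stays specific to S3Storage. A small illustrative sketch of building and serializing one, again assuming serde_json:

// Illustrative values; assumes the types above in scope.
let resolver_config = ObjectStoreVirtualChunkResolverConfig::S3(S3ClientOptions {
    region: Some("us-east-1".to_string()),
    endpoint: None,
    credentials: S3Credentials::FromEnv,
    allow_http: false,
});
let json = serde_json::to_string(&resolver_config)?;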