diff --git a/icechunk-python/python/icechunk/__init__.py b/icechunk-python/python/icechunk/__init__.py index b0a40b27..06552d62 100644 --- a/icechunk-python/python/icechunk/__init__.py +++ b/icechunk-python/python/icechunk/__init__.py @@ -377,6 +377,34 @@ def new_branch(self, branch_name: str) -> str: """ return self._store.new_branch(branch_name) + async def async_reset_branch(self, to_snapshot: str) -> None: + """Reset the currently checked out branch to point to a different snapshot. + + This requires having no uncommitted changes. + + The snapshot id can be obtained as the result of a commit operation, but, more probably, + as the id of one of the SnapshotMetadata objects returned by `ancestry()` + + This operation edits the repository history; it must be executed carefully. + In particular, the current snapshot may end up being inaccessible from any + other branches or tags. + """ + return await self._store.async_reset_branch(to_snapshot) + + def reset_branch(self, to_snapshot: str) -> None: + """Reset the currently checked out branch to point to a different snapshot. + + This requires having no uncommitted changes. + + The snapshot id can be obtained as the result of a commit operation, but, more probably, + as the id of one of the SnapshotMetadata objects returned by `ancestry()` + + This operation edits the repository history; it must be executed carefully. + In particular, the current snapshot may end up being inaccessible from any + other branches or tags. 
+ """ + return self._store.reset_branch(to_snapshot) + def tag(self, tag_name: str, snapshot_id: str) -> None: """Create a tag pointing to the current checked out snapshot.""" return self._store.tag(tag_name, snapshot_id=snapshot_id) diff --git a/icechunk-python/python/icechunk/_icechunk_python.pyi b/icechunk-python/python/icechunk/_icechunk_python.pyi index 2259650e..bb2c424d 100644 --- a/icechunk-python/python/icechunk/_icechunk_python.pyi +++ b/icechunk-python/python/icechunk/_icechunk_python.pyi @@ -31,6 +31,8 @@ class PyIcechunkStore: async def async_reset(self) -> None: ... def new_branch(self, branch_name: str) -> str: ... async def async_new_branch(self, branch_name: str) -> str: ... + def reset_branch(self, snapshot_id: str) -> None: ... + async def async_reset_branch(self, snapshot_id: str) -> None: ... def tag(self, tag: str, snapshot_id: str) -> None: ... async def async_tag(self, tag: str, snapshot_id: str) -> None: ... def ancestry(self) -> list[SnapshotMetadata]: ... diff --git a/icechunk-python/src/lib.rs b/icechunk-python/src/lib.rs index 91cc4efb..224be2ff 100644 --- a/icechunk-python/src/lib.rs +++ b/icechunk-python/src/lib.rs @@ -530,6 +530,29 @@ impl PyIcechunkStore { }) } + fn async_reset_branch<'py>( + &'py self, + py: Python<'py>, + to_snapshot: String, + ) -> PyResult> { + let store = Arc::clone(&self.store); + pyo3_asyncio_0_21::tokio::future_into_py(py, async move { + do_reset_branch(store, to_snapshot).await + }) + } + + fn reset_branch<'py>( + &'py self, + py: Python<'py>, + to_snapshot: String, + ) -> PyResult> { + let store = Arc::clone(&self.store); + pyo3_asyncio_0_21::tokio::get_runtime().block_on(async move { + do_reset_branch(store, to_snapshot).await?; + Ok(PyNone::get_bound(py).to_owned()) + }) + } + fn async_tag<'py>( &'py self, py: Python<'py>, @@ -955,6 +978,20 @@ async fn do_new_branch<'py>( Ok(String::from(&oid)) } +async fn do_reset_branch<'py>( + store: Arc>, + to_snapshot: String, +) -> PyResult<()> { + let to_snapshot 
= ObjectId::try_from(to_snapshot.as_str()) + .map_err(|e| PyIcechunkStoreError::UnkownError(e.to_string()))?; + let mut writeable_store = store.write().await; + writeable_store + .reset_branch(to_snapshot) + .await + .map_err(PyIcechunkStoreError::from)?; + Ok(()) +} + async fn do_tag<'py>( store: Arc>, tag: String, diff --git a/icechunk-python/tests/test_timetravel.py b/icechunk-python/tests/test_timetravel.py index ac5118c8..6f149139 100644 --- a/icechunk-python/tests/test_timetravel.py +++ b/icechunk-python/tests/test_timetravel.py @@ -58,3 +58,27 @@ def test_timetravel(): ] assert sorted(parents, key=lambda p: p.written_at) == list(reversed(parents)) assert len(set([snap.id for snap in parents])) == 4 + + +async def test_branch_reset(): + store = icechunk.IcechunkStore.create( + storage=icechunk.StorageConfig.memory("test"), + config=icechunk.StoreConfig(inline_chunk_threshold_bytes=1), + ) + + group = zarr.group(store=store, overwrite=True) + group.create_group("a") + prev_snapshot_id = store.commit("group a") + group.create_group("b") + store.commit("group b") + + keys = {k async for k in store.list()} + assert "a/zarr.json" in keys + assert "b/zarr.json" in keys + + store.reset_branch(prev_snapshot_id) + + keys = {k async for k in store.list()} + assert "a/zarr.json" in keys + assert "b/zarr.json" not in keys + diff --git a/icechunk/src/repository.rs b/icechunk/src/repository.rs index 55eb1341..f44cb7e5 100644 --- a/icechunk/src/repository.rs +++ b/icechunk/src/repository.rs @@ -265,6 +265,10 @@ impl Repository { } } + pub fn config(&self) -> &RepositoryConfig { + &self.config + } + pub(crate) fn set_snapshot_id(&mut self, snapshot_id: SnapshotId) { self.snapshot_id = snapshot_id; } diff --git a/icechunk/src/zarr.rs b/icechunk/src/zarr.rs index ba62562b..f18a5b92 100644 --- a/icechunk/src/zarr.rs +++ b/icechunk/src/zarr.rs @@ -27,7 +27,7 @@ use crate::{ snapshot::{NodeData, UserAttributesSnapshot}, ByteRange, ChunkOffset, IcechunkFormatError, SnapshotId, }, 
- refs::{BranchVersion, Ref}, + refs::{update_branch, BranchVersion, Ref, RefError}, repository::{ get_chunk, ArrayShape, ChunkIndices, ChunkKeyEncoding, ChunkPayload, ChunkShape, Codec, DataType, DimensionNames, FillValue, Path, RepositoryError, @@ -275,6 +275,8 @@ pub enum StoreError { NotFound(#[from] KeyNotFoundError), #[error("unsuccessful repository operation: `{0}`")] RepositoryError(#[from] RepositoryError), + #[error("unsuccessful ref operation: `{0}`")] + RefError(#[from] RefError), #[error("cannot commit when no snapshot is present")] NoSnapshot, #[error("all commits must be made on a branch")] @@ -443,6 +445,40 @@ impl Store { Ok((snapshot_id, version)) } + /// Make the current branch point to the given snapshot. + /// This fails if there are uncommitted changes, or if the branch has been updated + /// since checkout. + /// After execution, history of the repo branch will be altered, and the current + /// store will point to a different base snapshot_id + pub async fn reset_branch( + &mut self, + to_snapshot: SnapshotId, + ) -> StoreResult { + // TODO: this should check the snapshot exists + let mut guard = self.repository.write().await; + if guard.has_uncommitted_changes() { + return Err(StoreError::UncommittedChanges); + } + match self.current_branch() { + None => Err(StoreError::NotOnBranch), + Some(branch) => { + let old_snapshot = guard.snapshot_id(); + let storage = guard.storage(); + let overwrite = guard.config().unsafe_overwrite_refs; + let version = update_branch( + storage.as_ref(), + branch.as_str(), + to_snapshot.clone(), + Some(old_snapshot), + overwrite, + ) + .await?; + guard.set_snapshot_id(to_snapshot); + Ok(version) + } + } + } + /// Commit the current changes to the current branch. If the store is not currently /// on a branch, this will return an error. 
pub async fn commit(&mut self, message: &str) -> StoreResult { @@ -2415,6 +2451,65 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_branch_reset() -> Result<(), Box> { + let storage: Arc = + Arc::new(ObjectStorage::new_in_memory_store(Some("prefix".into()))); + + let mut store = Store::new_from_storage(Arc::clone(&storage)).await?; + + store + .set( + "zarr.json", + Bytes::copy_from_slice(br#"{"zarr_format":3, "node_type":"group"}"#), + ) + .await + .unwrap(); + + store.commit("root group").await.unwrap(); + + store + .set( + "a/zarr.json", + Bytes::copy_from_slice(br#"{"zarr_format":3, "node_type":"group"}"#), + ) + .await + .unwrap(); + + let prev_snap = store.commit("group a").await?; + + store + .set( + "b/zarr.json", + Bytes::copy_from_slice(br#"{"zarr_format":3, "node_type":"group"}"#), + ) + .await + .unwrap(); + + store.commit("group b").await?; + assert!(store.exists("a/zarr.json").await?); + assert!(store.exists("b/zarr.json").await?); + + store.reset_branch(prev_snap).await?; + + assert!(!store.exists("b/zarr.json").await?); + assert!(store.exists("a/zarr.json").await?); + + let (repo, _) = + RepositoryConfig::existing(VersionInfo::BranchTipRef("main".to_string())) + .make_repository(storage) + .await?; + let store = Store::from_repository( + repo, + AccessMode::ReadOnly, + Some("main".to_string()), + None, + ); + assert!(!store.exists("b/zarr.json").await?); + assert!(store.exists("a/zarr.json").await?); + Ok(()) + } + #[tokio::test] async fn test_access_mode() { let storage: Arc =