From 6424d0b17b4da9d7db15e9a39d14ef8543bbafb0 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 4 Nov 2024 17:18:45 -0800 Subject: [PATCH] More control for merging stores together (#361) * Add merging methods to rust without commits * lint rust * port to pythonrust bindings (but not to actual python) * Chagne reset to eject changeset bytes * Update python bindings, tests, and examples * mypy * Remove errors from merge method * Clean up mutable references --------- Co-authored-by: Deepak Cherian --- icechunk-python/examples/dask_write.py | 10 +- icechunk-python/python/icechunk/__init__.py | 44 ++++----- .../python/icechunk/_icechunk_python.pyi | 12 +-- icechunk-python/src/lib.rs | 72 +++++++------- .../tests/test_distributed_writers.py | 8 +- icechunk/src/repository.rs | 95 ++++++++----------- icechunk/src/zarr.rs | 58 +++++------ icechunk/tests/test_distributed_writes.rs | 8 +- 8 files changed, 139 insertions(+), 168 deletions(-) diff --git a/icechunk-python/examples/dask_write.py b/icechunk-python/examples/dask_write.py index 7d10e6a0..9f2cd1e2 100644 --- a/icechunk-python/examples/dask_write.py +++ b/icechunk-python/examples/dask_write.py @@ -57,7 +57,7 @@ def generate_task_array(task: Task, shape: tuple[int,...]) -> np.typing.ArrayLik return np.random.rand(*shape) -def execute_write_task(task: Task) -> icechunk.IcechunkStore: +def execute_write_task(task: Task) -> bytes: """Execute task as a write task. This will read the time coordinade from `task` and write a "pancake" in that position, @@ -78,7 +78,7 @@ def execute_write_task(task: Task) -> icechunk.IcechunkStore: data = generate_task_array(task, array.shape[0:2]) array[:, :, task.time] = data dprint(f"Writing at t={task.time} done") - return store + return store.change_set_bytes() def execute_read_task(task: Task) -> None: @@ -188,13 +188,15 @@ def update(args: argparse.Namespace) -> None: client = Client(n_workers=args.workers, threads_per_worker=1) map_result = client.map(execute_write_task, tasks) - worker_stores = client.gather(map_result) + worker_changes = client.gather(map_result) print("Starting distributed commit") # we can use the current store as the commit coordinator, because it doesn't have any pending changes, # all changes come from the tasks, Icechunk doesn't care about where the changes come from, the only # important thing is to not count changes twice - commit_res = store.distributed_commit("distributed commit", [ws.change_set_bytes() for ws in worker_stores]) + for changes in worker_changes: + store.merge(changes) + commit_res = store.commit("distributed commit") assert commit_res print("Distributed commit done") diff --git a/icechunk-python/python/icechunk/__init__.py b/icechunk-python/python/icechunk/__init__.py index 4a90c8e8..1ee677c4 100644 --- a/icechunk-python/python/icechunk/__init__.py +++ b/icechunk-python/python/icechunk/__init__.py @@ -326,11 +326,9 @@ async def async_commit(self, message: str) -> str: * some other writer updated the current branch since the repository was checked out """ return await self._store.async_commit(message) - - def distributed_commit( - self, message: str, other_change_set_bytes: list[bytes] - ) -> str: - """Commit any uncommitted changes to the store with a set of distributed changes. + + def merge(self, changes: bytes) -> None: + """Merge the changes from another store into this store. This will create a new snapshot on the current branch and return the new snapshot id. @@ -340,17 +338,12 @@ def distributed_commit( * there is no currently checked out branch * some other writer updated the current branch since the repository was checked out - other_change_set_bytes must be generated as the output of calling `change_set_bytes` - on other stores. The resulting commit will include changes from all stores. - The behavior is undefined if the stores applied conflicting changes. """ - return self._store.distributed_commit(message, other_change_set_bytes) - - async def async_distributed_commit( - self, message: str, other_change_set_bytes: list[bytes] - ) -> str: - """Commit any uncommitted changes to the store with a set of distributed changes. + return self._store.merge(changes) + + async def async_merge(self, changes: bytes) -> None: + """Merge the changes from another store into this store. This will create a new snapshot on the current branch and return the new snapshot id. @@ -360,24 +353,31 @@ async def async_distributed_commit( * there is no currently checked out branch * some other writer updated the current branch since the repository was checked out - other_change_set_bytes must be generated as the output of calling `change_set_bytes` - on other stores. The resulting commit will include changes from all stores. - The behavior is undefined if the stores applied conflicting changes. """ - return await self._store.async_distributed_commit(message, other_change_set_bytes) + return await self._store.async_merge(changes) @property def has_uncommitted_changes(self) -> bool: """Return True if there are uncommitted changes to the store""" return self._store.has_uncommitted_changes - async def async_reset(self) -> None: - """Discard any uncommitted changes and reset to the previous snapshot state.""" + async def async_reset(self) -> bytes: + """Pop any uncommitted changes and reset to the previous snapshot state. + + Returns + ------- + bytes : The changes that were taken from the working set + """ return await self._store.async_reset() - def reset(self) -> None: - """Discard any uncommitted changes and reset to the previous snapshot state.""" + def reset(self) -> bytes: + """Pop any uncommitted changes and reset to the previous snapshot state. + + Returns + ------- + bytes : The changes that were taken from the working set + """ return self._store.reset() async def async_new_branch(self, branch_name: str) -> str: diff --git a/icechunk-python/python/icechunk/_icechunk_python.pyi b/icechunk-python/python/icechunk/_icechunk_python.pyi index e0a73340..74fd42a0 100644 --- a/icechunk-python/python/icechunk/_icechunk_python.pyi +++ b/icechunk-python/python/icechunk/_icechunk_python.pyi @@ -18,18 +18,14 @@ class PyIcechunkStore: async def async_checkout_branch(self, branch: str) -> None: ... def checkout_tag(self, tag: str) -> None: ... async def async_checkout_tag(self, tag: str) -> None: ... - def distributed_commit( - self, message: str, other_change_set_bytes: list[bytes] - ) -> str: ... - async def async_distributed_commit( - self, message: str, other_change_set_bytes: list[bytes] - ) -> str: ... def commit(self, message: str) -> str: ... async def async_commit(self, message: str) -> str: ... @property def has_uncommitted_changes(self) -> bool: ... - def reset(self) -> None: ... - async def async_reset(self) -> None: ... + def reset(self) -> bytes: ... + async def async_reset(self) -> bytes: ... + def merge(self, changes: bytes) -> None: ... + async def async_merge(self, changes: bytes) -> None: ... def new_branch(self, branch_name: str) -> str: ... async def async_new_branch(self, branch_name: str) -> str: ... def reset_branch(self, snapshot_id: str) -> None: ... diff --git a/icechunk-python/src/lib.rs b/icechunk-python/src/lib.rs index 8a2e7eb2..5c77e5a0 100644 --- a/icechunk-python/src/lib.rs +++ b/icechunk-python/src/lib.rs @@ -12,7 +12,7 @@ use futures::{StreamExt, TryStreamExt}; use icechunk::{ format::{manifest::VirtualChunkRef, ChunkLength}, refs::Ref, - repository::VirtualChunkLocation, + repository::{ChangeSet, VirtualChunkLocation}, storage::virtual_ref::ObjectStoreVirtualChunkResolverConfig, zarr::{ ConsolidatedStore, ObjectId, RepositoryConfig, StorageConfig, StoreError, @@ -461,29 +461,23 @@ impl PyIcechunkStore { }) } - fn async_distributed_commit<'py>( - &'py self, + fn async_merge<'py>( + &self, py: Python<'py>, - message: String, - other_change_set_bytes: Vec>, + change_set_bytes: Vec, ) -> PyResult> { let store = Arc::clone(&self.store); pyo3_async_runtimes::tokio::future_into_py(py, async move { - do_distributed_commit(store, message, other_change_set_bytes).await + do_merge(store, change_set_bytes).await }) } - fn distributed_commit<'py>( - &'py self, - py: Python<'py>, - message: String, - other_change_set_bytes: Vec>, - ) -> PyResult> { + fn merge(&self, change_set_bytes: Vec) -> PyIcechunkStoreResult<()> { let store = Arc::clone(&self.store); + pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - let res = - do_distributed_commit(store, message, other_change_set_bytes).await?; - Ok(PyString::new_bound(py, res.as_str())) + do_merge(store, change_set_bytes).await?; + Ok(()) }) } @@ -512,17 +506,17 @@ impl PyIcechunkStore { fn async_reset<'py>(&'py self, py: Python<'py>) -> PyResult> { let store = Arc::clone(&self.store); - pyo3_async_runtimes::tokio::future_into_py( - py, - async move { do_reset(store).await }, - ) + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let changes = do_reset(store).await?; + Ok(changes) + }) } - fn reset<'py>(&'py self, py: Python<'py>) -> PyResult> { + fn reset<'py>(&'py self, py: Python<'py>) -> PyResult> { let store = Arc::clone(&self.store); pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - do_reset(store).await?; - Ok(PyNone::get_bound(py).to_owned()) + let changes = do_reset(store).await?; + Ok(PyBytes::new_bound(py, &changes)) }) } @@ -927,7 +921,7 @@ impl PyIcechunkStore { } async fn do_commit(store: Arc>, message: String) -> PyResult { - let mut store = store.write().await; + let store = store.write().await; let oid = store.commit(&message).await.map_err(PyIcechunkStoreError::from)?; Ok(String::from(&oid)) } @@ -968,24 +962,26 @@ async fn do_checkout_tag(store: Arc>, tag: String) -> PyResult<()> Ok(()) } -async fn do_distributed_commit( +async fn do_merge( store: Arc>, - message: String, - other_change_set_bytes: Vec>, -) -> PyResult { - let mut writeable_store = store.write().await; - let oid = writeable_store - .distributed_commit(&message, other_change_set_bytes) - .await - .map_err(PyIcechunkStoreError::from)?; - Ok(String::from(&oid)) -} + other_change_set_bytes: Vec, +) -> PyResult<()> { + let change_set = ChangeSet::import_from_bytes(&other_change_set_bytes) + .map_err(PyIcechunkStoreError::RepositoryError)?; -async fn do_reset<'py>(store: Arc>) -> PyResult<()> { - store.write().await.reset().await.map_err(PyIcechunkStoreError::StoreError)?; + let store = store.write().await; + store.merge(change_set).await; Ok(()) } +async fn do_reset<'py>(store: Arc>) -> PyResult> { + let changes = + store.write().await.reset().await.map_err(PyIcechunkStoreError::StoreError)?; + let serialized_changes = + changes.export_to_bytes().map_err(PyIcechunkStoreError::RepositoryError)?; + Ok(serialized_changes) +} + async fn do_new_branch<'py>( store: Arc>, branch_name: String, @@ -1017,10 +1013,10 @@ async fn do_tag<'py>( tag: String, snapshot_id: String, ) -> PyResult<()> { - let mut writeable_store = store.write().await; + let store = store.read().await; let oid = ObjectId::try_from(snapshot_id.as_str()) .map_err(|e| PyIcechunkStoreError::UnkownError(e.to_string()))?; - writeable_store.tag(&tag, &oid).await.map_err(PyIcechunkStoreError::from)?; + store.tag(&tag, &oid).await.map_err(PyIcechunkStoreError::from)?; Ok(()) } diff --git a/icechunk-python/tests/test_distributed_writers.py b/icechunk-python/tests/test_distributed_writers.py index d38bab35..b0e247f7 100644 --- a/icechunk-python/tests/test_distributed_writers.py +++ b/icechunk-python/tests/test_distributed_writers.py @@ -54,7 +54,7 @@ def generate_task_array(task: Task): return np.random.rand(nx, ny) -async def execute_task(task: Task): +async def execute_task(task: Task) -> bytes: store = mk_store("w", task) group = zarr.group(store=store, overwrite=False) @@ -134,12 +134,14 @@ async def test_distributed_writers(): _first_snap = store.commit("array created") map_result = client.map(run_task, tasks) - change_sets_bytes = client.gather(map_result) + changes = client.gather(map_result) # we can use the current store as the commit coordinator, because it doesn't have any pending changes, # all changes come from the tasks, Icechunk doesn't care about where the changes come from, the only # important thing is to not count changes twice - commit_res = store.distributed_commit("distributed commit", change_sets_bytes) + for change in changes: + store.merge(change) + commit_res = store.commit("distributed commit") assert commit_res # Lets open a new store to verify the results diff --git a/icechunk/src/repository.rs b/icechunk/src/repository.rs index b94f3f86..43dd7a58 100644 --- a/icechunk/src/repository.rs +++ b/icechunk/src/repository.rs @@ -635,26 +635,14 @@ impl Repository { all_chunks(self.storage.as_ref(), &self.change_set, self.snapshot_id()).await } - pub async fn distributed_flush>( - &mut self, - other_change_sets: I, - message: &str, - properties: SnapshotProperties, - ) -> RepositoryResult { - // FIXME: this clone can be avoided - let change_sets = iter::once(self.change_set.clone()).chain(other_change_sets); - let new_snapshot_id = distributed_flush( - self.storage.as_ref(), - change_sets, - self.snapshot_id(), - message, - properties, - ) - .await?; + /// Discard all uncommitted changes and return them as a `ChangeSet` + pub fn discard_changes(&mut self) -> ChangeSet { + std::mem::take(&mut self.change_set) + } - self.snapshot_id = new_snapshot_id.clone(); - self.change_set = ChangeSet::default(); - Ok(new_snapshot_id) + /// Merge a set of `ChangeSet`s into the repository without committing them + pub async fn merge(&mut self, changes: ChangeSet) { + self.change_set.merge(changes); } /// After changes to the repository have been made, this generates and writes to `Storage` the updated datastructures. @@ -669,36 +657,31 @@ impl Repository { message: &str, properties: SnapshotProperties, ) -> RepositoryResult { - self.distributed_flush(iter::empty(), message, properties).await - } + let new_snapshot_id = flush( + self.storage.as_ref(), + &self.change_set, + self.snapshot_id(), + message, + properties, + ) + .await?; - pub async fn commit( - &mut self, - update_branch_name: &str, - message: &str, - properties: Option, - ) -> RepositoryResult { - self.distributed_commit(update_branch_name, iter::empty(), message, properties) - .await + self.snapshot_id = new_snapshot_id.clone(); + self.change_set = ChangeSet::default(); + Ok(new_snapshot_id) } - pub async fn distributed_commit>( + pub async fn commit( &mut self, update_branch_name: &str, - other_change_sets: I, message: &str, properties: Option, ) -> RepositoryResult { let current = fetch_branch_tip(self.storage.as_ref(), update_branch_name).await; + match current { Err(RefError::RefNotFound(_)) => { - self.do_distributed_commit( - update_branch_name, - other_change_sets, - message, - properties, - ) - .await + self.do_commit(update_branch_name, message, properties).await } Err(err) => Err(err.into()), Ok(ref_data) => { @@ -709,29 +692,21 @@ impl Repository { actual_parent: Some(ref_data.snapshot.clone()), }) } else { - self.do_distributed_commit( - update_branch_name, - other_change_sets, - message, - properties, - ) - .await + self.do_commit(update_branch_name, message, properties).await } } } } - async fn do_distributed_commit>( + async fn do_commit( &mut self, update_branch_name: &str, - other_change_sets: I, message: &str, properties: Option, ) -> RepositoryResult { let parent_snapshot = self.snapshot_id.clone(); let properties = properties.unwrap_or_default(); - let new_snapshot = - self.distributed_flush(other_change_sets, message, properties).await?; + let new_snapshot = self.flush(message, properties).await?; match update_branch( self.storage.as_ref(), @@ -750,6 +725,10 @@ impl Repository { } } + pub fn changes(&self) -> &ChangeSet { + &self.change_set + } + pub fn change_set_bytes(&self) -> RepositoryResult> { self.change_set.export_to_bytes() } @@ -922,20 +901,18 @@ async fn get_existing_node<'a>( } } -async fn distributed_flush>( +async fn flush( storage: &(dyn Storage + Send + Sync), - change_sets: I, + change_set: &ChangeSet, parent_id: &SnapshotId, message: &str, properties: SnapshotProperties, ) -> RepositoryResult { - let mut change_set = ChangeSet::default(); - change_set.merge_many(change_sets); if change_set.is_empty() { return Err(RepositoryError::NoChangesToCommit); } - let chunks = all_chunks(storage, &change_set, parent_id) + let chunks = all_chunks(storage, change_set, parent_id) .await? .map_ok(|(_path, chunk_info)| chunk_info); @@ -949,7 +926,7 @@ async fn distributed_flush>( }; let all_nodes = - updated_nodes(storage, &change_set, parent_id, new_manifest_id.as_ref()).await?; + updated_nodes(storage, change_set, parent_id, new_manifest_id.as_ref()).await?; let old_snapshot = storage.fetch_snapshot(parent_id).await?; let mut new_snapshot = Snapshot::from_iter( @@ -1614,6 +1591,14 @@ mod tests { // wo commit to test the case of a chunkless array let _snapshot_id = ds.flush("commit", SnapshotProperties::default()).await?; + let new_new_array_path: Path = "/group/array2".try_into().unwrap(); + ds.add_array(new_new_array_path.clone(), zarr_meta.clone()).await?; + + assert!(ds.has_uncommitted_changes()); + let changes = ds.discard_changes(); + assert!(!changes.is_empty()); + assert!(!ds.has_uncommitted_changes()); + // we set a chunk in a new array ds.set_chunk_ref( new_array_path.clone(), diff --git a/icechunk/src/zarr.rs b/icechunk/src/zarr.rs index 287dfdb7..11368c38 100644 --- a/icechunk/src/zarr.rs +++ b/icechunk/src/zarr.rs @@ -276,6 +276,8 @@ pub enum StoreError { NotFound(#[from] KeyNotFoundError), #[error("unsuccessful repository operation: `{0}`")] RepositoryError(#[from] RepositoryError), + #[error("error merging stores: `{0}`")] + MergeError(String), #[error("unsuccessful ref operation: `{0}`")] RefError(#[from] RefError), #[error("cannot commit when no snapshot is present")] @@ -387,16 +389,8 @@ impl Store { /// Resets the store to the head commit state. If there are any uncommitted changes, they will /// be lost. - pub async fn reset(&mut self) -> StoreResult<()> { - let guard = self.repository.read().await; - // carefully avoid the deadlock if we were to call self.snapshot_id() - let head_snapshot = guard.snapshot_id().clone(); - let storage = Arc::clone(guard.storage()); - let new_repository = Repository::update(storage, head_snapshot).build(); - drop(guard); - self.repository = Arc::new(RwLock::new(new_repository)); - - Ok(()) + pub async fn reset(&mut self) -> StoreResult { + Ok(self.repository.write().await.discard_changes()) } /// Checkout a specific version of the repository. This can be a snapshot id, a tag, or a branch tip. @@ -492,37 +486,29 @@ impl Store { } } + pub async fn merge(&self, changes: ChangeSet) { + self.repository.write().await.merge(changes).await; + } + /// Commit the current changes to the current branch. If the store is not currently /// on a branch, this will return an error. - pub async fn commit(&mut self, message: &str) -> StoreResult { - self.distributed_commit(message, vec![]).await - } + pub async fn commit(&self, message: &str) -> StoreResult { + let Some(branch) = &self.current_branch else { + return Err(StoreError::NotOnBranch); + }; - pub async fn distributed_commit<'a, I: IntoIterator>>( - &mut self, - message: &str, - other_changesets_bytes: I, - ) -> StoreResult { - if let Some(branch) = &self.current_branch { - let other_change_sets: Vec = other_changesets_bytes - .into_iter() - .map(|v| ChangeSet::import_from_bytes(v.as_slice())) - .try_collect()?; - let result = self - .repository - .write() - .await - .deref_mut() - .distributed_commit(branch, other_change_sets, message, None) - .await?; - Ok(result) - } else { - Err(StoreError::NotOnBranch) - } + let result = self + .repository + .write() + .await + .deref_mut() + .commit(branch, message, None) + .await?; + Ok(result) } /// Tag the given snapshot with a specified tag - pub async fn tag(&mut self, tag: &str, snapshot_id: &SnapshotId) -> StoreResult<()> { + pub async fn tag(&self, tag: &str, snapshot_id: &SnapshotId) -> StoreResult<()> { self.repository.write().await.deref_mut().tag(tag, snapshot_id).await?; Ok(()) } @@ -1814,7 +1800,7 @@ mod tests { let storage = Arc::clone(&(in_mem_storage.clone() as Arc)); let ds = Repository::init(Arc::clone(&storage), false).await?.build(); - let mut store = Store::from_repository( + let store = Store::from_repository( ds, AccessMode::ReadWrite, Some("main".to_string()), diff --git a/icechunk/tests/test_distributed_writes.rs b/icechunk/tests/test_distributed_writes.rs index 49cafb2c..943f40ee 100644 --- a/icechunk/tests/test_distributed_writes.rs +++ b/icechunk/tests/test_distributed_writes.rs @@ -178,10 +178,14 @@ async fn test_distributed_writes() -> Result<(), Box