Skip to content

Commit

Permalink
More control for merging stores together (#361)
Browse files Browse the repository at this point in the history
* Add merging methods to rust without commits

* lint rust

* port to python/rust bindings (but not to actual python)

* Change reset to eject changeset bytes

* Update python bindings, tests, and examples

* mypy

* Remove errors from merge method

* Clean up mutable references

---------

Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
mpiannucci and dcherian authored Nov 5, 2024
1 parent 587090e commit 6424d0b
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 168 deletions.
10 changes: 6 additions & 4 deletions icechunk-python/examples/dask_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def generate_task_array(task: Task, shape: tuple[int,...]) -> np.typing.ArrayLik
return np.random.rand(*shape)


def execute_write_task(task: Task) -> icechunk.IcechunkStore:
def execute_write_task(task: Task) -> bytes:
"""Execute task as a write task.
This will read the time coordinate from `task` and write a "pancake" in that position,
Expand All @@ -78,7 +78,7 @@ def execute_write_task(task: Task) -> icechunk.IcechunkStore:
data = generate_task_array(task, array.shape[0:2])
array[:, :, task.time] = data
dprint(f"Writing at t={task.time} done")
return store
return store.change_set_bytes()


def execute_read_task(task: Task) -> None:
Expand Down Expand Up @@ -188,13 +188,15 @@ def update(args: argparse.Namespace) -> None:
client = Client(n_workers=args.workers, threads_per_worker=1)

map_result = client.map(execute_write_task, tasks)
worker_stores = client.gather(map_result)
worker_changes = client.gather(map_result)

print("Starting distributed commit")
# We can use the current store as the commit coordinator because it doesn't have any pending changes:
# all changes come from the tasks. Icechunk doesn't care where the changes come from; the only
# important thing is to not count changes twice.
commit_res = store.distributed_commit("distributed commit", [ws.change_set_bytes() for ws in worker_stores])
for changes in worker_changes:
store.merge(changes)
commit_res = store.commit("distributed commit")
assert commit_res
print("Distributed commit done")

Expand Down
44 changes: 22 additions & 22 deletions icechunk-python/python/icechunk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,9 @@ async def async_commit(self, message: str) -> str:
* some other writer updated the current branch since the repository was checked out
"""
return await self._store.async_commit(message)

def distributed_commit(
self, message: str, other_change_set_bytes: list[bytes]
) -> str:
"""Commit any uncommitted changes to the store with a set of distributed changes.

def merge(self, changes: bytes) -> None:
"""Merge the changes from another store into this store.
This will create a new snapshot on the current branch and return
the new snapshot id.
Expand All @@ -340,17 +338,12 @@ def distributed_commit(
* there is no currently checked out branch
* some other writer updated the current branch since the repository was checked out
`changes` must be generated as the output of calling `change_set_bytes`
on other stores. The resulting commit will include changes from all stores.
The behavior is undefined if the stores applied conflicting changes.
"""
return self._store.distributed_commit(message, other_change_set_bytes)

async def async_distributed_commit(
self, message: str, other_change_set_bytes: list[bytes]
) -> str:
"""Commit any uncommitted changes to the store with a set of distributed changes.
return self._store.merge(changes)

async def async_merge(self, changes: bytes) -> None:
"""Merge the changes from another store into this store.
This will create a new snapshot on the current branch and return
the new snapshot id.
Expand All @@ -360,24 +353,31 @@ async def async_distributed_commit(
* there is no currently checked out branch
* some other writer updated the current branch since the repository was checked out
`changes` must be generated as the output of calling `change_set_bytes`
on other stores. The resulting commit will include changes from all stores.
The behavior is undefined if the stores applied conflicting changes.
"""
return await self._store.async_distributed_commit(message, other_change_set_bytes)
return await self._store.async_merge(changes)

@property
def has_uncommitted_changes(self) -> bool:
"""Return True if there are uncommitted changes to the store"""
return self._store.has_uncommitted_changes

async def async_reset(self) -> None:
"""Discard any uncommitted changes and reset to the previous snapshot state."""
async def async_reset(self) -> bytes:
"""Pop any uncommitted changes and reset to the previous snapshot state.
Returns
-------
bytes : The changes that were taken from the working set
"""
return await self._store.async_reset()

def reset(self) -> None:
"""Discard any uncommitted changes and reset to the previous snapshot state."""
def reset(self) -> bytes:
"""Pop any uncommitted changes and reset to the previous snapshot state.
Returns
-------
bytes : The changes that were taken from the working set
"""
return self._store.reset()

async def async_new_branch(self, branch_name: str) -> str:
Expand Down
12 changes: 4 additions & 8 deletions icechunk-python/python/icechunk/_icechunk_python.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,14 @@ class PyIcechunkStore:
async def async_checkout_branch(self, branch: str) -> None: ...
def checkout_tag(self, tag: str) -> None: ...
async def async_checkout_tag(self, tag: str) -> None: ...
def distributed_commit(
self, message: str, other_change_set_bytes: list[bytes]
) -> str: ...
async def async_distributed_commit(
self, message: str, other_change_set_bytes: list[bytes]
) -> str: ...
def commit(self, message: str) -> str: ...
async def async_commit(self, message: str) -> str: ...
@property
def has_uncommitted_changes(self) -> bool: ...
def reset(self) -> None: ...
async def async_reset(self) -> None: ...
def reset(self) -> bytes: ...
async def async_reset(self) -> bytes: ...
def merge(self, changes: bytes) -> None: ...
async def async_merge(self, changes: bytes) -> None: ...
def new_branch(self, branch_name: str) -> str: ...
async def async_new_branch(self, branch_name: str) -> str: ...
def reset_branch(self, snapshot_id: str) -> None: ...
Expand Down
72 changes: 34 additions & 38 deletions icechunk-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use futures::{StreamExt, TryStreamExt};
use icechunk::{
format::{manifest::VirtualChunkRef, ChunkLength},
refs::Ref,
repository::VirtualChunkLocation,
repository::{ChangeSet, VirtualChunkLocation},
storage::virtual_ref::ObjectStoreVirtualChunkResolverConfig,
zarr::{
ConsolidatedStore, ObjectId, RepositoryConfig, StorageConfig, StoreError,
Expand Down Expand Up @@ -461,29 +461,23 @@ impl PyIcechunkStore {
})
}

fn async_distributed_commit<'py>(
&'py self,
fn async_merge<'py>(
&self,
py: Python<'py>,
message: String,
other_change_set_bytes: Vec<Vec<u8>>,
change_set_bytes: Vec<u8>,
) -> PyResult<Bound<'py, PyAny>> {
let store = Arc::clone(&self.store);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
do_distributed_commit(store, message, other_change_set_bytes).await
do_merge(store, change_set_bytes).await
})
}

fn distributed_commit<'py>(
&'py self,
py: Python<'py>,
message: String,
other_change_set_bytes: Vec<Vec<u8>>,
) -> PyResult<Bound<'py, PyString>> {
fn merge(&self, change_set_bytes: Vec<u8>) -> PyIcechunkStoreResult<()> {
let store = Arc::clone(&self.store);

pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let res =
do_distributed_commit(store, message, other_change_set_bytes).await?;
Ok(PyString::new_bound(py, res.as_str()))
do_merge(store, change_set_bytes).await?;
Ok(())
})
}

Expand Down Expand Up @@ -512,17 +506,17 @@ impl PyIcechunkStore {

fn async_reset<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let store = Arc::clone(&self.store);
pyo3_async_runtimes::tokio::future_into_py(
py,
async move { do_reset(store).await },
)
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let changes = do_reset(store).await?;
Ok(changes)
})
}

fn reset<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyNone>> {
fn reset<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let store = Arc::clone(&self.store);
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
do_reset(store).await?;
Ok(PyNone::get_bound(py).to_owned())
let changes = do_reset(store).await?;
Ok(PyBytes::new_bound(py, &changes))
})
}

Expand Down Expand Up @@ -927,7 +921,7 @@ impl PyIcechunkStore {
}

async fn do_commit(store: Arc<RwLock<Store>>, message: String) -> PyResult<String> {
let mut store = store.write().await;
let store = store.write().await;
let oid = store.commit(&message).await.map_err(PyIcechunkStoreError::from)?;
Ok(String::from(&oid))
}
Expand Down Expand Up @@ -968,24 +962,26 @@ async fn do_checkout_tag(store: Arc<RwLock<Store>>, tag: String) -> PyResult<()>
Ok(())
}

async fn do_distributed_commit(
async fn do_merge(
store: Arc<RwLock<Store>>,
message: String,
other_change_set_bytes: Vec<Vec<u8>>,
) -> PyResult<String> {
let mut writeable_store = store.write().await;
let oid = writeable_store
.distributed_commit(&message, other_change_set_bytes)
.await
.map_err(PyIcechunkStoreError::from)?;
Ok(String::from(&oid))
}
other_change_set_bytes: Vec<u8>,
) -> PyResult<()> {
let change_set = ChangeSet::import_from_bytes(&other_change_set_bytes)
.map_err(PyIcechunkStoreError::RepositoryError)?;

async fn do_reset<'py>(store: Arc<RwLock<Store>>) -> PyResult<()> {
store.write().await.reset().await.map_err(PyIcechunkStoreError::StoreError)?;
let store = store.write().await;
store.merge(change_set).await;
Ok(())
}

async fn do_reset<'py>(store: Arc<RwLock<Store>>) -> PyResult<Vec<u8>> {
let changes =
store.write().await.reset().await.map_err(PyIcechunkStoreError::StoreError)?;
let serialized_changes =
changes.export_to_bytes().map_err(PyIcechunkStoreError::RepositoryError)?;
Ok(serialized_changes)
}

async fn do_new_branch<'py>(
store: Arc<RwLock<Store>>,
branch_name: String,
Expand Down Expand Up @@ -1017,10 +1013,10 @@ async fn do_tag<'py>(
tag: String,
snapshot_id: String,
) -> PyResult<()> {
let mut writeable_store = store.write().await;
let store = store.read().await;
let oid = ObjectId::try_from(snapshot_id.as_str())
.map_err(|e| PyIcechunkStoreError::UnkownError(e.to_string()))?;
writeable_store.tag(&tag, &oid).await.map_err(PyIcechunkStoreError::from)?;
store.tag(&tag, &oid).await.map_err(PyIcechunkStoreError::from)?;
Ok(())
}

Expand Down
8 changes: 5 additions & 3 deletions icechunk-python/tests/test_distributed_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def generate_task_array(task: Task):
return np.random.rand(nx, ny)


async def execute_task(task: Task):
async def execute_task(task: Task) -> bytes:
store = mk_store("w", task)

group = zarr.group(store=store, overwrite=False)
Expand Down Expand Up @@ -134,12 +134,14 @@ async def test_distributed_writers():
_first_snap = store.commit("array created")

map_result = client.map(run_task, tasks)
change_sets_bytes = client.gather(map_result)
changes = client.gather(map_result)

# we can use the current store as the commit coordinator, because it doesn't have any pending changes,
# all changes come from the tasks, Icechunk doesn't care about where the changes come from, the only
# important thing is to not count changes twice
commit_res = store.distributed_commit("distributed commit", change_sets_bytes)
for change in changes:
store.merge(change)
commit_res = store.commit("distributed commit")
assert commit_res

# Let's open a new store to verify the results
Expand Down
Loading

0 comments on commit 6424d0b

Please sign in to comment.