Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python learns to detect conflicts and rebase #420

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions icechunk-python/python/icechunk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from typing import Any, Self

from icechunk._icechunk_python import (
BasicConflictSolver,
ConflictDetector,
PyIcechunkStore,
S3Credentials,
SnapshotMetadata,
StorageConfig,
StoreConfig,
VersionSelection,
VirtualRefConfig,
__version__,
pyicechunk_store_create,
Expand All @@ -22,11 +25,14 @@
from zarr.core.sync import SyncMixin

__all__ = [
"BasicConflictSolver",
"ConflictDetector",
"IcechunkStore",
"S3Credentials",
"SnapshotMetadata",
"StorageConfig",
"StoreConfig",
"VersionSelection",
"VirtualRefConfig",
"__version__",
]
Expand Down Expand Up @@ -323,9 +329,6 @@ async def async_commit(self, message: str) -> str:
def merge(self, changes: bytes) -> None:
"""Merge the changes from another store into this store.
This will create a new snapshot on the current branch and return
the new snapshot id.
This method will fail if:
* there is no currently checked out branch
Expand All @@ -338,9 +341,6 @@ def merge(self, changes: bytes) -> None:
async def async_merge(self, changes: bytes) -> None:
"""Merge the changes from another store into this store.
This will create a new snapshot on the current branch and return
the new snapshot id.
This method will fail if:
* there is no currently checked out branch
Expand All @@ -350,6 +350,30 @@ async def async_merge(self, changes: bytes) -> None:
"""
return await self._store.async_merge(changes)

def rebase(self, solver: ConflictDetector | BasicConflictSolver) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should write a few paragraphs of design document on this. Two topics would be:

  • how much do we want to expose of the solvers? This methods for example requires adding here when different strategies are created, and doesn't allow for users to extend the mechanisms (which I'd love some day)
  • how do we expose the results of the conflict detector. This is going to be important, I expect most production jobs to at least report to logs the conflicts when they find them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should. For 1: I think we can used subclassing/inheritance to solve this in Python probably.

I wanted to start this PR as a jumping off point, cuz I agree there is a lot to figure out here in terms of the API

"""Rebase the current branch onto the given branch by detecting and optionally
attempting to fix conflicts between the current store and the tip of the branch.
When there are more than one commit between the parent snapshot and the tip of
the branch, `rebase` iterates over all of them, older first, trying to fast-forward.
If at some point it finds a conflict it cannot recover from, `rebase` leaves the
store in a consistent state, that would successfully commit on top
of the latest successfully fast-forwarded commit.
"""
return self._store.rebase(solver)

async def async_rebase(self, solver: ConflictDetector | BasicConflictSolver) -> None:
"""Rebase the current branch onto the given branch by detecting and optionally
attempting to fix conflicts between the current store and the tip of the branch.
When there are more than one commit between the parent snapshot and the tip of
the branch, `rebase` iterates over all of them, older first, trying to fast-forward.
If at some point it finds a conflict it cannot recover from, `rebase` leaves the
store in a consistent state, that would successfully commit on top
of the latest successfully fast-forwarded commit.
"""
return await self._store.async_rebase(solver)

@property
def has_uncommitted_changes(self) -> bool:
"""Return True if there are uncommitted changes to the store"""
Expand Down
54 changes: 54 additions & 0 deletions icechunk-python/python/icechunk/_icechunk_python.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class PyIcechunkStore:
async def async_reset(self) -> bytes: ...
def merge(self, changes: bytes) -> None: ...
async def async_merge(self, changes: bytes) -> None: ...
def rebase(self, solver: ConflictDetector | BasicConflictSolver) -> None: ...
async def async_rebase(
self, solver: ConflictDetector | BasicConflictSolver
) -> None: ...
def new_branch(self, branch_name: str) -> str: ...
async def async_new_branch(self, branch_name: str) -> str: ...
def reset_branch(self, snapshot_id: str) -> None: ...
Expand Down Expand Up @@ -275,6 +279,56 @@ class StoreConfig:
"""
...

class VersionSelection:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this better than an enum?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how Rust enums map unfortunately (i think, I'll confirm)

"""Configuration for selecting a version when performing conflict resolution"""

@classmethod
def use_ours(cls) -> VersionSelection:
"""Select the local version when performing conflict resolution"""
...

@classmethod
def use_theirs(cls) -> VersionSelection:
"""Select the remote version when performing conflict resolution"""
...

@classmethod
def fail(cls) -> VersionSelection:
"""Fail if a conflict is encountered when performing conflict resolution"""
...

class BasicConflictSolver:
"""A basic conflict solver that allows for simple configuration of resolution behavior"""

def __init__(
self,
*,
on_user_attributes_conflict: VersionSelection = VersionSelection.use_ours(),
on_chunk_conflict: VersionSelection = VersionSelection.use_ours(),
fail_on_delete_of_updated_array: bool = False,
fail_on_delete_of_updated_group: bool = False,
) -> BasicConflictSolver:
"""Create a BasicConflictSolver object with the given configuration options
Parameters:
on_user_attributes_conflict: VersionSelection
The behavior to use when a user attribute conflict is encountered, by default VersionSelection.use_ours()
on_chunk_conflict: VersionSelection
The behavior to use when a chunk conflict is encountered, by default VersionSelection.use_theirs()
fail_on_delete_of_updated_array: bool
Whether to fail when a chunk is deleted that has been updated, by default False
fail_on_delete_of_updated_group: bool
Whether to fail when a group is deleted that has been updated, by default False
"""
...

class ConflictDetector:
"""A conflict detector that can be used to detect conflicts between two stores and
report if resolution is possible
"""

def __init__(self) -> ConflictDetector: ...

async def async_pyicechunk_store_exists(storage: StorageConfig) -> bool: ...
def pyicechunk_store_exists(storage: StorageConfig) -> bool: ...
async def async_pyicechunk_store_create(
Expand Down
137 changes: 137 additions & 0 deletions icechunk-python/src/conflicts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use icechunk::conflicts::{basic_solver::{BasicConflictSolver, VersionSelection}, detector::ConflictDetector, ConflictSolver};
use pyo3::{prelude::*, types::PyType};

#[pyclass(name = "VersionSelection")]
#[derive(Clone, Debug)]
pub struct PyVersionSelection(VersionSelection);

impl Default for PyVersionSelection {
fn default() -> Self {
PyVersionSelection(VersionSelection::UseOurs)
}
}

#[pymethods]
impl PyVersionSelection {
#[classmethod]
fn use_ours(_cls: &Bound<'_, PyType>) -> Self {
PyVersionSelection(VersionSelection::UseOurs)
}

#[classmethod]
fn use_theirs(_cls: &Bound<'_, PyType>) -> Self {
PyVersionSelection(VersionSelection::UseTheirs)
}

#[classmethod]
fn fail(_cls: &Bound<'_, PyType>) -> Self {
PyVersionSelection(VersionSelection::Fail)
}
}

impl From<PyVersionSelection> for VersionSelection {
fn from(value: PyVersionSelection) -> Self {
value.0
}
}

impl From<&PyVersionSelection> for VersionSelection {
fn from(value: &PyVersionSelection) -> Self {
value.0.clone()
}
}

impl AsRef<VersionSelection> for PyVersionSelection {
fn as_ref(&self) -> &VersionSelection {
&self.0
}
}

#[pyclass(name = "ConflictDetector")]
#[derive(Clone, Debug)]
pub struct PyConflictDetector(ConflictDetector);

#[pymethods]
impl PyConflictDetector {
#[new]
fn new() -> Self {
PyConflictDetector(ConflictDetector)
}
}

impl From<PyConflictDetector> for ConflictDetector {
fn from(value: PyConflictDetector) -> Self {
value.0
}
}

impl From<&PyConflictDetector> for ConflictDetector {
fn from(value: &PyConflictDetector) -> Self {
value.0.clone()
}
}

impl AsRef<ConflictDetector> for PyConflictDetector {
fn as_ref(&self) -> &ConflictDetector {
&self.0
}
}

#[pyclass(name = "BasicConflictSolver")]
#[derive(Clone, Debug)]
pub struct PyBasicConflictSolver(BasicConflictSolver);

#[pymethods]
impl PyBasicConflictSolver {
#[new]
#[pyo3(signature = (*, on_user_attributes_conflict=PyVersionSelection::default(), on_chunk_conflict=PyVersionSelection::default(), fail_on_delete_of_updated_array = false, fail_on_delete_of_updated_group = false))]
fn new(
on_user_attributes_conflict: PyVersionSelection,
on_chunk_conflict: PyVersionSelection,
fail_on_delete_of_updated_array: bool,
fail_on_delete_of_updated_group: bool,
) -> Self {
PyBasicConflictSolver(BasicConflictSolver {
on_user_attributes_conflict: on_user_attributes_conflict.into(),
on_chunk_conflict: on_chunk_conflict.into(),
fail_on_delete_of_updated_array,
fail_on_delete_of_updated_group,
})
}
}

impl From<PyBasicConflictSolver> for BasicConflictSolver {
fn from(value: PyBasicConflictSolver) -> Self {
value.0
}
}

impl From<&PyBasicConflictSolver> for BasicConflictSolver {
fn from(value: &PyBasicConflictSolver) -> Self {
value.0.clone()
}
}

impl AsRef<BasicConflictSolver> for PyBasicConflictSolver {
fn as_ref(&self) -> &BasicConflictSolver {
&self.0
}
}


#[derive(FromPyObject)]
pub enum PyConflictSolver {
#[pyo3(transparent)]
Detect(PyConflictDetector),
#[pyo3(transparent)]
Basic(PyBasicConflictSolver),
}

impl AsRef<dyn ConflictSolver + 'static> for PyConflictSolver {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks weird. I wonder why you need it. If you do, I think this probably should have a generic timeline instead of static. Something like

impl<'a> AsRef<dyn ConflictSolver + 'a> for .....

fn as_ref(&self) -> &(dyn ConflictSolver + 'static) {
match self {
PyConflictSolver::Detect(detector) => detector.as_ref(),
PyConflictSolver::Basic(solver) => solver.as_ref(),
}
}
}
34 changes: 34 additions & 0 deletions icechunk-python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod conflicts;
mod errors;
mod storage;
mod streams;
Expand All @@ -7,9 +8,11 @@ use std::{borrow::Cow, sync::Arc};
use ::icechunk::{format::ChunkOffset, Store};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use conflicts::{PyBasicConflictSolver, PyConflictDetector, PyConflictSolver, PyVersionSelection};
use errors::{PyIcechunkStoreError, PyIcechunkStoreResult};
use futures::{StreamExt, TryStreamExt};
use icechunk::{
conflicts::ConflictSolver,
format::{manifest::VirtualChunkRef, ChunkLength},
refs::Ref,
repository::{ChangeSet, VirtualChunkLocation},
Expand Down Expand Up @@ -481,6 +484,25 @@ impl PyIcechunkStore {
})
}

fn async_rebase<'py>(
&'py self,
py: Python<'py>,
solver: PyBasicConflictSolver,
) -> PyResult<Bound<'py, PyAny>> {
let store = Arc::clone(&self.store);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
do_rebase(store, solver.as_ref()).await
})
}

fn rebase(&self, solver: PyConflictSolver) -> PyIcechunkStoreResult<()> {
let store = Arc::clone(&self.store);
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
do_rebase(store, solver.as_ref()).await?;
Ok(())
})
}

fn change_set_bytes(&self) -> PyIcechunkStoreResult<Vec<u8>> {
let store = self.store.blocking_read();
let res = pyo3_async_runtimes::tokio::get_runtime()
Expand Down Expand Up @@ -945,6 +967,15 @@ async fn do_merge(
Ok(())
}

async fn do_rebase(
store: Arc<RwLock<Store>>,
solver: &dyn ConflictSolver,
) -> PyResult<()> {
let store = store.read().await;
store.rebase(solver).await.map_err(PyIcechunkStoreError::from)?;
Ok(())
}

async fn do_reset<'py>(store: Arc<RwLock<Store>>) -> PyResult<Vec<u8>> {
let changes =
store.write().await.reset().await.map_err(PyIcechunkStoreError::StoreError)?;
Expand Down Expand Up @@ -1018,6 +1049,9 @@ fn _icechunk_python(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyStoreConfig>()?;
m.add_class::<PySnapshotMetadata>()?;
m.add_class::<PyVirtualRefConfig>()?;
m.add_class::<PyVersionSelection>()?;
m.add_class::<PyConflictDetector>()?;
m.add_class::<PyBasicConflictSolver>()?;
m.add_function(wrap_pyfunction!(pyicechunk_store_exists, m)?)?;
m.add_function(wrap_pyfunction!(async_pyicechunk_store_exists, m)?)?;
m.add_function(wrap_pyfunction!(pyicechunk_store_create, m)?)?;
Expand Down
Loading
Loading