From 9ff8f5c9033837b1a50f1eb18f170b530f1d488e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 6 Dec 2024 18:58:34 -0700 Subject: [PATCH] Set store to writeable after creating a branch (#455) * Set store to writeable after creating a branch Closes #439 * Add tests --- icechunk-python/python/icechunk/__init__.py | 11 +++++++---- icechunk-python/tests/test_timetravel.py | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/icechunk-python/python/icechunk/__init__.py b/icechunk-python/python/icechunk/__init__.py index c6d8addc..8c09399c 100644 --- a/icechunk-python/python/icechunk/__init__.py +++ b/icechunk-python/python/icechunk/__init__.py @@ -242,7 +242,7 @@ def checkout( "only one of snapshot_id, branch, or tag may be specified" ) self._store.checkout_snapshot(snapshot_id) - self._read_only = True + self.set_read_only() return if branch is not None: if tag is not None: @@ -250,11 +250,12 @@ def checkout( "only one of snapshot_id, branch, or tag may be specified" ) self._store.checkout_branch(branch) - self._read_only = True + # We preserve the read-only status here so you can checkout a branch + # on a read-only store, and be guaranteed that you won't modify the store. return if tag is not None: self._store.checkout_tag(tag) - self._read_only = True + self.set_read_only() return raise ValueError("a snapshot_id, branch, or tag must be specified") @@ -385,7 +386,9 @@ def new_branch(self, branch_name: str) -> str: This requires having no uncommitted changes. """ - return self._store.new_branch(branch_name) + ret = self._store.new_branch(branch_name) + self.set_writeable() + return ret async def async_reset_branch(self, to_snapshot: str) -> None: """Reset the currently checked out branch to point to a different snapshot. diff --git a/icechunk-python/tests/test_timetravel.py b/icechunk-python/tests/test_timetravel.py index 594c789d..fe28d13b 100644 --- a/icechunk-python/tests/test_timetravel.py +++ b/icechunk-python/tests/test_timetravel.py @@ -20,6 +20,7 @@ def test_timetravel(): assert air_temp[200, 6] == 42 snapshot_id = store.commit("commit 1") + assert not store._read_only air_temp[:, :] = 54 assert air_temp[200, 6] == 54 @@ -49,12 +50,14 @@ def test_timetravel(): assert air_temp[200, 6] == 54 store.new_branch("feature") + assert not store._read_only assert store.branch == "feature" air_temp[:, :] = 90 feature_snapshot_id = store.commit("commit 3") store.tag("v1.0", feature_snapshot_id) store.checkout(tag="v1.0") + assert store._read_only assert store.branch is None assert air_temp[200, 6] == 90