Skip to content

Commit

Permalink
Check origin in nested transactions (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Jul 25, 2024
1 parent 9330409 commit 6dc298b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 5 additions & 0 deletions python/pycrdt/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def client_id(self) -> int:

def transaction(self, origin: Any = None) -> Transaction:
if self._txn is not None:
if origin is not None:
if origin != self._txn.origin:
raise RuntimeError(
"Nested transactions must have same origin as root transaction"
)
return self._txn
return Transaction(self, origin=origin)

Expand Down
11 changes: 6 additions & 5 deletions python/pycrdt/_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@ class Transaction:
_doc: Doc
_txn: _Transaction | None
_nb: int
_origin_hash: int | None

def __init__(self, doc: Doc, _txn: _Transaction | None = None, *, origin: Any = None) -> None:
self._doc = doc
self._txn = _txn
self._nb = 0
if origin is None:
self._origin = None
self._origin_hash = None
else:
self._origin = hash_origin(origin)
doc._origins[self._origin] = origin
self._origin_hash = hash_origin(origin)
doc._origins[self._origin_hash] = origin

def __enter__(self) -> Transaction:
self._nb += 1
if self._txn is None:
if self._origin is not None:
self._txn = self._doc._doc.create_transaction_with_origin(self._origin)
if self._origin_hash is not None:
self._txn = self._doc._doc.create_transaction_with_origin(self._origin_hash)
else:
self._txn = self._doc._doc.create_transaction()
self._doc._txn = self
Expand Down
9 changes: 9 additions & 0 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ def callback(event, txn):
assert len(doc0._origins) == 0
assert len(doc1._origins) == 0

with doc0.transaction(origin=123):
with doc0.transaction(origin=123):
with doc0.transaction():
with pytest.raises(RuntimeError) as excinfo:
with doc0.transaction(origin=456):
pass

assert str(excinfo.value) == "Nested transactions must have same origin as root transaction"


def test_observe_callback_params():
doc = Doc()
Expand Down

0 comments on commit 6dc298b

Please sign in to comment.