diff --git a/python/pycrdt/_doc.py b/python/pycrdt/_doc.py index 9658723..5ea08e2 100644 --- a/python/pycrdt/_doc.py +++ b/python/pycrdt/_doc.py @@ -36,6 +36,11 @@ def client_id(self) -> int: def transaction(self, origin: Any = None) -> Transaction: if self._txn is not None: + if origin is not None: + if origin != self._txn.origin: + raise RuntimeError( + "Nested transactions must have same origin as root transaction" + ) return self._txn return Transaction(self, origin=origin) diff --git a/python/pycrdt/_transaction.py b/python/pycrdt/_transaction.py index 9d54506..6128efa 100644 --- a/python/pycrdt/_transaction.py +++ b/python/pycrdt/_transaction.py @@ -13,22 +13,23 @@ class Transaction: _doc: Doc _txn: _Transaction | None _nb: int + _origin_hash: int | None def __init__(self, doc: Doc, _txn: _Transaction | None = None, *, origin: Any = None) -> None: self._doc = doc self._txn = _txn self._nb = 0 if origin is None: - self._origin = None + self._origin_hash = None else: - self._origin = hash_origin(origin) - doc._origins[self._origin] = origin + self._origin_hash = hash_origin(origin) + doc._origins[self._origin_hash] = origin def __enter__(self) -> Transaction: self._nb += 1 if self._txn is None: - if self._origin is not None: - self._txn = self._doc._doc.create_transaction_with_origin(self._origin) + if self._origin_hash is not None: + self._txn = self._doc._doc.create_transaction_with_origin(self._origin_hash) else: self._txn = self._doc._doc.create_transaction() self._doc._txn = self diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 93247df..3aff42f 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -89,6 +89,15 @@ def callback(event, txn): assert len(doc0._origins) == 0 assert len(doc1._origins) == 0 + with doc0.transaction(origin=123): + with doc0.transaction(origin=123): + with doc0.transaction(): + with pytest.raises(RuntimeError) as excinfo: + with doc0.transaction(origin=456): + pass + + assert str(excinfo.value) == "Nested transactions must have same origin as root transaction" + def test_observe_callback_params(): doc = Doc()