Skip to content

Commit

Permalink
Support transaction origin (#142)
Browse files Browse the repository at this point in the history
* Support transaction origin

* Add undo manager include_origin and exclude_origin
  • Loading branch information
davidbrochart authored Jul 24, 2024
1 parent 903aed0 commit 938f764
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 15 deletions.
33 changes: 26 additions & 7 deletions python/pycrdt/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import partial
from functools import lru_cache, partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Type, cast

from ._pycrdt import Doc as _Doc
Expand Down Expand Up @@ -151,17 +152,29 @@ def unobserve(self, subscription: Subscription) -> None:
subscription.drop()


def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any):
def observe_callback(
callback: Callable[[], None] | Callable[[Any], None] | Callable[[Any, ReadTransaction], None],
doc: Doc,
event: Any,
):
param_nb = count_parameters(callback)
_event = event_types[type(event)](event, doc)
with doc._read_transaction(event.transaction):
callback(_event)
with doc._read_transaction(event.transaction) as txn:
params = (_event, txn)
callback(*params[:param_nb]) # type: ignore[arg-type]


def observe_deep_callback(callback: Callable[[Any], None], doc: Doc, events: list[Any]):
def observe_deep_callback(
callback: Callable[[], None] | Callable[[Any], None] | Callable[[Any, ReadTransaction], None],
doc: Doc,
events: list[Any],
):
param_nb = count_parameters(callback)
for idx, event in enumerate(events):
events[idx] = event_types[type(event)](event, doc)
with doc._read_transaction(event.transaction):
callback(events)
with doc._read_transaction(event.transaction) as txn:
params = (events, txn)
callback(*params[:param_nb]) # type: ignore[arg-type]


class BaseEvent:
Expand Down Expand Up @@ -199,3 +212,9 @@ def process_event(value: Any, doc: Doc) -> Any:
base_type = cast(Type[BaseType], base_types[val_type])
value = base_type(_integrated=value, _doc=doc)
return value


@lru_cache(maxsize=1024)
def count_parameters(func: Callable) -> int:
"""Count the number of parameters in a callable"""
return len(signature(func).parameters)
6 changes: 3 additions & 3 deletions python/pycrdt/_doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Type, TypeVar, cast
from typing import Any, Callable, Type, TypeVar, cast

from ._base import BaseDoc, BaseType, base_types
from ._pycrdt import Doc as _Doc
Expand Down Expand Up @@ -34,10 +34,10 @@ def guid(self) -> int:
def client_id(self) -> int:
return self._doc.client_id()

def transaction(self) -> Transaction:
def transaction(self, origin: Any = None) -> Transaction:
if self._txn is not None:
return self._txn
return Transaction(self)
return Transaction(self, origin=origin)

def _read_transaction(self, _txn: _Transaction) -> ReadTransaction:
return ReadTransaction(self, _txn)
Expand Down
13 changes: 13 additions & 0 deletions python/pycrdt/_pycrdt.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class Doc:
def create_transaction(self) -> Transaction:
"""Create a document transaction."""

def create_transaction_with_origin(self, origin: Any) -> Transaction:
"""Create a document transaction with an origin."""

def get_or_insert_text(self, name: str) -> Text:
"""Create a text root type on this document, or get an existing one."""

Expand Down Expand Up @@ -60,6 +63,10 @@ class Transaction:
def commit(self) -> None:
"""Commit the document changes."""

@property
def origin(self) -> Any:
"""The origin of the transaction."""

class TransactionEvent:
"""Event generated by `Doc.observe` method. Emitted during transaction commit
phase."""
Expand Down Expand Up @@ -182,6 +189,12 @@ class UndoManager:
def expand_scope(self, scope: Text | Array | Map) -> None:
"""Extends a list of shared types tracked by current undo manager by a given scope."""

def include_origin(self, origin: int) -> None:
"""Extends a list of origins tracked by current undo manager by a given origin."""

def exclude_origin(self, origin: int) -> None:
"""Removes an origin from the list of origins tracked by current undo manager."""

def can_undo(self) -> bool:
"""Whether there is any change to undo."""

Expand Down
28 changes: 25 additions & 3 deletions python/pycrdt/_transaction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from types import TracebackType
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from ._pycrdt import Transaction as _Transaction

Expand All @@ -14,15 +14,23 @@ class Transaction:
_txn: _Transaction | None
_nb: int

def __init__(self, doc: Doc, _txn: _Transaction | None = None) -> None:
def __init__(self, doc: Doc, _txn: _Transaction | None = None, *, origin: Any = None) -> None:
self._doc = doc
self._txn = _txn
self._origin = origin
self._nb = 0
if origin is None:
self._origin = None
else:
self._origin = hash_origin(origin)

def __enter__(self) -> Transaction:
self._nb += 1
if self._txn is None:
self._txn = self._doc._doc.create_transaction()
if self._origin is not None:
self._txn = self._doc._doc.create_transaction_with_origin(self._origin)
else:
self._txn = self._doc._doc.create_transaction()
self._doc._txn = self
return self

Expand All @@ -43,6 +51,20 @@ def __exit__(
self._txn = None
self._doc._txn = None

@property
def origin(self) -> int:
if self._txn is None:
raise RuntimeError("No current transaction")

return self._txn.origin()


class ReadTransaction(Transaction):
pass


def hash_origin(origin: Any) -> int:
try:
return hash(origin)
except Exception:
raise TypeError("Origin must be hashable")
9 changes: 8 additions & 1 deletion python/pycrdt/_undo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from ._base import BaseType
from ._pycrdt import (
Expand All @@ -9,6 +9,7 @@
from ._pycrdt import (
UndoManager as _UndoManager,
)
from ._transaction import hash_origin

if TYPE_CHECKING: # pragma: no cover
from ._doc import Doc
Expand Down Expand Up @@ -36,6 +37,12 @@ def expand_scope(self, scope: BaseType) -> None:
method = getattr(self._undo_manager, f"expand_scope_{scope.type_name}")
method(scope._integrated)

def include_origin(self, origin: Any) -> None:
self._undo_manager.include_origin(hash_origin(origin))

def exclude_origin(self, origin: Any) -> None:
self._undo_manager.exclude_origin(hash_origin(origin))

def can_undo(self) -> bool:
return self._undo_manager.can_undo()

Expand Down
6 changes: 6 additions & 0 deletions src/doc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl Doc {
Ok(t)
}

fn create_transaction_with_origin(&self, py: Python<'_>, origin: i128) -> PyResult<Py<Transaction>> {
let txn = self.doc.transact_mut_with(origin);
let t: Py<Transaction> = Py::new(py, Transaction::from(txn))?;
Ok(t)
}

fn get_state(&mut self) -> PyObject {
let txn = self.doc.transact_mut();
let state = txn.state_vector().encode_v1();
Expand Down
13 changes: 12 additions & 1 deletion src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::prelude::*;
use std::cell::{RefCell, RefMut};
use yrs::TransactionMut;
use yrs::{Origin, TransactionMut};

pub enum Cell<'a, T> {
Owned(T),
Expand Down Expand Up @@ -59,4 +59,15 @@ impl Transaction {
pub fn drop(&self) {
self.0.replace(None);
}

pub fn origin(&self) -> Option<i128> {
let transaction = self.0.borrow();
let origin: Option<&Origin> = transaction.as_ref().unwrap().as_ref().origin();
if origin.is_some() {
let data: [u8; 16] = origin.unwrap().as_ref().try_into().expect("Slice with incorrect length");
Some(i128::from_be_bytes(data))
} else {
None
}
}
}
8 changes: 8 additions & 0 deletions src/undo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ impl UndoManager {
self.undo_manager.expand_scope(&scope.map);
}

pub fn include_origin(&mut self, origin: i128) {
self.undo_manager.include_origin(origin);
}

pub fn exclude_origin(&mut self, origin: i128) {
self.undo_manager.exclude_origin(origin);
}

pub fn can_undo(&mut self) -> bool {
self.undo_manager.can_undo()
}
Expand Down
60 changes: 60 additions & 0 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from pycrdt import Array, Doc, Map, Text


Expand Down Expand Up @@ -38,3 +39,62 @@ def callback(event):
{"foo": "bar"},
'{"foo":"bar"}',
]


def test_origin():
doc = Doc()
doc["text"] = text = Text()
origin = None

def callback(event, txn):
nonlocal origin
origin = txn.origin

text.observe(callback)

with doc.transaction(origin=123) as txn:
text += "Hello"

assert origin == 123

with pytest.raises(RuntimeError) as excinfo:
txn.origin()

assert str(excinfo.value) == "No current transaction"

with pytest.raises(TypeError) as excinfo:
doc.transaction(origin={})

assert str(excinfo.value) == "Origin must be hashable"


def test_observe_callback_params():
doc = Doc()
doc["text"] = text = Text()

cb0_called = False
cb1_called = False
cb2_called = False

def callback0():
nonlocal cb0_called
cb0_called = True

def callback1(event):
nonlocal cb1_called
cb1_called = True

def callback2(event, txn):
nonlocal cb2_called
cb2_called = True

text.observe(callback0)
text.observe(callback1)
text.observe(callback2)

with doc.transaction():
text += "Hello, World!"

assert cb0_called
assert cb1_called
assert cb2_called
23 changes: 23 additions & 0 deletions tests/test_undo.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,26 @@ def test_undo_redo_stacks():
undo_manager.undo()
assert len(undo_manager.undo_stack) == 0
assert len(undo_manager.redo_stack) == 2


def test_origin():
doc = Doc()
doc["text"] = text = Text()
undo_manager = UndoManager(scopes=[text], capture_timeout_millis=0)
undo_manager.include_origin(456)
text += "Hello"
assert not undo_manager.can_undo()
with doc.transaction(origin=456):
text += ", World!"
assert str(text) == "Hello, World!"
assert undo_manager.can_undo()
undo_manager.undo()
assert str(text) == "Hello"
assert not undo_manager.can_undo()
undo_manager.exclude_origin(456)
text += ", World!"
assert str(text) == "Hello, World!"
assert undo_manager.can_undo()
undo_manager.undo()
assert str(text) == "Hello"
assert not undo_manager.can_undo()

0 comments on commit 938f764

Please sign in to comment.