Skip to content

Commit

Permalink
Add update API (#148)
Browse files Browse the repository at this point in the history
* Intro Alternative Update API as pycrdt.Update

* Fix testcase

* Add verify merge result string

* Change API

---------

Co-authored-by: David Brochart <[email protected]>
  • Loading branch information
Wh1isper and davidbrochart authored Jul 31, 2024
1 parent b135c24 commit cf03fd3
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 1 deletion.
3 changes: 3 additions & 0 deletions python/pycrdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@
from ._transaction import ReadTransaction as ReadTransaction
from ._transaction import Transaction as Transaction
from ._undo import UndoManager as UndoManager
from ._update import get_state as get_state
from ._update import get_update as get_update
from ._update import merge_updates as merge_updates
4 changes: 4 additions & 0 deletions python/pycrdt/_pycrdt.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,7 @@ class StackItem:
compressed information about all updates and
deletions tracked by it.
"""

def merge_updates(updates: tuple[bytes, ...]) -> bytes: ...
def get_state(update: bytes) -> bytes: ...
def get_update(update: bytes, state: bytes) -> bytes: ...
20 changes: 20 additions & 0 deletions python/pycrdt/_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from ._pycrdt import get_state as _get_state
from ._pycrdt import get_update as _get_update
from ._pycrdt import merge_updates as _merge_updates


def get_state(update: bytes) -> bytes:
"""Returns a state from an update."""
return _get_state(update)


def get_update(update: bytes, state: bytes) -> bytes:
"""Returns an update consisting of all changes from a given update which have not
been seen in the given state.
"""
return _get_update(update, state)


def merge_updates(*updates: bytes) -> bytes:
"""Returns an update consisting of a combination of all given updates."""
return _merge_updates(updates)
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod transaction;
mod subscription;
mod type_conversions;
mod undo;
mod update;
use crate::doc::Doc;
use crate::doc::TransactionEvent;
use crate::doc::SubdocsEvent;
Expand All @@ -16,6 +17,7 @@ use crate::map::{Map, MapEvent};
use crate::transaction::Transaction;
use crate::subscription::Subscription;
use crate::undo::{StackItem, UndoManager};
use crate::update::{get_state, get_update, merge_updates};

#[pymodule]
fn _pycrdt(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -32,5 +34,8 @@ fn _pycrdt(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<StackItem>()?;
m.add_class::<Subscription>()?;
m.add_class::<UndoManager>()?;
m.add_function(wrap_pyfunction!(get_state, m)?)?;
m.add_function(wrap_pyfunction!(get_update, m)?)?;
m.add_function(wrap_pyfunction!(merge_updates, m)?)?;
Ok(())
}
37 changes: 37 additions & 0 deletions src/update.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use pyo3::prelude::*;
use pyo3::exceptions::PyValueError;
use pyo3::types::{PyBytes, PyTuple};
use yrs::{diff_updates_v1, encode_state_vector_from_update_v1, merge_updates_v1};

#[pyfunction]
pub fn merge_updates(updates: &Bound<'_, PyTuple>) -> PyResult<PyObject> {
let updates: Vec<Vec<u8>> = updates.extract().unwrap();
let Ok(update) = merge_updates_v1(&updates) else {
return Err(PyValueError::new_err("Cannot merge updates"));
};
let bytes: PyObject = Python::with_gil(|py| PyBytes::new_bound(py, &update).into());
Ok(bytes)
}

#[pyfunction]
pub fn get_state(update: &Bound<'_, PyBytes>) -> PyResult<PyObject> {
let update: &[u8] = update.extract()?;
let Ok(u) = encode_state_vector_from_update_v1(&update) else {
return Err(PyValueError::new_err(
"Cannot encode state vector from update",
));
};
let bytes: PyObject = Python::with_gil(|py| PyBytes::new_bound(py, &u).into());
Ok(bytes)
}

#[pyfunction]
pub fn get_update(update: &Bound<'_, PyBytes>, state: &Bound<'_, PyBytes>) -> PyResult<PyObject> {
let update: &[u8] = update.extract()?;
let state: &[u8] = state.extract()?;
let Ok(u) = diff_updates_v1(&update, &state) else {
return Err(PyValueError::new_err("Cannot diff updates"));
};
let bytes: PyObject = Python::with_gil(|py| PyBytes::new_bound(py, &u).into());
Ok(bytes)
}
2 changes: 1 addition & 1 deletion tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def callback(event, txn):
with doc0.transaction():
with pytest.raises(RuntimeError) as excinfo:
with doc0.transaction(origin=456):
pass
pass # pragma: no cover

assert str(excinfo.value) == "Nested transactions must have same origin as root transaction"

Expand Down
36 changes: 36 additions & 0 deletions tests/test_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pycrdt import Doc, Map, get_state, get_update, merge_updates


def test_update():
data0 = Map({"key0": "val0"})
doc0 = Doc()
doc0["data"] = data0

data1 = Map({"key1": "val1"})
doc1 = Doc()
doc1["data"] = data1

update0 = doc0.get_update()
update1 = doc1.get_update()

del doc0
del doc1
state0 = get_state(update0)
state1 = get_state(update1)

update01 = get_update(update0, state1)
update10 = get_update(update1, state0)

# sync clients
update0 = merge_updates(update0, update10)
update1 = merge_updates(update1, update01)
assert update0 == update1

doc0 = Doc()
data0 = doc0.get("data", type=Map)
doc0.apply_update(update0)
doc1 = Doc()
data1 = doc1.get("data", type=Map)
doc1.apply_update(update1)

assert data0.to_py() == data1.to_py() == {"key0": "val0", "key1": "val1"}

0 comments on commit cf03fd3

Please sign in to comment.