From cf03fd39507b9c1ad4208c05020888eeeb55de37 Mon Sep 17 00:00:00 2001 From: Zhongsheng Ji <9573586@qq.com> Date: Wed, 31 Jul 2024 22:10:09 +0800 Subject: [PATCH] Add update API (#148) * Intro Alternative Update API as pycrdt.Update * Fix testcase * Add verify merge result string * Change API --------- Co-authored-by: David Brochart --- python/pycrdt/__init__.py | 3 +++ python/pycrdt/_pycrdt.pyi | 4 ++++ python/pycrdt/_update.py | 20 ++++++++++++++++++++ src/lib.rs | 5 +++++ src/update.rs | 37 +++++++++++++++++++++++++++++++++++++ tests/test_transaction.py | 2 +- tests/test_update.py | 36 ++++++++++++++++++++++++++++++++++++ 7 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 python/pycrdt/_update.py create mode 100644 src/update.rs create mode 100644 tests/test_update.py diff --git a/python/pycrdt/__init__.py b/python/pycrdt/__init__.py index 775fca7..94f2bf8 100644 --- a/python/pycrdt/__init__.py +++ b/python/pycrdt/__init__.py @@ -18,3 +18,6 @@ from ._transaction import ReadTransaction as ReadTransaction from ._transaction import Transaction as Transaction from ._undo import UndoManager as UndoManager +from ._update import get_state as get_state +from ._update import get_update as get_update +from ._update import merge_updates as merge_updates diff --git a/python/pycrdt/_pycrdt.pyi b/python/pycrdt/_pycrdt.pyi index 6fd361e..80ef5f4 100644 --- a/python/pycrdt/_pycrdt.pyi +++ b/python/pycrdt/_pycrdt.pyi @@ -221,3 +221,7 @@ class StackItem: compressed information about all updates and deletions tracked by it. """ + +def merge_updates(updates: tuple[bytes, ...]) -> bytes: ... +def get_state(update: bytes) -> bytes: ... +def get_update(update: bytes, state: bytes) -> bytes: ... diff --git a/python/pycrdt/_update.py b/python/pycrdt/_update.py new file mode 100644 index 0000000..91c861a --- /dev/null +++ b/python/pycrdt/_update.py @@ -0,0 +1,20 @@ +from ._pycrdt import get_state as _get_state +from ._pycrdt import get_update as _get_update +from ._pycrdt import merge_updates as _merge_updates + + +def get_state(update: bytes) -> bytes: + """Returns a state from an update.""" + return _get_state(update) + + +def get_update(update: bytes, state: bytes) -> bytes: + """Returns an update consisting of all changes from a given update which have not + been seen in the given state. + """ + return _get_update(update, state) + + +def merge_updates(*updates: bytes) -> bytes: + """Returns an update consisting of a combination of all given updates.""" + return _merge_updates(updates) diff --git a/src/lib.rs b/src/lib.rs index 70e5294..2cb2fc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ mod transaction; mod subscription; mod type_conversions; mod undo; +mod update; use crate::doc::Doc; use crate::doc::TransactionEvent; use crate::doc::SubdocsEvent; @@ -16,6 +17,7 @@ use crate::map::{Map, MapEvent}; use crate::transaction::Transaction; use crate::subscription::Subscription; use crate::undo::{StackItem, UndoManager}; +use crate::update::{get_state, get_update, merge_updates}; #[pymodule] fn _pycrdt(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -32,5 +34,8 @@ fn _pycrdt(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(get_state, m)?)?; + m.add_function(wrap_pyfunction!(get_update, m)?)?; + m.add_function(wrap_pyfunction!(merge_updates, m)?)?; Ok(()) } diff --git a/src/update.rs b/src/update.rs new file mode 100644 index 0000000..faf92e3 --- /dev/null +++ b/src/update.rs @@ -0,0 +1,37 @@ +use pyo3::prelude::*; +use pyo3::exceptions::PyValueError; +use pyo3::types::{PyBytes, PyTuple}; +use yrs::{diff_updates_v1, encode_state_vector_from_update_v1, merge_updates_v1}; + +#[pyfunction] +pub fn merge_updates(updates: &Bound<'_, PyTuple>) -> PyResult { + let updates: Vec> = updates.extract().unwrap(); + let Ok(update) = merge_updates_v1(&updates) else { + return Err(PyValueError::new_err("Cannot merge updates")); + }; + let bytes: PyObject = Python::with_gil(|py| PyBytes::new_bound(py, &update).into()); + Ok(bytes) +} + +#[pyfunction] +pub fn get_state(update: &Bound<'_, PyBytes>) -> PyResult { + let update: &[u8] = update.extract()?; + let Ok(u) = encode_state_vector_from_update_v1(&update) else { + return Err(PyValueError::new_err( + "Cannot encode state vector from update", + )); + }; + let bytes: PyObject = Python::with_gil(|py| PyBytes::new_bound(py, &u).into()); + Ok(bytes) +} + +#[pyfunction] +pub fn get_update(update: &Bound<'_, PyBytes>, state: &Bound<'_, PyBytes>) -> PyResult { + let update: &[u8] = update.extract()?; + let state: &[u8] = state.extract()?; + let Ok(u) = diff_updates_v1(&update, &state) else { + return Err(PyValueError::new_err("Cannot diff updates")); + }; + let bytes: PyObject = Python::with_gil(|py| PyBytes::new_bound(py, &u).into()); + Ok(bytes) +} diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 3aff42f..394e560 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -94,7 +94,7 @@ def callback(event, txn): with doc0.transaction(): with pytest.raises(RuntimeError) as excinfo: with doc0.transaction(origin=456): - pass + pass # pragma: no cover assert str(excinfo.value) == "Nested transactions must have same origin as root transaction" diff --git a/tests/test_update.py b/tests/test_update.py new file mode 100644 index 0000000..a892b2e --- /dev/null +++ b/tests/test_update.py @@ -0,0 +1,36 @@ +from pycrdt import Doc, Map, get_state, get_update, merge_updates + + +def test_update(): + data0 = Map({"key0": "val0"}) + doc0 = Doc() + doc0["data"] = data0 + + data1 = Map({"key1": "val1"}) + doc1 = Doc() + doc1["data"] = data1 + + update0 = doc0.get_update() + update1 = doc1.get_update() + + del doc0 + del doc1 + state0 = get_state(update0) + state1 = get_state(update1) + + update01 = get_update(update0, state1) + update10 = get_update(update1, state0) + + # sync clients + update0 = merge_updates(update0, update10) + update1 = merge_updates(update1, update01) + assert update0 == update1 + + doc0 = Doc() + data0 = doc0.get("data", type=Map) + doc0.apply_update(update0) + doc1 = Doc() + data1 = doc1.get("data", type=Map) + doc1.apply_update(update1) + + assert data0.to_py() == data1.to_py() == {"key0": "val0", "key1": "val1"}