diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 09c9f27..1c8fb24 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -31,7 +31,7 @@ jobs: args: --release --target universal2-apple-darwin --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10 - name: Test built wheel - universal2 run: | - pip install pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" + pip install pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" pip install pycrdt --no-deps --no-index --find-links dist --force-reinstall pytest - name: Upload wheels @@ -64,7 +64,7 @@ jobs: args: --release --out dist -i ${{ matrix.platform.interpreter }} - name: Test built wheel run: | - pip install pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" + pip install pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" pip install pycrdt --no-deps --no-index --find-links dist --force-reinstall pytest - name: Upload wheels @@ -100,7 +100,7 @@ jobs: - name: Test built wheel if: matrix.target == 'x86_64' run: | - pip install pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" + pip install pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" pip install pycrdt --no-deps --no-index --find-links dist --force-reinstall pytest - name: Upload wheels @@ -136,7 +136,7 @@ jobs: install: | apt-get update apt-get install -y --no-install-recommends python3 python3-pip - pip3 install -U pip pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" + pip3 install -U pip pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'" run: | pip3 install pycrdt --no-deps --no-index --find-links dist/ --force-reinstall pytest diff --git a/docs/usage.md b/docs/usage.md index 81a17dc..decbefa 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -265,3 +265,69 @@ print(str(text)) ``` Undoing a change doesn't remove the change from the document's history, but applies a change that is the opposite of the previous change. + +## Type annotations + +`Array`, `Map` and `Doc` can be type-annotated for static type analysis. For instance, here is how to declare a `Doc` where all root types are `Array`s of `int`s: + +```py +from pycrdt import Array, Doc + +doc: Doc[Array[int]] = Doc() +array0: Array[int] = doc.get("array0", type=Array) +array0.append(0) +array0.append("foo") # error: Argument 1 to "append" of "Array" has incompatible type "str"; expected "int" [arg-type] +array1: Array[str] = doc.get("array1", type=Array) # error: Incompatible types in assignment (expression has type "Array[int]", variable has type "Array[str]") [assignment] +``` + +Trying to append a `str` will result in a type check error. Likewise if trying to get a root type of `Array[str]`. + +Like an `Array`, a `Map` can be declared as uniform, i.e. with values of the same type. But it can also be declared as a [TypedDict](https://mypy.readthedocs.io/en/stable/typed_dict.html): + +```py +from typing import TypedDict +from pycrdt import Doc, Map + +doc: Doc[Map] = Doc() + +MyMap = TypedDict( + "MyMap", + { + "name": str, + "toggle": bool, + "nested": Array[bool], + }, +) + +map0: MyMap = doc.get("map0", type=Map) # type: ignore[assignment] +map0["name"] = "foo" +map0["toggle"] = False +map0["toggle"] = 3 # error: Value of "toggle" has incompatible type "int"; expected "bool" +array0 = Array([1, 2, 3]) +map0["nested"] = array0 # error: Value of "nested" has incompatible type "Array[int]"; expected "Array[bool]" +array1 = Array([False, True]) +map0["nested"] = array1 +v0: str = map0["name"] +v1: str = map0["toggle"] # error: Incompatible types in assignment (expression has type "bool", variable has type "str") +v2: bool = map0["toggle"] +map0["key0"] # error: TypedDict "MyMap@7" has no key "key0" +``` + +Like a `Map`, a `Doc` can be declared as consisting of uniform root types, or as a `TypedDict`: + +```py +from typing import TypedDict +from pycrdt import Doc, Array, Text + +MyDoc = TypedDict( + "MyDoc", + { + "text0": Text, + "array0": Array[int], + } +) +doc: MyDoc = Doc() # type: ignore[assignment] +doc["text0"] = Text() +doc["array0"] = Array[bool]() # error: Value of "array0" has incompatible type "Array[bool]"; expected "Array[int]" +doc["array0"] = Array[int]() +``` diff --git a/pyproject.toml b/pyproject.toml index 390f087..791c9f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ [project.optional-dependencies] test = [ "pytest >=7.4.2,<8", + "pytest-mypy-testing", "anyio", "trio >=0.25.1,<0.27", "pydantic >=2.5.2,<3", @@ -58,14 +59,19 @@ python-source = "python" module-name = "pycrdt._pycrdt" [tool.ruff] +exclude = ["tests/test_types.py"] line-length = 100 lint.select = ["F", "E", "W", "I001"] [tool.coverage.run] source = ["python", "tests"] +omit = ["tests/test_types.py"] [tool.coverage.report] show_missing = true exclude_also = [ "if TYPE_CHECKING:" ] + +[tool.mypy] +check_untyped_defs = true diff --git a/python/pycrdt/_array.py b/python/pycrdt/_array.py index 3e146e2..9d80918 100644 --- a/python/pycrdt/_array.py +++ b/python/pycrdt/_array.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types from ._pycrdt import Array as _Array @@ -10,18 +10,20 @@ if TYPE_CHECKING: from ._doc import Doc +T = TypeVar("T") -class Array(BaseType): + +class Array(BaseType, Generic[T]): """ A collection used to store data in an indexed sequence structure, similar to a Python `list`. """ - _prelim: list | None + _prelim: list[T] | None _integrated: _Array | None def __init__( self, - init: list | None = None, + init: list[T] | None = None, *, _doc: Doc | None = None, _integrated: _Array | None = None, @@ -42,14 +44,14 @@ def __init__( _integrated=_integrated, ) - def _init(self, value: list[Any] | None) -> None: + def _init(self, value: list[T] | None) -> None: if value is None: return with self.doc.transaction(): for i, v in enumerate(value): self._set(i, v) - def _set(self, index: int, value: Any) -> None: + def _set(self, index: int, value: T) -> None: with self.doc.transaction() as txn: self._forbid_read_transaction(txn) if isinstance(value, BaseDoc): @@ -79,7 +81,7 @@ def __len__(self) -> int: with self.doc.transaction() as txn: return self.integrated.len(txn._txn) - def append(self, value: Any) -> None: + def append(self, value: T) -> None: """ Appends an item to the array. @@ -89,7 +91,7 @@ def append(self, value: Any) -> None: with self.doc.transaction(): self += [value] - def extend(self, value: list[Any]) -> None: + def extend(self, value: list[T]) -> None: """ Extends the array with a list of items. @@ -105,7 +107,7 @@ def clear(self) -> None: """ del self[:] - def insert(self, index: int, object: Any) -> None: + def insert(self, index: int, object: T) -> None: """ Inserts an item at a given index in the array. @@ -115,7 +117,7 @@ def insert(self, index: int, object: Any) -> None: """ self[index:index] = [object] - def pop(self, index: int = -1) -> Any: + def pop(self, index: int = -1) -> T: """ Removes the item at the given index from the array, and returns it. If no index is passed, removes and returns the last item. @@ -148,7 +150,7 @@ def move(self, source_index: int, destination_index: int) -> None: destination_index = self._check_index(destination_index) self.integrated.move_to(txn._txn, source_index, destination_index) - def __add__(self, value: list[Any]) -> Array: + def __add__(self, value: list[T]) -> Array[T]: """ Extends the array with a list of items: ```py @@ -168,7 +170,7 @@ def __add__(self, value: list[Any]) -> Array: self[length:length] = value return self - def __radd__(self, value: list[Any]) -> Array: + def __radd__(self, value: list[T]) -> Array[T]: """ Prepends a list of items to the array: ```py @@ -187,7 +189,13 @@ def __radd__(self, value: list[Any]) -> Array: self[0:0] = value return self - def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None: + @overload + def __setitem__(self, key: int, value: T) -> None: ... + + @overload + def __setitem__(self, key: slice, value: list[T]) -> None: ... + + def __setitem__(self, key, value): """ Replaces the item at the given index with a new item: ```py @@ -271,7 +279,13 @@ def __delitem__(self, key: int | slice) -> None: f"Array indices must be integers or slices, not {type(key).__name__}" ) - def __getitem__(self, key: int) -> BaseType: + @overload + def __getitem__(self, key: int) -> T: ... + + @overload + def __getitem__(self, key: slice) -> list[T]: ... + + def __getitem__(self, key): """ Gets the item at the given index: ```py @@ -304,7 +318,7 @@ def __iter__(self) -> ArrayIterator: """ return ArrayIterator(self) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: T) -> bool: """ Checks if the given item is in the array: ```py @@ -333,7 +347,7 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.to_json(txn._txn) - def to_py(self) -> list | None: + def to_py(self) -> list[T] | None: """ Recursively converts the array's items to Python objects, and returns them in a list. If the array was not yet inserted in a document, diff --git a/python/pycrdt/_base.py b/python/pycrdt/_base.py index 3e1c5b0..327c09f 100644 --- a/python/pycrdt/_base.py +++ b/python/pycrdt/_base.py @@ -222,6 +222,7 @@ def __init__(self, event: Any, doc: Doc): def __str__(self): str_list = [] + slot: Any for slot in self.__slots__: val = str(getattr(self, slot)) str_list.append(f"{slot}: {val}") diff --git a/python/pycrdt/_doc.py b/python/pycrdt/_doc.py index 8273af6..cbd6848 100644 --- a/python/pycrdt/_doc.py +++ b/python/pycrdt/_doc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, Iterable, Type, TypeVar, cast +from typing import Any, Callable, Generic, Iterable, Type, TypeVar, cast from ._base import BaseDoc, BaseType, base_types, forbid_read_transaction from ._pycrdt import Doc as _Doc @@ -8,10 +8,10 @@ from ._pycrdt import Transaction as _Transaction from ._transaction import NewTransaction, ReadTransaction, Transaction -T_BaseType = TypeVar("T_BaseType", bound=BaseType) +T = TypeVar("T", bound=BaseType) -class Doc(BaseDoc): +class Doc(BaseDoc, Generic[T]): """ A shared document. @@ -23,7 +23,7 @@ class Doc(BaseDoc): def __init__( self, - init: dict[str, BaseType] = {}, + init: dict[str, T] = {}, *, client_id: int | None = None, doc: _Doc | None = None, @@ -165,7 +165,7 @@ def apply_update(self, update: bytes) -> None: assert txn._txn is not None self._doc.apply_update(txn._txn, update) - def __setitem__(self, key: str, value: BaseType) -> None: + def __setitem__(self, key: str, value: T) -> None: """ Sets a document root type: ```py @@ -185,7 +185,7 @@ def __setitem__(self, key: str, value: BaseType) -> None: prelim = value._integrate(self, integrated) value._init(prelim) - def __getitem__(self, key: str) -> BaseType: + def __getitem__(self, key: str) -> T: """ Gets the document root type corresponding to the given key: ```py @@ -207,7 +207,7 @@ def __iter__(self) -> Iterable[str]: """ return iter(self.keys()) - def get(self, key: str, *, type: type[T_BaseType]) -> T_BaseType: + def get(self, key: str, *, type: type[T]) -> T: """ Gets the document root type corresponding to the given key. If it already exists, it will be cast to the given type (if different), @@ -230,14 +230,14 @@ def keys(self) -> Iterable[str]: """ return self._roots.keys() - def values(self) -> Iterable[BaseType]: + def values(self) -> Iterable[T]: """ Returns: An iterable over the document root types. """ return self._roots.values() - def items(self) -> Iterable[tuple[str, BaseType]]: + def items(self) -> Iterable[tuple[str, T]]: """ Returns: An iterable over the key-value pairs of document root types. @@ -245,14 +245,14 @@ def items(self) -> Iterable[tuple[str, BaseType]]: return self._roots.items() @property - def _roots(self) -> dict[str, BaseType]: + def _roots(self) -> dict[str, T]: with self.transaction() as txn: assert txn._txn is not None return { key: ( None if val is None - else cast(Type[BaseType], base_types[type(val)])(_integrated=val, _doc=self) + else cast(Type[T], base_types[type(val)])(_integrated=val, _doc=self) ) for key, val in self._doc.roots(txn._txn).items() } diff --git a/python/pycrdt/_map.py b/python/pycrdt/_map.py index 260e981..b99453e 100644 --- a/python/pycrdt/_map.py +++ b/python/pycrdt/_map.py @@ -1,6 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, cast +from typing import ( + TYPE_CHECKING, + Callable, + Generic, + Iterable, + TypeVar, + cast, + overload, +) from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types from ._pycrdt import Map as _Map @@ -10,8 +18,11 @@ if TYPE_CHECKING: from ._doc import Doc +T = TypeVar("T") +T_DefaultValue = TypeVar("T_DefaultValue") -class Map(BaseType): + +class Map(BaseType, Generic[T]): """ A collection used to store key-value entries in an unordered manner, similar to a Python `dict`. """ @@ -21,7 +32,7 @@ class Map(BaseType): def __init__( self, - init: dict | None = None, + init: dict[str, T] | None = None, *, _doc: Doc | None = None, _integrated: _Map | None = None, @@ -42,14 +53,14 @@ def __init__( _integrated=_integrated, ) - def _init(self, value: dict[str, Any] | None) -> None: + def _init(self, value: dict[str, T] | None) -> None: if value is None: return with self.doc.transaction(): for k, v in value.items(): self._set(k, v) - def _set(self, key: str, value: Any) -> None: + def _set(self, key: str, value: T) -> None: with self.doc.transaction() as txn: self._forbid_read_transaction(txn) if isinstance(value, BaseDoc): @@ -91,7 +102,7 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.to_json(txn._txn) - def to_py(self) -> dict | None: + def to_py(self) -> dict[str, T] | None: """ Recursively converts the map's items to Python objects, and returns them in a `dict`. If the map was not yet inserted in a document, @@ -128,7 +139,7 @@ def __delitem__(self, key: str) -> None: self._check_key(key) self.integrated.remove(txn._txn, key) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> T: """ Gets the value at the given key: ```py @@ -143,7 +154,7 @@ def __getitem__(self, key: str) -> Any: self._check_key(key) return self._maybe_as_type_or_doc(self.integrated.get(txn._txn, key)) - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key: str, value: T) -> None: """ Sets a value at the given key: ```py @@ -192,24 +203,38 @@ def __contains__(self, item: str) -> bool: """ return item in self.keys() - def get(self, key: str, default_value: Any | None = None) -> Any | None: + @overload + def get(self, key: str) -> T | None: ... + + @overload + def get(self, key: str, default_value: T_DefaultValue) -> T | T_DefaultValue: ... + + def get(self, *args): """ Returns the value corresponding to the given key if it exists, otherwise - returns the `default_value`. + returns the default value if passed, or `None`. Args: - key: The key of the value to get. - default_value: The optional default value to return if the key is not found. + args: The key of the value to get, and an optional default value. Returns: - The value at the given key, or the default value. + The value at the given key, or the default value or `None`. """ + key, *default_value = args with self.doc.transaction(): if key in self.keys(): return self[key] - return default_value + if not default_value: + return None + return default_value[0] + + @overload + def pop(self, key: str) -> T: ... + + @overload + def pop(self, key: str, default_value: T_DefaultValue) -> T | T_DefaultValue: ... - def pop(self, *args: Any) -> Any: + def pop(self, *args): """ Removes the entry at the given key from the map, and returns the corresponding value. @@ -231,7 +256,7 @@ def pop(self, *args: Any) -> Any: del self[key] return res - def _check_key(self, key: str): + def _check_key(self, key: str) -> None: if not isinstance(key, str): raise RuntimeError("Key must be of type string") if key not in self.keys(): @@ -245,7 +270,7 @@ def keys(self) -> Iterable[str]: with self.doc.transaction() as txn: return iter(self.integrated.keys(txn._txn)) - def values(self) -> Iterable[Any]: + def values(self) -> Iterable[T]: """ Returns: An iterable over the values of the map. @@ -254,7 +279,7 @@ def values(self) -> Iterable[Any]: for k in self.integrated.keys(txn._txn): yield self[k] - def items(self) -> Iterable[tuple[str, Any]]: + def items(self) -> Iterable[tuple[str, T]]: """ Returns: An iterable over the key-value pairs of the map. @@ -271,7 +296,7 @@ def clear(self) -> None: for k in self.integrated.keys(txn._txn): del self[k] - def update(self, value: dict[str, Any]) -> None: + def update(self, value: dict[str, T]) -> None: """ Sets entries in the map from all entries in the passed `dict`. diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..f39eea1 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,85 @@ +from typing import TypedDict + +import pytest +from pycrdt import Array, Doc, Map, Text + + +@pytest.mark.mypy_testing +def mypy_test_array(): + doc: Doc[Array[int]] = Doc() + array0: Array[int] = doc.get("array0", type=Array) + array0.append(0) + array0.append("foo") # E: Argument 1 to "append" of "Array" has incompatible type "str"; expected "int" [arg-type] + array1: Array[str] = doc.get("array1", type=Array) # E: Incompatible types in assignment (expression has type "Array[int]", variable has type "Array[str]") [assignment] + + +@pytest.mark.mypy_testing +def mypy_test_uniform_map(): + doc: Doc[Map] = Doc() + map0: Map[bool] = doc.get("map0", type=Map) + map0["foo"] = True + map0["foo"] = "bar" # E: Incompatible types in assignment (expression has type "str", target has type "bool") + v0: str = map0.pop("foo") # E: Incompatible types in assignment (expression has type "bool", variable has type "str") + v1: bool = map0.pop("foo") + + +@pytest.mark.mypy_testing +def mypy_test_typed_map(): + doc: Doc[Map] = Doc() + + MyMap = TypedDict( + "MyMap", + { + "name": str, + "toggle": bool, + "nested": Array[bool], + }, + ) + map0: MyMap = doc.get("map0", type=Map) # type: ignore[assignment] + map0["name"] = "foo" + map0["toggle"] = False + map0["toggle"] = 3 # E: Value of "toggle" has incompatible type "int"; expected "bool" + array0 = Array([1, 2, 3]) + map0["nested"] = array0 # E: Value of "nested" has incompatible type "Array[int]"; expected "Array[bool]" + array1 = Array([False, True]) + map0["nested"] = array1 + v0: str = map0["name"] + v1: str = map0["toggle"] # E: Incompatible types in assignment (expression has type "bool", variable has type "str") + v2: bool = map0["toggle"] + map0["key0"] # E: TypedDict "MyMap@30" has no key "key0" + + +@pytest.mark.mypy_testing +def mypy_test_uniform_doc(): + doc: Doc[Text] = Doc() + doc.get("text0", type=Text) + doc.get("array0", type=Array) # E: Argument "type" to "get" of "Doc" has incompatible type "type[pycrdt._array.Array[Any]]"; expected "type[Text]" + doc.get("Map0", type=Map) # E: Argument "type" to "get" of "Doc" has incompatible type "type[pycrdt._map.Map[Any]]"; expected "type[Text]" + + +@pytest.mark.mypy_testing +def mypy_test_typed_doc(): + MyMap = TypedDict( + "MyMap", + { + "name": str, + "toggle": bool, + "nested": Array[bool], + }, + ) + + MyDoc = TypedDict( + "MyDoc", + { + "text0": Text, + "array0": Array[int], + "map0": MyMap, + } + ) + doc: MyDoc = Doc() # type: ignore[assignment] + map0: MyMap = Map() # type: ignore[assignment] + doc["map0"] = map0 + doc["map0"] = Array() # E: Value of "map0" has incompatible type "Array[Never]"; expected "MyMap@62" + doc["text0"] = Text() + doc["array0"] = Array[bool]() # E: Value of "array0" has incompatible type "Array[bool]"; expected "Array[int]" + doc["array0"] = Array()