Skip to content

Commit

Permalink
Support type annotations (#204)
Browse files Browse the repository at this point in the history
* Support type annotations

* Add tests

* Allow a Map to be a TypedDict

* Simplify

* Support Doc

* Update publish workflow

* Add documentation
  • Loading branch information
davidbrochart authored Jan 2, 2025
1 parent 47ec580 commit 6a5567e
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 50 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
args: --release --target universal2-apple-darwin --out dist -i 3.9 3.10 3.11 3.12 3.13 pypy3.9 pypy3.10
- name: Test built wheel - universal2
run: |
pip install pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip install pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip install pycrdt --no-deps --no-index --find-links dist --force-reinstall
pytest
- name: Upload wheels
Expand Down Expand Up @@ -64,7 +64,7 @@ jobs:
args: --release --out dist -i ${{ matrix.platform.interpreter }}
- name: Test built wheel
run: |
pip install pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip install pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip install pycrdt --no-deps --no-index --find-links dist --force-reinstall
pytest
- name: Upload wheels
Expand Down Expand Up @@ -100,7 +100,7 @@ jobs:
- name: Test built wheel
if: matrix.target == 'x86_64'
run: |
pip install pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip install pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip install pycrdt --no-deps --no-index --find-links dist --force-reinstall
pytest
- name: Upload wheels
Expand Down Expand Up @@ -136,7 +136,7 @@ jobs:
install: |
apt-get update
apt-get install -y --no-install-recommends python3 python3-pip
pip3 install -U pip pytest "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
pip3 install -U pip pytest pytest-mypy-testing "pydantic>=2.5.2,<3" "anyio>=4.4.0,<5" "trio>=0.25.1,<0.26" "exceptiongroup; python_version<'3.11'"
run: |
pip3 install pycrdt --no-deps --no-index --find-links dist/ --force-reinstall
pytest
Expand Down
66 changes: 66 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,69 @@ print(str(text))
```

Undoing a change doesn't remove the change from the document's history, but applies a change that is the opposite of the previous change.

## Type annotations

`Array`, `Map` and `Doc` can be type-annotated for static type analysis. For instance, here is how to declare a `Doc` where all root types are `Array`s of `int`s:

```py
from pycrdt import Array, Doc
doc: Doc[Array[int]] = Doc()
array0: Array[int] = doc.get("array0", type=Array)
array0.append(0)
array0.append("foo") # error: Argument 1 to "append" of "Array" has incompatible type "str"; expected "int" [arg-type]
array1: Array[str] = doc.get("array1", type=Array) # error: Incompatible types in assignment (expression has type "Array[int]", variable has type "Array[str]") [assignment]
```

Trying to append a `str` will result in a type check error. Likewise if trying to get a root type of `Array[str]`.

Like an `Array`, a `Map` can be declared as uniform, i.e. with values of the same type. But it can also be declared as a [TypedDict](https://mypy.readthedocs.io/en/stable/typed_dict.html):

```py
from typing import TypedDict
from pycrdt import Doc, Map
doc: Doc[Map] = Doc()
MyMap = TypedDict(
"MyMap",
{
"name": str,
"toggle": bool,
"nested": Array[bool],
},
)
map0: MyMap = doc.get("map0", type=Map) # type: ignore[assignment]
map0["name"] = "foo"
map0["toggle"] = False
map0["toggle"] = 3 # error: Value of "toggle" has incompatible type "int"; expected "bool"
array0 = Array([1, 2, 3])
map0["nested"] = array0 # error: Value of "nested" has incompatible type "Array[int]"; expected "Array[bool]"
array1 = Array([False, True])
map0["nested"] = array1
v0: str = map0["name"]
v1: str = map0["toggle"] # error: Incompatible types in assignment (expression has type "bool", variable has type "str")
v2: bool = map0["toggle"]
map0["key0"] # error: TypedDict "MyMap@7" has no key "key0"
```

Like a `Map`, a `Doc` can be declared as consisting of uniform root types, or as a `TypedDict`:

```py
from typing import TypedDict
from pycrdt import Doc, Array, Text
MyDoc = TypedDict(
"MyDoc",
{
"text0": Text,
"array0": Array[int],
}
)
doc: MyDoc = Doc() # type: ignore[assignment]
doc["text0"] = Text()
doc["array0"] = Array[bool]() # error: Value of "array0" has incompatible type "Array[bool]"; expected "Array[int]"
doc["array0"] = Array[int]()
```
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
[project.optional-dependencies]
test = [
"pytest >=7.4.2,<8",
"pytest-mypy-testing",
"anyio",
"trio >=0.25.1,<0.27",
"pydantic >=2.5.2,<3",
Expand All @@ -58,14 +59,19 @@ python-source = "python"
module-name = "pycrdt._pycrdt"

[tool.ruff]
exclude = ["tests/test_types.py"]
line-length = 100
lint.select = ["F", "E", "W", "I001"]

[tool.coverage.run]
source = ["python", "tests"]
omit = ["tests/test_types.py"]

[tool.coverage.report]
show_missing = true
exclude_also = [
"if TYPE_CHECKING:"
]

[tool.mypy]
check_untyped_defs = true
46 changes: 30 additions & 16 deletions python/pycrdt/_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload

from ._base import BaseDoc, BaseEvent, BaseType, base_types, event_types
from ._pycrdt import Array as _Array
Expand All @@ -10,18 +10,20 @@
if TYPE_CHECKING:
from ._doc import Doc

T = TypeVar("T")

class Array(BaseType):

class Array(BaseType, Generic[T]):
"""
A collection used to store data in an indexed sequence structure, similar to a Python `list`.
"""

_prelim: list | None
_prelim: list[T] | None
_integrated: _Array | None

def __init__(
self,
init: list | None = None,
init: list[T] | None = None,
*,
_doc: Doc | None = None,
_integrated: _Array | None = None,
Expand All @@ -42,14 +44,14 @@ def __init__(
_integrated=_integrated,
)

def _init(self, value: list[Any] | None) -> None:
def _init(self, value: list[T] | None) -> None:
if value is None:
return
with self.doc.transaction():
for i, v in enumerate(value):
self._set(i, v)

def _set(self, index: int, value: Any) -> None:
def _set(self, index: int, value: T) -> None:
with self.doc.transaction() as txn:
self._forbid_read_transaction(txn)
if isinstance(value, BaseDoc):
Expand Down Expand Up @@ -79,7 +81,7 @@ def __len__(self) -> int:
with self.doc.transaction() as txn:
return self.integrated.len(txn._txn)

def append(self, value: Any) -> None:
def append(self, value: T) -> None:
"""
Appends an item to the array.
Expand All @@ -89,7 +91,7 @@ def append(self, value: Any) -> None:
with self.doc.transaction():
self += [value]

def extend(self, value: list[Any]) -> None:
def extend(self, value: list[T]) -> None:
"""
Extends the array with a list of items.
Expand All @@ -105,7 +107,7 @@ def clear(self) -> None:
"""
del self[:]

def insert(self, index: int, object: Any) -> None:
def insert(self, index: int, object: T) -> None:
"""
Inserts an item at a given index in the array.
Expand All @@ -115,7 +117,7 @@ def insert(self, index: int, object: Any) -> None:
"""
self[index:index] = [object]

def pop(self, index: int = -1) -> Any:
def pop(self, index: int = -1) -> T:
"""
Removes the item at the given index from the array, and returns it.
If no index is passed, removes and returns the last item.
Expand Down Expand Up @@ -148,7 +150,7 @@ def move(self, source_index: int, destination_index: int) -> None:
destination_index = self._check_index(destination_index)
self.integrated.move_to(txn._txn, source_index, destination_index)

def __add__(self, value: list[Any]) -> Array:
def __add__(self, value: list[T]) -> Array[T]:
"""
Extends the array with a list of items:
```py
Expand All @@ -168,7 +170,7 @@ def __add__(self, value: list[Any]) -> Array:
self[length:length] = value
return self

def __radd__(self, value: list[Any]) -> Array:
def __radd__(self, value: list[T]) -> Array[T]:
"""
Prepends a list of items to the array:
```py
Expand All @@ -187,7 +189,13 @@ def __radd__(self, value: list[Any]) -> Array:
self[0:0] = value
return self

def __setitem__(self, key: int | slice, value: Any | list[Any]) -> None:
@overload
def __setitem__(self, key: int, value: T) -> None: ...

@overload
def __setitem__(self, key: slice, value: list[T]) -> None: ...

def __setitem__(self, key, value):
"""
Replaces the item at the given index with a new item:
```py
Expand Down Expand Up @@ -271,7 +279,13 @@ def __delitem__(self, key: int | slice) -> None:
f"Array indices must be integers or slices, not {type(key).__name__}"
)

def __getitem__(self, key: int) -> BaseType:
@overload
def __getitem__(self, key: int) -> T: ...

@overload
def __getitem__(self, key: slice) -> list[T]: ...

def __getitem__(self, key):
"""
Gets the item at the given index:
```py
Expand Down Expand Up @@ -304,7 +318,7 @@ def __iter__(self) -> ArrayIterator:
"""
return ArrayIterator(self)

def __contains__(self, item: Any) -> bool:
def __contains__(self, item: T) -> bool:
"""
Checks if the given item is in the array:
```py
Expand Down Expand Up @@ -333,7 +347,7 @@ def __str__(self) -> str:
with self.doc.transaction() as txn:
return self.integrated.to_json(txn._txn)

def to_py(self) -> list | None:
def to_py(self) -> list[T] | None:
"""
Recursively converts the array's items to Python objects, and
returns them in a list. If the array was not yet inserted in a document,
Expand Down
1 change: 1 addition & 0 deletions python/pycrdt/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def __init__(self, event: Any, doc: Doc):

def __str__(self):
str_list = []
slot: Any
for slot in self.__slots__:
val = str(getattr(self, slot))
str_list.append(f"{slot}: {val}")
Expand Down
22 changes: 11 additions & 11 deletions python/pycrdt/_doc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from typing import Any, Callable, Iterable, Type, TypeVar, cast
from typing import Any, Callable, Generic, Iterable, Type, TypeVar, cast

from ._base import BaseDoc, BaseType, base_types, forbid_read_transaction
from ._pycrdt import Doc as _Doc
from ._pycrdt import SubdocsEvent, Subscription, TransactionEvent
from ._pycrdt import Transaction as _Transaction
from ._transaction import NewTransaction, ReadTransaction, Transaction

T_BaseType = TypeVar("T_BaseType", bound=BaseType)
T = TypeVar("T", bound=BaseType)


class Doc(BaseDoc):
class Doc(BaseDoc, Generic[T]):
"""
A shared document.
Expand All @@ -23,7 +23,7 @@ class Doc(BaseDoc):

def __init__(
self,
init: dict[str, BaseType] = {},
init: dict[str, T] = {},
*,
client_id: int | None = None,
doc: _Doc | None = None,
Expand Down Expand Up @@ -165,7 +165,7 @@ def apply_update(self, update: bytes) -> None:
assert txn._txn is not None
self._doc.apply_update(txn._txn, update)

def __setitem__(self, key: str, value: BaseType) -> None:
def __setitem__(self, key: str, value: T) -> None:
"""
Sets a document root type:
```py
Expand All @@ -185,7 +185,7 @@ def __setitem__(self, key: str, value: BaseType) -> None:
prelim = value._integrate(self, integrated)
value._init(prelim)

def __getitem__(self, key: str) -> BaseType:
def __getitem__(self, key: str) -> T:
"""
Gets the document root type corresponding to the given key:
```py
Expand All @@ -207,7 +207,7 @@ def __iter__(self) -> Iterable[str]:
"""
return iter(self.keys())

def get(self, key: str, *, type: type[T_BaseType]) -> T_BaseType:
def get(self, key: str, *, type: type[T]) -> T:
"""
Gets the document root type corresponding to the given key.
If it already exists, it will be cast to the given type (if different),
Expand All @@ -230,29 +230,29 @@ def keys(self) -> Iterable[str]:
"""
return self._roots.keys()

def values(self) -> Iterable[BaseType]:
def values(self) -> Iterable[T]:
"""
Returns:
An iterable over the document root types.
"""
return self._roots.values()

def items(self) -> Iterable[tuple[str, BaseType]]:
def items(self) -> Iterable[tuple[str, T]]:
"""
Returns:
An iterable over the key-value pairs of document root types.
"""
return self._roots.items()

@property
def _roots(self) -> dict[str, BaseType]:
def _roots(self) -> dict[str, T]:
with self.transaction() as txn:
assert txn._txn is not None
return {
key: (
None
if val is None
else cast(Type[BaseType], base_types[type(val)])(_integrated=val, _doc=self)
else cast(Type[T], base_types[type(val)])(_integrated=val, _doc=self)
)
for key, val in self._doc.roots(txn._txn).items()
}
Expand Down
Loading

0 comments on commit 6a5567e

Please sign in to comment.