diff --git a/.gitignore b/.gitignore index 74d79d1..635e439 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,7 @@ example/*.zarr example/.ipynb_checkpoints/* example/data/** __pycache__ + +.venv +.ipynb_checkpoints +dist/ diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000..bdd0ecd --- /dev/null +++ b/python/README.md @@ -0,0 +1,14 @@ +# vizarr + +```sh +pip install vizarr +``` + +```python +import vizarr +import numpy as np + +arr = np.random.randint(0, 255, (1024, 1024), dtype=np.uint8) +viewer = vizarr.Viewer(source=arr) +viewer +``` diff --git a/python/deno.json b/python/deno.json new file mode 100644 index 0000000..9052e4b --- /dev/null +++ b/python/deno.json @@ -0,0 +1,22 @@ +{ + "lock": false, + "compilerOptions": { + "checkJs": true, + "allowJs": true, + "lib": [ + "ES2020", + "DOM", + "DOM.Iterable" + ] + }, + "fmt": { + "useTabs": true + }, + "lint": { + "rules": { + "exclude": [ + "prefer-const" + ] + } + } +} diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 0000000..7173da2 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,15 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "vizarr" +version = "0.0.0" +dependencies = ["anywidget", "zarr"] + +[project.optional-dependencies] +dev = ["watchfiles", "jupyterlab"] + +# automatically add the dev feature to the default env (e.g., hatch shell) +[tool.hatch.envs.default] +features = ["dev"] diff --git a/python/src/vizarr/__init__.py b/python/src/vizarr/__init__.py new file mode 100644 index 0000000..dcd48e8 --- /dev/null +++ b/python/src/vizarr/__init__.py @@ -0,0 +1,10 @@ +import importlib.metadata + +try: + __version__ = importlib.metadata.version("vizarr") +except importlib.metadata.PackageNotFoundError: + __version__ = "unknown" + +del importlib + +from ._widget import Viewer diff --git a/python/src/vizarr/_widget.js b/python/src/vizarr/_widget.js new file mode 100644 index 0000000..c76fee1 --- /dev/null +++ b/python/src/vizarr/_widget.js @@ -0,0 +1,109 @@ +import * as vizarr from "https://hms-dbmi.github.io/vizarr/index.js"; +import debounce from "https://esm.sh/just-debounce-it@3"; + +/** + * @template T + * @param {import("npm:@anywidget/types").AnyModel} model + * @param {any} payload + * @param {{ timeout?: number }} [options] + * @returns {Promise<{ data: T, buffers: DataView[] }>} + */ +function send(model, payload, { timeout = 3000 } = {}) { + let uuid = globalThis.crypto.randomUUID(); + return new Promise((resolve, reject) => { + let timer = setTimeout(() => { + reject(new Error(`Promise timed out after ${timeout} ms`)); + model.off("msg:custom", handler); + }, timeout); + /** + * @param {{ uuid: string, payload: T }} msg + * @param {DataView[]} buffers + */ + function handler(msg, buffers) { + if (!(msg.uuid === uuid)) return; + clearTimeout(timer); + resolve({ data: msg.payload, buffers }); + model.off("msg:custom", handler); + } + model.on("msg:custom", handler); + model.send({ payload, uuid }); + }); +} + +/** @param {import("npm:@anywidget/types").AnyModel} model */ +function get_source(model) { + let source = model.get("_source"); + if (typeof source === "string") { + return source; + } + // create a python + return { + /** + * @param {string} key + * @return {Promise} + */ + async getItem(key) { + const { data, buffers } = await send(model, { + type: "get", + source_id: source.id, + key, + }); + if (!data.success) { + throw { __zarr__: "KeyError" }; + } + return buffers[0].buffer; + }, + /** + * @param {string} key + * @return {Promise} + */ + async containsItem(key) { + const { data } = await send(model, { + type: "has", + source_id: source.id, + key, + }); + return data; + }, + }; +} + +/** + * @typedef Model + * @property {string | { id: string }} _source + * @property {string} height + * @property {ViewState=} view_state + */ + +/** + * @typedef ViewState + * @property {number} zoom + * @property {[x: number, y: number]} target + */ + +/** @type {import("npm:@anywidget/types").Render} */ +export function render({ model, el }) { + let div = document.createElement("div"); + { + div.style.height = model.get("height"); + div.style.backgroundColor = "black"; + model.on("change:height", () => { + div.style.height = model.get("height"); + }); + } + let viewer = vizarr.createViewer(div); + { + model.on("change:view_state", () => { + viewer.setViewState(model.get("view_state")); + }); + viewer.on( + "viewStateChange", + debounce((/** @type {ViewState} */ update) => { + model.set("view_state", update); + model.save_changes(); + }, 200), + ); + } + viewer.addImage({ source: get_source(model) }); + el.appendChild(div); +} diff --git a/python/src/vizarr/_widget.py b/python/src/vizarr/_widget.py new file mode 100644 index 0000000..d41249f --- /dev/null +++ b/python/src/vizarr/_widget.py @@ -0,0 +1,57 @@ +import anywidget +import traitlets +import pathlib + +import zarr +import numpy as np + +__all__ = ["Viewer"] + +def _store_keyprefix(obj): + # Just grab the store and key_prefix from zarr.Array and zarr.Group objects + if isinstance(obj, (zarr.Array, zarr.Group)): + return obj.store, obj._key_prefix + + if isinstance(obj, np.ndarray): + # Create an in-memory store, and write array as as single chunk + store = {} + arr = zarr.create(store=store, shape=obj.shape, chunks=obj.shape, dtype=obj.dtype) + arr[:] = obj + return store, "" + + if hasattr(obj, "__getitem__") and hasattr(obj, "__contains__"): + return obj, "" + + raise TypeError("Cannot normalize store path") + +class Viewer(anywidget.AnyWidget): + _esm = pathlib.Path(__file__).parent / "_widget.js" + _source = traitlets.Any().tag(sync=True) + view_state = traitlets.Dict().tag(sync=True) + height = traitlets.Unicode("500px").tag(sync=True); + + def __init__(self, source, **kwargs): + self._store_paths = [] + if not isinstance(source, str): + store, key_prefix = _store_keyprefix(source) + source = { "id": len(self._store_paths) } + self._store_paths.append((store, key_prefix)) + super().__init__(_source=source, **kwargs) + self.on_msg(self._handle_custom_msg) + + def _handle_custom_msg(self, msg, buffers): + store, key_prefix = self._store_paths[msg["payload"]["source_id"]] + key = key_prefix + msg["payload"]["key"].lstrip("/") + + if msg["payload"]["type"] == "has": + self.send({ "uuid": msg["uuid"], "payload": key in store }) + print(store) + return + + if msg["payload"]["type"] == "get": + try: + buffers = [store[key]] + except KeyError: + buffers = [] + self.send({ "uuid": msg["uuid"], "payload": { "success": len(buffers) == 1 } }, buffers) + return