Skip to content

Commit

Permalink
feat: Add Jupyter Widget
Browse files Browse the repository at this point in the history
  • Loading branch information
manzt committed Oct 19, 2023
1 parent c5a1091 commit f532c5e
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ example/*.zarr
example/.ipynb_checkpoints/*
example/data/**
__pycache__

.venv
.ipynb_checkpoints
dist/
14 changes: 14 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# vizarr

```sh
pip install vizarr
```

```python
import vizarr
import numpy as np

arr = np.random.randint(0, 255, (1024, 1024), dtype=np.uint8)
viewer = vizarr.Viewer(source=arr)
viewer
```
22 changes: 22 additions & 0 deletions python/deno.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"lock": false,
"compilerOptions": {
"checkJs": true,
"allowJs": true,
"lib": [
"ES2020",
"DOM",
"DOM.Iterable"
]
},
"fmt": {
"useTabs": true
},
"lint": {
"rules": {
"exclude": [
"prefer-const"
]
}
}
}
15 changes: 15 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "vizarr"
version = "0.0.0"
dependencies = ["anywidget", "zarr"]

[project.optional-dependencies]
dev = ["watchfiles", "jupyterlab"]

# automatically add the dev feature to the default env (e.g., hatch shell)
[tool.hatch.envs.default]
features = ["dev"]
10 changes: 10 additions & 0 deletions python/src/vizarr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import importlib.metadata

try:
__version__ = importlib.metadata.version("vizarr")
except importlib.metadata.PackageNotFoundError:
__version__ = "unknown"

del importlib

from ._widget import Viewer
109 changes: 109 additions & 0 deletions python/src/vizarr/_widget.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import * as vizarr from "https://hms-dbmi.github.io/vizarr/index.js";
import debounce from "https://esm.sh/just-debounce-it@3";

/**
* @template T
* @param {import("npm:@anywidget/types").AnyModel} model
* @param {any} payload
* @param {{ timeout?: number }} [options]
* @returns {Promise<{ data: T, buffers: DataView[] }>}
*/
function send(model, payload, { timeout = 3000 } = {}) {
let uuid = globalThis.crypto.randomUUID();
return new Promise((resolve, reject) => {
let timer = setTimeout(() => {
reject(new Error(`Promise timed out after ${timeout} ms`));
model.off("msg:custom", handler);
}, timeout);
/**
* @param {{ uuid: string, payload: T }} msg
* @param {DataView[]} buffers
*/
function handler(msg, buffers) {
if (!(msg.uuid === uuid)) return;
clearTimeout(timer);
resolve({ data: msg.payload, buffers });
model.off("msg:custom", handler);
}
model.on("msg:custom", handler);
model.send({ payload, uuid });
});
}

/** @param {import("npm:@anywidget/types").AnyModel} model */
function get_source(model) {
let source = model.get("_source");
if (typeof source === "string") {
return source;
}
// create a python
return {
/**
* @param {string} key
* @return {Promise<ArrayBuffer>}
*/
async getItem(key) {
const { data, buffers } = await send(model, {
type: "get",
source_id: source.id,
key,
});
if (!data.success) {
throw { __zarr__: "KeyError" };
}
return buffers[0].buffer;
},
/**
* @param {string} key
* @return {Promise<boolean>}
*/
async containsItem(key) {
const { data } = await send(model, {
type: "has",
source_id: source.id,
key,
});
return data;
},
};
}

/**
* @typedef Model
* @property {string | { id: string }} _source
* @property {string} height
* @property {ViewState=} view_state
*/

/**
* @typedef ViewState
* @property {number} zoom
* @property {[x: number, y: number]} target
*/

/** @type {import("npm:@anywidget/types").Render<Model>} */
export function render({ model, el }) {
let div = document.createElement("div");
{
div.style.height = model.get("height");
div.style.backgroundColor = "black";
model.on("change:height", () => {
div.style.height = model.get("height");
});
}
let viewer = vizarr.createViewer(div);
{
model.on("change:view_state", () => {
viewer.setViewState(model.get("view_state"));
});
viewer.on(
"viewStateChange",
debounce((/** @type {ViewState} */ update) => {
model.set("view_state", update);
model.save_changes();
}, 200),
);
}
viewer.addImage({ source: get_source(model) });
el.appendChild(div);
}
57 changes: 57 additions & 0 deletions python/src/vizarr/_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import anywidget
import traitlets
import pathlib

import zarr
import numpy as np

__all__ = ["Viewer"]

def _store_keyprefix(obj):
# Just grab the store and key_prefix from zarr.Array and zarr.Group objects
if isinstance(obj, (zarr.Array, zarr.Group)):
return obj.store, obj._key_prefix

if isinstance(obj, np.ndarray):
# Create an in-memory store, and write array as as single chunk
store = {}
arr = zarr.create(store=store, shape=obj.shape, chunks=obj.shape, dtype=obj.dtype)
arr[:] = obj
return store, ""

if hasattr(obj, "__getitem__") and hasattr(obj, "__contains__"):
return obj, ""

raise TypeError("Cannot normalize store path")

class Viewer(anywidget.AnyWidget):
_esm = pathlib.Path(__file__).parent / "_widget.js"
_source = traitlets.Any().tag(sync=True)
view_state = traitlets.Dict().tag(sync=True)
height = traitlets.Unicode("500px").tag(sync=True);

def __init__(self, source, **kwargs):
self._store_paths = []
if not isinstance(source, str):
store, key_prefix = _store_keyprefix(source)
source = { "id": len(self._store_paths) }
self._store_paths.append((store, key_prefix))
super().__init__(_source=source, **kwargs)
self.on_msg(self._handle_custom_msg)

def _handle_custom_msg(self, msg, buffers):
store, key_prefix = self._store_paths[msg["payload"]["source_id"]]
key = key_prefix + msg["payload"]["key"].lstrip("/")

if msg["payload"]["type"] == "has":
self.send({ "uuid": msg["uuid"], "payload": key in store })
print(store)
return

if msg["payload"]["type"] == "get":
try:
buffers = [store[key]]
except KeyError:
buffers = []
self.send({ "uuid": msg["uuid"], "payload": { "success": len(buffers) == 1 } }, buffers)
return

0 comments on commit f532c5e

Please sign in to comment.