From 8760824ba9ce8d522ab8ccc5bed2b22e72ebaad2 Mon Sep 17 00:00:00 2001 From: Vineet Bansal Date: Fri, 22 Nov 2024 13:25:47 -0500 Subject: [PATCH] napari plugin improvements (#94) * selectable clustering variable for a dataset; sample data; added pooch as a dependency * moved dataset stuff to paste3.dataset; some tests * support for multiple overlap_fraction values for pairwise_align; tests for napari plugin * added notebook on using dataset api * napari is also a docs extra dependency * not forcing napari for docs; protected imports * adding layers after return to the main thread for center alignment --- docs/source/notebooks/paste3_dataset.ipynb | 237 +++++++++++++++++ src/paste3/dataset.py | 295 +++++++++++++++++++-- src/paste3/helper.py | 18 +- src/paste3/napari/__init__.py | 9 +- src/paste3/napari/_sample_data.py | 25 +- src/paste3/napari/_widget.py | 129 ++++++--- src/paste3/napari/data/ondemand.py | 64 ++++- src/paste3/napari/napari.yaml | 36 ++- src/paste3/paste.py | 195 ++++++++------ tests/conftest.py | 7 +- tests/test_dataset_slice.py | 10 +- tests/test_napari_plugin.py | 17 +- tests/test_paste.py | 1 + 13 files changed, 838 insertions(+), 205 deletions(-) create mode 100644 docs/source/notebooks/paste3_dataset.ipynb diff --git a/docs/source/notebooks/paste3_dataset.ipynb b/docs/source/notebooks/paste3_dataset.ipynb new file mode 100644 index 0000000..f939427 --- /dev/null +++ b/docs/source/notebooks/paste3_dataset.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6ed2400f-074c-41c1-9b9c-c4164e3e35ff", + "metadata": {}, + "source": [ + "# `Slice` and `AlignmentDataset` objects" + ] + }, + { + "cell_type": "markdown", + "id": "74234097-899a-450b-90c6-ecfc22fbc352", + "metadata": {}, + "source": [ + "The `paste3.dataset` module provides an easy-to-use API to access input datasets to the `paste3` alignment algorithms.\n", + "\n", + "The `Slice` class is a thin layer on top of an `AnnData` class, and an `AlignmentDataset` class is a collection of `Slice` objects." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T17:23:44.775643550Z", + "start_time": "2024-11-14T17:23:44.734901656Z" + } + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "from paste3.dataset import AlignmentDataset\n", + "from paste3.napari.data.ondemand import get_file" + ] + }, + { + "cell_type": "markdown", + "id": "d75a8a52-f594-4ff9-b81d-f7e794f1f711", + "metadata": {}, + "source": [ + "Individual `Slice` objects are created by providing a path to an `.h5ad` file. Each `.h5ad` file is expected to contain an `AnnData` object, and is internally read using a `scanpy.read_h5ad`.\n", + "\n", + "Here we download and cache a few `.h5ad` files locally using a `paste3.napari.data.ondemand.get_file` call. These are the files available as the Sample Data in the `paste3` napari plugin." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bdcae26500ecd773",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-14T17:23:44.898402985Z",
+     "start_time": "2024-11-14T17:23:44.775557096Z"
+    },
+    "collapsed": false,
+    "editable": true,
+    "jupyter": {
+     "outputs_hidden": false
+    },
+    "slideshow": {
+     "slide_type": ""
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "slice_files = [Path(f) for f in get_file(\"paste3_sample_patient_2_\")]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "30b304c3-236d-4cbb-8762-973665596109",
+   "metadata": {},
+   "source": [
+    "A dataset is created using the paths to the individual slices."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ab9f0325-ff80-4322-bd8d-a0e80920abce",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = AlignmentDataset(file_paths=slice_files)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a1cfd794-7890-4f5d-a6dd-bef3b0380980",
+   "metadata": {},
+   "source": [
+    "Any individual slice can be rendered in a Jupyter notebook by simply typing the slice variable name in a cell, which renders the slice using the `squidpy` library. (Note: this is roughly equivalent to calling `squidpy.pl.spatial_scatter(slice.adata, ...)`.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "96c22c2e-d50b-40bc-b9c8-7092274dfdb7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset.slices[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5c2b5106-24c0-467e-8c38-71c39e711bd1",
+   "metadata": {},
+   "source": [
+    "An entire dataset can be rendered by typing the dataset variable name in a cell, which renders each slice in order."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "89a94821-197c-4fe4-b042-fcecca8355dd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "44319b2a-1eca-4eb4-91f4-fca1a5f9807c",
+   "metadata": {},
+   "source": [
+    "## Center Aligning a Dataset\n",
+    "\n",
+    "A dataset object can be center aligned in two steps:\n",
+    "\n",
+    "1. Find the \"center slice\" (or the \"consensus slice\") and the optimal transport plans mapping spots of the center slice to spots of each input slice, using the `.find_center_slice` method. **This is a time-consuming step that benefits from running in a GPU-enabled environment.**\n",
+    "2. Use these values to center align the dataset using the `.center_align` method.\n",
+    "\n",
+    "`.center_align` returns the aligned dataset as its first value, along with other useful information (rotation angles and translations); here we ignore everything except the first returned value.\n",
+    "\n",
+    "Center alignment is explained in detail in the [Paste](https://www.nature.com/articles/s41592-022-01459-6) paper."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3d5fba70-a7e8-4924-8268-d3bdd8741e23",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "center_slice, pis = dataset.find_center_slice()\n",
+    "aligned_dataset, *_ = dataset.center_align(center_slice=center_slice, pis=pis)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "866dfa67-3150-4792-a3ae-b1b7a5a01635",
+   "metadata": {},
+   "source": [
+    "We can render the center slice and the aligned dataset as usual."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b05c39a9-5aa5-4a2d-be75-8c52029ec0e5", + "metadata": {}, + "outputs": [], + "source": [ + "center_slice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3937cdf5-fd86-487b-ac3b-a9c97da27815", + "metadata": {}, + "outputs": [], + "source": [ + "aligned_dataset" + ] + }, + { + "cell_type": "markdown", + "id": "6616430c-76b2-48b9-9eb3-e69d87838cb7", + "metadata": {}, + "source": [ + "## Pairwise aligning a Dataset\n", + "\n", + "A dataset can be pairwise aligned using the `.pairwise_align` method. An `overlap_fraction` value (between 0 and 1) can be specified.\n", + "\n", + "A value of `None` results in pairwise alignment that is identical to the approach mentioned in the [Paste](https://www.nature.com/articles/s41592-022-01459-6) paper. Any other value between 0 and 1 results in pairwise alignment explained in the [Paste2](https://pubmed.ncbi.nlm.nih.gov/37553263/) paper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5fa6853-f112-4cd6-8115-ed697b53d1cc", + "metadata": {}, + "outputs": [], + "source": [ + "pairwise_aligned_dataset = dataset.pairwise_align(overlap_fraction=0.7)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "319621a8-aafb-48c4-b4bc-03f742578e77", + "metadata": {}, + "outputs": [], + "source": [ + "pairwise_aligned_dataset[0]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/paste3/dataset.py b/src/paste3/dataset.py index d47bd0c..afbc44f 100644 --- a/src/paste3/dataset.py +++ b/src/paste3/dataset.py @@ -11,37 +11,92 @@ from anndata import AnnData from sklearn.cluster import KMeans -from paste3.paste import center_align, pairwise_align +from paste3.helper import wait +from paste3.paste import center_align_gen, pairwise_align from paste3.visualization import stack_slices_center, stack_slices_pairwise logger = logging.getLogger(__name__) class Slice: - def __init__(self, filepath: Path | None = None, adata: AnnData | None = None): + """ + A single slice of spatial data. + """ + + def __init__( + self, + filepath: Path | None = None, + adata: AnnData | None = None, + name: str | None = None, + ): + """ + Initialize a slice of spatial data. + + Parameters + ---------- + filepath : Path, optional + Path to an h5ad file containing spatial data. + adata : AnnData, optional + Anndata object containing spatial data. + If specified, takes precedence over `filepath`. + name : str, optional + Name of the slice. + If not specified, the name is inferred from the file path or the adata object. + """ self.filepath = filepath self._adata = adata + if name is None: + if self.filepath is not None: + self.name = Path(self.filepath).stem + else: + self.name = "Slice with adata: " + str(self.adata).split("\n")[0] + else: + self.name = name - # Is the 'obs' array of `adata` indexed by strings of the form "XxY", - # where X/Y are Visium array locations? - # This format has been observed in legacy data. + """ + Is the 'obs' array of `adata` indexed by strings of the form "XxY", + where X/Y are Visium array locations? 
+        This format has been observed in legacy data.
+        """
         self.has_coordinate_indices = all(
             "x" in index for index in self.adata.obs.index.values
         )
-
         self.has_spatial_data = "spatial" in self.adata.obsm
 
     def __str__(self):
-        if self.filepath is not None:
-            return Path(self.filepath).stem
-        return "Slice with adata: " + str(self.adata).split("\n")[0]
+        return self.name
+
+    def _repr_mimebundle_(self, include=None, exclude=None):  # noqa: ARG002
+        try:
+            import squidpy
+        except ImportError:
+            return {}
+        else:
+            squidpy.pl.spatial_scatter(
+                self.adata,
+                frameon=False,
+                shape=None,
+                color="original_clusters",
+                title=str(self),
+            )
+
+        # squidpy takes care of the rendering so we return an empty dict
+        return {}
 
     @cached_property
     def adata(self):
+        """
+        AnnData object containing spatial data.
+        """
         return self._adata or sc.read_h5ad(str(self.filepath))
 
     @cached_property
     def obs(self):
+        """
+        Observation metadata of the underlying AnnData object, as a DataFrame.
+        The index of this dataframe is updated to be a MultiIndex
+        with Visium array coordinates as indices if the observation
+        metadata was originally indexed by strings of the form "XxY".
+        """
         if self.has_coordinate_indices:
             logger.debug("Updating obs indices for easy access")
             obs = self.adata.obs.copy()
@@ -51,22 +106,68 @@ def obs(self):
             return obs
         return self.adata.obs
 
-    def get_obs_values(self, which, coordinates=None):
+    def get_obs_values(self, which: str, coordinates: Any | None = None):
+        """
+        Get values from the observation metadata for specific coordinates.
+
+        Parameters
+        ----------
+        which : str
+            Column name to extract values from.
+        coordinates : Any, optional
+            List of Visium array coordinates to extract values for.
+            These should be in the form of a list of tuples (X, Y),
+            or whatever the format of the index of the observation metadata is.
+            If not specified, values for all coordinates are returned.
+        """
         assert which in self.obs.columns, f"Unknown column: {which}"
         if coordinates is None:
            coordinates = self.obs.index.values
         return self.obs.loc[coordinates][which].tolist()
 
-    def set_obs_values(self, which, values):
+    def set_obs_values(self, which: str, values: Any):
+        """
+        Set values for a column of the observation metadata.
+
+        Parameters
+        ----------
+        which : str
+            Column name to set values for.
+        values : Any
+            List of values to set for the specified column.
+        """
         self.obs[which] = values
 
     def cluster(
         self,
         n_clusters: int,
-        uns_key: str = "paste_W",
-        random_state: int = 5,
+        uns_key: str,
+        random_state: int = 0,
         save_as: str | None = None,
-    ):
+    ) -> list[Any]:
+        """
+        Cluster observations based on a specified uns (unstructured) key
+        in the underlying AnnData object of the Slice.
+        The uns key is expected to contain a matrix of weights with shape
+        (n_obs, n_features).
+
+        Parameters
+        ----------
+        n_clusters : int
+            Number of clusters to form.
+        uns_key : str
+            Key in the uns array of the AnnData object to use for clustering.
+        random_state : int, optional
+            Random seed for reproducibility. Default 0.
+        save_as : str, optional
+            Name of the observation metadata column to save the cluster labels to.
+            If not specified, the labels are not saved.
+
+        Returns
+        -------
+        labels : list
+            Cluster labels for each observation.
+        """
         a = self.adata.uns[uns_key].copy()
         a = (a.T / a.sum(axis=1)).T
         a = a + 1
@@ -81,6 +182,10 @@ def cluster(
 
 
 class AlignmentDataset:
+    """
+    A dataset of spatial slices that can be aligned together.
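+
+    A minimal usage sketch (the file names below are hypothetical):
+
+    >>> from pathlib import Path
+    >>> dataset = AlignmentDataset(
+    ...     file_paths=[Path("slice_0.h5ad"), Path("slice_1.h5ad")]
+    ... )
+    >>> for slice_ in dataset:
+    ...     print(slice_)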
+ """ + def __init__( self, file_paths: list[Path] | None = None, @@ -89,6 +194,26 @@ def __init__( max_slices: int | None = None, name: str | None = None, ): + """ + Initialize a dataset of spatial slices. + + Parameters + ---------- + file_paths : list of Path, optional + List of paths to h5ad files containing spatial data. + glob_pattern : str, optional + Glob pattern to match files containing spatial data. + If specified, takes precedence over `file_paths`. + slices : list of Slice, optional + List of Slice objects containing spatial data. + If specified, takes precedence over `file_paths` and `glob_pattern`. + max_slices : int, optional + Maximum number of slices to load. + If not specified, all slices are loaded. + name : str, optional + Name of the dataset. + If not specified, the name is inferred from the common prefix of slice names. + """ if slices is not None: self.slices = slices[:max_slices] elif glob_pattern is not None: @@ -117,11 +242,34 @@ def __iter__(self): def __len__(self): return len(self.slices) + def _repr_mimebundle_(self, include=None, exclude=None): + for slice in self.slices: + slice._repr_mimebundle_(include=include, exclude=exclude) + + # each slice takes care of the rendering so we return an empty dict + return {} + @property def slices_adata(self) -> list[AnnData]: + """ + List of AnnData objects containing spatial data. + """ return [slice_.adata for slice_ in self.slices] - def get_obs_values(self, which, coordinates=None): + def get_obs_values(self, which: str, coordinates: Any | None = None): + """ + Get values from the observation metadata for specific coordinates. + + Parameters + ---------- + which : str + Column name to extract values from. + coordinates : Any, optional + List of Visium array coordinates to extract values for. + These should be in the form of a list of tuples (X, Y), + or whatever the format of the index of the observation metadata is. + If not specified, values for all coordinates are returned. + """ return [slice_.get_obs_values(which, coordinates) for slice_ in self.slices] def align( @@ -131,6 +279,24 @@ def align( overlap_fraction: float | list[float] | None = None, max_iters: int = 1000, ): + """ + Align slices in the dataset. + + Parameters + ---------- + center_align : bool, optional + Whether to center-align the slices. Default False. + If False, pairwise-align the slices. + pis : np.ndarray, optional + Pairwise similarity between slices. Only used in pairwise-align + mode. If not specified, the similarity is calculated. + overlap_fraction : float or list of float, optional + Fraction of overlap between slices. Only used, and required + in pairwise-align mode. + max_iters : int, optional + Maximum number of iterations for alignment. Default 1000. + Only used in pairwise-align mode. 
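+
+        Examples
+        --------
+        A minimal sketch of both modes (argument values are illustrative):
+
+        >>> aligned, *_ = dataset.align(center_align=True)
+        >>> aligned, *_ = dataset.align(
+        ...     center_align=False, overlap_fraction=0.7
+        ... )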
+ """ if center_align: if overlap_fraction is not None: logger.warning( @@ -146,7 +312,9 @@ def align( overlap_fraction=overlap_fraction, pis=pis, max_iters=max_iters ) - def find_pis(self, overlap_fraction: float | list[float], max_iters: int = 1000): + def find_pis( + self, overlap_fraction: float | list[float] | None = None, max_iters: int = 1000 + ): # If multiple overlap_fraction values are specified # ensure that they are |slices| - 1 in length try: @@ -172,17 +340,44 @@ def find_pis(self, overlap_fraction: float | list[float], max_iters: int = 1000) def pairwise_align( self, - overlap_fraction: float | list[float], + overlap_fraction: float | list[float] | None = None, pis: list[np.ndarray] | None = None, max_iters: int = 1000, - ): + ) -> tuple["AlignmentDataset", list[np.ndarray], list[np.ndarray]]: + """ + Pairwise align slices in the dataset. + + Parameters + ---------- + overlap_fraction : float or list of float or None, optional + Fraction of overlap between each adjacent pair of slices. + If a single value between 0 and 1 is specified, it is used for all pairs. + If None, then a full alignment is performed. + pis : list of np.ndarray, optional + Pairwise similarity between slices. + If not specified, the similarity is calculated. + max_iters : int, optional + Maximum number of iterations for alignment. Default 1000. + + Returns + ------- + aligned_dataset : AlignmentDataset + Aligned dataset. + rotation_angles : list of np.ndarray + Rotation angles for each slice. + translations : list of np.ndarray + Mutual translations for each pair of adjacent slices. + """ if pis is None: pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters) new_slices, rotation_angles, translations = stack_slices_pairwise( self.slices_adata, pis ) aligned_dataset = AlignmentDataset( - slices=[Slice(adata=s) for s in new_slices], + slices=[ + Slice(adata=new_slice, name=old_slice.name + "_pairwise_aligned") + for old_slice, new_slice in zip(self.slices, new_slices, strict=False) + ], name=self.name + "_pairwise_aligned", ) @@ -199,13 +394,53 @@ def find_center_slice( exp_dissim_metric: str = "kl", norm: bool = False, random_seed: int | None = None, - pbar: Any = None, + block: bool = True, ) -> tuple[Slice, list[np.ndarray]]: + r""" + Find the center slice of the dataset. + + Parameters + ---------- + initial_slice : Slice, optional + Initial slice to be used as a reference data for alignment. + If not specified, the first slice in the dataset is used. + slice_weights : list of float, optional + Weights for each slice. If not specified, all slices are equally weighted. + alpha : float, optional, default 0.1 + Regularization parameter balancing transcriptional dissimilarity and spatial distance among aligned spots. + Setting \alpha = 0 uses only transcriptional information, while \alpha = 1 uses only spatial coordinates. + n_components : int, optional, default 15 + Number of components to use in the NMF decomposition. + threshold : float, optional, default 0.001 + Convergence threshold for the NMF algorithm. + max_iter : int, optional, default 10 + Maximum number of iterations for the NMF algorithm. + exp_dissim_metric : str, optional, default 'kl' + The metric used to compute dissimilarity. Options include "euclidean" or "kl" for + Kullback-Leibler divergence. + norm : bool, default=False + If True, normalize spatial distances. + random_seed : Optional[int], default=None + Random seed for reproducibility. 
+ block : bool, optional, default True + Whether to block till the center slice is found. + Set False to return a generator. + + Returns + ------- + Tuple[Slice, List[np.ndarray]] + A tuple containing: + - center_slice : Slice + Center slice of the dataset. + - pis : List[np.ndarray] + List of optimal transport distributions for each slice + with the center slice. + """ logger.info("Finding center slice") if initial_slice is None: initial_slice = self.slices[0] - center_slice, pis = center_align( + gen = center_align_gen( initial_slice=initial_slice.adata, slices=self.slices_adata, slice_weights=slice_weights, @@ -216,10 +451,13 @@ def find_center_slice( exp_dissim_metric=exp_dissim_metric, norm=norm, random_seed=random_seed, - pbar=pbar, fast=True, ) - return Slice(adata=center_slice), pis + + if block: + center_slice, pis = wait(gen) + return Slice(adata=center_slice, name=self.name + "_center_slice"), pis + return iter(gen) def center_align( self, @@ -233,7 +471,9 @@ def center_align( logger.warning( "Ignoring pis argument since center_slice is not provided" ) - center_slice, pis = self.find_center_slice(initial_slice=initial_slice) + center_slice, pis = self.find_center_slice( + initial_slice=initial_slice, block=True + ) logger.info("Stacking slices around center slice") new_center, new_slices, rotation_angles, translations = stack_slices_center( @@ -242,7 +482,10 @@ def center_align( pis=pis, ) aligned_dataset = AlignmentDataset( - slices=[Slice(adata=s) for s in new_slices], + slices=[ + Slice(adata=new_slice, name=old_slice.name + "_center_aligned") + for old_slice, new_slice in zip(self.slices, new_slices, strict=False) + ], name=self.name + "_center_aligned", ) diff --git a/src/paste3/helper.py b/src/paste3/helper.py index accefe0..fd93349 100644 --- a/src/paste3/helper.py +++ b/src/paste3/helper.py @@ -1,9 +1,3 @@ -""" -This module provides helper functions to compute an optimal transport plan that aligns multiple tissue slices -using result of an ST experiment that includes a p genes by n spots transcript count matrix and coordinate -matrix of the spots -""" - import logging import anndata as ad @@ -378,3 +372,15 @@ def dissimilarity_metric(which, a_slice, b_slice, a_exp_dissim, b_exp_dissim, ** case _: msg = f"Error: Invalid dissimilarity metric {which}" raise RuntimeError(msg) + + +def wait(gen): + """ + Wait for the completion of a passed-in generator, + returning the final value. 
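+
+    A minimal sketch of the intended use (``count_to`` is a hypothetical
+    generator that returns a final value via ``StopIteration``):
+
+    >>> def count_to(n):
+    ...     for _ in range(n):
+    ...         yield
+    ...     return n
+    >>> wait(count_to(3))
+    3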
+ """ + try: + while True: + next(gen) + except StopIteration as e: + return e.value diff --git a/src/paste3/napari/__init__.py b/src/paste3/napari/__init__.py index 7efbd77..d395b04 100644 --- a/src/paste3/napari/__init__.py +++ b/src/paste3/napari/__init__.py @@ -1,12 +1,5 @@ __version__ = "0.0.1" -from ._reader import napari_get_reader from ._sample_data import make_sample_data -from ._widget import CenterAlignContainer, PairwiseAlignContainer -__all__ = ( - "make_sample_data", - "napari_get_reader", - "CenterAlignContainer", - "PairwiseAlignContainer", -) +__all__ = ("make_sample_data",) diff --git a/src/paste3/napari/_sample_data.py b/src/paste3/napari/_sample_data.py index 70ee918..703a142 100644 --- a/src/paste3/napari/_sample_data.py +++ b/src/paste3/napari/_sample_data.py @@ -6,13 +6,8 @@ from paste3.napari.data.ondemand import get_file -def make_sample_data(): - remote_files = [ - "paste3_sample_patient_2_slice_0.h5ad", - "paste3_sample_patient_2_slice_1.h5ad", - "paste3_sample_patient_2_slice_2.h5ad", - ] - local_files = [get_file(file) for file in remote_files] # paths to local files +def make_sample_data(prefix): + local_files = get_file(prefix) # paths to local files dataset = AlignmentDataset(file_paths=[Path(file) for file in local_files]) data = [] # list of 3-tuples (data, kwargs, layer_type) @@ -55,3 +50,19 @@ def make_sample_data(): ) return data + + +def make_sample_data0(): + return make_sample_data("paste3_sample_patient_2_") + + +def make_sample_data1(): + return make_sample_data("paste3_sample_patient_5_") + + +def make_sample_data2(): + return make_sample_data("paste3_sample_patient_9_") + + +def make_sample_data3(): + return make_sample_data("paste3_sample_patient_10_") diff --git a/src/paste3/napari/_widget.py b/src/paste3/napari/_widget.py index dcc9aff..cd894da 100644 --- a/src/paste3/napari/_widget.py +++ b/src/paste3/napari/_widget.py @@ -1,10 +1,10 @@ import napari import seaborn as sns from magicgui.widgets import Container, create_widget +from napari.qt.threading import create_worker from napari.utils.notifications import show_error -from napari.utils.progress import progress -from paste3.dataset import AlignmentDataset +from paste3.dataset import AlignmentDataset, Slice face_color_cycle = sns.color_palette("Paired", 20) @@ -76,14 +76,16 @@ def __init__( options={"choices": [str(slice) for slice in self.dataset]}, ) - keys = list(self.dataset.slices[0].adata.obs.keys()) spot_color_key = None - for key in keys: - if "cluster" in key: - spot_color_key = key - break - if spot_color_key is None and len(keys) > 0: - spot_color_key = keys[0] + keys = [] + if len(self.dataset) > 0: + keys = list(self.dataset.slices[0].adata.obs.keys()) + for key in keys: + if "cluster" in key: + spot_color_key = key + break + if spot_color_key is None and len(keys) > 0: + spot_color_key = keys[0] self._spot_color_key_dropdown = create_widget( label="Spot Color Key", @@ -109,7 +111,7 @@ def __init__( ) self._threshold_textbox = create_widget( - label="Threshold", annotation=float, value=0.001 + label="Threshold", annotation=float, value=0.001, options={"step": 0.00001} ) self._max_iterations_textbox = create_widget( @@ -136,7 +138,7 @@ def __init__( label="Run", annotation=None, widget_type="PushButton" ) - self._run_button.changed.connect(self._run) + self._run_button.changed.connect(self.run) self.extend( [ @@ -154,12 +156,23 @@ def __init__( ] ) - def _run(self): - cluster_indices = set() - for slice in self.dataset: - clusters = 
set(slice.get_obs_values(self._spot_color_key_dropdown.value)) - cluster_indices |= clusters - n_clusters = len(cluster_indices) + def _find_center_slice(self): + """ + Start center alignment. + Since center alignment is typically a long running process, + we'll run it in a separate thread and yield for progress updates. + + Returns + ------- + Tuple[Slice, List[np.ndarray]] + A tuple containing: + - center_slice : Slice + The aligned Slice object representing the center slice after optimization. + - pis : List[np.ndarray] + List of optimal transport distributions for each slice after alignment. + """ + if self.dataset is None or len(self.dataset) < 2: + return show_error("Please select a dataset with at least 2 slices.") reference_slice = self._viewer.layers[ self._reference_slice_dropdown.value @@ -172,20 +185,33 @@ def _run(self): except ValueError: slice_weights = None - with progress(total=self._max_iterations_textbox.value) as pbar: - center_slice, pis = self.dataset.find_center_slice( - initial_slice=reference_slice, - slice_weights=slice_weights, - alpha=self._alpha_slider.value, - n_components=self._n_components_textbox.value, - threshold=self._threshold_textbox.value, - max_iter=self._max_iterations_textbox.value, - exp_dissim_metric=self._exp_dis_metric_dropdown.value, - norm=self._norm_checkbox.value, - random_seed=self._random_seed_textbox.value, - pbar=pbar, + gen = self.dataset.find_center_slice( + initial_slice=reference_slice, + slice_weights=slice_weights, + alpha=self._alpha_slider.value, + n_components=self._n_components_textbox.value, + threshold=self._threshold_textbox.value, + max_iter=self._max_iterations_textbox.value, + exp_dissim_metric=self._exp_dis_metric_dropdown.value, + norm=self._norm_checkbox.value, + random_seed=self._random_seed_textbox.value, + block=False, + ) + + try: + while True: + # Be a good citizen and yield for progress updates + yield next(gen) + except StopIteration as e: + center_slice, pis = e.value + center_slice = Slice( + adata=center_slice, name=self.dataset.name + "_center_slice" ) + return center_slice, pis + + def _found_center_slice(self, center_slice_and_pis): + center_slice, pis = center_slice_and_pis aligned_dataset, _, translations = self.dataset.center_align( center_slice=center_slice, pis=pis ) @@ -201,9 +227,19 @@ def _run(self): first_layer_translation=first_layer_translation, ) - # Show center slice + # Find the number of clusters in the original dataset + cluster_indices = set() + for slice in self.dataset: + clusters = set(slice.get_obs_values(self._spot_color_key_dropdown.value)) + cluster_indices |= clusters + n_clusters = len(cluster_indices) + + # Show center slice with the same no. 
of clusters as the original + # dataset center_slice_points = center_slice.adata.obsm["spatial"] - center_slice_clusters = center_slice.cluster(n_clusters) + center_slice_clusters = center_slice.cluster( + n_clusters=n_clusters, uns_key="paste_W" + ) self._viewer.add_points( center_slice_points, ndim=2, @@ -214,6 +250,13 @@ def _run(self): name="paste3_center_slice", ) + def run(self): + worker = create_worker( + self._find_center_slice, _start_thread=False, _progress=True + ) + worker.returned.connect(self._found_center_slice) + worker.start() + class PairwiseAlignContainer(AlignContainer): def __init__( @@ -261,7 +304,7 @@ def __init__( label="Run", annotation=None, widget_type="PushButton" ) - self._run_button.changed.connect(self._run) + self._run_button.changed.connect(self.run) self.extend( [ @@ -272,14 +315,20 @@ def __init__( ] ) - def _run(self): - overlap = [float(w) for w in self._overlap_textbox.value.split(",")] - if len(overlap) == 1: # scalar provided - overlap = [overlap[0]] * (len(self.dataset) - 1) - if len(overlap) != len(self.dataset) - 1: - return show_error( - "Overlap fraction must be a scalar or a list of length n-1, where n is the number of slices" - ) + def run(self): + if self.dataset is None or len(self.dataset) < 2: + return show_error("Please select a dataset with at least 2 slices.") + + if self._overlap_textbox.value.strip() == "": + overlap = None + else: + overlap = [float(w) for w in self._overlap_textbox.value.split(",")] + if len(overlap) == 1: # scalar provided + overlap = [overlap[0]] * (len(self.dataset) - 1) + if len(overlap) != len(self.dataset) - 1: + return show_error( + "Overlap fraction must be a scalar or a list of length n-1, where n is the number of slices" + ) aligned_dataset, _, translations = self.dataset.pairwise_align( overlap_fraction=overlap, diff --git a/src/paste3/napari/data/ondemand.py b/src/paste3/napari/data/ondemand.py index 063fa77..ef176c3 100644 --- a/src/paste3/napari/data/ondemand.py +++ b/src/paste3/napari/data/ondemand.py @@ -15,27 +15,65 @@ files = { "paste3_sample_patient_2_slice_0.h5ad": { - "url": "https://dl.dropboxusercontent.com/scl/fi/zq0dlcgjaxfe9fqbp0hf4/patient_2_slice_0.h5ad?rlkey=sxj5c843b38vd3iv2n74824hu&st=pdelsbuz&dl=1", + "url": "https://dl.dropboxusercontent.com/scl/fi/zq0dlcgjaxfe9fqbp0hf4/paste3_sample_patient_2_slice_0.h5ad?rlkey=sxj5c843b38vd3iv2n74824hu&st=wcy6oxbt&dl=1", "hash": "md5:3f2a599a067d3752bd735ea2a01e19f3", }, "paste3_sample_patient_2_slice_1.h5ad": { - "url": "https://dl.dropboxusercontent.com/scl/fi/a5ufhjylxfnvcn5sw4yp0/patient_2_slice_1.h5ad?rlkey=p6dp78qhz6qrh0ut49s7b3fvj&st=2ysuoay4&dl=1", + "url": "https://dl.dropboxusercontent.com/scl/fi/a5ufhjylxfnvcn5sw4yp0/paste3_sample_patient_2_slice_1.h5ad?rlkey=p6dp78qhz6qrh0ut49s7b3fvj&st=tyamjq8b&dl=1", "hash": "md5:a6d1db8ae803e52154cb47e7f8433ffa", }, "paste3_sample_patient_2_slice_2.h5ad": { - "url": "https://dl.dropboxusercontent.com/scl/fi/u7aaq9az8sia26cn4ac4s/patient_2_slice_2.h5ad?rlkey=3ynobd5ajhlvc7lwdbyg0akj1&st=fp7aq5zh&dl=1", + "url": "https://dl.dropboxusercontent.com/scl/fi/u7aaq9az8sia26cn4ac4s/paste3_sample_patient_2_slice_2.h5ad?rlkey=3ynobd5ajhlvc7lwdbyg0akj1&st=0l2nw8i2&dl=1", "hash": "md5:7a64c48af327554dd314439fdbe718ce", }, + "paste3_sample_patient_5_slice_0.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/ypj05gsopwh74ruycjll8/paste3_sample_patient_5_slice_0.h5ad?rlkey=fdbdpuncunpcmqyxed5x687t6&st=u61mutsr&dl=1", + "hash": "md5:d74b47b1e8e9af45085a76c463169f75", + }, + 
"paste3_sample_patient_5_slice_1.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/e1cqlwi313ykjgzl8pmoi/paste3_sample_patient_5_slice_1.h5ad?rlkey=g60hoh2d6qpleaqb59m4xr8n4&st=whwe4oxr&dl=1", + "hash": "md5:cfa621ccb3d13181bd82edc58de4ba22", + }, + "paste3_sample_patient_5_slice_2.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/0jim40rezs1kfk0hhx8r7/paste3_sample_patient_5_slice_2.h5ad?rlkey=gu5wh4m2i58so35gwiyvpvkmj&st=qhpnoqjg&dl=1", + "hash": "md5:2c93234cf9592a6a3d79afe08f59a144", + }, + "paste3_sample_patient_9_slice_0.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/auu0hj6b7b7adhek23eal/paste3_sample_patient_9_slice_0.h5ad?rlkey=k2cwouul7pk7zjeymn6ryt2ef&st=sdseia87&dl=1", + "hash": "md5:6854eed7b4dc768007ca91e4c7ea35df", + }, + "paste3_sample_patient_9_slice_1.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/4nt6rd70u7czsftv0ka0u/paste3_sample_patient_9_slice_1.h5ad?rlkey=b79bmtp9dz48u9oa4tlkq3uqg&st=bvlw62ym&dl=1", + "hash": "md5:a8734cf971d08b851a4502c96f7b56a5", + }, + "paste3_sample_patient_9_slice_2.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/o9jhgkwzzppgfsuewy6eo/paste3_sample_patient_9_slice_2.h5ad?rlkey=kjl6jn24awtwafhoz9yjamdob&st=r7le0v8j&dl=1", + "hash": "md5:87df8ce67e0b3ce0891a868d565e2216", + }, + "paste3_sample_patient_10_slice_0.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/b7s1hfkfy3ajtzb0wy1yb/paste3_sample_patient_10_slice_0.h5ad?rlkey=u77grnq5xud7q7wwzm2hq05hx&st=icmc2ttg&dl=1", + "hash": "md5:3dacdb5d8b39056d1b764b401b94cbac", + }, + "paste3_sample_patient_10_slice_1.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/4nlwrpllzows1u0rodi7j/paste3_sample_patient_10_slice_1.h5ad?rlkey=6m3qzp0gqwrgpahlm3qn1o2w5&st=gondhkyg&dl=1", + "hash": "md5:ec8abf1a6a0cf6e8f5841b5dad15bdb7", + }, + "paste3_sample_patient_10_slice_2.h5ad": { + "url": "https://dl.dropboxusercontent.com/scl/fi/qjb82v0kqirkr00x0acyf/paste3_sample_patient_10_slice_2.h5ad?rlkey=z8j0pb57le9h6h802dcn9t9i2&st=qxt6a0bw&dl=1", + "hash": "md5:1b60fbd0bf267babb0f0c0434bdc5d21", + }, } -def get_file(which): - assert which in files, f"Unknown file {which}" - file = files[which] - return pooch.retrieve( - url=file["url"], - known_hash=file["hash"], - fname=which, - processor=file.get("processor"), - path=CACHE_PATH, - ) +def get_file(which: str) -> list[str]: + remote_files = {k: v for k, v in files.items() if k.startswith(which)} + return [ + pooch.retrieve( + url=v["url"], + known_hash=v["hash"], + fname=k, + processor=v.get("processor"), + path=CACHE_PATH, + ) + for k, v in remote_files.items() + ] diff --git a/src/paste3/napari/napari.yaml b/src/paste3/napari/napari.yaml index 8912bf5..2df74a3 100644 --- a/src/paste3/napari/napari.yaml +++ b/src/paste3/napari/napari.yaml @@ -1,30 +1,46 @@ name: paste3 display_name: Paste3 -# use 'hidden' to remove plugin from napari hub search results -visibility: hidden contributions: commands: - - id: paste3.make_sample_data - python_name: paste3.napari._sample_data:make_sample_data - title: Load sample data for Paste3 + - id: paste3.make_sample_data0 + python_name: paste3.napari._sample_data:make_sample_data0 + title: sample_data0 + - id: paste3.make_sample_data1 + python_name: paste3.napari._sample_data:make_sample_data1 + title: sample_data1 + - id: paste3.make_sample_data2 + python_name: paste3.napari._sample_data:make_sample_data2 + title: sample_data2 + - id: paste3.make_sample_data3 + python_name: paste3.napari._sample_data:make_sample_data3 + title: sample_data3 - id: 
paste3.get_reader python_name: paste3.napari._reader:napari_get_reader title: Open data with Paste3 - id: paste3.center_align_widget - python_name: paste3.napari:CenterAlignContainer + python_name: paste3.napari._widget:CenterAlignContainer title: Center Align - id: paste3.pairwise_align_widget - python_name: paste3.napari:PairwiseAlignContainer + python_name: paste3.napari._widget:PairwiseAlignContainer title: Pairwise Align readers: - command: paste3.get_reader accepts_directories: false filename_patterns: [ '*.h5ad' ] sample_data: - - command: paste3.make_sample_data - display_name: Paste3 Sample Data - key: data.napari.paste3.raphael-group.github.com + - command: paste3.make_sample_data0 + display_name: SCC Patient 2 + key: data0.napari.paste3.raphael-group.github.com + - command: paste3.make_sample_data1 + display_name: SCC Patient 5 + key: data1.napari.paste3.raphael-group.github.com + - command: paste3.make_sample_data2 + display_name: SCC Patient 9 + key: data2.napari.paste3.raphael-group.github.com + - command: paste3.make_sample_data3 + display_name: SCC Patient 10 + key: data3.napari.paste3.raphael-group.github.com widgets: - command: paste3.center_align_widget display_name: Center Align diff --git a/src/paste3/paste.py b/src/paste3/paste.py index 5a1dd19..5111e98 100644 --- a/src/paste3/paste.py +++ b/src/paste3/paste.py @@ -6,7 +6,7 @@ import logging from collections.abc import Callable -from typing import Any +from time import sleep import numpy as np import ot @@ -21,6 +21,7 @@ dissimilarity_metric, get_common_genes, to_dense_array, + wait, ) logger = logging.getLogger(__name__) @@ -196,7 +197,7 @@ def pairwise_align( if norm: a_spatial_dist /= torch.min(a_spatial_dist[a_spatial_dist > 0]) b_spatial_dist /= torch.min(b_spatial_dist[b_spatial_dist > 0]) - if overlap_fraction: + if overlap_fraction is not None: a_spatial_dist /= a_spatial_dist[a_spatial_dist > 0].max() a_spatial_dist *= exp_dissim_matrix.max() b_spatial_dist /= b_spatial_dist[b_spatial_dist > 0].max() @@ -220,7 +221,7 @@ def pairwise_align( ) -def center_align( +def center_align_gen( initial_slice: AnnData, slices: list[AnnData], slice_weights=None, @@ -235,81 +236,11 @@ def center_align( spots_weights=None, use_gpu: bool = True, fast: bool = False, - pbar: Any = None, ) -> tuple[AnnData, list[np.ndarray]]: - r""" - Infers a "center" slice consisting of a low rank expression matrix :math:`X = WH` and a collection of - :math:`\pi` of mappings from the spots of the center slice to the spots of each input slice. - - Given slices :math:`(X^{(1)}, D^{(1)}, g^{(1)}), \dots, (X^{(t)}, D^{(t)}, g^{(t)})` containing :math:`n_1, \dots, n_t` - spots, respectively over the same :math:`p` genes, a spot distance matrix :math:`D \in \mathbb{R}^{n \times n}_{+}`, - a distribution :math:`g` over :math:`n` spots, an expression cost function :math:`c`, a distribution - :math:`\lambda \in \mathbb{R}^t_{+}` and parameters :math:`0 \leq \alpha \leq 1`, :math:`m \in \mathbb{N}`, - find an expression matrix :math:`X = WH` where :math:`W \in \mathbb{R}^{p \times m}_{+}` and :math:`H \in \mathbb{R}^{m \times n}_{+}`, - and mappings :math:`\Pi^{(q)} \in \Gamma(g, g^{(q)})` for each slice :math:`q = 1, \dots, t` that minimize the following objective: - - .. 
math:: - R(W, H, \Pi^{(1)}, \dots, \Pi^{(t)}) = \sum_q \lambda_q F(\Pi^{(q)}; WH, D, X^{(q)}, D^{(q)}, c, \alpha) - - = \sum_q \lambda_q \left[(1 - \alpha) \sum_{i,j} c(WH_{\cdot,i}, x^{(q)}_j) \pi^{(q)}_{ij} + \alpha \sum_{i,j,k,l} (d_{ik} - d^{(q)}_{jl})^2 \pi^{(q)}_{ij} \pi^{(q)}_{kl} \right]. - - Where: - - - :math:`X^{q} = [x_{ij}] \in \mathbb{N}^{p \times n_t}` is a :math:`p` genes by :math:`n_t` spots transcript count matrix for :math:`q^{th}` slice, - - :math:`D^{(q)}`, where :math:`d_ij = \parallel z_.i - z_.j \parallel` is the spatial distance between spot :math:`i` and :math:`j`, represents the spot pairwise distance matrix for :math:`q^{th}` slice, - - :math:`c: \mathbb{R}^{p}_{+} \times \mathbb{R}^{p}_{+} \to \mathbb{R}_{+}`, is a function that measures a nonnegative cost between the expression profiles of two spots over all genes - - :math:`\alpha` is a parameter balancing expression and spatial distance preservation, - - :math:`W` and :math:`H` form the low-rank approximation of the center slice's expression matrix, and - - :math:`\lambda_q` weighs each slice :math:`q` in the objective. - - Parameters - ---------- - initial_slice : AnnData - An AnnData object that represent a slice to be used as a reference data for alignment - slices : List[AnnData] - A list of AnnData objects that represent different slices to be aligned with the initial slice. - slice_weights : List[float], optional - Weights for each slice in the alignment process. If None, all slices are treated equally. - alpha : float, default=0.1 - Regularization parameter balancing transcriptional dissimilarity and spatial distance among aligned spots. - Setting \alpha = 0 uses only transcriptional information, while \alpha = 1 uses only spatial coordinates. - n_components : int, default=15 - Number of components to use for the NMF. - threshold : float, default=0.001 - Convergence threshold for the optimization process. The process stops when the change - in loss is below this threshold. - max_iter : int, default=10 - Maximum number of iterations for the optimization process. - exp_dissim_metric : str, default="kl" - The metric used to compute dissimilarity. Options include "euclidean" or "kl" for - Kullback-Leibler divergence. - norm : bool, default=False - If True, normalizes spatial distances. - random_seed : Optional[int], default=None - Random seed for reproducibility. - pi_inits : Optional[List[np.ndarray]], default=None - Initial transport plans for each slice. If None, it will be computed. - spots_weights : List[float], optional - Weights for individual spots in each slices. If None, uniform distribution is used. - use_gpu : bool, default=True - Whether to use GPU for computations. If True but no GPU is available, will default to CPU. - fast : bool, default=False - Whether to use the fast (untested) torch nmf library - pbar : Any, default=None - Progress bar (tqdm or derived) for tracking the optimization process. - Something that has an `update` method. - Returns - ------- - Tuple[AnnData, List[np.ndarray]] - A tuple containing: - - center_slice : AnnData - The aligned AnnData object representing the center slice after optimization. - - pis : List[np.ndarray] - List of optimal transport distributions for each slice after alignment. - - Returns: - - Inferred center slice with full and low dimensional representations (feature_matrix, coeff_matrix) of the gene expression matrix. - - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns). 
+ """ + Analogous to the blocking `center_align` function, but since it can be + time-intensive, implemented as a generator function that yields at each + iteration. """ if use_gpu and not torch.cuda.is_available(): logger.info("GPU is not available, resorting to torch CPU.") @@ -375,8 +306,8 @@ def center_align( logger.info(f"Objective {loss_new} | Difference: {loss_diff}") loss_init = loss_new - if pbar is not None: - pbar.update(1) + yield + sleep(0.01) center_slice = initial_slice.copy() center_slice.X = np.dot(feature_matrix, coeff_matrix) @@ -387,9 +318,115 @@ def center_align( * compute_slice_weights(slice_weights, pis, slices, device).cpu().numpy() ) center_slice.uns["obj"] = loss_init + logger.info("Center slice computed.") return center_slice, pis +def center_align( + initial_slice: AnnData, + slices: list[AnnData], + slice_weights=None, + alpha: float = 0.1, + n_components: int = 15, + threshold: float = 0.001, + max_iter: int = 10, + exp_dissim_metric: str = "kl", + norm: bool = False, + random_seed: int | None = None, + pi_inits: list[np.ndarray] | None = None, + spots_weights=None, + use_gpu: bool = True, + fast: bool = False, +) -> tuple[AnnData, list[np.ndarray]]: + r""" + Infers a "center" slice consisting of a low rank expression matrix :math:`X = WH` and a collection of + :math:`\pi` of mappings from the spots of the center slice to the spots of each input slice. + + Given slices :math:`(X^{(1)}, D^{(1)}, g^{(1)}), \dots, (X^{(t)}, D^{(t)}, g^{(t)})` containing :math:`n_1, \dots, n_t` + spots, respectively over the same :math:`p` genes, a spot distance matrix :math:`D \in \mathbb{R}^{n \times n}_{+}`, + a distribution :math:`g` over :math:`n` spots, an expression cost function :math:`c`, a distribution + :math:`\lambda \in \mathbb{R}^t_{+}` and parameters :math:`0 \leq \alpha \leq 1`, :math:`m \in \mathbb{N}`, + find an expression matrix :math:`X = WH` where :math:`W \in \mathbb{R}^{p \times m}_{+}` and :math:`H \in \mathbb{R}^{m \times n}_{+}`, + and mappings :math:`\Pi^{(q)} \in \Gamma(g, g^{(q)})` for each slice :math:`q = 1, \dots, t` that minimize the following objective: + + .. math:: + R(W, H, \Pi^{(1)}, \dots, \Pi^{(t)}) = \sum_q \lambda_q F(\Pi^{(q)}; WH, D, X^{(q)}, D^{(q)}, c, \alpha) + + = \sum_q \lambda_q \left[(1 - \alpha) \sum_{i,j} c(WH_{\cdot,i}, x^{(q)}_j) \pi^{(q)}_{ij} + \alpha \sum_{i,j,k,l} (d_{ik} - d^{(q)}_{jl})^2 \pi^{(q)}_{ij} \pi^{(q)}_{kl} \right]. + + Where: + + - :math:`X^{q} = [x_{ij}] \in \mathbb{N}^{p \times n_t}` is a :math:`p` genes by :math:`n_t` spots transcript count matrix for :math:`q^{th}` slice, + - :math:`D^{(q)}`, where :math:`d_ij = \parallel z_.i - z_.j \parallel` is the spatial distance between spot :math:`i` and :math:`j`, represents the spot pairwise distance matrix for :math:`q^{th}` slice, + - :math:`c: \mathbb{R}^{p}_{+} \times \mathbb{R}^{p}_{+} \to \mathbb{R}_{+}`, is a function that measures a nonnegative cost between the expression profiles of two spots over all genes + - :math:`\alpha` is a parameter balancing expression and spatial distance preservation, + - :math:`W` and :math:`H` form the low-rank approximation of the center slice's expression matrix, and + - :math:`\lambda_q` weighs each slice :math:`q` in the objective. + + Parameters + ---------- + initial_slice : AnnData + An AnnData object that represent a slice to be used as a reference data for alignment + slices : List[AnnData] + A list of AnnData objects that represent different slices to be aligned with the initial slice. 
+ slice_weights : List[float], optional + Weights for each slice in the alignment process. If None, all slices are treated equally. + alpha : float, default=0.1 + Regularization parameter balancing transcriptional dissimilarity and spatial distance among aligned spots. + Setting \alpha = 0 uses only transcriptional information, while \alpha = 1 uses only spatial coordinates. + n_components : int, default=15 + Number of components to use for the NMF. + threshold : float, default=0.001 + Convergence threshold for the optimization process. The process stops when the change + in loss is below this threshold. + max_iter : int, default=10 + Maximum number of iterations for the optimization process. + exp_dissim_metric : str, default="kl" + The metric used to compute dissimilarity. Options include "euclidean" or "kl" for + Kullback-Leibler divergence. + norm : bool, default=False + If True, normalizes spatial distances. + random_seed : Optional[int], default=None + Random seed for reproducibility. + pi_inits : Optional[List[np.ndarray]], default=None + Initial transport plans for each slice. If None, it will be computed. + spots_weights : List[float], optional + Weights for individual spots in each slices. If None, uniform distribution is used. + use_gpu : bool, default=True + Whether to use GPU for computations. If True but no GPU is available, will default to CPU. + fast : bool, default=False + Whether to use the fast (untested) torch nmf library + + Returns + ------- + Tuple[AnnData, List[np.ndarray]] + A tuple containing: + - center_slice : AnnData + The aligned AnnData object representing the center slice after optimization. + - pis : List[np.ndarray] + List of optimal transport distributions for each slice after alignment. + """ + # Call our generator function, but wait for its completion + return wait( + center_align_gen( + initial_slice=initial_slice, + slices=slices, + slice_weights=slice_weights, + alpha=alpha, + n_components=n_components, + threshold=threshold, + max_iter=max_iter, + exp_dissim_metric=exp_dissim_metric, + norm=norm, + random_seed=random_seed, + pi_inits=pi_inits, + spots_weights=spots_weights, + use_gpu=use_gpu, + fast=fast, + ) + ) + + # --------------------------- HELPER METHODS ----------------------------------- diff --git a/tests/conftest.py b/tests/conftest.py index 5493e86..88adb8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,12 +75,7 @@ def slices2(): @pytest.fixture(scope="session") def sample_data_files(): - files = [ - "paste3_sample_patient_2_slice_0.h5ad", - "paste3_sample_patient_2_slice_1.h5ad", - "paste3_sample_patient_2_slice_2.h5ad", - ] - return [Path(get_file(f)) for f in files] + return [Path(f) for f in get_file("paste3_sample_patient_2_")] @pytest.fixture(autouse=True) diff --git a/tests/test_dataset_slice.py b/tests/test_dataset_slice.py index 7caaf79..08e5c11 100644 --- a/tests/test_dataset_slice.py +++ b/tests/test_dataset_slice.py @@ -19,7 +19,14 @@ def test_slice_adata(slices): assert np.all(slice.adata.obs_names == slices[0].obs_names) -def test_slice_adata_str(slices): +def test_slice_adata_name_str(slices): + # We can give names to slices + slice = Slice(adata=slices[0], name="my_slice") + assert str(slice) == "my_slice" + + +def test_slice_adata_noname_str(slices): + # If we don't give names to slices, we have a sensible default slice = Slice(adata=slices[0]) assert ( str(slice) @@ -28,6 +35,7 @@ def test_slice_adata_str(slices): def test_slice_filepath_str(sample_data_files): + # If we don't give names to slices, we 
have a sensible default slice = Slice(filepath=sample_data_files[0]) assert str(slice) == "paste3_sample_patient_2_slice_0" diff --git a/tests/test_napari_plugin.py b/tests/test_napari_plugin.py index 2cbc025..1ffb9fb 100644 --- a/tests/test_napari_plugin.py +++ b/tests/test_napari_plugin.py @@ -1,10 +1,8 @@ from paste3.dataset import AlignmentDataset -from paste3.napari import ( - CenterAlignContainer, - PairwiseAlignContainer, - make_sample_data, - napari_get_reader, -) +from paste3.helper import wait +from paste3.napari._reader import napari_get_reader +from paste3.napari._sample_data import make_sample_data +from paste3.napari._widget import CenterAlignContainer, PairwiseAlignContainer def test_reader(sample_data_files, make_napari_viewer_proxy): @@ -28,7 +26,7 @@ def test_reader(sample_data_files, make_napari_viewer_proxy): def test_sample_data(): - layer_data = make_sample_data() + layer_data = make_sample_data("paste3_sample_patient_2") # We should have 4 Point layers, one for each slice # and one for the volume assert len(layer_data) == 4 @@ -45,7 +43,8 @@ def test_center_align_widget(sample_data_files, make_napari_viewer_proxy): widget._reference_slice_dropdown.value = "paste3_sample_patient_2_slice_0" widget._max_iterations_textbox.value = "1" - widget._run() + center_slice_and_pis = wait(widget._find_center_slice()) # wait for completion + widget._found_center_slice(center_slice_and_pis) layers = viewer.layers @@ -74,7 +73,7 @@ def test_pairwise_align_widget(sample_data_files, make_napari_viewer_proxy): # emulate UI interaction widget._max_iterations_textbox.value = "1" - widget._run() + widget.run() layers = viewer.layers diff --git a/tests/test_paste.py b/tests/test_paste.py index f775203..5425650 100644 --- a/tests/test_paste.py +++ b/tests/test_paste.py @@ -82,6 +82,7 @@ def test_center_alignment(slices): for i in range(len(slices)) ], ) + assert_frame_equal( pd.DataFrame( center_slice.uns["paste_W"],