diff --git a/README.md b/README.md
index 3aa05c6..05e9181 100644
--- a/README.md
+++ b/README.md
@@ -15,17 +15,43 @@ Package for PCA-based spatial domain identification in single-cell spatial trans
 - [API documentation][link-api]. -->
 
-Given an AnnData object `adata`, you can run nichepca as follows:
+Given an AnnData object `adata`, you can run nichepca starting from raw counts as follows:
 
 ```python
 import scanpy as sc
 import nichepca as npc
 
-npc.wf.run_nichepca(adata, knn=5)
-sc.pp.neighbors(adata)
+npc.wf.nichepca(adata, knn=25)
+sc.pp.neighbors(adata, use_rep="X_npca")
 sc.tl.leiden(adata, resolution=0.5)
 ```
 
+If you have multiple samples in `adata.obs['sample']`, you can pass the column name via the `sample_key` argument of `npc.wf.nichepca`:
+
+```python
+npc.wf.nichepca(adata, knn=25, sample_key="sample")
+```
+
+If you have cell type labels in `adata.obs['cell_type']`, you can provide them directly to `nichepca` as follows:
+
+```python
+npc.wf.nichepca(adata, knn=25, obs_key="cell_type")
+```
+
+The `nichepca` function also allows you to customize the default `("norm", "log1p", "agg", "pca")` pipeline, e.g., without median normalization:
+
+```python
+npc.wf.nichepca(adata, knn=25, pipeline=["log1p", "agg", "pca"])
+```
+
+or with `"pca"` before `"agg"`:
+
+```python
+npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "pca", "agg"])
+```
+
+or without `"pca"` at all:
+
+```python
+npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "agg"])
+```
+
+## Setting parameters
+
+We found that a higher number of neighbors, e.g., `knn=25`, leads to better results in brain tissue, while `knn=10` works well for kidney data. We recommend qualitatively optimizing these parameters on a small subset of your data. The number of PCs (`n_comps=30` by default) seems to have a negligible effect on the results.
+
 ## Installation
 
 You need to have Python 3.10 or newer installed on your system. If you don't have
diff --git a/src/nichepca/graph_construction/_spatial_graph.py b/src/nichepca/graph_construction/_spatial_graph.py
index 1e02673..72f69fd 100644
--- a/src/nichepca/graph_construction/_spatial_graph.py
+++ b/src/nichepca/graph_construction/_spatial_graph.py
@@ -209,7 +209,7 @@ def knn_graph(
 
 def distance_graph(
     adata: AnnData,
-    radius: int = 50,
+    radius: float = 50,
     obsm_key: str = "spatial",
     remove_self_loops: bool = False,
     p: int = 2,
@@ -223,7 +223,7 @@ def distance_graph(
     ----------
     adata : AnnData
         Annotated data object.
-    radius : int, default 50
+    radius : float, default 50
         Radius for the distance threshold.
     obsm_key : str, default "spatial"
         Key in `obsm` attribute where the spatial data is stored.
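The `radius` change above only widens the type hint; runtime behavior is unchanged. For illustration, a minimal usage sketch of the radius-based graph construction with a float radius, given an AnnData object `adata` as in the README examples and assuming `distance_graph` is re-exported under the `npc.gc` alias used in the tests (the value `30.0` is illustrative):

```python
import nichepca as npc

# connect all cells whose spatial distance is at most 30.0 units;
# float radii now match the annotation fixed above
npc.gc.distance_graph(adata, radius=30.0)
```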
diff --git a/src/nichepca/workflows/__init__.py b/src/nichepca/workflows/__init__.py
index c80f759..f4e1b54 100644
--- a/src/nichepca/workflows/__init__.py
+++ b/src/nichepca/workflows/__init__.py
@@ -1 +1 @@
-from ._nichepca import run_nichepca
+from ._nichepca import nichepca
diff --git a/src/nichepca/workflows/_nichepca.py b/src/nichepca/workflows/_nichepca.py
index b1a2a68..a9d92e7 100644
--- a/src/nichepca/workflows/_nichepca.py
+++ b/src/nichepca/workflows/_nichepca.py
@@ -2,87 +2,182 @@
 
 from typing import TYPE_CHECKING
 
+import numpy as np
+import pandas as pd
 import scanpy as sc
 
 from nichepca.graph_construction import (
     construct_multi_sample_graph,
-    distance_graph,
-    knn_graph,
+    resolve_graph_constructor,
 )
 from nichepca.nhood_embedding import aggregate
-from nichepca.utils import check_for_raw_counts
+from nichepca.utils import check_for_raw_counts, normalize_per_sample
 
 if TYPE_CHECKING:
     from anndata import AnnData
 
 
-def run_nichepca(
+def nichepca(
     adata: AnnData,
-    knn: int = None,
-    radius: float = None,
-    sample_key: str = None,
+    knn: int | None = None,
+    radius: float | None = None,
+    delaunay: bool = False,
     n_comps: int = 30,
-    max_iter_harmony: int = 50,
+    obs_key: str | None = None,
+    obsm_key: str | None = None,
+    sample_key: str | None = None,
+    pipeline: tuple | list = ("norm", "log1p", "agg", "pca"),
     norm_per_sample: bool = True,
+    backend: str = "pyg",
+    aggr: str = "mean",
+    allow_harmony: bool = True,
+    max_iter_harmony: int = 50,
     **kwargs,
 ):
     """
-    Run the NichePCA workflow.
+    Run the general NichePCA workflow.
 
     Parameters
     ----------
     adata : AnnData
-        Annotated data object.
-    knn : int
-        Number of nearest neighbors for the kNN graph.
-    sample_key : str, optional
-        Key in `adata.obs` that identifies distinct samples. If provided, harmony will be used to
-        integrate the data.
-    radius : float, optional
-        The radius of the neighborhood graph.
+        The input AnnData object.
+    knn : int | None, optional
+        Number of nearest neighbors to use for graph construction.
+    radius : float | None, optional
+        Radius for graph construction.
+    delaunay : bool, optional
+        Whether to use Delaunay triangulation for graph construction.
     n_comps : int, optional
         Number of principal components to compute.
-    max_iter_harmony : int, optional
-        Maximum number of iterations for harmony.
+    obs_key : str | None, optional
+        Key in `adata.obs` with categorical labels (e.g., cell types). If provided,
+        a one-hot encoded AnnData is generated and used as input.
+    obsm_key : str | None, optional
+        Key in `adata.obsm` to use as the input matrix.
+    sample_key : str | None, optional
+        Key in `adata.obs` that identifies distinct samples. If provided, the graph
+        is constructed per sample and Harmony is used for integration.
+    pipeline : tuple | list, optional
+        Ordered steps of the pipeline to perform. Must include 'agg'.
     norm_per_sample : bool, optional
-        Whether to normalize the data per sample.
-    kwargs : dict, optional
-        Additional keyword arguments for the graph construction.
+        Whether to normalize per sample.
+    backend : str, optional
+        Backend to use for aggregation.
+    aggr : str, optional
+        Aggregation method to use.
+    allow_harmony : bool, optional
+        Whether to allow Harmony integration.
+    max_iter_harmony : int, optional
+        Maximum number of iterations for Harmony.
+    **kwargs : dict
+        Additional keyword arguments passed to the graph construction.
 
     Returns
     -------
     None
     """
-    check_for_raw_counts(adata)
+    # we always need to use agg
+    assert "agg" in pipeline, "aggregation must be part of the pipeline"
+    # assert that pca comes after norm and log1p (when those steps are present)
+    if "pca" in pipeline:
+        pca_idx = pipeline.index("pca")
+        pca_after_norm = "norm" not in pipeline or pipeline.index("norm") < pca_idx
+        pca_after_log1p = "log1p" not in pipeline or pipeline.index("log1p") < pca_idx
+        assert (
+            pca_after_norm and pca_after_log1p
+        ), "pca must be executed after norm and log1p"
+
+    # perform sanity check in case we are normalizing the data
+    if (
+        "norm" in pipeline or "log1p" in pipeline
+    ) and obs_key is None and obsm_key is None:
+        check_for_raw_counts(adata)
 
+    # extract any additional kwargs that are not directed to the graph construction
+    target_sum = kwargs.pop("target_sum", None)
+
+    # construct the (multi-sample) graph
     if sample_key is not None:
         construct_multi_sample_graph(
-            adata, sample_key=sample_key, knn=knn, radius=radius, **kwargs
+            adata,
+            sample_key=sample_key,
+            knn=knn,
+            radius=radius,
+            delaunay=delaunay,
+            **kwargs,
         )
     else:
-        if knn is not None:
-            knn_graph(adata, knn, **kwargs)
-        elif radius is not None:
-            distance_graph(adata, radius, **kwargs)
-        else:
-            raise ValueError("Either knn or radius must be provided.")
-
-    if norm_per_sample and sample_key is not None:
-        for sample in adata.obs[sample_key].unique():
-            mask = adata.obs[sample_key] == sample
-            sub_ad = adata[mask].copy()
-            sc.pp.normalize_total(sub_ad)
-            adata[mask].X = sub_ad.X
+        resolve_graph_constructor(radius, knn, delaunay)(adata, **kwargs)
+
+    # if an obs_key is provided, generate a new AnnData from one-hot encoded labels
+    if obs_key is not None:
+        df = pd.get_dummies(adata.obs[obs_key], dtype=np.int8)
+        X = df.values
+        var = pd.DataFrame(index=df.columns)
+        # remove normalization steps
+        pipeline = [p for p in pipeline if p not in ["norm", "log1p"]]
+        print(f"obs_key provided, running pipeline: {'->'.join(pipeline)}")
+    elif obsm_key is not None:
+        X = adata.obsm[obsm_key]
+        # obsm matrices have their own feature dimension, so use placeholder var names
+        var = pd.DataFrame(index=[f"dim_{i}" for i in range(X.shape[1])])
     else:
-        sc.pp.normalize_total(adata)
+        X = adata.X
+        var = adata.var[[]]
+        print(f"Running pipeline: {'->'.join(pipeline)}")
 
-    sc.pp.log1p(adata)
+    # create intermediate AnnData
+    ad_tmp = sc.AnnData(
+        X=X,
+        obs=adata.obs,
+        var=var,
+        uns=adata.uns,
+    )
 
-    aggregate(adata)
+    for fn in pipeline:
+        if fn == "norm":
+            if norm_per_sample and sample_key is not None:
+                normalize_per_sample(
+                    ad_tmp, sample_key=sample_key, target_sum=target_sum
+                )
+            else:
+                sc.pp.normalize_total(ad_tmp, target_sum=target_sum)
+        elif fn == "log1p":
+            sc.pp.log1p(ad_tmp)
+        elif fn == "agg":
+            # if pca was executed before agg, we need to aggregate the pca results
+            if "X_pca_harmony" in ad_tmp.obsm:
+                obsm_key_agg = "X_pca_harmony"
+            elif "X_pca" in ad_tmp.obsm:
+                obsm_key_agg = "X_pca"
+            else:
+                obsm_key_agg = None
+            aggregate(
+                ad_tmp,
+                backend=backend,
+                aggr=aggr,
+                obsm_key=obsm_key_agg,
+                suffix="",
+            )
+        elif fn == "pca":
+            sc.tl.pca(ad_tmp, n_comps=n_comps)
+            # run harmony if sample_key is provided and obs_key is None
+            if sample_key is not None and obs_key is None and allow_harmony:
+                sc.external.pp.harmony_integrate(
+                    ad_tmp, key=sample_key, max_iter_harmony=max_iter_harmony
+                )
+        else:
+            raise ValueError(f"Unknown step in the pipeline: {fn}")
 
-    sc.tl.pca(adata, n_comps=n_comps)
+    # extract the results
+    if "X_pca_harmony" in ad_tmp.obsm:
+        X_npca = ad_tmp.obsm["X_pca_harmony"]
+    elif "X_pca" in ad_tmp.obsm:
+        X_npca = ad_tmp.obsm["X_pca"]
+    else:
+        # no pca step was run; fall back to the aggregated matrix
+        X_npca = ad_tmp.X
 
-    if sample_key is not None:
-        sc.external.pp.harmony_integrate(
-            adata, key=sample_key, max_iter_harmony=max_iter_harmony
-        )
+    # store the results
+    adata.obsm["X_npca"] = X_npca
+    if "pca" in pipeline:
+        adata.uns["npca"] = ad_tmp.uns["pca"]
+        adata.uns["npca"]["PCs"] = pd.DataFrame(
+            data=ad_tmp.varm["PCs"],
+            index=ad_tmp.var_names,
+            columns=[f"PC{i}" for i in range(n_comps)],
+        )
diff --git a/tests/test_workflows.py b/tests/test_workflows.py
index bb769c4..645153b 100644
--- a/tests/test_workflows.py
+++ b/tests/test_workflows.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pandas as pd
 import scanpy as sc
 from utils import generate_dummy_adata
 
@@ -7,7 +8,7 @@
 
 def test_nichepca_single():
     adata_1 = generate_dummy_adata()
-    npc.wf.run_nichepca(adata_1, knn=10, n_comps=30)
+    npc.wf.nichepca(adata_1, knn=10, n_comps=30)
 
     adata_2 = generate_dummy_adata()
     sc.pp.normalize_total(adata_2)
@@ -16,12 +17,76 @@
     npc.ne.aggregate(adata_2)
     sc.tl.pca(adata_2, n_comps=30)
 
-    assert np.all(adata_1.obsm["X_pca"] == adata_2.obsm["X_pca"])
+    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])
 
+    # test with obs_key
+    obs_key = "cell_type"
+    n_celltypes = 5
+    adata_1 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.wf.nichepca(adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key)
 
-def test_nichepca_multi():
-    adata = generate_dummy_adata()
-    npc.wf.run_nichepca(adata, knn=10, sample_key="sample")
+    adata_2 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.gc.knn_graph(adata_2, knn=10)
+    df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8)
+    ad_tmp = sc.AnnData(
+        X=df.values,
+        obs=adata_2.obs,
+        var=pd.DataFrame(index=df.columns),
+        uns=adata_2.uns,
+    )
+    npc.ne.aggregate(ad_tmp)
+    sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1)
+
+    assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"])
+
+    # test with pca before agg
+    adata_1 = generate_dummy_adata()
+    npc.wf.nichepca(
+        adata_1, knn=10, n_comps=30, pipeline=["norm", "log1p", "pca", "agg"]
+    )
+
+    adata_2 = generate_dummy_adata()
+    npc.gc.knn_graph(adata_2, knn=10)
+    sc.pp.normalize_total(adata_2)
+    sc.pp.log1p(adata_2)
+    sc.pp.pca(adata_2, n_comps=30)
+    npc.ne.aggregate(adata_2, obsm_key="X_pca", suffix="")
+
+    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])
+
+
+def test_nichepca_multi_sample():
+    adata_1 = generate_dummy_adata()
+    npc.wf.nichepca(adata_1, knn=10, n_comps=30, sample_key="sample")
+
+    adata_2 = generate_dummy_adata()
+    npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample")
+    npc.utils.normalize_per_sample(adata_2, sample_key="sample")
+    sc.pp.log1p(adata_2)
+    npc.ne.aggregate(adata_2)
+    sc.tl.pca(adata_2, n_comps=30)
+    sc.external.pp.harmony_integrate(adata_2, key="sample", max_iter_harmony=50)
+
+    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca_harmony"])
+
+    # test with obs_key
+    obs_key = "cell_type"
+    n_celltypes = 5
+    adata_1 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.wf.nichepca(
+        adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key, sample_key="sample"
+    )
+
+    adata_2 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample")
+    df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8)
+    ad_tmp = sc.AnnData(
+        X=df.values,
+        obs=adata_2.obs,
+        var=pd.DataFrame(index=df.columns),
+        uns=adata_2.uns,
+    )
+    npc.ne.aggregate(ad_tmp)
+    sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1)
 
-    assert "X_pca" in adata.obsm.keys()
-    assert "X_pca_harmony" in adata.obsm.keys()
+    assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"])
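The multi-sample test above calls `npc.utils.normalize_per_sample`, which this diff imports but does not define. As a reference for reviewers, a minimal sketch of what it could look like, reconstructed from the per-sample loop removed from `run_nichepca`; the exact signature (including `target_sum`) is assumed from the call sites:

```python
import scanpy as sc
from anndata import AnnData


def normalize_per_sample(adata: AnnData, sample_key: str, target_sum: float | None = None):
    # normalize total counts independently within each sample,
    # mirroring the inline loop previously in run_nichepca
    for sample in adata.obs[sample_key].unique():
        mask = adata.obs[sample_key] == sample
        sub = adata[mask].copy()
        sc.pp.normalize_total(sub, target_sum=target_sum)
        adata[mask].X = sub.X
```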
diff --git a/tests/utils.py b/tests/utils.py
index aa00dbf..4a40c2c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,7 +4,7 @@
 from sklearn.cluster import KMeans
 
 
-def generate_dummy_adata(n_cells=100, n_genes=50, n_samples=2, seed=0):
+def generate_dummy_adata(n_cells=100, n_genes=50, n_samples=2, n_celltypes=5, seed=0):
     random_state = np.random.RandomState(seed)
 
     X = random_state.randint(0, 400, size=(n_cells, n_genes))
@@ -23,4 +23,8 @@
     samples = kmeans.fit_predict(coords)
     adata.obs["sample"] = [str(s) for s in samples]
 
+    # create artificial cell type column
+    adata.obs["cell_type"] = random_state.randint(0, n_celltypes, size=n_cells)
+    adata.obs["cell_type"] = adata.obs["cell_type"].astype(str).astype("category")
+
    return adata
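Putting the pieces together, the new `n_celltypes` fixture argument feeds the `obs_key` workflow end to end. A small usage sketch mirroring the tests above (parameter values are illustrative; `n_comps` is kept below the number of one-hot features so the PCA step is valid):

```python
import scanpy as sc

import nichepca as npc
from utils import generate_dummy_adata

adata = generate_dummy_adata(n_celltypes=5)
# aggregate one-hot cell-type labels over a 10-NN multi-sample graph, then PCA
npc.wf.nichepca(adata, knn=10, n_comps=4, obs_key="cell_type", sample_key="sample")
# cluster on the resulting niche embedding, as in the README
sc.pp.neighbors(adata, use_rep="X_npca")
sc.tl.leiden(adata, resolution=0.5)
```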