Merge pull request #4 from dschaub95/nichepca-refactor
Nichepca refactor
dschaub95 authored Sep 30, 2024
2 parents b64111d + bef5eea commit 9d8c311
Showing 6 changed files with 249 additions and 59 deletions.
32 changes: 29 additions & 3 deletions README.md
@@ -15,17 +15,43 @@ Package for PCA-based spatial domain identification in single-cell spatial trans
 - [API documentation][link-api]. -->
 
-Given an AnnData object `adata`, you can run nichepca as follows:
+Given an AnnData object `adata`, you can run nichepca starting from raw counts as follows:
 
 ```python
 import scanpy as sc
 import nichepca as npc
 
-npc.wf.run_nichepca(adata, knn=5)
-sc.pp.neighbors(adata)
+npc.wf.nichepca(adata, knn=25)
+sc.pp.neighbors(adata, use_rep="X_npca")
 sc.tl.leiden(adata, resolution=0.5)
 ```
 
+If you have multiple samples in `adata.obs['sample']`, you can provide the key `sample` to `npc.wf.nichepca`:
+
+```python
+npc.wf.nichepca(adata, knn=25, sample_key="sample")
+```
+
+If you have cell type labels in `adata.obs['cell_type']`, you can directly provide them to `nichepca` as follows:
+
+```python
+npc.wf.nichepca(adata, knn=25, obs_key='cell_type')
+```
+
+The `nichepca` function also allows you to customize the original `("norm", "log1p", "agg", "pca")` pipeline, e.g., without median normalization:
+```python
+npc.wf.nichepca(adata, knn=25, pipeline=["log1p", "agg", "pca"])
+```
+or with `"pca"` before `"agg"`:
+```python
+npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "pca", "agg"])
+```
+or without `"pca"` at all:
+```python
+npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "agg"])
+```
+## Setting parameters
+We found that a higher number of neighbors, e.g., `knn=25`, leads to better results in brain tissue, while `knn=10` works well for kidney data. We recommend qualitatively optimizing these parameters on a small subset of your data (see the sketch below). The number of PCs (`n_comps=30` by default) seems to have a negligible effect on the results.
 ## Installation
 
 You need to have Python 3.10 or newer installed on your system. If you don't have
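As an aside on the new "Setting parameters" section above, here is a minimal sketch of how one might qualitatively scan `knn` on a subset of cells, as the README recommends. This is not part of the commit; it assumes `adata` is an AnnData with raw counts and spatial coordinates in `adata.obsm["spatial"]`.

```python
import numpy as np
import scanpy as sc
import nichepca as npc

# Draw a reproducible random subset to keep the scan cheap.
rng = np.random.default_rng(0)
idx = rng.choice(adata.n_obs, size=min(10_000, adata.n_obs), replace=False)
subset = adata[idx].copy()

for knn in (5, 10, 25):
    ad = subset.copy()
    npc.wf.nichepca(ad, knn=knn)
    sc.pp.neighbors(ad, use_rep="X_npca")
    sc.tl.leiden(ad, resolution=0.5, key_added=f"leiden_knn{knn}")
    # Inspect the domains on the tissue coordinates and pick the knn
    # that yields spatially coherent regions.
    sc.pl.embedding(ad, basis="spatial", color=f"leiden_knn{knn}")
```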
4 changes: 2 additions & 2 deletions src/nichepca/graph_construction/_spatial_graph.py
@@ -209,7 +209,7 @@ def knn_graph(
 
 def distance_graph(
     adata: AnnData,
-    radius: int = 50,
+    radius: float = 50,
     obsm_key: str = "spatial",
     remove_self_loops: bool = False,
     p: int = 2,
@@ -223,7 +223,7 @@ def distance_graph(
     ----------
     adata : AnnData
         Annotated data object.
-    radius : int, default 50
+    radius : float, default 50
        Radius for the distance threshold.
     obsm_key : str, default "spatial"
        Key in `obsm` attribute where the spatial data is stored.
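The `int` to `float` widening above matters when coordinates are in physical units. A hedged sketch, assuming `distance_graph` is exposed as `npc.gc.distance_graph` (the tests below use the `npc.gc` accessor for `knn_graph`) and that `adata.obsm["spatial"]` exists:

```python
import nichepca as npc

# Build a radius graph with a non-integer threshold, e.g. 37.5 in the
# units of adata.obsm["spatial"]. Floats always worked at runtime; the
# type hint is simply correct now.
npc.gc.distance_graph(adata, radius=37.5, obsm_key="spatial")
```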
2 changes: 1 addition & 1 deletion src/nichepca/workflows/__init__.py
@@ -1 +1 @@
-from ._nichepca import run_nichepca
+from ._nichepca import nichepca
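The one-line change above renames the public entry point; a minimal before/after sketch (assuming an `adata` object, as in the README):

```python
import nichepca as npc

# npc.wf.run_nichepca(adata, knn=25)  # old name, removed by this PR
npc.wf.nichepca(adata, knn=25)        # new name
```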
185 changes: 140 additions & 45 deletions src/nichepca/workflows/_nichepca.py
@@ -2,87 +2,182 @@
 
 from typing import TYPE_CHECKING
 
+import numpy as np
+import pandas as pd
 import scanpy as sc
 
 from nichepca.graph_construction import (
     construct_multi_sample_graph,
-    distance_graph,
-    knn_graph,
+    resolve_graph_constructor,
 )
 from nichepca.nhood_embedding import aggregate
-from nichepca.utils import check_for_raw_counts
+from nichepca.utils import check_for_raw_counts, normalize_per_sample
 
 if TYPE_CHECKING:
     from anndata import AnnData
 
 
-def run_nichepca(
+def nichepca(
     adata: AnnData,
-    knn: int = None,
-    radius: float = None,
-    sample_key: str = None,
+    knn: int | None = None,
+    radius: float | None = None,
+    delaunay: bool = False,
     n_comps: int = 30,
-    max_iter_harmony: int = 50,
+    obs_key: str | None = None,
+    obsm_key: str | None = None,
+    sample_key: str | None = None,
+    pipeline: tuple | list = ("norm", "log1p", "agg", "pca"),
     norm_per_sample: bool = True,
+    backend: str = "pyg",
+    aggr: str = "mean",
+    allow_harmony: bool = True,
+    max_iter_harmony: int = 50,
     **kwargs,
 ):
     """
-    Run the NichePCA workflow.
+    Run the general NichePCA workflow.
 
     Parameters
     ----------
     adata : AnnData
-        Annotated data object.
-    knn : int
-        Number of nearest neighbors for the kNN graph.
-    sample_key : str, optional
-        Key in `adata.obs` that identifies distinct samples. If provided, harmony will be used to
-        integrate the data.
-    radius : float, optional
-        The radius of the neighborhood graph.
+        The input AnnData object.
+    knn : int | None, optional
+        Number of nearest neighbors to use for graph construction.
+    radius : float | None, optional
+        Radius for graph construction.
+    delaunay : bool, optional
+        Whether to use Delaunay triangulation for graph construction.
     n_comps : int, optional
         Number of principal components to compute.
-    max_iter_harmony : int, optional
-        Maximum number of iterations for harmony.
+    obs_key : str | None, optional
+        Observation key to use for generating a new AnnData object.
+    obsm_key : str | None, optional
+        Observation matrix key to use as input.
+    sample_key : str | None, optional
+        Sample key to use for multi-sample graph construction.
+    pipeline : tuple | list, optional
+        Pipeline of steps to perform. Must include 'agg'.
     norm_per_sample : bool, optional
-        Whether to normalize the data per sample.
-    kwargs : dict, optional
-        Additional keyword arguments for the graph construction.
+        Whether to normalize per sample.
+    backend : str, optional
+        Backend to use for aggregation.
+    aggr : str, optional
+        Aggregation method to use.
+    allow_harmony : bool, optional
+        Whether to allow Harmony integration.
+    max_iter_harmony : int, optional
+        Maximum number of iterations for Harmony.
+    **kwargs : dict
+        Additional keyword arguments.
 
     Returns
     -------
     None
     """
-    check_for_raw_counts(adata)
+    # we always need to use agg
+    assert "agg" in pipeline, "aggregation must be part of the pipeline"
+    # assert that pca comes after norm and log1p
+    if "pca" in pipeline:
+        pca_after_norm = np.argmax(np.array(pipeline) == "pca") > np.argmax(
+            np.array(pipeline) == "norm"
+        )
+        pca_after_log1p = np.argmax(np.array(pipeline) == "pca") > np.argmax(
+            np.array(pipeline) == "log1p"
+        )
+        assert (
+            pca_after_norm and pca_after_log1p
+        ), "pca must be executed after norm and log1p"
+
+    # perform sanity check in case we are normalizing the data
+    if (
+        ("norm" in pipeline or "log1p" in pipeline)
+        and obs_key is None
+        and obsm_key is None
+    ):
+        check_for_raw_counts(adata)
+
+    # extract any additional kwargs that are not directed to the graph construction
+    target_sum = kwargs.pop("target_sum", None)
 
     # construct the (multi-sample) graph
     if sample_key is not None:
         construct_multi_sample_graph(
-            adata, sample_key=sample_key, knn=knn, radius=radius, **kwargs
+            adata,
+            sample_key=sample_key,
+            knn=knn,
+            radius=radius,
+            delaunay=delaunay,
+            **kwargs,
         )
     else:
-        if knn is not None:
-            knn_graph(adata, knn, **kwargs)
-        elif radius is not None:
-            distance_graph(adata, radius, **kwargs)
-        else:
-            raise ValueError("Either knn or radius must be provided.")
-
-    if norm_per_sample and sample_key is not None:
-        for sample in adata.obs[sample_key].unique():
-            mask = adata.obs[sample_key] == sample
-            sub_ad = adata[mask].copy()
-            sc.pp.normalize_total(sub_ad)
-            adata[mask].X = sub_ad.X
+        resolve_graph_constructor(radius, knn, delaunay)(adata, **kwargs)
 
+    # if an obs_key is provided generate a new AnnData
+    if obs_key is not None:
+        df = pd.get_dummies(adata.obs[obs_key], dtype=np.int8)
+        X = df.values
+        var = pd.DataFrame(index=df.columns)
+        # remove normalization steps
+        pipeline = [p for p in pipeline if p not in ["norm", "log1p"]]
+        print(f"obs_key provided, running pipeline: {'->'.join(pipeline)}")
+    elif obsm_key is not None:
+        X = adata.obsm[obsm_key]
+        var = adata.var[[]]
+    else:
-    sc.pp.normalize_total(adata)
+        X = adata.X
+        var = adata.var[[]]
+        print(f"Running pipeline: {'->'.join(pipeline)}")
 
-    sc.pp.log1p(adata)
+    # create intermediate AnnData
+    ad_tmp = sc.AnnData(
+        X=X,
+        obs=adata.obs,
+        var=var,
+        uns=adata.uns,
+    )
 
-    aggregate(adata)
+    for fn in pipeline:
+        if fn == "norm":
+            if norm_per_sample and sample_key is not None:
+                normalize_per_sample(
+                    ad_tmp, sample_key=sample_key, target_sum=target_sum
+                )
+            else:
+                sc.pp.normalize_total(ad_tmp, target_sum=target_sum)
+        elif fn == "log1p":
+            sc.pp.log1p(ad_tmp)
+        elif fn == "agg":
+            # if pca is executed before agg, we need to aggregate the pca results
+            if "X_pca_harmony" in ad_tmp.obsm:
+                obsm_key_agg = "X_pca_harmony"
+            elif "X_pca" in ad_tmp.obsm:
+                obsm_key_agg = "X_pca"
+            else:
+                obsm_key_agg = None
+            aggregate(
+                ad_tmp,
+                backend=backend,
+                aggr=aggr,
+                obsm_key=obsm_key_agg,
+                suffix="",
+            )
+        elif fn == "pca":
+            sc.tl.pca(ad_tmp, n_comps=n_comps)
+            # run harmony if sample_key is provided and obs_key is None
+            if sample_key is not None and obs_key is None and allow_harmony:
+                sc.external.pp.harmony_integrate(
+                    ad_tmp, key=sample_key, max_iter_harmony=max_iter_harmony
+                )
+        else:
+            raise ValueError(f"Unknown step in the pipeline: {fn}")
 
-    sc.tl.pca(adata, n_comps=n_comps)
+    # extract the results and remove old keys
+    if "X_pca_harmony" in ad_tmp.obsm:
+        X_npca = ad_tmp.obsm["X_pca_harmony"]
+    else:
+        X_npca = ad_tmp.obsm["X_pca"]
 
-    if sample_key is not None:
-        sc.external.pp.harmony_integrate(
-            adata, key=sample_key, max_iter_harmony=max_iter_harmony
-        )
+    # store the results
+    adata.obsm["X_npca"] = X_npca
+    adata.uns["npca"] = ad_tmp.uns["pca"]
+    adata.uns["npca"]["PCs"] = pd.DataFrame(
+        data=ad_tmp.varm["PCs"],
+        index=ad_tmp.var_names,
+        columns=[f"PC{i}" for i in range(n_comps)],
+    )
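To make the refactor's output contract concrete, a hedged usage sketch based on the assignments at the end of `nichepca()` above; it is not from the commit and assumes `adata` holds raw counts and a spatial graph can be built with `knn=25`:

```python
import scanpy as sc
import nichepca as npc

npc.wf.nichepca(adata, knn=25)

# The niche embedding now lives under a dedicated obsm key ...
emb = adata.obsm["X_npca"]
# ... and the loadings are kept as a genes-by-PCs DataFrame in uns.
pcs = adata.uns["npca"]["PCs"]
print(emb.shape, pcs.shape)

# Downstream clustering must therefore point at the new representation:
sc.pp.neighbors(adata, use_rep="X_npca")
sc.tl.leiden(adata, resolution=0.5)
```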
79 changes: 72 additions & 7 deletions tests/test_workflows.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pandas as pd
 import scanpy as sc
 from utils import generate_dummy_adata
 
@@ -7,7 +8,7 @@
 
 def test_nichepca_single():
     adata_1 = generate_dummy_adata()
-    npc.wf.run_nichepca(adata_1, knn=10, n_comps=30)
+    npc.wf.nichepca(adata_1, knn=10, n_comps=30)
 
     adata_2 = generate_dummy_adata()
     sc.pp.normalize_total(adata_2)
@@ -16,12 +17,76 @@ def test_nichepca_single():
     npc.ne.aggregate(adata_2)
     sc.tl.pca(adata_2, n_comps=30)
 
-    assert np.all(adata_1.obsm["X_pca"] == adata_2.obsm["X_pca"])
+    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])
+
+    # test with obs_key
+    obs_key = "cell_type"
+    n_celltypes = 5
+    adata_1 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.wf.nichepca(adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key)
 
-def test_nichepca_multi():
-    adata = generate_dummy_adata()
-    npc.wf.run_nichepca(adata, knn=10, sample_key="sample")
+    adata_2 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.gc.knn_graph(adata_2, knn=10)
+    df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8)
+    ad_tmp = sc.AnnData(
+        X=df.values,
+        obs=adata_2.obs,
+        var=pd.DataFrame(index=df.columns),
+        uns=adata_2.uns,
+    )
+    npc.ne.aggregate(ad_tmp)
+    sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1)
+
+    assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"])
+
+    # test with pca before agg
+    adata_1 = generate_dummy_adata()
+    npc.wf.nichepca(
+        adata_1, knn=10, n_comps=30, pipeline=["norm", "log1p", "pca", "agg"]
+    )
+
+    adata_2 = generate_dummy_adata()
+    npc.gc.knn_graph(adata_2, knn=10)
+    sc.pp.normalize_total(adata_2)
+    sc.pp.log1p(adata_2)
+    sc.pp.pca(adata_2, n_comps=30)
+    npc.ne.aggregate(adata_2, obsm_key="X_pca", suffix="")
+
+    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])
+
+
+def test_nichepca_multi_sample():
+    adata_1 = generate_dummy_adata()
+    npc.wf.nichepca(adata_1, knn=10, n_comps=30, sample_key="sample")
+
+    adata_2 = generate_dummy_adata()
+    npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample")
+    npc.utils.normalize_per_sample(adata_2, sample_key="sample")
+    sc.pp.log1p(adata_2)
+    npc.ne.aggregate(adata_2)
+    sc.tl.pca(adata_2, n_comps=30)
+    sc.external.pp.harmony_integrate(adata_2, key="sample", max_iter_harmony=50)
+
+    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca_harmony"])
+
+    # test with obs_key
+    obs_key = "cell_type"
+    n_celltypes = 5
+    adata_1 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.wf.nichepca(
+        adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key, sample_key="sample"
+    )
+
+    adata_2 = generate_dummy_adata(n_celltypes=n_celltypes)
+    npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample")
+    df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8)
+    ad_tmp = sc.AnnData(
+        X=df.values,
+        obs=adata_2.obs,
+        var=pd.DataFrame(index=df.columns),
+        uns=adata_2.uns,
+    )
+    npc.ne.aggregate(ad_tmp)
+    sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1)
 
-    assert "X_pca" in adata.obsm.keys()
-    assert "X_pca_harmony" in adata.obsm.keys()
+    assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"])
