Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nichepca refactor #4

Merged
merged 8 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,43 @@ Package for PCA-based spatial domain identification in single-cell spatial trans

- [API documentation][link-api]. -->

Given an AnnData object `adata`, you can run nichepca as follows:
Given an AnnData object `adata`, you can run nichepca starting from raw counts as follows:

```python
import scanpy as sc
import nichepca as npc

npc.wf.run_nichepca(adata, knn=5)
sc.pp.neighbors(adata)
npc.wf.nichepca(adata, knn=25)
sc.pp.neighbors(adata, use_rep="X_npca")
sc.tl.leiden(adata, resolution=0.5)
```

If you have multiple samples in `adata.obs['sample']`, you can provide the key `sample` to `npc.wf.nichepca`:

```python
npc.wf.nichepca(adata, knn=25, sample_key="sample")
```

If you have cell type labels in `adata.obs['cell_type']`, you can directly provide them to `nichepca` as follows:

```python
npc.wf.nichepca(adata, knn=25, obs_key='cell_type')
```

The `nichepca` function also allows you to customize the original `("norm", "log1p", "agg", "pca")` pipeline, e.g., without median normalization:
```python
npc.wf.nichepca(adata, knn=25, pipeline=["log1p", "agg", "pca"])
```
or with `"pca"` before `"agg"`:
```python
npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "pca", "agg"])
```
or without `"pca"` at all:
```python
npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "agg"])
```
## Setting parameters
We found that a higher number of neighbors, e.g., `knn=25`, leads to better results in brain tissue, while `knn=10` works well for kidney data. We recommend qualitatively optimizing these parameters on a small subset of your data. The number of PCs (`n_comps=30` by default) seems to have a negligible effect on the results.
## Installation

You need to have Python 3.10 or newer installed on your system. If you don't have
Expand Down
4 changes: 2 additions & 2 deletions src/nichepca/graph_construction/_spatial_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def knn_graph(

def distance_graph(
adata: AnnData,
radius: int = 50,
radius: float = 50,
obsm_key: str = "spatial",
remove_self_loops: bool = False,
p: int = 2,
Expand All @@ -223,7 +223,7 @@ def distance_graph(
----------
adata : AnnData
Annotated data object.
radius : int, default 50
radius : float, default 50
Radius for the distance threshold.
obsm_key : str, default "spatial"
Key in `obsm` attribute where the spatial data is stored.
Expand Down
2 changes: 1 addition & 1 deletion src/nichepca/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._nichepca import run_nichepca
from ._nichepca import nichepca
185 changes: 140 additions & 45 deletions src/nichepca/workflows/_nichepca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,182 @@

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import scanpy as sc

from nichepca.graph_construction import (
construct_multi_sample_graph,
distance_graph,
knn_graph,
resolve_graph_constructor,
)
from nichepca.nhood_embedding import aggregate
from nichepca.utils import check_for_raw_counts
from nichepca.utils import check_for_raw_counts, normalize_per_sample

if TYPE_CHECKING:
from anndata import AnnData


def run_nichepca(
def nichepca(
    adata: AnnData,
    knn: int | None = None,
    radius: float | None = None,
    delaunay: bool = False,
    n_comps: int = 30,
    obs_key: str | None = None,
    obsm_key: str | None = None,
    sample_key: str | None = None,
    pipeline: tuple | list = ("norm", "log1p", "agg", "pca"),
    norm_per_sample: bool = True,
    backend: str = "pyg",
    aggr: str = "mean",
    allow_harmony: bool = True,
    max_iter_harmony: int = 50,
    **kwargs,
):
    """
    Run the general NichePCA workflow.

    Builds a spatial neighborhood graph, optionally one-hot encodes cell-type
    labels or takes an existing ``obsm`` representation as input, then executes
    the requested ``pipeline`` of steps ("norm", "log1p", "agg", "pca") on an
    intermediate AnnData and stores the result in ``adata.obsm["X_npca"]``.

    Parameters
    ----------
    adata : AnnData
        The input AnnData object.
    knn : int | None, optional
        Number of nearest neighbors to use for graph construction.
    radius : float | None, optional
        Radius for graph construction.
    delaunay : bool, optional
        Whether to use Delaunay triangulation for graph construction.
    n_comps : int, optional
        Number of principal components to compute.
    obs_key : str | None, optional
        Observation key (e.g. cell-type labels) to one-hot encode and use as
        input instead of ``adata.X``; normalization steps are skipped.
    obsm_key : str | None, optional
        Observation matrix key to use as input instead of ``adata.X``.
    sample_key : str | None, optional
        Sample key to use for multi-sample graph construction. If provided
        (and ``obs_key`` is None and ``allow_harmony``), Harmony integration
        is run after PCA.
    pipeline : tuple | list, optional
        Pipeline of steps to perform. Must include 'agg'.
    norm_per_sample : bool, optional
        Whether to normalize per sample (requires ``sample_key``).
    backend : str, optional
        Backend to use for aggregation.
    aggr : str, optional
        Aggregation method to use.
    allow_harmony : bool, optional
        Whether to allow Harmony integration.
    max_iter_harmony : int, optional
        Maximum number of iterations for Harmony.
    **kwargs : dict
        Additional keyword arguments forwarded to the graph construction;
        ``target_sum`` is extracted and forwarded to normalization instead.

    Returns
    -------
    None
    """
    # we always need to use agg
    assert "agg" in pipeline, "aggregation must be part of the pipeline"

    # pca (when present) must run after any norm/log1p steps.
    # NOTE: using index() on the steps that are actually present avoids the
    # argmax-on-all-False pitfall where a missing "norm"/"log1p" resolves to
    # position 0 and wrongly trips the assertion.
    if "pca" in pipeline:
        pca_pos = list(pipeline).index("pca")
        for step in ("norm", "log1p"):
            if step in pipeline:
                assert (
                    list(pipeline).index(step) < pca_pos
                ), "pca must be executed after norm and log1p"

    # sanity check raw counts only when we will normalize adata.X itself.
    # (Parenthesization matters here: `"norm" or "log1p" in pipeline` would
    # always be truthy.)
    if (
        ("norm" in pipeline or "log1p" in pipeline)
        and obs_key is None
        and obsm_key is None
    ):
        check_for_raw_counts(adata)

    # extract any additional kwargs that are not directed to the graph construction
    target_sum = kwargs.pop("target_sum", None)

    # construct the (multi-sample) graph
    if sample_key is not None:
        construct_multi_sample_graph(
            adata,
            sample_key=sample_key,
            knn=knn,
            radius=radius,
            delaunay=delaunay,
            **kwargs,
        )
    else:
        resolve_graph_constructor(radius, knn, delaunay)(adata, **kwargs)

    # if an obs_key is provided generate a new AnnData from one-hot labels
    if obs_key is not None:
        df = pd.get_dummies(adata.obs[obs_key], dtype=np.int8)
        X = df.values
        var = pd.DataFrame(index=df.columns)
        # one-hot encodings must not be normalized; drop those steps
        pipeline = [p for p in pipeline if p not in ["norm", "log1p"]]
        print(f"obs_key provided, running pipeline: {'->'.join(pipeline)}")
    elif obsm_key is not None:
        X = adata.obsm[obsm_key]
        # NOTE(review): assumes adata.obsm[obsm_key] has n_vars columns so the
        # empty var frame aligns — confirm for non-gene-space representations.
        var = adata.var[[]]
    else:
        X = adata.X
        var = adata.var[[]]
        print(f"Running pipeline: {'->'.join(pipeline)}")

    # create intermediate AnnData; results are copied back to `adata` at the end
    ad_tmp = sc.AnnData(
        X=X,
        obs=adata.obs,
        var=var,
        uns=adata.uns,
    )

    for fn in pipeline:
        if fn == "norm":
            if norm_per_sample and sample_key is not None:
                normalize_per_sample(
                    ad_tmp, sample_key=sample_key, target_sum=target_sum
                )
            else:
                sc.pp.normalize_total(ad_tmp, target_sum=target_sum)
        elif fn == "log1p":
            sc.pp.log1p(ad_tmp)
        elif fn == "agg":
            # if pca is executed before agg, we need to aggregate the pca results
            if "X_pca_harmony" in ad_tmp.obsm:
                obsm_key_agg = "X_pca_harmony"
            elif "X_pca" in ad_tmp.obsm:
                obsm_key_agg = "X_pca"
            else:
                obsm_key_agg = None
            aggregate(
                ad_tmp,
                backend=backend,
                aggr=aggr,
                obsm_key=obsm_key_agg,
                suffix="",
            )
        elif fn == "pca":
            sc.tl.pca(ad_tmp, n_comps=n_comps)
            # run harmony if sample_key is provided and obs key is None
            if sample_key is not None and obs_key is None and allow_harmony:
                sc.external.pp.harmony_integrate(
                    ad_tmp, key=sample_key, max_iter_harmony=max_iter_harmony
                )
        else:
            raise ValueError(f"Unknown step in the pipeline: {fn}")

    # extract the results; fall back to the aggregated matrix when the
    # pipeline (legitimately) contains no "pca" step
    if "X_pca_harmony" in ad_tmp.obsm:
        X_npca = ad_tmp.obsm["X_pca_harmony"]
    elif "X_pca" in ad_tmp.obsm:
        X_npca = ad_tmp.obsm["X_pca"]
    else:
        X_npca = ad_tmp.X

    # store the results
    adata.obsm["X_npca"] = X_npca
    if "pca" in pipeline:
        adata.uns["npca"] = ad_tmp.uns["pca"]
        n_pcs = ad_tmp.varm["PCs"].shape[1]
        adata.uns["npca"]["PCs"] = pd.DataFrame(
            data=ad_tmp.varm["PCs"],
            index=ad_tmp.var_names,
            columns=[f"PC{i}" for i in range(n_pcs)],
        )
79 changes: 72 additions & 7 deletions tests/test_workflows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import scanpy as sc
from utils import generate_dummy_adata

Expand All @@ -7,7 +8,7 @@

def test_nichepca_single():
adata_1 = generate_dummy_adata()
npc.wf.run_nichepca(adata_1, knn=10, n_comps=30)
npc.wf.nichepca(adata_1, knn=10, n_comps=30)

adata_2 = generate_dummy_adata()
sc.pp.normalize_total(adata_2)
Expand All @@ -16,12 +17,76 @@ def test_nichepca_single():
npc.ne.aggregate(adata_2)
sc.tl.pca(adata_2, n_comps=30)

assert np.all(adata_1.obsm["X_pca"] == adata_2.obsm["X_pca"])
assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])

# test with obs_key
obs_key = "cell_type"
n_celltypes = 5
adata_1 = generate_dummy_adata(n_celltypes=n_celltypes)
npc.wf.nichepca(adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key)

def test_nichepca_multi():
adata = generate_dummy_adata()
npc.wf.run_nichepca(adata, knn=10, sample_key="sample")
adata_2 = generate_dummy_adata(n_celltypes=n_celltypes)
npc.gc.knn_graph(adata_2, knn=10)
df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8)
ad_tmp = sc.AnnData(
X=df.values,
obs=adata_2.obs,
var=pd.DataFrame(index=df.columns),
uns=adata_2.uns,
)
npc.ne.aggregate(ad_tmp)
sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1)

assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"])

# test with pca before agg
adata_1 = generate_dummy_adata()
npc.wf.nichepca(
adata_1, knn=10, n_comps=30, pipeline=["norm", "log1p", "pca", "agg"]
)

adata_2 = generate_dummy_adata()
npc.gc.knn_graph(adata_2, knn=10)
sc.pp.normalize_total(adata_2)
sc.pp.log1p(adata_2)
sc.pp.pca(adata_2, n_comps=30)
npc.ne.aggregate(adata_2, obsm_key="X_pca", suffix="")

assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])


def test_nichepca_multi_sample():
    """Multi-sample workflow must match the manual norm->log1p->agg->pca->harmony chain."""
    adata_1 = generate_dummy_adata()
    npc.wf.nichepca(adata_1, knn=10, n_comps=30, sample_key="sample")

    # manually reproduce the default pipeline on a second copy
    adata_2 = generate_dummy_adata()
    npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample")
    npc.utils.normalize_per_sample(adata_2, sample_key="sample")
    sc.pp.log1p(adata_2)
    npc.ne.aggregate(adata_2)
    sc.tl.pca(adata_2, n_comps=30)
    sc.external.pp.harmony_integrate(adata_2, key="sample", max_iter_harmony=50)

    assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca_harmony"])

    # test with obs_key: one-hot cell-type input skips normalization and harmony
    obs_key = "cell_type"
    n_celltypes = 5
    adata_1 = generate_dummy_adata(n_celltypes=n_celltypes)
    npc.wf.nichepca(
        adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key, sample_key="sample"
    )

    adata_2 = generate_dummy_adata(n_celltypes=n_celltypes)
    npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample")
    df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8)
    ad_tmp = sc.AnnData(
        X=df.values,
        obs=adata_2.obs,
        var=pd.DataFrame(index=df.columns),
        uns=adata_2.uns,
    )
    npc.ne.aggregate(ad_tmp)
    sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1)

    assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"])
Loading
Loading