Skip to content

Commit

Permalink
Merge pull request #5 from dschaub95/nichepca-refactor
Browse files Browse the repository at this point in the history
fix bug when running no normalization and handle pipeline without pca
  • Loading branch information
dschaub95 authored Oct 1, 2024
2 parents 9d8c311 + 66c98fd commit 4010bc0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
40 changes: 23 additions & 17 deletions src/nichepca/workflows/_nichepca.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
resolve_graph_constructor,
)
from nichepca.nhood_embedding import aggregate
from nichepca.utils import check_for_raw_counts, normalize_per_sample
from nichepca.utils import check_for_raw_counts, normalize_per_sample, to_numpy

if TYPE_CHECKING:
from anndata import AnnData
Expand Down Expand Up @@ -74,19 +74,22 @@ def nichepca(
-------
None
"""
# make sure pipeline is an iterable
if isinstance(pipeline, str):
pipeline = [pipeline]

# we always need to use agg
assert "agg" in pipeline, "aggregation must be part of the pipeline"

# assert that the pca is behind norm and log1p
if "pca" in pipeline:
pca_after_norm = np.argmax(np.array(pipeline) == "pca") > np.argmax(
np.array(pipeline) == "norm"
)
pca_after_log1p = np.argmax(np.array(pipeline) == "pca") > np.argmax(
np.array(pipeline) == "log1p"
)
if "pca" in pipeline and ("norm" in pipeline or "log1p" in pipeline):
pca_index = np.argmax(np.array(pipeline) == "pca")
norm_index = np.argmax(np.array(pipeline) == "norm")
log1p_index = np.argmax(np.array(pipeline) == "log1p")
# argmax returns 0 if not found
assert (
pca_after_norm and pca_after_log1p
), "pca must be executed after norm and log1p"
norm_index <= pca_index and log1p_index <= pca_index
), "PCA must be executed after both norm and log1p."

# perform sanity check in case we are normalizing the data
if "norm" or "log1p" in pipeline and obs_key is None and obsm_key is None:
Expand Down Expand Up @@ -170,14 +173,17 @@ def nichepca(
# extract the results and remove old keys
if "X_pca_harmony" in ad_tmp.obsm:
X_npca = ad_tmp.obsm["X_pca_harmony"]
else:
elif "X_pca" in ad_tmp.obsm:
X_npca = ad_tmp.obsm["X_pca"]
else:
X_npca = to_numpy(ad_tmp.X)

# store the results
adata.obsm["X_npca"] = X_npca
adata.uns["npca"] = ad_tmp.uns["pca"]
adata.uns["npca"]["PCs"] = pd.DataFrame(
data=ad_tmp.varm["PCs"],
index=ad_tmp.var_names,
columns=[f"PC{i}" for i in range(n_comps)],
)
if "pca" in pipeline:
adata.uns["npca"] = ad_tmp.uns["pca"]
adata.uns["npca"]["PCs"] = pd.DataFrame(
data=ad_tmp.varm["PCs"],
index=ad_tmp.var_names,
columns=[f"PC{i}" for i in range(n_comps)],
)
14 changes: 14 additions & 0 deletions tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def test_nichepca_single():

assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])

# test without pca
pipeline = "agg"

adata = generate_dummy_adata()
npc.wf.nichepca(adata, knn=5, pipeline=pipeline)
X_npca_0 = adata.obsm["X_npca"]

adata = generate_dummy_adata()
npc.gc.knn_graph(adata, knn=5)
npc.ne.aggregate(adata)
X_npca_1 = npc.utils.to_numpy(adata.X)

assert np.all(X_npca_0 == X_npca_1)


def test_nichepca_multi_sample():
adata_1 = generate_dummy_adata()
Expand Down

0 comments on commit 4010bc0

Please sign in to comment.