diff --git a/src/nichepca/workflows/_nichepca.py b/src/nichepca/workflows/_nichepca.py index 9678d16..f7f8c8c 100644 --- a/src/nichepca/workflows/_nichepca.py +++ b/src/nichepca/workflows/_nichepca.py @@ -161,6 +161,9 @@ def nichepca( suffix="", ) elif fn == "pca": + # pca requires float dtype + if "float" not in str(ad_tmp.X.dtype): + ad_tmp.X = ad_tmp.X.astype(np.float32) sc.tl.pca(ad_tmp, n_comps=n_comps) # run harmony if sample_key is provided and obs key is None if sample_key is not None and obs_key is None and allow_harmony: diff --git a/tests/test_workflows.py b/tests/test_workflows.py index 86ce3d7..57ebfcb 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -68,6 +68,22 @@ def test_nichepca_single(): assert np.all(X_npca_0 == X_npca_1) + # test with pca on raw counts + pipeline = ("pca", "agg") + + adata = generate_dummy_adata() + npc.wf.nichepca(adata, knn=5, pipeline=pipeline) + X_npca_0 = adata.obsm["X_npca"] + + adata = generate_dummy_adata() + adata.X = adata.X.astype(np.float32) + sc.pp.pca(adata, n_comps=30) + npc.gc.knn_graph(adata, knn=5) + npc.ne.aggregate(adata, obsm_key="X_pca") + X_npca_1 = adata.obsm["X_pca_agg"] + + assert np.all(X_npca_0 == X_npca_1) + def test_nichepca_multi_sample(): adata_1 = generate_dummy_adata()