From 63d53fba1159d32fde8e7e0aa556fa54763b99c1 Mon Sep 17 00:00:00 2001
From: dschaub95
Date: Thu, 26 Sep 2024 12:50:03 +0200
Subject: [PATCH 1/5] add mean aggregation for sparse backend

---
 src/nichepca/nhood_embedding/_aggr.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/nichepca/nhood_embedding/_aggr.py b/src/nichepca/nhood_embedding/_aggr.py
index 31531d5..0fe87c7 100644
--- a/src/nichepca/nhood_embedding/_aggr.py
+++ b/src/nichepca/nhood_embedding/_aggr.py
@@ -90,15 +90,17 @@ def aggregate(
     elif backend == "sparse":
         N = adata.shape[0]

+        aggr = kwargs.get("aggr", "mean")
         edge_index = adata.uns["graph"]["edge_index"]

         A = scipy.sparse.csr_matrix(
             (np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])), shape=(N, N)
         )
-        A_mean = A / A.sum(1)
+        if aggr == "mean":
+            A = A / A.sum(1)

         for _ in range(n_layers):
-            X = A_mean @ X
+            X = A @ X

     if out_key is not None:
         adata.obsm[out_key] = to_numpy(X)

From 34147d52d6da42f0d83c37220174404a89143b0e Mon Sep 17 00:00:00 2001
From: dschaub95
Date: Thu, 26 Sep 2024 13:01:21 +0200
Subject: [PATCH 2/5] simplify aggr function

---
 src/nichepca/nhood_embedding/_aggr.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/nichepca/nhood_embedding/_aggr.py b/src/nichepca/nhood_embedding/_aggr.py
index 0fe87c7..a4512e2 100644
--- a/src/nichepca/nhood_embedding/_aggr.py
+++ b/src/nichepca/nhood_embedding/_aggr.py
@@ -43,7 +43,7 @@ def aggregate(
     n_layers: int = 1,
     out_key: str | None = None,
     backend: str = "pyg",
-    **kwargs,
+    aggr="mean",
 ):
     """
     Aggregate data in an AnnData object based on a previously constructed graph.
@@ -79,7 +79,7 @@ def aggregate(
     X = adata.obsm[obsm_key]

     if backend == "pyg":
-        aggr_fn = GraphAggregation(**kwargs)
+        aggr_fn = GraphAggregation(aggr)

         X = torch.tensor(X).float()
         edge_index = torch.tensor(adata.uns["graph"]["edge_index"])
@@ -90,7 +90,6 @@ def aggregate(
     elif backend == "sparse":
         N = adata.shape[0]

-        aggr = kwargs.get("aggr", "mean")
         edge_index = adata.uns["graph"]["edge_index"]

         A = scipy.sparse.csr_matrix(
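
A quick illustration of what the sparse backend computes for `aggr="mean"`: row-normalizing the adjacency matrix turns each matmul into an average over neighbors. The following is a self-contained toy sketch, not part of the patches; the graph and feature matrix are made up:

    import numpy as np
    import scipy.sparse

    N = 4
    # edge_index[0] indexes the receiving node (row), edge_index[1] the
    # neighbor it averages over (column)
    edge_index = np.array([[0, 1, 1, 2, 3], [1, 0, 2, 1, 3]])
    X = np.random.rand(N, 8)

    # same construction as in the patch above
    A = scipy.sparse.csr_matrix(
        (np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])), shape=(N, N)
    )
    # row-normalize; A.sum(1) is dense, so A becomes a dense np.matrix here,
    # and nodes without any edge would yield NaN rows
    A = A / A.sum(1)
    X_mean = A @ X  # one layer of neighborhood mean aggregation
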
From 49a047be1709366c0f11ee7aadaaf00c06f276d7 Mon Sep 17 00:00:00 2001
From: dschaub95
Date: Fri, 27 Sep 2024 17:12:50 +0200
Subject: [PATCH 3/5] add graph construction resolver, refactor multi-sample graph construction, harmonize output format

---
 src/nichepca/graph_construction/__init__.py |   1 +
 .../graph_construction/_spatial_graph.py    | 137 +++++++++++++-----
 2 files changed, 103 insertions(+), 35 deletions(-)

diff --git a/src/nichepca/graph_construction/__init__.py b/src/nichepca/graph_construction/__init__.py
index 0256a24..e0773c9 100644
--- a/src/nichepca/graph_construction/__init__.py
+++ b/src/nichepca/graph_construction/__init__.py
@@ -7,5 +7,6 @@
     knn_graph,
     print_graph_stats,
     remove_long_links,
+    resolve_graph_constructor,
     to_squidpy,
 )

diff --git a/src/nichepca/graph_construction/_spatial_graph.py b/src/nichepca/graph_construction/_spatial_graph.py
index be1fa49..1e02673 100644
--- a/src/nichepca/graph_construction/_spatial_graph.py
+++ b/src/nichepca/graph_construction/_spatial_graph.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from functools import partial
 from typing import TYPE_CHECKING

 import numpy as np
@@ -106,6 +107,7 @@ def store_graph(
     adata: AnnData,
     edge_index: torch.tensor | np.ndarray,
     edge_weight: torch.tensor | np.ndarray,
+    extra_key: str | None = None,
 ):
     """
     Store graph data in `adata.uns`.
@@ -118,6 +120,8 @@
         Edge index of the graph.
     edge_weight : Union[torch.Tensor, np.ndarray]
         Edge weight of the graph.
+    extra_key : str, optional
+        Extra key under which to store the graph data in `adata.uns`.

     Returns
     -------
     None
     """
     if "graph" not in adata.uns:
         adata.uns["graph"] = {}
-    adata.uns["graph"]["edge_index"] = to_numpy(edge_index)
-    adata.uns["graph"]["edge_weight"] = to_numpy(edge_weight)
+    if extra_key is not None:
+        adata.uns["graph"][extra_key] = {}
+        adata.uns["graph"][extra_key]["edge_index"] = to_numpy(edge_index)
+        adata.uns["graph"][extra_key]["edge_weight"] = to_numpy(edge_weight)
+    else:
+        adata.uns["graph"]["edge_index"] = to_numpy(edge_index)
+        adata.uns["graph"]["edge_weight"] = to_numpy(edge_weight)


 def knn_graph(
@@ -309,8 +318,7 @@ def remove_long_links(
     if copy:
         return edge_index, edge_weight
     else:
-        adata.uns["graph"]["edge_index"] = edge_index
-        adata.uns["graph"]["edge_weight"] = edge_weight
+        store_graph(adata, edge_index, edge_weight)


 def delaunay_graph(
@@ -368,6 +376,9 @@ def delaunay_graph(
         edge_index, edge_weight = remove_long_links(
             edge_index=edge_index, edge_weight=edge_weight
         )
+        # convert back to torch tensors
+        edge_index = torch.tensor(edge_index, dtype=torch.long)
+        edge_weight = torch.tensor(edge_weight, dtype=torch.float)

     if verbose:
         print_graph_stats(edge_index=edge_index, num_nodes=adata.n_obs)
@@ -386,16 +397,41 @@ def construct_multi_sample_graph(
     delaunay: bool = False,
     return_graph: bool = False,
     keep_local_edge_index: bool = False,
+    verbose: bool = True,
     **kwargs,
 ):
-    # make sure only one of knn, radius, delaunay is provided
-    assert (
-        sum([knn is not None, radius is not None, delaunay]) == 1
-    ), "Only one of knn, radius, delaunay must be provided."
+    """
+    Construct a multi-sample graph from AnnData.
+
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data object.
+    sample_key : str
+        Key in `adata.obs` where sample information is stored.
+    knn : int | None, optional
+        Number of nearest neighbors. Defaults to None.
+    radius : float | None, optional
+        Radius for the distance graph constructor. Defaults to None.
+    delaunay : bool, optional
+        Whether to use the Delaunay graph constructor. Defaults to False.
+    return_graph : bool, optional
+        Whether to return the graph instead of storing it in `adata`. Defaults to False.
+    keep_local_edge_index : bool, optional
+        Whether to keep the local edge index. Defaults to False.
+    verbose : bool, optional
+        Whether to print graph statistics. Defaults to True.
+    **kwargs
+        Additional keyword arguments passed to the graph constructor.
+
+    Returns
+    -------
+    edge_index, edge_weight : torch.Tensor, torch.Tensor
+        Edge index and edge weight of the constructed graph if `return_graph` is True.
+    """
     edge_index = []
     edge_weight = []
-    global_indices = np.arange(adata.n_obs)
+    global_indices = torch.arange(adata.n_obs)

     if "graph" not in adata.uns and not return_graph:
         adata.uns["graph"] = {}
@@ -404,20 +440,16 @@ def construct_multi_sample_graph(
         mask = adata.obs[sample_key] == sample
         ad_sub = adata[mask]

-        local_global_indices = global_indices[mask]
+        local_global_indices = global_indices[mask.values]

-        if knn is not None:
-            local_edge_index, local_edge_weight = knn_graph(
-                ad_sub, knn, return_graph=True, **kwargs
-            )
-        elif radius is not None:
-            local_edge_index, local_edge_weight = distance_graph(
-                ad_sub, radius, return_graph=True, **kwargs
-            )
-        elif delaunay:
-            local_edge_index, local_edge_weight = delaunay_graph(
-                ad_sub, return_graph=True, **kwargs
-            )
+        local_edge_index, local_edge_weight = resolve_graph_constructor(
+            radius, knn, delaunay
+        )(
+            ad_sub,
+            return_graph=True,
+            verbose=False,
+            **kwargs,
+        )

         local_global_edge_index = local_global_indices[local_edge_index]

         edge_weight.append(local_edge_weight)

         if not return_graph:
-            adata.uns["graph"][sample] = {
-                "edge_index": (
-                    to_numpy(local_edge_index)
-                    if keep_local_edge_index
-                    else to_numpy(local_global_edge_index)
-                ),
-                "edge_weight": to_numpy(local_edge_weight),
-            }
-
-    edge_index = to_numpy(np.concatenate(edge_index, axis=1))
-    edge_weight = to_numpy(np.concatenate(edge_weight, axis=0))
+            store_graph(
+                adata,
+                local_edge_index if keep_local_edge_index else local_global_edge_index,
+                local_edge_weight,
+                extra_key=sample,
+            )
+
+    edge_index = torch.cat(edge_index, axis=1)
+    edge_weight = torch.cat(edge_weight, axis=0)
+
+    if verbose:
+        print_graph_stats(edge_index=edge_index, num_nodes=adata.n_obs)

     if not return_graph:
-        adata.uns["graph"]["edge_index"] = edge_index
-        adata.uns["graph"]["edge_weight"] = edge_weight
+        store_graph(adata, edge_index, edge_weight)
     else:
         return edge_index, edge_weight


+def resolve_graph_constructor(
+    radius: float | None = None, knn: int | None = None, delaunay: bool = False
+):
+    """
+    Resolve and return the graph constructor based on the provided parameters.
+
+    Parameters
+    ----------
+    radius
+        The radius for the distance graph constructor.
+    knn
+        The number of nearest neighbors for the knn graph constructor.
+    delaunay
+        Whether to use the Delaunay graph constructor.
+
+    Returns
+    -------
+    callable
+        The resolved graph constructor function with the appropriate parameters.
+
+    Raises
+    ------
+    ValueError
+        If no graph constructor is specified.
+    """
+    if radius is not None:
+        return partial(distance_graph, radius=radius)
+    elif knn is not None:
+        return partial(knn_graph, knn=knn)
+    elif delaunay:
+        return delaunay_graph
+    else:
+        raise ValueError("No graph constructor specified.")
+
+
 def to_squidpy(adata: AnnData):
     """
     Convert the pyg graph stored in `adata.uns` to squidpy format.
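
A minimal usage sketch of the new resolver, mirroring the test added in the next patch (it assumes the test suite's `generate_dummy_adata` helper and the `npc.gc` alias used there):

    import nichepca as npc
    from utils import generate_dummy_adata

    adata = generate_dummy_adata()

    # exactly one of radius / knn / delaunay selects the constructor; the
    # returned callable takes the same arguments as knn_graph /
    # distance_graph / delaunay_graph
    constructor = npc.gc.resolve_graph_constructor(knn=10)
    edge_index, edge_weight = constructor(adata, return_graph=True)

    # equivalent direct call
    edge_index_ref, edge_weight_ref = npc.gc.knn_graph(
        adata, knn=10, return_graph=True
    )

    # calling the resolver with no selector at all raises
    # ValueError("No graph constructor specified.")
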
From d7e79847907ce6183181a9144282857b5864ba99 Mon Sep 17 00:00:00 2001
From: dschaub95
Date: Fri, 27 Sep 2024 17:14:19 +0200
Subject: [PATCH 4/5] add test for graph construction resolver and correct mask handling

---
 tests/test_graph_construction.py | 36 +++++++++++++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/tests/test_graph_construction.py b/tests/test_graph_construction.py
index 824e006..f10df01 100644
--- a/tests/test_graph_construction.py
+++ b/tests/test_graph_construction.py
@@ -3,6 +3,7 @@

 import numpy as np
 import pytest
+import torch
 from utils import generate_dummy_adata

 import nichepca as npc
@@ -123,7 +124,7 @@ def test_construct_multi_sample_graph():
         sub_ad = adata_2[mask].copy()
         sub_ad.uns["graph"] = adata_2.uns["graph"][batch].copy()
         npc.ne.aggregate(sub_ad)
-        adata_2[mask].X = sub_ad.X.copy()
+        adata_2.X[mask] = sub_ad.X.copy()

     # manual variant
     adata_3.X = npc.utils.to_numpy(adata_3.X).astype(np.float32)
@@ -166,3 +167,36 @@

     assert np.all(edge_index == edge_index_new)
     assert np.all(edge_weight == edge_weight_new)
+
+
+def test_resolve_graph_constructor():
+    adata = generate_dummy_adata()
+
+    knn = 10
+    edge_index_1, edge_weight_1 = npc.gc.knn_graph(adata, knn=knn, return_graph=True)
+    edge_index_2, edge_weight_2 = npc.gc.resolve_graph_constructor(knn=knn)(
+        adata, return_graph=True
+    )
+
+    assert torch.all(edge_index_1 == edge_index_2)
+    assert torch.all(edge_weight_1 == edge_weight_2)
+
+    radius = 0.1
+    edge_index_1, edge_weight_1 = npc.gc.distance_graph(
+        adata, radius=radius, return_graph=True
+    )
+    edge_index_2, edge_weight_2 = npc.gc.resolve_graph_constructor(radius=radius)(
+        adata, return_graph=True
+    )
+
+    assert torch.all(edge_index_1 == edge_index_2)
+    assert torch.all(edge_weight_1 == edge_weight_2)
+
+    edge_index_1, edge_weight_1 = npc.gc.delaunay_graph(adata, return_graph=True)
+    edge_index_2, edge_weight_2 = npc.gc.resolve_graph_constructor(delaunay=True)(
+        adata, return_graph=True
+    )
+    print(type(edge_index_1), type(edge_index_2))
+
+    assert torch.all(edge_index_1 == edge_index_2)
+    assert torch.all(edge_weight_1 == edge_weight_2)

From b7f798bc1911fac92dcbd06e336ab416aec3b41e Mon Sep 17 00:00:00 2001
From: dschaub95
Date: Fri, 27 Sep 2024 17:15:02 +0200
Subject: [PATCH 5/5] add helper function to normalize counts per sample

---
 src/nichepca/utils/__init__.py |  2 +-
 src/nichepca/utils/_helper.py  | 30 ++++++++++++++++++++++++++++
 tests/test_utils.py            | 36 ++++++++++++++++++++++++++++++++++
 3 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/src/nichepca/utils/__init__.py b/src/nichepca/utils/__init__.py
index b2d3d0b..2f2a155 100644
--- a/src/nichepca/utils/__init__.py
+++ b/src/nichepca/utils/__init__.py
@@ -1 +1 @@
-from ._helper import check_for_raw_counts, to_numpy, to_torch
+from ._helper import check_for_raw_counts, normalize_per_sample, to_numpy, to_torch

diff --git a/src/nichepca/utils/_helper.py b/src/nichepca/utils/_helper.py
index f1a63a5..7329874 100644
--- a/src/nichepca/utils/_helper.py
+++ b/src/nichepca/utils/_helper.py
@@ -4,6 +4,7 @@
 from warnings import warn

 import numpy as np
+import scanpy as sc
 import scipy.sparse as sp
 import torch
@@ -81,3 +82,32 @@ def check_for_raw_counts(adata: AnnData):
             UserWarning,
             stacklevel=1,
         )
+
+
+def normalize_per_sample(adata, sample_key, **kwargs):
+    """
+    Normalize counts per sample in the `adata` object, based on the given `sample_key`.
+
+    Parameters
+    ----------
+    adata : AnnData
+        The annotated data object.
+    sample_key : str
+        The key in `adata.obs` that identifies distinct samples.
+    **kwargs
+        Additional keyword arguments passed to `sc.pp.normalize_total`.
+
+    Returns
+    -------
+    None
+    """
+    if kwargs.get("target_sum", None) is not None:
+        # with a fixed target_sum, per-sample normalization equals global normalization
+        sc.pp.normalize_total(adata, **kwargs)
+    else:
+        adata.X = adata.X.astype(np.float32)
+        for sample in adata.obs[sample_key].unique():
+            mask = adata.obs[sample_key] == sample
+            sub_ad = adata[mask].copy()
+            sc.pp.normalize_total(sub_ad, **kwargs)
+            adata.X[mask.values] = sub_ad.X

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 93f8a7b..86b90e8 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,6 +2,7 @@

 import numpy as np
 import pytest
+import scanpy as sc
 import torch
 from utils import generate_dummy_adata
@@ -49,3 +50,38 @@ def test_check_for_raw_counts():
     # Check for the specific warning
     with pytest.warns(UserWarning):
         npc.utils.check_for_raw_counts(adata)
+
+
+def test_normalize_per_sample():
+    sample_key = "sample"
+
+    target_sum = 1e4
+
+    adata_1 = generate_dummy_adata()
+    npc.utils.normalize_per_sample(
+        adata_1, target_sum=target_sum, sample_key=sample_key
+    )
+
+    adata_2 = generate_dummy_adata()
+    sc.pp.normalize_total(adata_2, target_sum=target_sum)
+
+    assert np.all(adata_1.X.toarray() == adata_2.X.toarray())
+
+    # second test without fixed target sum
+    target_sum = None
+
+    adata_1 = generate_dummy_adata()
+    npc.utils.normalize_per_sample(
+        adata_1, target_sum=target_sum, sample_key=sample_key
+    )
+
+    adata_2 = generate_dummy_adata()
+    adata_2.X = adata_2.X.astype(np.float32).toarray()
+
+    for sample in adata_2.obs[sample_key].unique():
+        mask = adata_2.obs[sample_key] == sample
+        sub_ad = adata_2[mask].copy()
+        sc.pp.normalize_total(sub_ad)
+        adata_2.X[mask.values] = sub_ad.X
+
+    assert np.all(adata_1.X.astype(np.float32).toarray() == adata_2.X)
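
Finally, a short usage sketch for the new helper, again leaning on the test suite's `generate_dummy_adata` (whose `.obs` provides the "sample" column used in the tests above):

    import nichepca as npc
    from utils import generate_dummy_adata

    # without target_sum, each sample is normalized independently: every cell
    # is scaled to its own sample's median total count (scanpy's default)
    adata = generate_dummy_adata()
    npc.utils.normalize_per_sample(adata, sample_key="sample")

    # with a fixed target_sum the grouping makes no difference, so the helper
    # falls back to a single global sc.pp.normalize_total call
    adata = generate_dummy_adata()
    npc.utils.normalize_per_sample(adata, sample_key="sample", target_sum=1e4)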