Skip to content

Commit

Permalink
Merge pull request #3 from dschaub95/main
Browse files Browse the repository at this point in the history
Prepare NichePCA refactoring
  • Loading branch information
dschaub95 authored Sep 27, 2024
2 parents 22c751f + b7f798b commit b64111d
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 41 deletions.
1 change: 1 addition & 0 deletions src/nichepca/graph_construction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
knn_graph,
print_graph_stats,
remove_long_links,
resolve_graph_constructor,
to_squidpy,
)
137 changes: 102 additions & 35 deletions src/nichepca/graph_construction/_spatial_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -106,6 +107,7 @@ def store_graph(
adata: AnnData,
edge_index: torch.tensor | np.ndarray,
edge_weight: torch.tensor | np.ndarray,
extra_key: str | None = None,
):
"""
Store graph data in `adata.uns`.
Expand All @@ -118,15 +120,22 @@ def store_graph(
Edge index of the graph.
edge_weight : Union[torch.Tensor, np.ndarray]
Edge weight of the graph.
extra_key : str, optional
Extra key to store the graph data in `adata.uns`.
Returns
-------
None
"""
if "graph" not in adata.uns:
adata.uns["graph"] = {}
adata.uns["graph"]["edge_index"] = to_numpy(edge_index)
adata.uns["graph"]["edge_weight"] = to_numpy(edge_weight)
if extra_key is not None:
adata.uns["graph"][extra_key] = {}
adata.uns["graph"][extra_key]["edge_index"] = to_numpy(edge_index)
adata.uns["graph"][extra_key]["edge_weight"] = to_numpy(edge_weight)
else:
adata.uns["graph"]["edge_index"] = to_numpy(edge_index)
adata.uns["graph"]["edge_weight"] = to_numpy(edge_weight)


def knn_graph(
Expand Down Expand Up @@ -309,8 +318,7 @@ def remove_long_links(
if copy:
return edge_index, edge_weight
else:
adata.uns["graph"]["edge_index"] = edge_index
adata.uns["graph"]["edge_weight"] = edge_weight
store_graph(adata, edge_index, edge_weight)


def delaunay_graph(
Expand Down Expand Up @@ -368,6 +376,9 @@ def delaunay_graph(
edge_index, edge_weight = remove_long_links(
edge_index=edge_index, edge_weight=edge_weight
)
# convert back to torch tensors
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_weight = torch.tensor(edge_weight, dtype=torch.float)

if verbose:
print_graph_stats(edge_index=edge_index, num_nodes=adata.n_obs)
Expand All @@ -386,16 +397,41 @@ def construct_multi_sample_graph(
delaunay: bool = False,
return_graph: bool = False,
keep_local_edge_index: bool = False,
verbose: bool = True,
**kwargs,
):
# make sure only one of knn, radius, delaunay is provided
assert (
sum([knn is not None, radius is not None, delaunay]) == 1
), "Only one of knn, radius, delaunay must be provided."
"""
Construct a multi-sample graph from AnnData.
Parameters
----------
adata : AnnData
Annotated data object.
sample_key : str
Key in `adata.obs` where sample information is stored.
knn : int | None, optional
Number of nearest neighbors. Defaults to None.
radius : float | None, optional
Radius for the distance graph constructor. Defaults to None.
delaunay : bool, optional
Whether to use the delaunay graph constructor. Defaults to False.
return_graph : bool, optional
Whether to return the graph instead of storing it in `adata`. Defaults to False.
keep_local_edge_index : bool, optional
Whether to keep the local edge index. Defaults to False.
verbose : bool, optional
Whether to print graph statistics. Defaults to True.
**kwargs
Additional keyword arguments passed to the graph constructor.
Returns
-------
edge_index, edge_weight : torch.Tensor, torch.Tensor
Edge index and edge weight of the constructed graph if `return_graph` is True.
"""
edge_index = []
edge_weight = []
global_indices = np.arange(adata.n_obs)
global_indices = torch.arange(adata.n_obs)

if "graph" not in adata.uns and not return_graph:
adata.uns["graph"] = {}
Expand All @@ -404,46 +440,77 @@ def construct_multi_sample_graph(
mask = adata.obs[sample_key] == sample
ad_sub = adata[mask]

local_global_indices = global_indices[mask]
local_global_indices = global_indices[mask.values]

if knn is not None:
local_edge_index, local_edge_weight = knn_graph(
ad_sub, knn, return_graph=True, **kwargs
)
elif radius is not None:
local_edge_index, local_edge_weight = distance_graph(
ad_sub, radius, return_graph=True, **kwargs
)
elif delaunay:
local_edge_index, local_edge_weight = delaunay_graph(
ad_sub, return_graph=True, **kwargs
)
local_edge_index, local_edge_weight = resolve_graph_constructor(
radius, knn, delaunay
)(
ad_sub,
return_graph=True,
verbose=False,
**kwargs,
)

local_global_edge_index = local_global_indices[local_edge_index]

edge_index.append(local_global_edge_index)
edge_weight.append(local_edge_weight)

if not return_graph:
adata.uns["graph"][sample] = {
"edge_index": (
to_numpy(local_edge_index)
if keep_local_edge_index
else to_numpy(local_global_edge_index)
),
"edge_weight": to_numpy(local_edge_weight),
}

edge_index = to_numpy(np.concatenate(edge_index, axis=1))
edge_weight = to_numpy(np.concatenate(edge_weight, axis=0))
store_graph(
adata,
local_edge_index if keep_local_edge_index else local_global_edge_index,
local_edge_weight,
extra_key=sample,
)

edge_index = torch.cat(edge_index, axis=1)
edge_weight = torch.cat(edge_weight, axis=0)

if verbose:
print_graph_stats(edge_index=edge_index, num_nodes=adata.n_obs)

if not return_graph:
adata.uns["graph"]["edge_index"] = edge_index
adata.uns["graph"]["edge_weight"] = edge_weight
store_graph(adata, edge_index, edge_weight)
else:
return edge_index, edge_weight


def resolve_graph_constructor(
    radius: float | None = None, knn: int | None = None, delaunay: bool = False
):
    """
    Pick the graph constructor that matches the provided parameters.

    The options are checked in the order `radius`, `knn`, `delaunay`. The first
    one that is set decides the result; any remaining ones are ignored.

    Parameters
    ----------
    radius
        Radius that is bound to `distance_graph` through its `radius` keyword.
    knn
        Number of nearest neighbors that is bound to `knn_graph` through its
        `knn` keyword.
    delaunay
        If truthy (and neither `radius` nor `knn` is given), `delaunay_graph`
        is selected.

    Returns
    -------
    callable
        `partial(distance_graph, radius=radius)`, `partial(knn_graph, knn=knn)`
        or `delaunay_graph` itself.

    Raises
    ------
    ValueError
        If `radius` and `knn` are both None and `delaunay` is falsy.
    """
    # guard clauses, evaluated in precedence order
    if radius is not None:
        return partial(distance_graph, radius=radius)
    if knn is not None:
        return partial(knn_graph, knn=knn)
    if delaunay:
        return delaunay_graph
    raise ValueError("No graph constructor specified.")


def to_squidpy(adata: AnnData):
"""
Convert the pyg graph stored in `adata.uns` to squidpy format.
Expand Down
9 changes: 5 additions & 4 deletions src/nichepca/nhood_embedding/_aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def aggregate(
n_layers: int = 1,
out_key: str | None = None,
backend: str = "pyg",
**kwargs,
aggr="mean",
):
"""
Aggregate data in an AnnData object based on a previously constructed graph.
Expand Down Expand Up @@ -79,7 +79,7 @@ def aggregate(
X = adata.obsm[obsm_key]

if backend == "pyg":
aggr_fn = GraphAggregation(**kwargs)
aggr_fn = GraphAggregation(aggr)

X = torch.tensor(X).float()
edge_index = torch.tensor(adata.uns["graph"]["edge_index"])
Expand All @@ -95,10 +95,11 @@ def aggregate(
A = scipy.sparse.csr_matrix(
(np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])), shape=(N, N)
)
A_mean = A / A.sum(1)
if aggr == "mean":
A = A / A.sum(1)

for _ in range(n_layers):
X = A_mean @ X
X = A @ X

if out_key is not None:
adata.obsm[out_key] = to_numpy(X)
Expand Down
2 changes: 1 addition & 1 deletion src/nichepca/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._helper import check_for_raw_counts, to_numpy, to_torch
from ._helper import check_for_raw_counts, normalize_per_sample, to_numpy, to_torch
30 changes: 30 additions & 0 deletions src/nichepca/utils/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from warnings import warn

import numpy as np
import scanpy as sc
import scipy.sparse as sp
import torch

Expand Down Expand Up @@ -81,3 +82,32 @@ def check_for_raw_counts(adata: AnnData):
UserWarning,
stacklevel=1,
)


def normalize_per_sample(adata, sample_key, **kwargs):
    """
    Normalize the counts in `adata` separately for every sample.

    The `adata` object is modified in place.

    Parameters
    ----------
    adata : AnnData
        The annotated data object.
    sample_key : str
        The key in `adata.obs` that identifies distinct samples.
    kwargs : dict, optional
        Additional keyword arguments forwarded to `sc.pp.normalize_total`.

    Returns
    -------
    None
    """
    if kwargs.get("target_sum") is not None:
        # with a fixed target sum the sample split makes no difference, so a
        # single call on the full object is sufficient
        sc.pp.normalize_total(adata, **kwargs)
        return

    adata.X = adata.X.astype(np.float32)
    sample_labels = adata.obs[sample_key]
    for sample in sample_labels.unique():
        in_sample = sample_labels == sample
        # normalize an independent copy of the sample, then write the rows back
        sample_adata = adata[in_sample].copy()
        sc.pp.normalize_total(sample_adata, **kwargs)
        adata.X[in_sample.values] = sample_adata.X
36 changes: 35 additions & 1 deletion tests/test_graph_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
import torch
from utils import generate_dummy_adata

import nichepca as npc
Expand Down Expand Up @@ -123,7 +124,7 @@ def test_construct_multi_sample_graph():
sub_ad = adata_2[mask].copy()
sub_ad.uns["graph"] = adata_2.uns["graph"][batch].copy()
npc.ne.aggregate(sub_ad)
adata_2[mask].X = sub_ad.X.copy()
adata_2.X[mask] = sub_ad.X.copy()

# manual variant
adata_3.X = npc.utils.to_numpy(adata_3.X).astype(np.float32)
Expand Down Expand Up @@ -166,3 +167,36 @@ def test_squidpy_conversion():

assert np.all(edge_index == edge_index_new)
assert np.all(edge_weight == edge_weight_new)


def test_resolve_graph_constructor():
    """Check that resolved constructors reproduce the direct constructor calls."""

    def assert_same_graph(expected, actual):
        # both arguments are (edge_index, edge_weight) pairs as returned with
        # `return_graph=True`; every component has to match element-wise
        for expected_part, actual_part in zip(expected, actual):
            assert torch.all(expected_part == actual_part)

    adata = generate_dummy_adata()

    # knn graph
    knn = 10
    assert_same_graph(
        npc.gc.knn_graph(adata, knn=knn, return_graph=True),
        npc.gc.resolve_graph_constructor(knn=knn)(adata, return_graph=True),
    )

    # distance graph
    radius = 0.1
    assert_same_graph(
        npc.gc.distance_graph(adata, radius=radius, return_graph=True),
        npc.gc.resolve_graph_constructor(radius=radius)(adata, return_graph=True),
    )

    # delaunay graph
    assert_same_graph(
        npc.gc.delaunay_graph(adata, return_graph=True),
        npc.gc.resolve_graph_constructor(delaunay=True)(adata, return_graph=True),
    )
36 changes: 36 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
import scanpy as sc
import torch
from utils import generate_dummy_adata

Expand Down Expand Up @@ -49,3 +50,38 @@ def test_check_for_raw_counts():
# Check for the specific warning
with pytest.warns(UserWarning):
npc.utils.check_for_raw_counts(adata)


def test_normalize_per_sample():
    """Compare `normalize_per_sample` against direct `sc.pp.normalize_total` calls."""
    sample_key = "sample"

    # first case: fixed target sum, compared against one call on the full object
    target_sum = 1e4

    normalized = generate_dummy_adata()
    npc.utils.normalize_per_sample(
        normalized, sample_key=sample_key, target_sum=target_sum
    )

    reference = generate_dummy_adata()
    sc.pp.normalize_total(reference, target_sum=target_sum)

    assert np.all(normalized.X.toarray() == reference.X.toarray())

    # second case: no fixed target sum, compared against a manual per-sample loop
    target_sum = None

    normalized = generate_dummy_adata()
    npc.utils.normalize_per_sample(
        normalized, sample_key=sample_key, target_sum=target_sum
    )

    reference = generate_dummy_adata()
    reference.X = reference.X.astype(np.float32).toarray()

    sample_labels = reference.obs[sample_key]
    for sample in sample_labels.unique():
        in_sample = sample_labels == sample
        sample_adata = reference[in_sample].copy()
        sc.pp.normalize_total(sample_adata)
        reference.X[in_sample.values] = sample_adata.X

    assert np.all(normalized.X.astype(np.float32).toarray() == reference.X)

0 comments on commit b64111d

Please sign in to comment.