Skip to content

Commit

Permalink
add test for graph construction resolver and correct mask handling
Browse files Browse the repository at this point in the history
  • Loading branch information
dschaub95 committed Sep 27, 2024
1 parent 49a047b commit d7e7984
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion tests/test_graph_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
import torch
from utils import generate_dummy_adata

import nichepca as npc
Expand Down Expand Up @@ -123,7 +124,7 @@ def test_construct_multi_sample_graph():
sub_ad = adata_2[mask].copy()
sub_ad.uns["graph"] = adata_2.uns["graph"][batch].copy()
npc.ne.aggregate(sub_ad)
adata_2[mask].X = sub_ad.X.copy()
adata_2.X[mask] = sub_ad.X.copy()

# manual variant
adata_3.X = npc.utils.to_numpy(adata_3.X).astype(np.float32)
Expand Down Expand Up @@ -166,3 +167,36 @@ def test_squidpy_conversion():

assert np.all(edge_index == edge_index_new)
assert np.all(edge_weight == edge_weight_new)


def test_resolve_graph_constructor():
adata = generate_dummy_adata()

knn = 10
edge_index_1, edge_weight_1 = npc.gc.knn_graph(adata, knn=knn, return_graph=True)
edge_index_2, edge_weight_2 = npc.gc.resolve_graph_constructor(knn=knn)(
adata, return_graph=True
)

assert torch.all(edge_index_1 == edge_index_2)
assert torch.all(edge_weight_1 == edge_weight_2)

radius = 0.1
edge_index_1, edge_weight_1 = npc.gc.distance_graph(
adata, radius=radius, return_graph=True
)
edge_index_2, edge_weight_2 = npc.gc.resolve_graph_constructor(radius=radius)(
adata, return_graph=True
)

assert torch.all(edge_index_1 == edge_index_2)
assert torch.all(edge_weight_1 == edge_weight_2)

edge_index_1, edge_weight_1 = npc.gc.delaunay_graph(adata, return_graph=True)
edge_index_2, edge_weight_2 = npc.gc.resolve_graph_constructor(delaunay=True)(
adata, return_graph=True
)
print(type(edge_index_1), type(edge_index_2))

assert torch.all(edge_index_1 == edge_index_2)
assert torch.all(edge_weight_1 == edge_weight_2)

0 comments on commit d7e7984

Please sign in to comment.