-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.py
117 lines (96 loc) · 3.22 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
from scipy.spatial import ConvexHull, Delaunay
import networkx as nx
# Public API of this module (names exported by `from <module> import *`).
__all__ = ["points2graph", "plus_freq_dim", "stereographic_projection"]
def stereographic_projection(points: np.ndarray):
    """Map points on the unit sphere to the plane by stereographic projection.

    The projection is taken from the north pole ``(0, 0, 1)`` onto the
    plane ``z = 0``: a point ``(x, y, z)`` is sent to
    ``(x / (1 - z), y / (1 - z))``. Rows with ``z == 1`` divide by zero,
    which NumPy reports with a RuntimeWarning and an ``inf``/``nan`` result.

    Parameters
    ----------
    points : ndarray
        N x 3 array of points.

    Returns
    -------
    points_proj : ndarray
        N x 2 float64 array of projected points.
    """
    assert points.ndim == 2
    assert points.shape[1] == 3
    # Keep the z column 2-D (N x 1) so it broadcasts across both x and y.
    denominator = 1 - points[:, 2:3]
    planar = points[:, :2] / denominator
    # Always hand back a fresh C-ordered float64 (N x 2) array.
    return np.ascontiguousarray(planar, dtype=np.float64)
def points2graph(points: np.ndarray, stereo_proj: bool = False):
    """Create a mesh graph from a set of points scaled onto the unit sphere.

    Every point is first divided by its Euclidean norm. The points are then
    triangulated in one of two ways. By default the 3-D convex hull is used.
    With ``stereo_proj=True``, the 2-D Delaunay triangulation of their
    stereographic projection is used instead. The three edges of every
    triangle become the graph's edges.

    Parameters
    ----------
    points : ndarray
        N x 3 array of points. Rows must have non-zero norm (a zero row
        would produce NaNs when normalized).
    stereo_proj : bool, optional
        If True, triangulate in the stereographic plane instead of taking
        the 3-D convex hull. The projection is taken from the north pole
        (``z == 1``). If some point lies exactly there, the whole point set
        is negated so that the free south pole is used instead.

    Returns
    -------
    G : networkx.Graph
        Graph whose nodes are row indices into ``points`` and whose edges
        are the triangle edges. Rows that Qhull leaves out of every simplex
        do not appear as nodes.
    hull_simplices : list of list of int
        Vertex indices of every triangle. If the triangulation has an open
        outer boundary (triangles with a ``-1`` neighbor, as in the Delaunay
        case), the boundary loop is appended as one extra polygon. Its size
        is printed.

    Raises
    ------
    ValueError
        If ``stereo_proj`` is True and points occupy both poles
        (``z == 1`` and ``z == -1``), leaving no pole to project from.
    """
    assert points.ndim == 2
    assert points.shape[1] == 3
    # Scale every point onto the unit sphere.
    R = np.linalg.norm(points, axis=1)
    points = points / R[:, None]

    if stereo_proj:
        z = points[:, 2]
        # The projection sends z == 1 to infinity. Only when a point sits
        # there do we negate the set (moving it to z == -1); that is
        # impossible if the south pole is occupied too.
        if not np.all(z < 1):
            if np.all(z > -1):
                points = -points
            else:
                raise ValueError(
                    f"z values must be in (-1, 1), but got ({z.min()}, {z.max()})"
                )
        points_proj = stereographic_projection(points)
        # Qhull option "QJ": joggle the input instead of merging facets.
        hull = Delaunay(points_proj, qhull_options="QJ")
    else:
        hull = ConvexHull(points)

    simplices = hull.simplices
    # Each triangle (a, b, c) contributes edges (a, b), (b, c) and (a, c).
    edges = np.vstack((simplices[:, :2], simplices[:, 1:], simplices[:, ::2]))
    G = nx.Graph()
    G.add_edges_from(edges)

    hull_simplices = simplices.tolist()
    # neighbors[i, j] == -1: triangle i has no neighbor across the edge
    # opposite its j-th vertex, i.e. that edge is on the outer boundary.
    # Edge (s[k], s[k+1]) is opposite vertex k-1, hence rolling the mask by 1.
    mask = hull.neighbors == -1
    if np.any(mask):
        simplex_edges = np.stack(
            (simplices, np.roll(simplices, -1, 1)), axis=2
        )[np.roll(mask, 1, 1)]
        simplex_G = nx.Graph(simplex_edges.tolist())
        # The boundary edges must form exactly one closed loop that visits
        # every boundary vertex; that loop is the "bottom" face.
        cycles = nx.cycle_basis(simplex_G)
        assert len(cycles) == 1, "more than one cycle detected"
        bottom_simplex = cycles[0]
        assert len(bottom_simplex) == len(
            simplex_G.nodes
        ), "bottom simplex is not complete"
        print("Size of the bottom simplex:", len(bottom_simplex))
        hull_simplices.append(bottom_simplex)
    return G, hull_simplices
def plus_freq_dim(G: nx.Graph, f: int):
    """Stack ``f`` copies of an undirected graph and link neighbouring layers.

    Layer ``i`` (``0 <= i < f``) is a copy of ``G`` in which node ``k`` is
    relabelled ``k + i * num_nodes``. In addition, node ``k`` of layer ``i``
    is joined to node ``k`` of layer ``i + 1``.

    Parameters
    ----------
    G : networkx.Graph
        Undirected graph. Its nodes are assumed to be the integers
        ``0 .. G.number_of_nodes() - 1`` (as produced by ``points2graph``
        when every point is a hull vertex). Other labellings will collide
        or fail.
    f : int
        Number of layers; must be at least 1.

    Returns
    -------
    new_G : networkx.Graph
        The layered graph, containing nodes ``0 .. f * num_nodes - 1``.

    Raises
    ------
    ValueError
        If ``f < 1``.
    """
    assert not G.is_directed()
    if f < 1:
        raise ValueError(f"f must be a positive integer, but got {f}")
    num_nodes = G.number_of_nodes()
    # reshape keeps an (0, 2) shape for edgeless graphs so vstack still works.
    edges = np.array(list(G.edges), dtype=int).reshape(-1, 2)
    # Intra-layer edges: one shifted copy of G per layer.
    layer_edges = [edges + i * num_nodes for i in range(f)]
    idxs = np.arange(num_nodes)
    # Inter-layer edges: node k of layer i <-> node k of layer i + 1.
    # The list is empty when f == 1.
    inter_layer_edges = [
        np.column_stack((idxs + num_nodes * i, idxs + num_nodes * (i + 1)))
        for i in range(f - 1)
    ]
    full_edges = np.vstack(layer_edges + inter_layer_edges)
    new_G = nx.Graph()
    new_G.add_edges_from(full_edges)
    # Keep nodes that have no edge (possible when f == 1). Added after the
    # edges so the insertion order of already-present nodes is unchanged.
    new_G.add_nodes_from(range(f * num_nodes))
    return new_G