Skip to content

Commit

Permalink
Adding test
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus committed Dec 5, 2023
1 parent a020f61 commit 3691526
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,43 @@ def test_labels_centers3():
assert np.all(np.equal(np.sort(np.flatnonzero(c)), np.sort(centers_flat[1:])))


def test_labels_centers_weights():
labels = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
]
)
weights = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 2, 0, 0],
[0, 2, 9, 2, 0],
[0, 0, 2, 0, 0],
[0, 0, 0, 0, 0],
]
)
centers = np.array([[2.0, 2.0]])

for i in range(labels.shape[0]):
for j in range(labels.shape[1]):
ret = eye.labels_centers_of_mass(labels, weights, [1], periodic=True)
assert np.allclose(ret, centers)

labels = np.roll(labels, 1, 1)
weights = np.roll(weights, 1, 1)
y = centers[:, 1] + 1
centers[:, 1] = np.where(y > labels.shape[1], y - labels.shape[1], y)

labels = np.roll(labels, 1, 0)
weights = np.roll(weights, 1, 0)
x = centers[:, 0] + 1
centers[:, 0] = np.where(x > labels.shape[0], x - labels.shape[0], x)


def test_prune():
segmenter = eye.ClusterLabeller(shape=(4, 4))
segmenter.add_points([0, 2, 8, 10, 1])
Expand Down

0 comments on commit 3691526

Please sign in to comment.