diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 5bb21c0..87e7340 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -302,6 +302,43 @@ def test_labels_centers3(): assert np.all(np.equal(np.sort(np.flatnonzero(c)), np.sort(centers_flat[1:]))) +def test_labels_centers_weights(): + labels = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + weights = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 2, 0, 0], + [0, 2, 9, 2, 0], + [0, 0, 2, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + centers = np.array([[2.0, 2.0]]) + + for i in range(labels.shape[0]): + for j in range(labels.shape[1]): + ret = eye.labels_centers_of_mass(labels, weights, [1], periodic=True) + assert np.allclose(ret, centers) + + labels = np.roll(labels, 1, 1) + weights = np.roll(weights, 1, 1) + y = centers[:, 1] + 1 + centers[:, 1] = np.where(y > labels.shape[1], y - labels.shape[1], y) + + labels = np.roll(labels, 1, 0) + weights = np.roll(weights, 1, 0) + x = centers[:, 0] + 1 + centers[:, 0] = np.where(x > labels.shape[0], x - labels.shape[0], x) + + def test_prune(): segmenter = eye.ClusterLabeller(shape=(4, 4)) segmenter.add_points([0, 2, 8, 10, 1])