Skip to content

Commit

Permalink
Adding labels_centers (deprecating pos2img and center_of_mass)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus committed Dec 5, 2023
1 parent 6856dfe commit 3a00e1d
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 105 deletions.
23 changes: 21 additions & 2 deletions include/GooseEYE/GooseEYE.h
Original file line number Diff line number Diff line change
Expand Up @@ -1282,8 +1282,11 @@ array_type::array<int> clusters(const T& f, bool periodic = true)
* @return The image, with labels inserted (overwritten) at the positions.
*/
template <typename T, typename U, typename V>
[[deprecated("Will not be supported in the future. See Python warning for new API.")]]
inline T pos2img(const T& img, const U& positions, const V& labels)
{
GOOSEEYE_WARNING_PYTHON("pos2img(img, positions, labels) deprecated, use: "
"i = ravel_multi_index(positions.T, img.shape); img.flat[i] = labels")
GOOSEEYE_ASSERT(img.dimension() > 0, std::out_of_range);
GOOSEEYE_ASSERT(img.dimension() <= 3, std::out_of_range);
GOOSEEYE_ASSERT(img.dimension() == positions.shape(1), std::out_of_range);
Expand Down Expand Up @@ -1319,11 +1322,13 @@ inline T pos2img(const T& img, const U& positions, const V& labels)
* https://en.wikipedia.org/wiki/Center_of_mass#Systems_with_periodic_boundary_conditions
*
* @param labels An image with labels [0..N].
* @param weights Weights for each pixel.
* @param periodic Switch to assume image periodic.
* @return The position of the center of each label.
*/
template <class T>
array_type::tensor<double, 2> center_of_mass(const T& labels, bool periodic = true)
template <class T, class W = decltype(nullptr)>
array_type::tensor<double, 2>
labels_centers(const T& labels, const W& weights = nullptr, bool periodic = true)
{
static_assert(std::is_integral<typename T::value_type>::value, "Integral labels required.");

Expand Down Expand Up @@ -1364,6 +1369,20 @@ array_type::tensor<double, 2> center_of_mass(const T& labels, bool periodic = tr
return xt::view(ret, xt::all(), xt::keep(axes));
}

/**
* @brief Get the position of the center of each label.
* @param labels An image with labels [0..N].
* @param periodic Switch to assume image periodic.
* @return The position of the center of each label (in array-index).
*/
template <class T>
[[deprecated("Use labels_centers() instead.")]] array_type::tensor<double, 2>
center_of_mass(const T& labels, bool periodic = true)
{
GOOSEEYE_WARNING_PYTHON("center_of_mass() is deprecated, use labels_centers() instead");
return labels_centers(labels, nullptr, periodic);
}

/**
* Compute ensemble averaged statistics, by repetitively calling the member-function of a certain
* statistical measure with different data.
Expand Down
9 changes: 8 additions & 1 deletion python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,17 @@ PYBIND11_MODULE(_GooseEYE, m)

m.def(
"center_of_mass",
&GooseEYE::center_of_mass<xt::pyarray<size_t>>,
&GooseEYE::center_of_mass<xt::pyarray<ptrdiff_t>>,
py::arg("labels"),
py::arg("periodic") = true);

m.def(
"labels_centers",
&GooseEYE::labels_centers<xt::pyarray<ptrdiff_t>, xt::pyarray<double>>,
py::arg("labels"),
py::arg("weights").none(true) = nullptr,
py::arg("periodic") = true);

py::class_<GooseEYE::Ensemble>(m, "Ensemble")

// Constructors
Expand Down
102 changes: 0 additions & 102 deletions tests/clusters.cpp

This file was deleted.

86 changes: 86 additions & 0 deletions tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,92 @@ def test_labels_sizes():
)


def test_labels_centers():
labels = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
]
)
centers = np.array(
[
[2.0, 2.0], # background (0)
[2.0, 2.0], # label (1)
]
)
assert np.allclose(eye.labels_centers(labels, periodic=False), centers)


def test_labels_centers2():
labels = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
]
)
centers = np.array(
[
[4.5, 4.5], # background (0)
[2.0, 2.0], # label (1)
]
)

for i in range(labels.shape[0]):
for j in range(labels.shape[1]):
assert np.allclose(eye.labels_centers(labels, periodic=True), centers)
labels = np.roll(labels, 1, 1)
y = centers[:, 1] + 1
centers[:, 1] = np.where(y > labels.shape[1], y - labels.shape[1], y)

labels = np.roll(labels, 1, 0)
x = centers[:, 0] + 1
centers[:, 0] = np.where(x > labels.shape[0], x - labels.shape[0], x)


def test_labels_centers3():
labels = np.array(
[
[2, 3, 3, 3, 2],
[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 0, 0],
[2, 4, 4, 4, 2],
]
)

c = np.array(
[
[0, 0, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 4, 0, 2],
]
)

centers = np.array(
[
[2.0, 4.5],
[2.0, 2.0],
[4.5, 4.5],
[0.0, 2.0],
[4.0, 2.0],
]
)

res_centers = eye.labels_centers(labels)
assert np.allclose(res_centers, centers)

centers_flat = np.ravel_multi_index(np.floor(centers).astype(int).T, c.shape)
assert np.all(np.equal(np.sort(np.flatnonzero(c)), np.sort(centers_flat[1:])))


def test_prune():
segmenter = eye.ClusterLabeller(shape=(4, 4))
segmenter.add_points([0, 2, 8, 10, 1])
Expand Down

0 comments on commit 3a00e1d

Please sign in to comment.