diff --git a/include/GooseEYE/GooseEYE.h b/include/GooseEYE/GooseEYE.h index 6839c71b..d63be5c4 100644 --- a/include/GooseEYE/GooseEYE.h +++ b/include/GooseEYE/GooseEYE.h @@ -1282,8 +1282,11 @@ array_type::array clusters(const T& f, bool periodic = true) * @return The image, with labels inserted (overwritten) at the positions. */ template +[[deprecated("Will not be supported in the future. See Python warning for new API.")]] inline T pos2img(const T& img, const U& positions, const V& labels) { + GOOSEEYE_WARNING_PYTHON("pos2img(img, positions, labels) deprecated, use: " + "i = ravel_multi_index(positions.T, img.shape); img.flat[i] = labels") GOOSEEYE_ASSERT(img.dimension() > 0, std::out_of_range); GOOSEEYE_ASSERT(img.dimension() <= 3, std::out_of_range); GOOSEEYE_ASSERT(img.dimension() == positions.shape(1), std::out_of_range); @@ -1319,11 +1322,13 @@ inline T pos2img(const T& img, const U& positions, const V& labels) * https://en.wikipedia.org/wiki/Center_of_mass#Systems_with_periodic_boundary_conditions * * @param labels An image with labels [0..N]. + * @param weights Weights for each pixel. * @param periodic Switch to assume image periodic. * @return The position of the center of each label. */ -template -array_type::tensor center_of_mass(const T& labels, bool periodic = true) +template +array_type::tensor +labels_centers(const T& labels, const W& weights = nullptr, bool periodic = true) { static_assert(std::is_integral::value, "Integral labels required."); @@ -1364,6 +1369,20 @@ array_type::tensor center_of_mass(const T& labels, bool periodic = tr return xt::view(ret, xt::all(), xt::keep(axes)); } +/** + * @brief Get the position of the center of each label. + * @param labels An image with labels [0..N]. + * @param periodic Switch to assume image periodic. + * @return The position of the center of each label (in array-index). + */ +template +[[deprecated("Use labels_centers() instead.")]] array_type::tensor +center_of_mass(const T& labels, bool periodic = true) +{ + GOOSEEYE_WARNING_PYTHON("center_of_mass() is deprecated, use labels_centers() instead"); + return labels_centers(labels, nullptr, periodic); +} + /** * Compute ensemble averaged statistics, by repetitively calling the member-function of a certain * statistical measure with different data. diff --git a/python/main.cpp b/python/main.cpp index 51b81fbd..19b95489 100644 --- a/python/main.cpp +++ b/python/main.cpp @@ -212,10 +212,17 @@ PYBIND11_MODULE(_GooseEYE, m) m.def( "center_of_mass", - &GooseEYE::center_of_mass>, + &GooseEYE::center_of_mass>, py::arg("labels"), py::arg("periodic") = true); + m.def( + "labels_centers", + &GooseEYE::labels_centers, xt::pyarray>, + py::arg("labels"), + py::arg("weights").none(true) = nullptr, + py::arg("periodic") = true); + py::class_(m, "Ensemble") // Constructors diff --git a/tests/clusters.cpp b/tests/clusters.cpp deleted file mode 100644 index 668c36b1..00000000 --- a/tests/clusters.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include -#include - -TEST_CASE("GooseEYE::clusters", "clusters.hpp") -{ - - SECTION("center_of_mass - not periodic") - { - xt::xtensor l = { - {0, 0, 0, 0, 0}, - {0, 0, 1, 0, 0}, - {0, 1, 1, 1, 0}, - {0, 0, 1, 0, 0}, - {0, 0, 0, 0, 0}, - }; - - xt::xtensor centers = {{2.0, 2.0}, {2.0, 2.0}}; - - REQUIRE(xt::allclose(GooseEYE::center_of_mass(l, false), centers)); - } - - SECTION("center_of_mass - periodic") - { - xt::xtensor l = { - {0, 0, 0, 0, 0}, - {0, 0, 1, 0, 0}, - {0, 1, 1, 1, 0}, - {0, 0, 1, 0, 0}, - {0, 0, 0, 0, 0}, - }; - - xt::xtensor centers = {{4.5, 4.5}, {2.0, 2.0}}; - - double row = static_cast(l.shape(0)); - double col = static_cast(l.shape(1)); - - for (size_t i = 0; i < l.shape(0); ++i) { - for (size_t j = 0; j < l.shape(1); ++j) { - REQUIRE(xt::allclose(GooseEYE::center_of_mass(l), centers)); - l = xt::roll(l, 1, 1); - auto y = xt::view(centers, xt::all(), 1); - y += 1.0; - y = xt::where(y > col, y - col, y); - } - l = xt::roll(l, 1, 0); - auto x = xt::view(centers, xt::all(), 0); - x += 1.0; - x = xt::where(x > row, x - row, x); - } - } - - SECTION("center_of_mass - labels") - { - xt::xtensor l = { - {2, 3, 3, 3, 2}, - {0, 0, 1, 0, 0}, - {0, 1, 1, 1, 0}, - {0, 0, 1, 0, 0}, - {2, 4, 4, 4, 2}, - }; - - xt::xtensor c = { - {0, 0, 3, 0, 0}, - {0, 0, 0, 0, 0}, - {0, 0, 1, 0, 0}, - {0, 0, 0, 0, 0}, - {0, 0, 4, 0, 2}, - }; - - xt::xtensor centers = { - {2.0, 4.5}, {2.0, 2.0}, {4.5, 4.5}, {0.0, 2.0}, {4.0, 2.0}}; - - auto res_centers = GooseEYE::center_of_mass(l); - xt::xtensor res_c = xt::zeros_like(l); - res_c = - GooseEYE::pos2img(res_c, xt::floor(res_centers), xt::arange(centers.shape(0))); - - REQUIRE(xt::allclose(res_centers, centers)); - REQUIRE(xt::all(xt::equal(res_c, c))); - } - - SECTION("Clusters::centers") - { - xt::xarray I = xt::zeros({5, 5}); - xt::xarray C = xt::zeros({5, 5}); - - I(0, 0) = 1; - I(0, 3) = 1; - I(0, 4) = 1; - I(3, 0) = 1; - I(3, 3) = 1; - I(3, 4) = 1; - - C(0, 4) = 1; - C(3, 4) = 2; - - GooseEYE::Clusters clusters(I, true); - REQUIRE(xt::all(xt::equal(C, clusters.centers()))); - REQUIRE(xt::all(xt::equal(GooseEYE::clusters(I, true), clusters.labels()))); - } -} diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 00b75130..afaf71b1 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -216,6 +216,92 @@ def test_labels_sizes(): ) +def test_labels_centers(): + labels = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + centers = np.array( + [ + [2.0, 2.0], # background (0) + [2.0, 2.0], # label (1) + ] + ) + assert np.allclose(eye.labels_centers(labels, periodic=False), centers) + + +def test_labels_centers2(): + labels = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + centers = np.array( + [ + [4.5, 4.5], # background (0) + [2.0, 2.0], # label (1) + ] + ) + + for i in range(labels.shape[0]): + for j in range(labels.shape[1]): + assert np.allclose(eye.labels_centers(labels, periodic=True), centers) + labels = np.roll(labels, 1, 1) + y = centers[:, 1] + 1 + centers[:, 1] = np.where(y > labels.shape[1], y - labels.shape[1], y) + + labels = np.roll(labels, 1, 0) + x = centers[:, 0] + 1 + centers[:, 0] = np.where(x > labels.shape[0], x - labels.shape[0], x) + + +def test_labels_centers3(): + labels = np.array( + [ + [2, 3, 3, 3, 2], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [2, 4, 4, 4, 2], + ] + ) + + c = np.array( + [ + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 4, 0, 2], + ] + ) + + centers = np.array( + [ + [2.0, 4.5], + [2.0, 2.0], + [4.5, 4.5], + [0.0, 2.0], + [4.0, 2.0], + ] + ) + + res_centers = eye.labels_centers(labels) + assert np.allclose(res_centers, centers) + + centers_flat = np.ravel_multi_index(np.floor(centers).astype(int).T, c.shape) + assert np.all(np.equal(np.sort(np.flatnonzero(c)), np.sort(centers_flat[1:]))) + + def test_prune(): segmenter = eye.ClusterLabeller(shape=(4, 4)) segmenter.add_points([0, 2, 8, 10, 1])