diff --git a/include/GooseEYE/GooseEYE.h b/include/GooseEYE/GooseEYE.h index 6839c71b..f066bc99 100644 --- a/include/GooseEYE/GooseEYE.h +++ b/include/GooseEYE/GooseEYE.h @@ -1282,8 +1282,11 @@ array_type::array clusters(const T& f, bool periodic = true) * @return The image, with labels inserted (overwritten) at the positions. */ template -inline T pos2img(const T& img, const U& positions, const V& labels) +[[deprecated("Will not be supported in the future. See Python warning for new API.")]] inline T +pos2img(const T& img, const U& positions, const V& labels) { + GOOSEEYE_WARNING_PYTHON("pos2img(img, positions, labels) deprecated, use: " + "i = ravel_multi_index(positions.T, img.shape); img.flat[i] = labels") GOOSEEYE_ASSERT(img.dimension() > 0, std::out_of_range); GOOSEEYE_ASSERT(img.dimension() <= 3, std::out_of_range); GOOSEEYE_ASSERT(img.dimension() == positions.shape(1), std::out_of_range); @@ -1312,56 +1315,159 @@ inline T pos2img(const T& img, const U& positions, const V& labels) } /** - * @brief Get the position of the center of each label. + * @brief Return the geometric center of a list of positions. * * @details * For periodic algorithm, see: * https://en.wikipedia.org/wiki/Center_of_mass#Systems_with_periodic_boundary_conditions * - * @param labels An image with labels [0..N]. + * @param shape Shape of the array. + * @param positions List of positions (in array coordinates). + * @param periodic Switch to assume array periodic. + * @return Coordinates of the center (in array coordinates). + */ +inline array_type::tensor center( + const array_type::tensor& shape, + const array_type::tensor& positions, + bool periodic = true) +{ + if (positions.size() == 0) { + return xt::zeros({shape.size()}); + } + + if (!periodic) { + return xt::mean(positions, 0); + } + else { + double pi = xt::numeric_constants::PI; + auto theta = 2.0 * pi * positions / shape; + auto xi = xt::cos(theta); + auto zeta = xt::sin(theta); + auto xi_bar = xt::mean(xi, 0); + auto zeta_bar = xt::mean(zeta, 0); + auto theta_bar = xt::atan2(-zeta_bar, -xi_bar) + pi; + return shape * theta_bar / (2.0 * pi); + } +} + +/** + * @copydoc GooseEYE::center() + * @param weights Weight for each position. + */ +inline array_type::tensor center_of_mass( + const array_type::tensor& shape, + const array_type::tensor& positions, + const array_type::tensor& weights, + bool periodic = true) +{ + if (positions.size() == 0) { + return xt::zeros({shape.size()}); + } + + if (!periodic) { + return xt::average(positions, weights, 0); + } + else { + double pi = xt::numeric_constants::PI; + auto theta = 2.0 * pi * positions / shape; + auto xi = xt::cos(theta); + auto zeta = xt::sin(theta); + auto xi_bar = xt::average(xi, weights, 0); + auto zeta_bar = xt::average(zeta, weights, 0); + auto theta_bar = xt::atan2(-zeta_bar, -xi_bar) + pi; + return shape * theta_bar / (2.0 * pi); + } +} + +namespace detail { + +class PositionList { +public: + array_type::tensor positions; + array_type::tensor weights; + + PositionList() = default; + + template + void set(const T& condition) + { + positions = xt::from_indices(xt::argwhere(condition)); + } + + template + void set(const T& condition, const W& w) + { + auto pos = xt::argwhere(condition); + weights = xt::empty({pos.size()}); + for (size_t i = 0; i < pos.size(); ++i) { + weights(i) = w[pos[i]]; + } + positions = xt::from_indices(pos); + } +}; + +} // namespace detail + +/** + * @brief Get the position of the center of each label. + * + * @param labels An image with labels. + * @param names List of labels to compute the center for. * @param periodic Switch to assume image periodic. - * @return The position of the center of each label. + * @return Coordinates of the center (in array coordinates), in order of the unique (sorted) labels. */ -template -array_type::tensor center_of_mass(const T& labels, bool periodic = true) +template +inline array_type::tensor +labels_centers(const T& labels, const N& names, bool periodic = true) { static_assert(std::is_integral::value, "Integral labels required."); - GOOSEEYE_ASSERT(labels.dimension() > 0, std::out_of_range); - GOOSEEYE_ASSERT(xt::all(labels >= 0), std::out_of_range); + GOOSEEYE_ASSERT(names.dimension() == 1, std::out_of_range); - double pi = xt::numeric_constants::PI; - size_t N = static_cast(xt::amax(labels)(0)) + 1ul; size_t rank = labels.dimension(); - auto axes = detail::atleast_3d_axes(rank); array_type::tensor shape = xt::adapt(labels.shape()); - array_type::tensor ret = xt::zeros({N, rank}); + array_type::tensor ret = xt::zeros({names.size(), rank}); + detail::PositionList plist; - for (size_t l = 0; l < N; ++l) { - array_type::tensor positions = - xt::from_indices(xt::argwhere(xt::equal(labels, l))); - if (positions.size() == 0) { + for (size_t l = 0; l < names.size(); ++l) { + plist.set(xt::equal(labels, names(l))); + if (plist.positions.size() == 0) { continue; } - if (!periodic) { - xt::view(ret, l, xt::all()) = xt::mean(positions, 0); - } - else { - if (xt::all(xt::equal(positions, 0.0))) { - continue; - } - auto theta = 2.0 * pi * positions / shape; - auto xi = xt::cos(theta); - auto zeta = xt::sin(theta); - auto xi_bar = xt::mean(xi, 0); - auto zeta_bar = xt::mean(zeta, 0); - auto theta_bar = xt::atan2(-zeta_bar, -xi_bar) + pi; - auto positions_bar = shape * theta_bar / (2.0 * pi); - xt::view(ret, l, xt::all()) = positions_bar; + xt::view(ret, l, xt::all()) = center(shape, plist.positions, periodic); + } + + return ret; +} + +/** + * @copydoc GooseEYE::labels_centers() + * @param weights Weight for each pixel. + */ +template +inline array_type::tensor +labels_centers_of_mass(const T& labels, const W& weights, const N& names, bool periodic = true) +{ + static_assert(std::is_integral::value, "Integral labels required."); + GOOSEEYE_ASSERT(xt::has_shape(labels, weights.shape()), std::out_of_range); + GOOSEEYE_ASSERT(labels.dimension() > 0, std::out_of_range); + GOOSEEYE_ASSERT(names.dimension() == 1, std::out_of_range); + + size_t rank = labels.dimension(); + array_type::tensor shape = xt::adapt(labels.shape()); + array_type::tensor ret = xt::zeros({names.size(), rank}); + detail::PositionList plist; + + for (size_t l = 0; l < names.size(); ++l) { + plist.set(xt::equal(labels, names(l)), weights); + if (plist.positions.size() == 0) { + continue; } + xt::view(ret, l, xt::all()) = + center_of_mass(shape, plist.positions, plist.weights, periodic); } - return xt::view(ret, xt::all(), xt::keep(axes)); + return ret; } /** diff --git a/python/main.cpp b/python/main.cpp index 51b81fbd..56f048de 100644 --- a/python/main.cpp +++ b/python/main.cpp @@ -210,10 +210,37 @@ PYBIND11_MODULE(_GooseEYE, m) py::arg("positions"), py::arg("labels")); + m.def( + "center", + &GooseEYE::center, + py::arg("shape"), + py::arg("positions"), + py::arg("periodic") = true); + m.def( "center_of_mass", - &GooseEYE::center_of_mass>, + &GooseEYE::center_of_mass, + py::arg("shape"), + py::arg("positions"), + py::arg("weights"), + py::arg("periodic") = true); + + m.def( + "labels_centers", + &GooseEYE::labels_centers, xt::pytensor>, + py::arg("labels"), + py::arg("names"), + py::arg("periodic") = true); + + m.def( + "labels_centers_of_mass", + &GooseEYE::labels_centers_of_mass< + xt::pyarray, + xt::pyarray, + xt::pytensor>, py::arg("labels"), + py::arg("weights"), + py::arg("names"), py::arg("periodic") = true); py::class_(m, "Ensemble") diff --git a/tests/clusters.cpp b/tests/clusters.cpp deleted file mode 100644 index 668c36b1..00000000 --- a/tests/clusters.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include -#include - -TEST_CASE("GooseEYE::clusters", "clusters.hpp") -{ - - SECTION("center_of_mass - not periodic") - { - xt::xtensor l = { - {0, 0, 0, 0, 0}, - {0, 0, 1, 0, 0}, - {0, 1, 1, 1, 0}, - {0, 0, 1, 0, 0}, - {0, 0, 0, 0, 0}, - }; - - xt::xtensor centers = {{2.0, 2.0}, {2.0, 2.0}}; - - REQUIRE(xt::allclose(GooseEYE::center_of_mass(l, false), centers)); - } - - SECTION("center_of_mass - periodic") - { - xt::xtensor l = { - {0, 0, 0, 0, 0}, - {0, 0, 1, 0, 0}, - {0, 1, 1, 1, 0}, - {0, 0, 1, 0, 0}, - {0, 0, 0, 0, 0}, - }; - - xt::xtensor centers = {{4.5, 4.5}, {2.0, 2.0}}; - - double row = static_cast(l.shape(0)); - double col = static_cast(l.shape(1)); - - for (size_t i = 0; i < l.shape(0); ++i) { - for (size_t j = 0; j < l.shape(1); ++j) { - REQUIRE(xt::allclose(GooseEYE::center_of_mass(l), centers)); - l = xt::roll(l, 1, 1); - auto y = xt::view(centers, xt::all(), 1); - y += 1.0; - y = xt::where(y > col, y - col, y); - } - l = xt::roll(l, 1, 0); - auto x = xt::view(centers, xt::all(), 0); - x += 1.0; - x = xt::where(x > row, x - row, x); - } - } - - SECTION("center_of_mass - labels") - { - xt::xtensor l = { - {2, 3, 3, 3, 2}, - {0, 0, 1, 0, 0}, - {0, 1, 1, 1, 0}, - {0, 0, 1, 0, 0}, - {2, 4, 4, 4, 2}, - }; - - xt::xtensor c = { - {0, 0, 3, 0, 0}, - {0, 0, 0, 0, 0}, - {0, 0, 1, 0, 0}, - {0, 0, 0, 0, 0}, - {0, 0, 4, 0, 2}, - }; - - xt::xtensor centers = { - {2.0, 4.5}, {2.0, 2.0}, {4.5, 4.5}, {0.0, 2.0}, {4.0, 2.0}}; - - auto res_centers = GooseEYE::center_of_mass(l); - xt::xtensor res_c = xt::zeros_like(l); - res_c = - GooseEYE::pos2img(res_c, xt::floor(res_centers), xt::arange(centers.shape(0))); - - REQUIRE(xt::allclose(res_centers, centers)); - REQUIRE(xt::all(xt::equal(res_c, c))); - } - - SECTION("Clusters::centers") - { - xt::xarray I = xt::zeros({5, 5}); - xt::xarray C = xt::zeros({5, 5}); - - I(0, 0) = 1; - I(0, 3) = 1; - I(0, 4) = 1; - I(3, 0) = 1; - I(3, 3) = 1; - I(3, 4) = 1; - - C(0, 4) = 1; - C(3, 4) = 2; - - GooseEYE::Clusters clusters(I, true); - REQUIRE(xt::all(xt::equal(C, clusters.centers()))); - REQUIRE(xt::all(xt::equal(GooseEYE::clusters(I, true), clusters.labels()))); - } -} diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 00b75130..87e7340a 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -216,6 +216,129 @@ def test_labels_sizes(): ) +def test_labels_centers(): + labels = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + centers = np.array( + [ + [2.0, 2.0], # background (0) + [2.0, 2.0], # label (1) + ] + ) + assert np.allclose(eye.labels_centers(labels, [0, 1], periodic=False), centers) + + +def test_labels_centers2(): + labels = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + centers = np.array( + [ + [4.5, 4.5], # background (0) + [2.0, 2.0], # label (1) + ] + ) + + for i in range(labels.shape[0]): + for j in range(labels.shape[1]): + assert np.allclose(eye.labels_centers(labels, [0, 1], periodic=True), centers) + labels = np.roll(labels, 1, 1) + y = centers[:, 1] + 1 + centers[:, 1] = np.where(y > labels.shape[1], y - labels.shape[1], y) + + labels = np.roll(labels, 1, 0) + x = centers[:, 0] + 1 + centers[:, 0] = np.where(x > labels.shape[0], x - labels.shape[0], x) + + +def test_labels_centers3(): + labels = np.array( + [ + [2, 3, 3, 3, 2], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [2, 4, 4, 4, 2], + ] + ) + + c = np.array( + [ + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 4, 0, 2], + ] + ) + + centers = np.array( + [ + [2.0, 4.5], + [2.0, 2.0], + [4.5, 4.5], + [0.0, 2.0], + [4.0, 2.0], + ] + ) + + res_centers = eye.labels_centers(labels, [0, 1, 2, 3, 4]) + assert np.allclose(res_centers, centers) + + centers_flat = np.ravel_multi_index(np.floor(centers).astype(int).T, c.shape) + assert np.all(np.equal(np.sort(np.flatnonzero(c)), np.sort(centers_flat[1:]))) + + +def test_labels_centers_weights(): + labels = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + weights = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 2, 0, 0], + [0, 2, 9, 2, 0], + [0, 0, 2, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + centers = np.array([[2.0, 2.0]]) + + for i in range(labels.shape[0]): + for j in range(labels.shape[1]): + ret = eye.labels_centers_of_mass(labels, weights, [1], periodic=True) + assert np.allclose(ret, centers) + + labels = np.roll(labels, 1, 1) + weights = np.roll(weights, 1, 1) + y = centers[:, 1] + 1 + centers[:, 1] = np.where(y > labels.shape[1], y - labels.shape[1], y) + + labels = np.roll(labels, 1, 0) + weights = np.roll(weights, 1, 0) + x = centers[:, 0] + 1 + centers[:, 0] = np.where(x > labels.shape[0], x - labels.shape[0], x) + + def test_prune(): segmenter = eye.ClusterLabeller(shape=(4, 4)) segmenter.add_points([0, 2, 8, 10, 1])