Skip to content

Commit

Permalink
BREAKING CHANGE Adding labels_centers/center, changing `center_of…
Browse files Browse the repository at this point in the history
…_mass`, deprecating `pos2img` (#96)
  • Loading branch information
tdegeus authored Dec 5, 2023
1 parent 6856dfe commit 674e659
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 135 deletions.
170 changes: 138 additions & 32 deletions include/GooseEYE/GooseEYE.h
Original file line number Diff line number Diff line change
Expand Up @@ -1282,8 +1282,11 @@ array_type::array<int> clusters(const T& f, bool periodic = true)
* @return The image, with labels inserted (overwritten) at the positions.
*/
template <typename T, typename U, typename V>
inline T pos2img(const T& img, const U& positions, const V& labels)
[[deprecated("Will not be supported in the future. See Python warning for new API.")]] inline T
pos2img(const T& img, const U& positions, const V& labels)
{
GOOSEEYE_WARNING_PYTHON("pos2img(img, positions, labels) deprecated, use: "
"i = ravel_multi_index(positions.T, img.shape); img.flat[i] = labels")
GOOSEEYE_ASSERT(img.dimension() > 0, std::out_of_range);
GOOSEEYE_ASSERT(img.dimension() <= 3, std::out_of_range);
GOOSEEYE_ASSERT(img.dimension() == positions.shape(1), std::out_of_range);
Expand Down Expand Up @@ -1312,56 +1315,159 @@ inline T pos2img(const T& img, const U& positions, const V& labels)
}

/**
* @brief Get the position of the center of each label.
* @brief Return the geometric center of a list of positions.
*
* @details
* For periodic algorithm, see:
* https://en.wikipedia.org/wiki/Center_of_mass#Systems_with_periodic_boundary_conditions
*
* @param labels An image with labels [0..N].
* @param shape Shape of the array.
* @param positions List of positions (in array coordinates).
* @param periodic Switch to assume array periodic.
* @return Coordinates of the center (in array coordinates).
*/
inline array_type::tensor<double, 1> center(
const array_type::tensor<double, 1>& shape,
const array_type::tensor<double, 2>& positions,
bool periodic = true)
{
if (positions.size() == 0) {
return xt::zeros<double>({shape.size()});
}

if (!periodic) {
return xt::mean(positions, 0);
}
else {
double pi = xt::numeric_constants<double>::PI;
auto theta = 2.0 * pi * positions / shape;
auto xi = xt::cos(theta);
auto zeta = xt::sin(theta);
auto xi_bar = xt::mean(xi, 0);
auto zeta_bar = xt::mean(zeta, 0);
auto theta_bar = xt::atan2(-zeta_bar, -xi_bar) + pi;
return shape * theta_bar / (2.0 * pi);
}
}

/**
* @copydoc GooseEYE::center()
* @param weights Weight for each position.
*/
inline array_type::tensor<double, 1> center_of_mass(
const array_type::tensor<double, 1>& shape,
const array_type::tensor<double, 2>& positions,
const array_type::tensor<double, 1>& weights,
bool periodic = true)
{
if (positions.size() == 0) {
return xt::zeros<double>({shape.size()});
}

if (!periodic) {
return xt::average(positions, weights, 0);
}
else {
double pi = xt::numeric_constants<double>::PI;
auto theta = 2.0 * pi * positions / shape;
auto xi = xt::cos(theta);
auto zeta = xt::sin(theta);
auto xi_bar = xt::average(xi, weights, 0);
auto zeta_bar = xt::average(zeta, weights, 0);
auto theta_bar = xt::atan2(-zeta_bar, -xi_bar) + pi;
return shape * theta_bar / (2.0 * pi);
}
}

namespace detail {

class PositionList {
public:
array_type::tensor<double, 2> positions;
array_type::tensor<double, 1> weights;

PositionList() = default;

template <class T>
void set(const T& condition)
{
positions = xt::from_indices(xt::argwhere(condition));
}

template <class T, class W>
void set(const T& condition, const W& w)
{
auto pos = xt::argwhere(condition);
weights = xt::empty<double>({pos.size()});
for (size_t i = 0; i < pos.size(); ++i) {
weights(i) = w[pos[i]];
}
positions = xt::from_indices(pos);
}
};

} // namespace detail

/**
* @brief Get the position of the center of each label.
*
* @param labels An image with labels.
* @param names List of labels to compute the center for.
* @param periodic Switch to assume image periodic.
* @return The position of the center of each label.
* @return Coordinates of the center (in array coordinates), in order of the unique (sorted) labels.
*/
template <class T>
array_type::tensor<double, 2> center_of_mass(const T& labels, bool periodic = true)
template <class T, class N>
inline array_type::tensor<double, 2>
labels_centers(const T& labels, const N& names, bool periodic = true)
{
static_assert(std::is_integral<typename T::value_type>::value, "Integral labels required.");

GOOSEEYE_ASSERT(labels.dimension() > 0, std::out_of_range);
GOOSEEYE_ASSERT(xt::all(labels >= 0), std::out_of_range);
GOOSEEYE_ASSERT(names.dimension() == 1, std::out_of_range);

double pi = xt::numeric_constants<double>::PI;
size_t N = static_cast<size_t>(xt::amax(labels)(0)) + 1ul;
size_t rank = labels.dimension();
auto axes = detail::atleast_3d_axes(rank);
array_type::tensor<double, 1> shape = xt::adapt(labels.shape());
array_type::tensor<double, 2> ret = xt::zeros<double>({N, rank});
array_type::tensor<double, 2> ret = xt::zeros<double>({names.size(), rank});
detail::PositionList plist;

for (size_t l = 0; l < N; ++l) {
array_type::tensor<double, 2> positions =
xt::from_indices(xt::argwhere(xt::equal(labels, l)));
if (positions.size() == 0) {
for (size_t l = 0; l < names.size(); ++l) {
plist.set(xt::equal(labels, names(l)));
if (plist.positions.size() == 0) {
continue;
}
if (!periodic) {
xt::view(ret, l, xt::all()) = xt::mean(positions, 0);
}
else {
if (xt::all(xt::equal(positions, 0.0))) {
continue;
}
auto theta = 2.0 * pi * positions / shape;
auto xi = xt::cos(theta);
auto zeta = xt::sin(theta);
auto xi_bar = xt::mean(xi, 0);
auto zeta_bar = xt::mean(zeta, 0);
auto theta_bar = xt::atan2(-zeta_bar, -xi_bar) + pi;
auto positions_bar = shape * theta_bar / (2.0 * pi);
xt::view(ret, l, xt::all()) = positions_bar;
xt::view(ret, l, xt::all()) = center(shape, plist.positions, periodic);
}

return ret;
}

/**
* @copydoc GooseEYE::labels_centers()
* @param weights Weight for each pixel.
*/
template <class T, class W, class N>
inline array_type::tensor<double, 2>
labels_centers_of_mass(const T& labels, const W& weights, const N& names, bool periodic = true)
{
static_assert(std::is_integral<typename T::value_type>::value, "Integral labels required.");
GOOSEEYE_ASSERT(xt::has_shape(labels, weights.shape()), std::out_of_range);
GOOSEEYE_ASSERT(labels.dimension() > 0, std::out_of_range);
GOOSEEYE_ASSERT(names.dimension() == 1, std::out_of_range);

size_t rank = labels.dimension();
array_type::tensor<double, 1> shape = xt::adapt(labels.shape());
array_type::tensor<double, 2> ret = xt::zeros<double>({names.size(), rank});
detail::PositionList plist;

for (size_t l = 0; l < names.size(); ++l) {
plist.set(xt::equal(labels, names(l)), weights);
if (plist.positions.size() == 0) {
continue;
}
xt::view(ret, l, xt::all()) =
center_of_mass(shape, plist.positions, plist.weights, periodic);
}

return xt::view(ret, xt::all(), xt::keep(axes));
return ret;
}

/**
Expand Down
29 changes: 28 additions & 1 deletion python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,37 @@ PYBIND11_MODULE(_GooseEYE, m)
py::arg("positions"),
py::arg("labels"));

m.def(
"center",
&GooseEYE::center,
py::arg("shape"),
py::arg("positions"),
py::arg("periodic") = true);

m.def(
"center_of_mass",
&GooseEYE::center_of_mass<xt::pyarray<size_t>>,
&GooseEYE::center_of_mass,
py::arg("shape"),
py::arg("positions"),
py::arg("weights"),
py::arg("periodic") = true);

m.def(
"labels_centers",
&GooseEYE::labels_centers<xt::pyarray<ptrdiff_t>, xt::pytensor<ptrdiff_t, 1>>,
py::arg("labels"),
py::arg("names"),
py::arg("periodic") = true);

m.def(
"labels_centers_of_mass",
&GooseEYE::labels_centers_of_mass<
xt::pyarray<ptrdiff_t>,
xt::pyarray<double>,
xt::pytensor<ptrdiff_t, 1>>,
py::arg("labels"),
py::arg("weights"),
py::arg("names"),
py::arg("periodic") = true);

py::class_<GooseEYE::Ensemble>(m, "Ensemble")
Expand Down
102 changes: 0 additions & 102 deletions tests/clusters.cpp

This file was deleted.

Loading

0 comments on commit 674e659

Please sign in to comment.