Skip to content

Commit

Permalink
Upgrade relabel_map
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus committed Nov 23, 2023
1 parent 09147a8 commit 117fe51
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- ninja
- numpy
- pybind11
- pytest
- python
- scikit-build
- setuptools_scm
Expand Down
24 changes: 9 additions & 15 deletions include/GooseEYE/GooseEYE.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,26 +165,20 @@ template <class T, std::enable_if_t<std::is_integral<typename T::value_type>::va
inline T dilate(const T& f, size_t iterations = 1, bool periodic = true);

/**
* Find map to relabel from "src" to "dest".
* @param src Image.
* @param dest Image.
* Find map to relabel from `a` to `b`.
* @param a Image.
* @param b Image.
* @return List of length `max(a) + 1` with per label in `a` the corresponding label in `b`.
*/
template <class T, class S>
array_type::tensor<size_t, 1> relabel_map(const T& src, const S& dest)
array_type::tensor<size_t, 1> relabel_map(const T& a, const S& b)
{
GOOSEEYE_ASSERT(xt::has_shape(src, dest.shape()));
GOOSEEYE_ASSERT(xt::has_shape(a, b.shape()));

array_type::tensor<size_t, 1> ret =
xt::zeros<size_t>({static_cast<size_t>(xt::amax(src)() + 1)});
auto A = xt::atleast_3d(src);
auto B = xt::atleast_3d(dest);
array_type::tensor<size_t, 1> ret = xt::zeros<size_t>({static_cast<size_t>(xt::amax(a)() + 1)});

for (size_t h = 0; h < A.shape(0); ++h) {
for (size_t i = 0; i < A.shape(1); ++i) {
for (size_t j = 0; j < A.shape(2); ++j) {
ret(A(h, i, j)) = B(h, i, j);
}
}
for (size_t i = 0; i < a.size(); ++i) {
ret(a.flat(i)) = b.flat(i);
}

return ret;
Expand Down
6 changes: 6 additions & 0 deletions python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ PYBIND11_MODULE(_GooseEYE, m)
py::arg("f"),
py::arg("periodic") = true);

m.def(
"relabel_map",
&GooseEYE::relabel_map<xt::pyarray<int>, xt::pyarray<int>>,
py::arg("a"),
py::arg("b"));

m.def(
"pos2img",
&GooseEYE::pos2img<xt::pyarray<size_t>, xt::pytensor<double, 2>, xt::pytensor<size_t, 1>>,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import GooseEYE as eye
import numpy as np


def test_relabel_map():
a = np.array(
[
[1, 1, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 0, 3, 3, 0],
[0, 0, 3, 0, 0],
[1, 0, 0, 0, 1],
]
)

b = a.copy()
b = np.where(b == 1, 4, b)
b = np.where(b == 3, 7, b)

assert list(eye.relabel_map(a, b)) == [0, 4, 0, 7]

0 comments on commit 117fe51

Please sign in to comment.