Skip to content

Commit

Permalink
Optimised strides: dealing with edge cases (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus authored Dec 5, 2023
1 parent d980cb4 commit 7a7beae
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 35 deletions.
160 changes: 128 additions & 32 deletions include/GooseEYE/GooseEYE.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ namespace detail {
* @return Array of distances.
*/
template <size_t Dim, class T>
inline array_type::tensor<ptrdiff_t, Dim> kernel_to_dx(T kernel)
inline array_type::tensor<ptrdiff_t, 2> kernel_to_dx(T kernel)
{
#ifdef GOOSEEYE_ENABLE_ASSERT
for (size_t i = 0; i < Dim; ++i) {
Expand All @@ -453,7 +453,10 @@ inline array_type::tensor<ptrdiff_t, Dim> kernel_to_dx(T kernel)
kernel.flat(idx) = 0;

if constexpr (Dim == 1) {
return xt::flatten_indices(xt::argwhere(kernel)) - mid[0];
auto i = xt::flatten_indices(xt::argwhere(kernel)) - mid[0];
array_type::tensor<ptrdiff_t, 2> ret = xt::empty<ptrdiff_t>({i.size(), size_t(1)});
std::copy(i.begin(), i.end(), ret.begin());
return ret;
}

auto ret = xt::from_indices(xt::argwhere(kernel));
Expand All @@ -478,7 +481,7 @@ class ClusterLabeller {

private:
std::array<ptrdiff_t, Dim> m_shape; ///< Shape of the system.
array_type::tensor<ptrdiff_t, Dim> m_dx; ///< Kernel (in distances along each dimension).
array_type::tensor<ptrdiff_t, 2> m_dx; ///< Kernel (in distances along each dimension).
array_type::tensor<ptrdiff_t, Dim> m_label; ///< Per block, the label (`0` for background).
ptrdiff_t m_new_label = 1; ///< The next label number to assign.
size_t m_nmerge = 0; ///< Number of times that clusters have been merged.
Expand All @@ -505,6 +508,9 @@ class ClusterLabeller {
std::vector<ptrdiff_t> m_next;
std::vector<ptrdiff_t> m_connected; ///< List of labels connected to the current block.

typedef ptrdiff_t (ClusterLabeller<Dimension, Periodicity>::*CompareImpl)(size_t, size_t);
CompareImpl get_compare = &ClusterLabeller<Dimension, Periodicity>::get_compare_default;

public:
/**
* @param shape @copydoc ClusterLabeller::m_shape
Expand All @@ -514,12 +520,18 @@ class ClusterLabeller {
{
if constexpr (Dim == 1) {
// kernel = {1, 1, 1}
m_dx = {-1, 1};
m_dx = {{-1}, {1}};
}
else if constexpr (Dim == 2) {
// kernel = {{0, 1, 0}, {1, 1, 1}, {0, 1, 0}};
m_dx = {{-1, 0}, {0, -1}, {0, 1}, {1, 0}};
}
else if constexpr (Dim == 3) {
m_dx = {{-1, 0, 0}, {0, -1, 0}, {0, 0, -1}, {0, 0, 1}, {0, 1, 0}, {1, 0, 0}};
}
else {
throw std::runtime_error("Please specify the kernel in dimensions > 3.");
}
this->init(shape);
}

Expand All @@ -538,17 +550,28 @@ class ClusterLabeller {
// Allocate the label array and the bookkeeping buffers for a system of the given shape,
// and select the neighbour-lookup implementation (`get_compare`) that fits that shape.
// `shape` is read as `shape[i]` for `i < Dim` and is passed to `xt::empty` as the array shape.
template <class T>
void init(const T& shape)
{
// NOTE(review): this restricts `Dim` to 1 or 2, whereas other code in this file handles
// `Dim == 3`; it looks like a line from before the change — confirm whether it should stay.
static_assert(Dim == 1 || Dim == 2, "WIP: 1d and 2d supported.");
// Storage only: `xt::empty` does not initialise the labels
// (presumably `reset()` below takes care of that — TODO confirm).
m_label = xt::empty<ptrdiff_t>(shape);
// One entry more than the number of blocks
// (presumably because label `0` is reserved for the background — TODO confirm).
m_renum.resize(m_label.size() + 1);
m_next.resize(m_label.size() + 1);
// Keep signed copies of the shape and the strides of `m_label`.
for (size_t i = 0; i < Dim; ++i) {
m_shape[i] = static_cast<ptrdiff_t>(shape[i]);
// NOTE(review): the unconditional assignment below is repeated verbatim inside the
// `Dim >= 2` guard that follows; one of the two is likely stale — confirm.
m_strides[i] = static_cast<ptrdiff_t>(m_label.strides()[i]);
if constexpr (Dim >= 2) {
m_strides[i] = static_cast<ptrdiff_t>(m_label.strides()[i]);
}
}
// NOTE(review): the comment further down anticipates zero strides; if the last stride can
// be zero (e.g. for `shape = [n, 1]`) this assertion would reject that shape — confirm.
GOOSEEYE_ASSERT(m_strides.back() == 1, std::out_of_range);
this->reset();
// One slot per kernel entry (row of `m_dx`).
m_connected.resize(m_dx.shape(0));

// Dim == 2: the default lookup divides by `m_strides[0]`, so strides are assumed to be
// non-zero there to avoid extra checks. Check once whether an axis has length one (the case
// in which zero strides are expected) and, if so, switch to a dedicated implementation.
if constexpr (Dim == 2) {
if (m_shape[0] == 1) {
get_compare = &ClusterLabeller<Dimension, Periodicity>::get_compare_2d_1n;
}
else if (m_shape[1] == 1) {
get_compare = &ClusterLabeller<Dimension, Periodicity>::get_compare_2d_n1;
}
}
}

public:
Expand Down Expand Up @@ -660,37 +683,104 @@ class ClusterLabeller {
m_nmerge = 0;
}

void label_impl(size_t idx)
/**
 * @brief Get the pixel to compare with.
 * @note If the pixel is out of bounds (non-periodic only), return -1.
 * @warning Implementation only for 2D, `shape = [1, n]`.
 *
 * @param idx Flat index of the pixel (with a single row this is used as the column index).
 * @param j Index of the kernel (in `m_dx`).
 * @return ptrdiff_t Flat index of the pixel to compare with, or -1 if there is none.
 */
ptrdiff_t get_compare_2d_1n(size_t idx, size_t j)
{
    static_assert(Dim == 1 || Dim == 2, "WIP: 1d and 2d supported.");

    // Do all arithmetic in signed integers: mixing `size_t` and `ptrdiff_t` silently
    // switches to unsigned arithmetic, which wraps incorrectly for large negative offsets.
    const ptrdiff_t ncol = m_shape[1];
    const ptrdiff_t col = static_cast<ptrdiff_t>(idx) + m_dx(j, 1);

    if constexpr (Periodic) {
        // The single row wraps onto itself, so the row offset `m_dx(j, 0)` is irrelevant.
        // `(n + x % n) % n` is a true modulo for any sign of `x`.
        return (ncol + (col % ncol)) % ncol;
    }
    else {
        // With a single row, any non-zero row offset points outside the system;
        // it must not be mistaken for a neighbour in the same row.
        if (m_dx(j, 0) != 0 || col < 0 || col >= ncol) {
            return -1;
        }
        return col;
    }
}

ptrdiff_t compare;
size_t nconnected = 0;
/**
 * @brief Get the pixel to compare with.
 * @note If the pixel is out of bounds (non-periodic only), return -1.
 * @warning Implementation only for 2D, `shape = [n, 1]`.
 *
 * @param idx Flat index of the pixel (with a single column this is used as the row index).
 * @param j Index of the kernel (in `m_dx`).
 * @return ptrdiff_t Flat index of the pixel to compare with, or -1 if there is none.
 */
ptrdiff_t get_compare_2d_n1(size_t idx, size_t j)
{
    // Do all arithmetic in signed integers: mixing `size_t` and `ptrdiff_t` silently
    // switches to unsigned arithmetic, which wraps incorrectly for large negative offsets.
    const ptrdiff_t nrow = m_shape[0];
    const ptrdiff_t row = static_cast<ptrdiff_t>(idx) + m_dx(j, 0);

    if constexpr (Periodic) {
        // The single column wraps onto itself, so the column offset `m_dx(j, 1)` is irrelevant.
        // `(n + x % n) % n` is a true modulo for any sign of `x`.
        return (nrow + (row % nrow)) % nrow;
    }
    else {
        // With a single column, any non-zero column offset points outside the system;
        // it must not be mistaken for a neighbour in the same column.
        if (m_dx(j, 1) != 0 || row < 0 || row >= nrow) {
            return -1;
        }
        return row;
    }
}

for (size_t j = 0; j < m_dx.shape(0); ++j) {
if constexpr (Dim == 1 && Periodic) {
compare = (m_shape[0] + idx + m_dx(j)) % m_shape[0];
/**
* @brief Get the pixel to compare with.
* @note If the pixel is out of bounds, return -1.
*
* @param idx Flat index of the pixel.
* @param j Index of the kernel (in `m_dx`).
* @return ptrdiff_t Index of the pixel to compare with.
*/
ptrdiff_t get_compare_default(size_t idx, size_t j)
{
if constexpr (Dim == 1 && Periodic) {
return (m_shape[0] + idx + m_dx.flat(j)) % m_shape[0];
}
if constexpr (Dim == 1 && !Periodic) {
ptrdiff_t compare = idx + m_dx.flat(j);
if (compare < 0 || compare >= m_shape[0]) {
return -1;
}
else if constexpr (Dim == 1 && !Periodic) {
if (compare < 0 || compare >= m_shape[0]) {
continue;
}
compare = idx + m_dx(j);
return idx + m_dx.flat(j);
}
if constexpr (Dim == 2 && Periodic) {
ptrdiff_t ii = (m_shape[0] + (idx / m_strides[0]) + m_dx(j, 0)) % m_shape[0];
ptrdiff_t jj = (m_shape[1] + (idx % m_strides[0]) + m_dx(j, 1)) % m_shape[1];
return ii * m_shape[1] + jj;
}
if constexpr (Dim == 2 && !Periodic) {
ptrdiff_t ii = (idx / m_strides[0]) + m_dx(j, 0);
ptrdiff_t jj = (idx % m_strides[0]) + m_dx(j, 1);
if (ii < 0 || ii >= m_shape[0] || jj < 0 || jj >= m_shape[1]) {
return -1;
}
else if constexpr (Dim == 2 && Periodic) {
ptrdiff_t ii = (m_shape[0] + (idx / m_strides[0]) + m_dx(j, 0)) % m_shape[0];
ptrdiff_t jj = (m_shape[1] + (idx % m_strides[0]) + m_dx(j, 1)) % m_shape[1];
compare = ii * m_shape[1] + jj;
return ii * m_shape[1] + jj;
}
else {
auto index = xt::unravel_from_strides(idx, m_strides, xt::layout_type::row_major);
for (size_t d = 0; d < Dim; ++d) {
index[d] += m_dx(j, d);
if constexpr (!Periodic) {
if (index[d] < 0 || index[d] >= m_shape[d]) {
return -1;
}
}
else {
auto n = m_shape[d];
index[d] = (n + (index[d] % n)) % n;
}
}
else if constexpr (Dim == 2 && !Periodic) {
ptrdiff_t ii = (idx / m_strides[0]) + m_dx(j, 0);
ptrdiff_t jj = (idx % m_strides[0]) + m_dx(j, 1);
if (ii < 0 || ii >= m_shape[0] || jj < 0 || jj >= m_shape[1]) {
return xt::ravel_index(index, m_shape, xt::layout_type::row_major);
}
}

void label_impl(size_t idx)
{
size_t nconnected = 0;

for (size_t j = 0; j < m_dx.shape(0); ++j) {
ptrdiff_t compare = (this->*get_compare)(idx, j);
if constexpr (!Periodic) {
if (compare == -1) {
continue;
}
compare = ii * m_shape[1] + jj;
}

if (m_label.flat(compare) != 0) {
m_connected[nconnected] = m_renum[m_label.flat(compare)];
nconnected++;
Expand Down Expand Up @@ -1160,6 +1250,8 @@ class ClusterLabellerOverload : public ClusterLabeller<Dimension, Periodicity> {
template <class T>
array_type::array<int> clusters(const T& f, bool periodic = true)
{
GOOSEEYE_ASSERT(f.layout() == xt::layout_type::row_major, std::runtime_error);

auto n = f.dimension();
if (n == 1 && periodic) {
return detail::ClusterLabellerOverload<1, true>(f).get();
Expand All @@ -1173,9 +1265,13 @@ array_type::array<int> clusters(const T& f, bool periodic = true)
if (n == 2 && !periodic) {
return detail::ClusterLabellerOverload<2, false>(f).get();
}

GOOSEEYE_WARNING("WIP: updated 3d implementation needs to be completed. Please file a PR.");
return Clusters(f, kernel::nearest(f.dimension()), periodic).labels();
if (n == 3 && periodic) {
return detail::ClusterLabellerOverload<3, true>(f).get();
}
if (n == 3 && !periodic) {
return detail::ClusterLabellerOverload<3, false>(f).get();
}
throw std::runtime_error("Please use ClusterLabeller directly for dimensions > 3.");
}

/**
Expand Down
6 changes: 6 additions & 0 deletions include/GooseEYE/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@
* @endcond
*/

/**
 * @brief Assertions that are always enabled.
 *
 * Forwards both arguments to GOOSEEYE_ASSERT_IMPL, together with the location of the
 * call site (`__FILE__`, `__LINE__`, `__FUNCTION__`).
 *
 * @param expr Expression to check (presumably the assertion fails if it evaluates to
 * `false` — confirm against GOOSEEYE_ASSERT_IMPL).
 * @param assertion Passed on unchanged (presumably the exception type to throw on failure,
 * e.g. `std::runtime_error` — confirm against GOOSEEYE_ASSERT_IMPL).
 */
#define GOOSEEYE_REQUIRE(expr, assertion) \
GOOSEEYE_ASSERT_IMPL(expr, assertion, __FILE__, __LINE__, __FUNCTION__)

/**
* All assertions are implemented as:
*
Expand Down
6 changes: 5 additions & 1 deletion python/GooseEYE/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ def ClusterLabeller(shape, periodic=True, **kwargs):
if periodic:
return ClusterLabeller2p(shape, **kwargs)
return ClusterLabeller2(shape, **kwargs)
raise NotImplementedError("3d extension needed, please open a PR")
if len(shape) == 3:
if periodic:
return ClusterLabeller3p(shape, **kwargs)
return ClusterLabeller3(shape, **kwargs)
raise NotImplementedError("dimension > 3 not compiled by default")


class Structure(enstat.static):
Expand Down
4 changes: 2 additions & 2 deletions python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ PYBIND11_MODULE(_GooseEYE, m)
py::arg("iterations") = 1,
py::arg("periodic") = true);

static_for<1, 3>(
static_for<1, 4>(
[&](auto i) { allocate_ClusterLabeller<GooseEYE::ClusterLabeller<i, true>>(m); });
static_for<1, 3>(
static_for<1, 4>(
[&](auto i) { allocate_ClusterLabeller<GooseEYE::ClusterLabeller<i, false>>(m); });

py::class_<GooseEYE::Clusters>(m, "Clusters")
Expand Down
11 changes: 11 additions & 0 deletions tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ def test_clusters_simple3():
assert np.all(np.equal(eye.clusters(np.where(labels > 0, 1, 0)), labels))


def test_clusters_one():
    """
    An axis of length one (``shape[i] == 1``) must not break the labelling:
    the same sequence is checked as a single row and as a single column.
    """
    as_row = np.array([[1, 1, 0, 2, 2, 2, 0, 3, 3, 3, 0, 1]])
    as_column = as_row.reshape(-1, 1)

    for labels in (as_row, as_column):
        # binary image: every labelled block becomes 1, the background stays 0
        image = np.where(labels > 0, 1, 0)
        assert np.all(np.equal(eye.clusters(image), labels))


def test_clusters_scipy():
img = eye.dummy_circles((500, 500), periodic=False)
clusters = eye.clusters(img, periodic=False)
Expand Down

0 comments on commit 7a7beae

Please sign in to comment.