From 7a7beae2be9c141d179b9ea4d5fdd2743255076d Mon Sep 17 00:00:00 2001 From: Tom de Geus Date: Tue, 5 Dec 2023 12:52:47 +0100 Subject: [PATCH] Optimised strides: dealing with edge cases (#95) --- include/GooseEYE/GooseEYE.h | 160 ++++++++++++++++++++++++++++-------- include/GooseEYE/config.h | 6 ++ python/GooseEYE/__init__.py | 6 +- python/main.cpp | 4 +- tests/test_clusters.py | 11 +++ 5 files changed, 152 insertions(+), 35 deletions(-) diff --git a/include/GooseEYE/GooseEYE.h b/include/GooseEYE/GooseEYE.h index 2460ce18..6839c71b 100644 --- a/include/GooseEYE/GooseEYE.h +++ b/include/GooseEYE/GooseEYE.h @@ -433,7 +433,7 @@ namespace detail { * @return Array of distances. */ template -inline array_type::tensor kernel_to_dx(T kernel) +inline array_type::tensor kernel_to_dx(T kernel) { #ifdef GOOSEEYE_ENABLE_ASSERT for (size_t i = 0; i < Dim; ++i) { @@ -453,7 +453,10 @@ inline array_type::tensor kernel_to_dx(T kernel) kernel.flat(idx) = 0; if constexpr (Dim == 1) { - return xt::flatten_indices(xt::argwhere(kernel)) - mid[0]; + auto i = xt::flatten_indices(xt::argwhere(kernel)) - mid[0]; + array_type::tensor ret = xt::empty({i.size(), size_t(1)}); + std::copy(i.begin(), i.end(), ret.begin()); + return ret; } auto ret = xt::from_indices(xt::argwhere(kernel)); @@ -478,7 +481,7 @@ class ClusterLabeller { private: std::array m_shape; ///< Shape of the system. - array_type::tensor m_dx; ///< Kernel (in distances along each dimension). + array_type::tensor m_dx; ///< Kernel (in distances along each dimension). array_type::tensor m_label; ///< Per block, the label (`0` for background). ptrdiff_t m_new_label = 1; ///< The next label number to assign. size_t m_nmerge = 0; ///< Number of times that clusters have been merged. @@ -505,6 +508,9 @@ class ClusterLabeller { std::vector m_next; std::vector m_connected; ///< List of labels connected to the current block. + typedef ptrdiff_t (ClusterLabeller::*CompareImpl)(size_t, size_t); + CompareImpl get_compare = &ClusterLabeller::get_compare_default; + public: /** * @param shape @copydoc ClusterLabeller::m_shape @@ -514,12 +520,18 @@ class ClusterLabeller { { if constexpr (Dim == 1) { // kernel = {1, 1, 1} - m_dx = {-1, 1}; + m_dx = {{-1}, {1}}; } else if constexpr (Dim == 2) { // kernel = {{0, 1, 0}, {1, 1, 1}, {0, 1, 0}}; m_dx = {{-1, 0}, {0, -1}, {0, 1}, {1, 0}}; } + else if constexpr (Dim == 3) { + m_dx = {{-1, 0, 0}, {0, -1, 0}, {0, 0, -1}, {0, 0, 1}, {0, 1, 0}, {1, 0, 0}}; + } + else { + throw std::runtime_error("Please specify the kernel in dimensions > 3."); + } this->init(shape); } @@ -538,17 +550,28 @@ class ClusterLabeller { template void init(const T& shape) { - static_assert(Dim == 1 || Dim == 2, "WIP: 1d and 2d supported."); m_label = xt::empty(shape); m_renum.resize(m_label.size() + 1); m_next.resize(m_label.size() + 1); for (size_t i = 0; i < Dim; ++i) { m_shape[i] = static_cast(shape[i]); - m_strides[i] = static_cast(m_label.strides()[i]); + if constexpr (Dim >= 2) { + m_strides[i] = static_cast(m_label.strides()[i]); + } } - GOOSEEYE_ASSERT(m_strides.back() == 1, std::out_of_range); this->reset(); m_connected.resize(m_dx.shape(0)); + + // Dim == 2: by default strides are assumed non-zero to avoid extra checks + // check once if zeros strides occur and if so use a special implementation of unravel_index + if constexpr (Dim == 2) { + if (m_shape[0] == 1) { + get_compare = &ClusterLabeller::get_compare_2d_1n; + } + else if (m_shape[1] == 1) { + get_compare = &ClusterLabeller::get_compare_2d_n1; + } + } } public: @@ -660,37 +683,104 @@ class ClusterLabeller { m_nmerge = 0; } - void label_impl(size_t idx) + /** + * @copydoc ClusterLabeller::get_compare_default + * @warning Implementation only for 2D, `shape = [1, n]`. + */ + ptrdiff_t get_compare_2d_1n(size_t idx, size_t j) { - static_assert(Dim == 1 || Dim == 2, "WIP: 1d and 2d supported."); + if constexpr (Periodic) { + return (m_shape[1] + idx + m_dx(j, 1)) % m_shape[1]; + } + if constexpr (!Periodic) { + ptrdiff_t compare = idx + m_dx(j, 1); + if (compare < 0 || compare >= m_shape[1]) { + return -1; + } + return compare; + } + } - ptrdiff_t compare; - size_t nconnected = 0; + /** + * @copydoc ClusterLabeller::get_compare_default + * @warning Implementation only for 2D, `shape = [n, 1]`. + */ + ptrdiff_t get_compare_2d_n1(size_t idx, size_t j) + { + if constexpr (Periodic) { + return (m_shape[0] + idx + m_dx(j, 0)) % m_shape[0]; + } + if constexpr (!Periodic) { + ptrdiff_t compare = idx + m_dx(j, 0); + if (compare < 0 || compare >= m_shape[0]) { + return -1; + } + return compare; + } + } - for (size_t j = 0; j < m_dx.shape(0); ++j) { - if constexpr (Dim == 1 && Periodic) { - compare = (m_shape[0] + idx + m_dx(j)) % m_shape[0]; + /** + * @brief Get the pixel to compare with. + * @note If the pixel is out of bounds, return -1. + * + * @param idx Flat index of the pixel. + * @param j Index of the kernel (in `m_dx`). + * @return ptrdiff_t Index of the pixel to compare with. + */ + ptrdiff_t get_compare_default(size_t idx, size_t j) + { + if constexpr (Dim == 1 && Periodic) { + return (m_shape[0] + idx + m_dx.flat(j)) % m_shape[0]; + } + if constexpr (Dim == 1 && !Periodic) { + ptrdiff_t compare = idx + m_dx.flat(j); + if (compare < 0 || compare >= m_shape[0]) { + return -1; } - else if constexpr (Dim == 1 && !Periodic) { - if (compare < 0 || compare >= m_shape[0]) { - continue; - } - compare = idx + m_dx(j); + return idx + m_dx.flat(j); + } + if constexpr (Dim == 2 && Periodic) { + ptrdiff_t ii = (m_shape[0] + (idx / m_strides[0]) + m_dx(j, 0)) % m_shape[0]; + ptrdiff_t jj = (m_shape[1] + (idx % m_strides[0]) + m_dx(j, 1)) % m_shape[1]; + return ii * m_shape[1] + jj; + } + if constexpr (Dim == 2 && !Periodic) { + ptrdiff_t ii = (idx / m_strides[0]) + m_dx(j, 0); + ptrdiff_t jj = (idx % m_strides[0]) + m_dx(j, 1); + if (ii < 0 || ii >= m_shape[0] || jj < 0 || jj >= m_shape[1]) { + return -1; } - else if constexpr (Dim == 2 && Periodic) { - ptrdiff_t ii = (m_shape[0] + (idx / m_strides[0]) + m_dx(j, 0)) % m_shape[0]; - ptrdiff_t jj = (m_shape[1] + (idx % m_strides[0]) + m_dx(j, 1)) % m_shape[1]; - compare = ii * m_shape[1] + jj; + return ii * m_shape[1] + jj; + } + else { + auto index = xt::unravel_from_strides(idx, m_strides, xt::layout_type::row_major); + for (size_t d = 0; d < Dim; ++d) { + index[d] += m_dx(j, d); + if constexpr (!Periodic) { + if (index[d] < 0 || index[d] >= m_shape[d]) { + return -1; + } + } + else { + auto n = m_shape[d]; + index[d] = (n + (index[d] % n)) % n; + } } - else if constexpr (Dim == 2 && !Periodic) { - ptrdiff_t ii = (idx / m_strides[0]) + m_dx(j, 0); - ptrdiff_t jj = (idx % m_strides[0]) + m_dx(j, 1); - if (ii < 0 || ii >= m_shape[0] || jj < 0 || jj >= m_shape[1]) { + return xt::ravel_index(index, m_shape, xt::layout_type::row_major); + } + } + + void label_impl(size_t idx) + { + size_t nconnected = 0; + + for (size_t j = 0; j < m_dx.shape(0); ++j) { + ptrdiff_t compare = (this->*get_compare)(idx, j); + if constexpr (!Periodic) { + if (compare == -1) { continue; } - compare = ii * m_shape[1] + jj; } - if (m_label.flat(compare) != 0) { m_connected[nconnected] = m_renum[m_label.flat(compare)]; nconnected++; @@ -1160,6 +1250,8 @@ class ClusterLabellerOverload : public ClusterLabeller { template array_type::array clusters(const T& f, bool periodic = true) { + GOOSEEYE_ASSERT(f.layout() == xt::layout_type::row_major, std::runtime_error); + auto n = f.dimension(); if (n == 1 && periodic) { return detail::ClusterLabellerOverload<1, true>(f).get(); @@ -1173,9 +1265,13 @@ array_type::array clusters(const T& f, bool periodic = true) if (n == 2 && !periodic) { return detail::ClusterLabellerOverload<2, false>(f).get(); } - - GOOSEEYE_WARNING("WIP: updated 3d implementation needs to be completed. Please file a PR."); - return Clusters(f, kernel::nearest(f.dimension()), periodic).labels(); + if (n == 3 && periodic) { + return detail::ClusterLabellerOverload<3, true>(f).get(); + } + if (n == 3 && !periodic) { + return detail::ClusterLabellerOverload<3, false>(f).get(); + } + throw std::runtime_error("Please use ClusterLabeller directly for dimensions > 3."); } /** diff --git a/include/GooseEYE/config.h b/include/GooseEYE/config.h index 1f1a4c21..9df48c6c 100644 --- a/include/GooseEYE/config.h +++ b/include/GooseEYE/config.h @@ -61,6 +61,12 @@ * @endcond */ +/** + * @brief Assertions that are always enabled. + */ +#define GOOSEEYE_REQUIRE(expr, assertion) \ + GOOSEEYE_ASSERT_IMPL(expr, assertion, __FILE__, __LINE__, __FUNCTION__) + /** * All assertions are implemented as: * diff --git a/python/GooseEYE/__init__.py b/python/GooseEYE/__init__.py index 0bae7bc9..36f2e23c 100644 --- a/python/GooseEYE/__init__.py +++ b/python/GooseEYE/__init__.py @@ -23,7 +23,11 @@ def ClusterLabeller(shape, periodic=True, **kwargs): if periodic: return ClusterLabeller2p(shape, **kwargs) return ClusterLabeller2(shape, **kwargs) - raise NotImplementedError("3d extension needed, please open a PR") + if len(shape) == 3: + if periodic: + return ClusterLabeller3p(shape, **kwargs) + return ClusterLabeller3(shape, **kwargs) + raise NotImplementedError("dimension > 3 not compiled by default") class Structure(enstat.static): diff --git a/python/main.cpp b/python/main.cpp index 1aaf5515..51b81fbd 100644 --- a/python/main.cpp +++ b/python/main.cpp @@ -143,9 +143,9 @@ PYBIND11_MODULE(_GooseEYE, m) py::arg("iterations") = 1, py::arg("periodic") = true); - static_for<1, 3>( + static_for<1, 4>( [&](auto i) { allocate_ClusterLabeller>(m); }); - static_for<1, 3>( + static_for<1, 4>( [&](auto i) { allocate_ClusterLabeller>(m); }); py::class_(m, "Clusters") diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 798a4079..00b75130 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -147,6 +147,17 @@ def test_clusters_simple3(): assert np.all(np.equal(eye.clusters(np.where(labels > 0, 1, 0)), labels)) +def test_clusters_one(): + """ + Test that having a shape[i] == 1 does not break the code. + """ + labels = np.array([[1, 1, 0, 2, 2, 2, 0, 3, 3, 3, 0, 1]]) + assert np.all(np.equal(eye.clusters(np.where(labels > 0, 1, 0)), labels)) + + labels = labels.reshape(-1, 1) + assert np.all(np.equal(eye.clusters(np.where(labels > 0, 1, 0)), labels)) + + def test_clusters_scipy(): img = eye.dummy_circles((500, 500), periodic=False) clusters = eye.clusters(img, periodic=False)