diff --git a/include/GooseEYE/GooseEYE.h b/include/GooseEYE/GooseEYE.h index 6839c71b..2565fd3f 100644 --- a/include/GooseEYE/GooseEYE.h +++ b/include/GooseEYE/GooseEYE.h @@ -654,6 +654,18 @@ class ClusterLabeller { m_next[a] = b; } + /** + * @brief Sorted unique list. + * @param labels List of labels. + * @param nlabels Size of the list. + * @return Size of the unique list. + */ + size_t unique(ptrdiff_t* labels, size_t nlabels) + { + std::sort(labels, labels + nlabels); + return std::unique(labels, labels + nlabels) - labels; + } + /** * @brief Mark list of labels as merged. * @note Link all labels to the lowest label in the list. @@ -663,8 +675,6 @@ class ClusterLabeller { */ ptrdiff_t merge(ptrdiff_t* labels, size_t nlabels) { - std::sort(labels, labels + nlabels); - nlabels = std::unique(labels, labels + nlabels) - labels; ptrdiff_t target = labels[0]; for (size_t i = 1; i < nlabels; ++i) { this->merge_detail(target, labels[i]); @@ -798,6 +808,12 @@ class ClusterLabeller { return; } + nconnected = this->unique(&m_connected[0], nconnected); + if (nconnected == 1) { + m_label.flat(idx) = m_connected[0]; + return; + } + // mark all labels in the list for merging // `m_label` is not yet updated to avoid looping over all blocks too frequently // the new label can be read by `m_renum[lab]` (as done above) @@ -827,27 +843,29 @@ class ClusterLabeller { this->apply_merge(); } - /** - * @brief Add sequence of points. - * @param begin Iterator to first point. - * @param end Iterator to last point. - */ +private: template - void add_points(const T& begin, const T& end) + bool legal_points(const T& begin, const T& end) { -#ifdef GOOSEEYE_ENABLE_ASSERT size_t n = m_label.size(); if constexpr (std::is_signed_v) { - GOOSEEYE_ASSERT( - !std::any_of(begin, end, [n](size_t i) { return i < 0 || i >= n; }), - std::out_of_range); + return !std::any_of(begin, end, [n](size_t i) { return i < 0 || i >= n; }); } else { - GOOSEEYE_ASSERT( - !std::any_of(begin, end, [n](size_t i) { return i >= n; }), std::out_of_range); + return !std::any_of(begin, end, [n](size_t i) { return i >= n; }); } -#endif + } +public: + /** + * @brief Add sequence of points. + * @param begin Iterator to first point. + * @param end Iterator to last point. + */ + template + void add_points(const T& begin, const T& end) + { + GOOSEEYE_ASSERT(this->legal_points(begin, end), std::out_of_range); for (auto it = begin; it != end; ++it) { if (m_label.flat(*it) != 0) { continue; @@ -868,6 +886,53 @@ class ClusterLabeller { return this->add_points(idx.begin(), idx.end()); } + /** + * @brief + * Add a sequence of points. + * Mark index every time a new cluster is started or a cluster is merged. + * + * @param idx List of points. + * @return List of indices. + */ + template + std::vector add_sequence(const T& idx) + { + GOOSEEYE_ASSERT(idx.dimension() == 1, std::out_of_range); + GOOSEEYE_ASSERT(idx.size() >= 1, std::out_of_range); + GOOSEEYE_ASSERT(this->legal_points(idx.begin(), idx.end()), std::out_of_range); + std::vector ret; + size_t i = 0; + while (true) { + auto nl = m_new_label; + auto nm = m_nmerge; + auto lab = m_label.flat(idx(i)); + + for (; i < idx.size(); ++i) { + auto l = m_label.flat(idx(i)); + if (l != lab && l != 0) { + ret.push_back(i); + break; + } + if (l != 0) { + continue; + } + this->label_impl(idx(i)); + if (m_new_label != nl || m_nmerge != nm) { + ret.push_back(i); + break; + } + } + + if (i == idx.size()) { + break; + } + } + + this->apply_merge(); + ret.push_back(idx.size()); + return ret; + } + /** * @brief Basic class info. * @return std::string diff --git a/python/main.cpp b/python/main.cpp index 51b81fbd..6219cad1 100644 --- a/python/main.cpp +++ b/python/main.cpp @@ -58,6 +58,12 @@ void allocate_ClusterLabeller(py::module& mod) static_cast&)>(&Class::add_points), "Add points", py::arg("idx")); + + cls.def( + "add_sequence", + &Class::template add_sequence>, + "Add points", + py::arg("idx")); } PYBIND11_MODULE(_GooseEYE, m) diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 00b75130..ab894020 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -13,6 +13,79 @@ def test_init(): assert np.all(eye.clusters(s) == s) +def test_clusters_sequence(): + """ + Get sequence of clusters. + """ + split = [0] + idx = [] + labels = np.zeros([5, 5], dtype=int) + + patch = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + ] + ) + idx += np.argwhere(patch.ravel() > 0).ravel().tolist() + split.append(len(idx)) + labels += patch + + patch = np.array( + [ + [0, 0, 0, 0, 0], + [0, 2, 2, 0, 0], + [0, 2, 2, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + idx += np.argwhere(patch.ravel() > 0).ravel().tolist() + split.append(len(idx)) + labels += patch + + patch = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 3], + [0, 0, 0, 0, 3], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + idx += np.argwhere(patch.ravel() > 0).ravel().tolist() + split.append(len(idx)) + labels += patch + + segmenter = eye.ClusterLabeller(shape=labels.shape) + ret = segmenter.add_sequence(idx) + assert np.all(np.equal(segmenter.labels, labels)) + assert np.all(np.equal(ret, split)) + + patch = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 4, 0], + [0, 0, 0, 4, 0], + [0, 0, 0, 4, 0], + [0, 0, 0, 0, 0], + ] + ) + idx += np.argwhere(patch.ravel() > 0).ravel().tolist() + split.append(len(idx)) + labels += patch + labels = np.where(labels == 3, 2, labels) + labels = np.where(labels == 4, 2, labels) + + segmenter = eye.ClusterLabeller(shape=labels.shape) + ret = segmenter.add_sequence(idx) + assert np.all(np.equal(segmenter.labels, labels)) + assert np.all(np.equal(ret, split)) + + def test_labels_prune(): a = np.array([[-2, -2, 0, 0], [0, 0, 8, 8], [3, 3, 0, 0], [0, 0, 6, 6]]) b = np.array([[1, 1, 0, 0], [0, 0, 4, 4], [2, 2, 0, 0], [0, 0, 3, 3]])