Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect sequence of clusters #97

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 80 additions & 15 deletions include/GooseEYE/GooseEYE.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,18 @@ class ClusterLabeller {
m_next[a] = b;
}

/**
* @brief Sorted unique list.
* @param labels List of labels.
* @param nlabels Size of the list.
* @return Size of the unique list.
*/
size_t unique(ptrdiff_t* labels, size_t nlabels)
{
std::sort(labels, labels + nlabels);
return std::unique(labels, labels + nlabels) - labels;
}

/**
* @brief Mark list of labels as merged.
* @note Link all labels to the lowest label in the list.
Expand All @@ -663,8 +675,6 @@ class ClusterLabeller {
*/
ptrdiff_t merge(ptrdiff_t* labels, size_t nlabels)
{
std::sort(labels, labels + nlabels);
nlabels = std::unique(labels, labels + nlabels) - labels;
ptrdiff_t target = labels[0];
for (size_t i = 1; i < nlabels; ++i) {
this->merge_detail(target, labels[i]);
Expand Down Expand Up @@ -798,6 +808,12 @@ class ClusterLabeller {
return;
}

nconnected = this->unique(&m_connected[0], nconnected);
if (nconnected == 1) {
m_label.flat(idx) = m_connected[0];
return;
}

// mark all labels in the list for merging
// `m_label` is not yet updated to avoid looping over all blocks too frequently
// the new label can be read by `m_renum[lab]` (as done above)
Expand Down Expand Up @@ -827,27 +843,29 @@ class ClusterLabeller {
this->apply_merge();
}

/**
* @brief Add sequence of points.
* @param begin Iterator to first point.
* @param end Iterator to last point.
*/
private:
template <class T>
void add_points(const T& begin, const T& end)
bool legal_points(const T& begin, const T& end)
{
#ifdef GOOSEEYE_ENABLE_ASSERT
size_t n = m_label.size();
if constexpr (std::is_signed_v<typename T::value_type>) {
GOOSEEYE_ASSERT(
!std::any_of(begin, end, [n](size_t i) { return i < 0 || i >= n; }),
std::out_of_range);
return !std::any_of(begin, end, [n](size_t i) { return i < 0 || i >= n; });
}
else {
GOOSEEYE_ASSERT(
!std::any_of(begin, end, [n](size_t i) { return i >= n; }), std::out_of_range);
return !std::any_of(begin, end, [n](size_t i) { return i >= n; });
}
#endif
}

public:
/**
* @brief Add sequence of points.
* @param begin Iterator to first point.
* @param end Iterator to last point.
*/
template <class T>
void add_points(const T& begin, const T& end)
{
GOOSEEYE_ASSERT(this->legal_points(begin, end), std::out_of_range);
for (auto it = begin; it != end; ++it) {
if (m_label.flat(*it) != 0) {
continue;
Expand All @@ -868,6 +886,53 @@ class ClusterLabeller {
return this->add_points(idx.begin(), idx.end());
}

/**
* @brief
* Add a sequence of points.
* Mark index every time a new cluster is started or a cluster is merged.
*
* @param idx List of points.
* @return List of indices.
*/
template <class T>
std::vector<size_t> add_sequence(const T& idx)
{
GOOSEEYE_ASSERT(idx.dimension() == 1, std::out_of_range);
GOOSEEYE_ASSERT(idx.size() >= 1, std::out_of_range);
GOOSEEYE_ASSERT(this->legal_points(idx.begin(), idx.end()), std::out_of_range);
std::vector<size_t> ret;
size_t i = 0;
while (true) {
auto nl = m_new_label;
auto nm = m_nmerge;
auto lab = m_label.flat(idx(i));

for (; i < idx.size(); ++i) {
auto l = m_label.flat(idx(i));
if (l != lab && l != 0) {
ret.push_back(i);
break;
}
if (l != 0) {
continue;
}
this->label_impl(idx(i));
if (m_new_label != nl || m_nmerge != nm) {
ret.push_back(i);
break;
}
}

if (i == idx.size()) {
break;
}
}

this->apply_merge();
ret.push_back(idx.size());
return ret;
}

/**
* @brief Basic class info.
* @return std::string
Expand Down
6 changes: 6 additions & 0 deletions python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ void allocate_ClusterLabeller(py::module& mod)
static_cast<void (Class::*)(const xt::pytensor<size_t, 1>&)>(&Class::add_points),
"Add points",
py::arg("idx"));

cls.def(
"add_sequence",
&Class::template add_sequence<xt::pytensor<size_t, 1>>,
"Add points",
py::arg("idx"));
}

PYBIND11_MODULE(_GooseEYE, m)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,79 @@ def test_init():
assert np.all(eye.clusters(s) == s)


def test_clusters_sequence():
"""
Get sequence of clusters.
"""
split = [0]
idx = []
labels = np.zeros([5, 5], dtype=int)

patch = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
]
)
idx += np.argwhere(patch.ravel() > 0).ravel().tolist()
split.append(len(idx))
labels += patch

patch = np.array(
[
[0, 0, 0, 0, 0],
[0, 2, 2, 0, 0],
[0, 2, 2, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]
)
idx += np.argwhere(patch.ravel() > 0).ravel().tolist()
split.append(len(idx))
labels += patch

patch = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 3],
[0, 0, 0, 0, 3],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]
)
idx += np.argwhere(patch.ravel() > 0).ravel().tolist()
split.append(len(idx))
labels += patch

segmenter = eye.ClusterLabeller(shape=labels.shape)
ret = segmenter.add_sequence(idx)
assert np.all(np.equal(segmenter.labels, labels))
assert np.all(np.equal(ret, split))

patch = np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 0, 4, 0],
[0, 0, 0, 4, 0],
[0, 0, 0, 4, 0],
[0, 0, 0, 0, 0],
]
)
idx += np.argwhere(patch.ravel() > 0).ravel().tolist()
split.append(len(idx))
labels += patch
labels = np.where(labels == 3, 2, labels)
labels = np.where(labels == 4, 2, labels)

segmenter = eye.ClusterLabeller(shape=labels.shape)
ret = segmenter.add_sequence(idx)
assert np.all(np.equal(segmenter.labels, labels))
assert np.all(np.equal(ret, split))


def test_labels_prune():
a = np.array([[-2, -2, 0, 0], [0, 0, 8, 8], [3, 3, 0, 0], [0, 0, 6, 6]])
b = np.array([[1, 1, 0, 0], [0, 0, 4, 4], [2, 2, 0, 0], [0, 0, 3, 3]])
Expand Down
Loading