Skip to content

Commit

Permalink
Column iterator for SGMatrix (#4997)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpsy authored Jun 2, 2020
1 parent c0feb0a commit a3f8d98
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 0 deletions.
285 changes: 285 additions & 0 deletions src/shogun/lib/SGMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ template<class T> class SGMatrix : public SGReferencedData
public:
typedef RandomIterator<T> iterator;
typedef ConstRandomIterator<T> const_iterator;
#ifndef SWIG
public:
class column_iterator;
class const_column_iterator;

private:
template<class MatrixType>
struct column_struct;
#endif
public:
typedef Eigen::Matrix<T,-1,-1,0,-1,-1> EigenMatrixXt;
typedef Eigen::Map<EigenMatrixXt,0,Eigen::Stride<0,0> > EigenMatrixXtMap;
Expand Down Expand Up @@ -258,6 +266,24 @@ template<class T> class SGMatrix : public SGReferencedData
/** Returns a const iterator to the element following the last element of the container. */
const_iterator end() const noexcept { return const_iterator(matrix + (num_rows * num_cols)); }

/** Returns an iterator to the first column of the container. */
column_iterator begin_column() noexcept { return column_iterator(*this); }

/** Returns an iterator to the column following the last column of the container. */
column_iterator end_column() noexcept { return column_iterator(*this,num_cols); }

/** Returns a const iterator to the first column of the container. */
const_column_iterator begin_column() const noexcept { return const_column_iterator(*this); }

/** Returns a const iterator to the column following the last column of the container. */
const_column_iterator end_column() const noexcept { return const_column_iterator(*this,num_cols); }

/** Returns a helper struct to help iterate over columns*/
auto columns() noexcept { return column_struct<decltype(*this)>(*this); }

/** Returns a helper struct to help const iterate over columns*/
auto columns() const noexcept { return column_struct<decltype(*this)>(*this); }

#endif // SWIG should skip this part

/** Get element at index
Expand Down Expand Up @@ -532,5 +558,264 @@ template<class T> class SGMatrix : public SGReferencedData
/** GPU Matrix structure. Stores pointer to the data on GPU. */
std::shared_ptr<GPUMemoryBase<T>> gpu_ptr;
};
#ifndef SWIG
template<typename T>
class SGMatrix<T>::column_iterator
{
public:
// Iterator traits
using difference_type = index_t;
using value_type = SGVector<T>;
using pointer = SGVector<T>*;
using reference = SGVector<T>&;
using iterator_category = std::random_access_iterator_tag;

column_iterator(SGMatrix<T> mat, index_t idx = 0)
: m_mat(mat), m_col_idx(idx)
{
}
column_iterator(const column_iterator& other)
: m_mat(other.m_mat), m_col_idx(other.m_col_idx)
{
}

column_iterator& operator++()
{
++m_col_idx;
return *this;
}
column_iterator operator++(int)
{
column_iterator retval(*this);
++(*this);
return retval;
}
column_iterator& operator--()
{
--m_col_idx;
return *this;
}
column_iterator operator--(int)
{
column_iterator retval(*this);
--(*this);
return retval;
}

bool operator==(const column_iterator& other) const
{
return m_mat == other.m_mat && m_col_idx == other.m_col_idx;
}
bool operator!=(const column_iterator& other) const
{
return !(*this == other);
}

value_type operator*()
{
return m_mat.get_column(m_col_idx);
}

column_iterator& operator+=(difference_type d)
{
m_col_idx += d;
return *this;
}
column_iterator& operator-=(difference_type d)
{
m_col_idx -= d;
return *this;
}

column_iterator operator+(difference_type d) const
{
return column_iterator(m_mat, m_col_idx + d);
}
column_iterator operator-(difference_type d) const
{
return column_iterator(m_mat, m_col_idx - d);
}

value_type operator[](difference_type d) const
{
return *column_iterator(m_mat, m_col_idx + d);
}

bool operator<(const column_iterator& other) const
{
return m_col_idx < other.m_col_idx;
}
bool operator>(const column_iterator& other) const
{
return m_col_idx > other.m_col_idx;
}
bool operator<=(const column_iterator& other) const
{
return m_col_idx <= other.m_col_idx;
}
bool operator>=(const column_iterator& other) const
{
return m_col_idx >= other.m_col_idx;
}

difference_type operator-(const column_iterator& other) const
{
return m_col_idx - other.m_col_idx;
}

private:
SGMatrix<T> m_mat;
difference_type m_col_idx;
};
template<typename T>
class SGMatrix<T>::const_column_iterator
{
public:
// Iterator traits
using difference_type = index_t;
using value_type = const SGVector<T>;
using pointer = const SGVector<T>*;
using reference = const SGVector<T>&;
using iterator_category = std::random_access_iterator_tag;

const_column_iterator(const SGMatrix<T> mat, index_t idx = 0)
: m_mat(mat), m_col_idx(idx)
{
}
const_column_iterator(const const_column_iterator& other)
: m_mat(other.m_mat), m_col_idx(other.m_col_idx)
{
}

const_column_iterator& operator++()
{
++m_col_idx;
return *this;
}
const_column_iterator operator++(int)
{
const_column_iterator retval(*this);
++(*this);
return retval;
}
const_column_iterator& operator--()
{
--m_col_idx;
return *this;
}
const_column_iterator operator--(int)
{
const_column_iterator retval(*this);
--(*this);
return retval;
}

bool operator==(const const_column_iterator& other) const
{
return m_mat == other.m_mat && m_col_idx == other.m_col_idx;
}
bool operator!=(const const_column_iterator& other) const
{
return !(*this == other);
}

value_type operator*()
{
return m_mat.get_column(m_col_idx);
}

const_column_iterator& operator+=(difference_type d)
{
m_col_idx += d;
return *this;
}
const_column_iterator& operator-=(difference_type d)
{
m_col_idx -= d;
return *this;
}

const_column_iterator operator+(difference_type d) const
{
return const_column_iterator(m_mat, m_col_idx + d);
}
const_column_iterator operator-(difference_type d) const
{
return const_column_iterator(m_mat, m_col_idx - d);
}

value_type operator[](difference_type d) const
{
return *const_column_iterator(m_mat, m_col_idx + d);
}

bool operator<(const const_column_iterator& other) const
{
return m_col_idx < other.m_col_idx;
}
bool operator>(const const_column_iterator& other) const
{
return m_col_idx > other.m_col_idx;
}
bool operator<=(const const_column_iterator& other) const
{
return m_col_idx <= other.m_col_idx;
}
bool operator>=(const const_column_iterator& other) const
{
return m_col_idx >= other.m_col_idx;
}

difference_type operator-(const const_column_iterator& other) const
{
return m_col_idx - other.m_col_idx;
}

private:
const SGMatrix<T> m_mat;
difference_type m_col_idx;
};
template<typename T>
template<class MatrixType>
struct SGMatrix<T>::column_struct
{
column_struct(MatrixType mat):m_mat(mat)
{
}

auto begin()
{
return m_mat.begin_column();
}

auto end()
{
return m_mat.end_column();
}

auto begin() const
{
return m_mat.begin_column();
}

auto end() const
{
return m_mat.end_column();
}

auto cbegin() const
{
return m_mat.begin_column();
}

auto cend() const
{
return m_mat.end_column();
}

private:
MatrixType m_mat;
};
#endif
}
#endif // __SGMATRIX_H__
51 changes: 51 additions & 0 deletions tests/unit/lib/SGMatrix_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,54 @@ TEST(SGMatrixTest,iterator)
for (auto v: mat)
EXPECT_EQ(mat[index++], v);
}

TEST(SGMatrixTest, column_iterator)
{
constexpr index_t size = 5;
SGMatrix<float64_t> mat(size, size);
linalg::range_fill(mat, 1.0);

auto begin_col = mat.begin_column();
auto end_col = mat.end_column();

EXPECT_EQ(mat.num_cols, std::distance(begin_col, end_col));
EXPECT_EQ(mat.get_column(0), *begin_col);
++begin_col;
EXPECT_EQ(mat.get_column(1), *begin_col);
--begin_col;
EXPECT_EQ(mat.get_column(0), *begin_col++);
EXPECT_TRUE(begin_col != end_col);
++begin_col;
EXPECT_EQ(mat.get_column(2), *begin_col--);
--begin_col;
EXPECT_EQ(mat.get_column(0), *begin_col);
begin_col += 2;
EXPECT_EQ(mat.get_column(2), *begin_col);
begin_col -= 2;
EXPECT_EQ(mat.get_column(0), *begin_col);
EXPECT_EQ(mat.get_column(1), begin_col[1]);

auto new_itr = begin_col + 2;
EXPECT_EQ(mat.get_column(2), *new_itr);

// range-based loop should work as well
auto index = 0;
for (const auto& i : mat.columns())
EXPECT_EQ(i, mat.get_column(index++));

const SGMatrix<float64_t> const_mat(mat);

index = 0;
for (const auto& i : const_mat.columns())
EXPECT_EQ(i, const_mat.get_column(index++));

// Modifying value of column vector
auto mat_copy = mat.clone();

for (auto vec : mat.columns())
linalg::add_scalar(vec, 1.0);

for (auto col_idx = 0; col_idx < mat.num_cols; col_idx++)
EXPECT_EQ(*mat.get_column(col_idx), *mat_copy.get_column(col_idx) + 1.0);
}

0 comments on commit a3f8d98

Please sign in to comment.