Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed May 13, 2019
1 parent c3acaa1 commit 583a724
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 50 deletions.
131 changes: 107 additions & 24 deletions src/shogun/io/ARFFFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
* Authors: Gil Hoben
*/

#include <shogun/features/DenseFeatures.h>
#include <shogun/io/ARFFFile.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>

#include <date/date.h>

Expand All @@ -18,6 +18,18 @@ const char* ARFFDeserializer::m_attribute_string = "@attribute";
const char* ARFFDeserializer::m_data_string = "@data";
// Default ARFF date layout (ISO-8601, i.e. Java "yyyy-MM-dd'T'HH:mm:ss")
// expressed as a date::parse format string. The date part must use %m
// (month) and %d (day of month): %M means minutes and %D is shorthand for
// %m/%d/%y, so neither is valid there.
const char* ARFFDeserializer::m_default_date_format = "%Y-%m-%dT%H:%M:%S";

struct VectorSizeVisitor
{
size_t operator()(const std::vector<float64_t>& v) const
{
return v.size();
}
size_t operator()(const std::vector<std::string>& v) const
{
return v.size();
}
};

void ARFFDeserializer::read()
{
m_line_number = 0;
Expand Down Expand Up @@ -78,6 +90,7 @@ void ARFFDeserializer::read()
// check if it is nominal
if (type[0] == '{')
{
// @ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica}
std::vector<std::string> attributes;
// split nominal values: "{A, B, C}" to vector{A, B, C}
split(
Expand All @@ -86,6 +99,7 @@ void ARFFDeserializer::read()
m_nominal_attributes.emplace_back(
std::make_pair(name, attributes));
m_attributes.push_back(Attribute::Nominal);
m_data_vectors.emplace_back(std::vector<float64_t>{});
return;
}

Expand Down Expand Up @@ -120,23 +134,32 @@ void ARFFDeserializer::read()
m_current_line.c_str())
}
m_attributes.push_back(Attribute::Date);
m_data_vectors.emplace_back(std::vector<float64_t>{});
}
else if (is_primitive_type(type))
{
type = string_to_lower(type);
// numeric attributes
if (type == "numeric")
{
m_attributes.push_back(Attribute::Numeric);
m_data_vectors.emplace_back(std::vector<float64_t>{});
}
else if (type == "integer")
{
m_attributes.push_back(Attribute::Integer);
m_data_vectors.emplace_back(std::vector<float64_t>{});
}
else if (type == "real")
{
m_attributes.push_back(Attribute::Real);
m_data_vectors.emplace_back(std::vector<float64_t>{});
}
else if (type == "string")
{
// @ATTRIBUTE LCC string
// m_attributes.emplace(std::make_pair(elems[0],
// "string"));
m_attributes.push_back(Attribute::String);
m_data_vectors.emplace_back(std::vector<std::string>{});
}
else
SG_SERROR(
Expand Down Expand Up @@ -180,7 +203,8 @@ void ARFFDeserializer::read()
split(m_current_line, ",", std::back_inserter(elems), "\'\"");
auto nominal_pos = m_nominal_attributes.begin();
auto date_pos = m_date_formats.begin();
for (int i = 0; i < elems.size(); ++i)
int i = 0;
for (; i < elems.size(); ++i)
{
Attribute type = m_attributes[i];
switch (type)
Expand All @@ -191,7 +215,8 @@ void ARFFDeserializer::read()
{
try
{
m_data.push_back(std::stod(elems[i]));
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
.push_back(std::stod(elems[i]));
}
catch (const std::invalid_argument&)
{
Expand All @@ -216,7 +241,8 @@ void ARFFDeserializer::read()
"Unexpected value \"%s\" on line %d\n",
elems[i].c_str(), m_line_number);
float64_t idx = std::distance(encoding.begin(), pos);
m_data.push_back(idx);
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
.push_back(idx);
nominal_pos = std::next(nominal_pos);
}
break;
Expand All @@ -227,49 +253,106 @@ void ARFFDeserializer::read()
if (date_pos == m_date_formats.end())
SG_SERROR(
"Unexpected date value \"%s\" on line %d.\n",
elems[i].c_str(), m_line_number);
elems[i].c_str(), m_line_number);
ss >> date::parse(*date_pos, t);
if (bool(ss))
{
auto value_timestamp = t.time_since_epoch().count();
m_data.emplace_back(value_timestamp);
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
.push_back(value_timestamp);
}
else
SG_SERROR(
"Error parsing date \"%s\" with date format \"%s\" "
"on line %d.\n",
elems[i].c_str(), (*date_pos).c_str(), m_line_number)
elems[i].c_str(), (*date_pos).c_str(), m_line_number)
++date_pos;
}
break;
case (Attribute::String):
SG_SERROR("String parsing not implemented.\n")
shogun::get<std::vector<std::string>>(m_data_vectors[i])
.emplace_back(elems[i]);
}
}
if (i != m_attributes.size())
SG_SERROR(
"Unexpected number of values on line %d, expected %d values, "
"but found %d.\n",
m_line_number, m_attributes.size(), i)
++m_row_count;
};
auto check_data = [this]() {
// check X values
SG_SDEBUG(
"size: %d, cols: %d, rows: %d", m_data.size(),
m_data.size() / m_row_count, m_row_count)
if (!m_data.empty())
"size: %d, cols: %d, rows: %d", m_data_vectors.size(),
m_data_vectors.size() / m_row_count, m_row_count)
if (!m_data_vectors.empty())
{
auto tmp =
SGMatrix<float64_t>(m_data.size() / m_row_count, m_row_count);
m_data_matrix =
SGMatrix<float64_t>(m_row_count, m_data.size() / m_row_count);
memcpy(
tmp.matrix, m_data.data(), m_data.size() * sizeof(float64_t));
typename SGMatrix<float64_t>::EigenMatrixXtMap tmp_eigen = tmp;
typename SGMatrix<float64_t>::EigenMatrixXtMap m_data_matrix_eigen =
m_data_matrix;

m_data_matrix_eigen = tmp_eigen.transpose();
auto feature_count = m_data_vectors.size();
index_t row_count =
shogun::visit(VectorSizeVisitor{}, m_data_vectors[0]);
for (int i = 1; i < feature_count; ++i)
{
REQUIRE(
shogun::visit(VectorSizeVisitor{}, m_data_vectors[i]) ==
row_count,
"All columns must have the same number of features!\n")
}
}
else
return false;
return true;
};
process_chunk(read_data, check_data, true);
}

std::shared_ptr<CCombinedFeatures> ARFFDeserializer::get_features()
{
auto result = std::make_shared<CCombinedFeatures>();
index_t row_count = shogun::visit(VectorSizeVisitor{}, m_data_vectors[0]);
for (int i = 0; i < m_data_vectors.size(); ++i)
{
Attribute att = m_attributes[i];
auto vec = m_data_vectors[i];
switch (att)
{
case Attribute::Numeric:
case Attribute::Integer:
case Attribute::Real:
case Attribute::Date:
case Attribute::Nominal:
{
auto casted_vec = shogun::get<std::vector<float64_t>>(vec);
SGMatrix<float64_t> mat(1, row_count);
memcpy(
mat.matrix, casted_vec.data(),
casted_vec.size() * sizeof(float64_t));
auto* feat = new CDenseFeatures<float64_t>(mat);
result->append_feature_obj(feat);
}
break;
case Attribute::String:
{
auto casted_vec = shogun::get<std::vector<std::string>>(vec);
index_t max_string_length = 0;
for (const auto& el : casted_vec)
{
if (max_string_length < el.size())
max_string_length = el.size();
}
SGStringList<char> strings(row_count, max_string_length);
for (int j = 0; j < row_count; ++j)
{
SGString<char> current(max_string_length);
memcpy(
current.string, casted_vec[j].data(),
(casted_vec.size()+1) * sizeof(char));
strings.strings[j] = current;
}
auto* feat = new CStringFeatures<char>(strings, EAlphabet::RAWBYTE);
result->append_feature_obj(feat);
}
}
}
return result;
}
42 changes: 28 additions & 14 deletions src/shogun/io/ARFFFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define SHOGUN_ARFFFILE_H

#include <shogun/base/init.h>
#include <shogun/base/variant.h>
#include <shogun/features/CombinedFeatures.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>

Expand Down Expand Up @@ -41,15 +43,18 @@ namespace shogun
SG_FORCED_INLINE void left_trim(std::string& s)
{
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](char val) {
return !std::isspace(val);
}));
return !std::isspace(val);
}));
}

SG_FORCED_INLINE void right_trim(std::string& s)
{
s.erase(std::find_if(s.rbegin(), s.rend(), [](char val) {
return !std::isspace(val);
}).base(), s.end());
s.erase(
std::find_if(
s.rbegin(), s.rend(),
[](char val) { return !std::isspace(val); })
.base(),
s.end());
}

SG_FORCED_INLINE std::string trim(std::string line)
Expand Down Expand Up @@ -170,7 +175,8 @@ namespace shogun
* @param java_token
* @return
*/
SG_FORCED_INLINE const char* process_javatoken(const std::string& java_token)
SG_FORCED_INLINE const char*
process_javatoken(const std::string& java_token)
{
if (java_token == "yy")
return "%y";
Expand All @@ -191,7 +197,7 @@ namespace shogun
if (java_token == "Z")
return "%z";
if (java_token == "z")
return "%Z";
SG_SERROR("Timezone abbreviations are currently not supported.\n")
if (java_token.empty())
return "";
if (java_token == "SSS")
Expand Down Expand Up @@ -237,7 +243,8 @@ namespace shogun
return nullptr;
}

SG_FORCED_INLINE std::string javatime_to_cpptime(const std::string& java_time)
SG_FORCED_INLINE std::string
javatime_to_cpptime(const std::string& java_time)
{
std::string cpp_time;
std::string token;
Expand Down Expand Up @@ -326,7 +333,7 @@ namespace shogun
"have the right permissions to open it.\n",
filename.c_str())
}
m_stream = std::unique_ptr<std::istream>(static_cast<std::istream*>(file_stream));
m_stream = std::unique_ptr<std::istream>(file_stream);
}

/**
Expand All @@ -348,14 +355,20 @@ namespace shogun
void read();

/**
* Returns the data processed after parsing.
* @return matrix with parsed data
* Returns string parsed in @relation line
* @return the relation string
*/
SGMatrix<float64_t> get_data()
SG_FORCED_INLINE std::string get_relation()
{
return m_data_matrix;
return m_relation;
}

/**
* Get combined features from parsed data
* @return combined features holding one feature object per parsed attribute
*/
std::shared_ptr<CCombinedFeatures> get_features();

private:
/**
* Processes a chunk. A chunk is defined as a set of lines that
Expand Down Expand Up @@ -455,7 +468,8 @@ namespace shogun
m_nominal_attributes;

/** dynamic continuous vector with the parsed data */
std::vector<float64_t> m_data;
std::vector<variant<std::vector<float64_t>, std::vector<std::string>>>
m_data_vectors;
/** sgmatrix with the properly formatted data from m_data */
SGMatrix<float64_t> m_data_matrix;
};
Expand Down
Loading

0 comments on commit 583a724

Please sign in to comment.