diff --git a/src/shogun/io/ARFFFile.cpp b/src/shogun/io/ARFFFile.cpp new file mode 100644 index 00000000000..d2adb2fbc58 --- /dev/null +++ b/src/shogun/io/ARFFFile.cpp @@ -0,0 +1,263 @@ +/* + * This software is distributed under BSD 3-clause license (see LICENSE file). + * + * Authors: Gil Hoben + */ + +#include +#include + +using namespace shogun; +using namespace shogun::arff_detail; + +const char* ARFFDeserializer::m_comment_string = "%"; +const char* ARFFDeserializer::m_relation_string = "@RELATION"; +const char* ARFFDeserializer::m_attribute_string = "@ATTRIBUTE"; +const char* ARFFDeserializer::m_data_string = "@DATA"; + +std::vector +ARFFDeserializer::clean_up(std::vector& line) +{ + std::string result_string; + std::vector result; + std::vector::iterator begin; + + for (auto& elem : line) + { + elem.erase( + std::remove_if( + elem.begin(), elem.end(), + [](auto& v) { return v == ',' || v == '{' || v == '}'; }), + elem.end()); + } + for (auto iter = line.begin(); iter != line.end(); ++iter) + { + if (iter->front() == '\'' or iter->front() == '\"') + { + result_string = *iter; + if (iter->back() != '\'' and iter->back() != '\"') + { + begin = iter; + ++iter; + while (iter->back() != '\'' and iter->back() != '\"') + { + if (iter == line.end()) + { + SG_SERROR("Unbalanced quotes") + } + ++iter; + } + // concatenate strings within quotes with a space in + // between + result_string = std::accumulate( + begin + 1, iter + 1, *begin, + [](std::string s0, std::string& s1) { + remove_char_inplace(s0, '\''); + remove_char_inplace(s1, '\''); + return s0 += " " + s1; + }); + } + else + remove_char_inplace(result_string, '\''); + result.push_back(result_string); + } + else + { + result_string = *iter; + remove_char_inplace(result_string, '\''); + if (!result_string.empty()) + result.push_back(result_string); + } + } + return result; +} + +void ARFFDeserializer::read() +{ + m_line_number = 0; + m_row_count = 0; + m_file_done = false; + auto read_comment = [this]() { + if (string_to_lower(m_current_line.substr(0, 1)) == m_comment_string) + m_comments.push_back(m_current_line.substr(1, std::string::npos)); + else + m_state = true; + }; + auto check_comment = [this]() { return true; }; + process_chunk(read_comment, check_comment, false); + + auto read_relation = [this]() { + if (string_to_lower(m_current_line.substr( + 0, strlen(m_relation_string))) == m_relation_string) + { + m_relation = remove_whitespace( + m_current_line.substr(strlen(m_relation_string))); + } + else + m_state = true; + }; + // a relation has to be defined + auto check_relation = [this]() { return !m_relation.empty(); }; + process_chunk(read_relation, check_relation, true); + + auto read_attributes = [this]() { + if (string_to_lower(m_current_line.substr( + 0, strlen(m_attribute_string))) == m_attribute_string) + { + std::vector elems; + auto innner_string = + m_current_line.substr(strlen(m_attribute_string)); + split(innner_string, " ,\t\r\f\v", std::back_inserter(elems)); + std::transform( + elems.begin(), elems.end(), elems.begin(), + [](const auto& val) { return remove_whitespace(val); }); + // check if it is nominal + if (elems[1] == "{" || elems[1].front() == '{') + { + elems = clean_up(elems); + std::vector attributes( + elems.begin() + 1, elems.end()); + m_nominal_attributes.emplace_back( + std::make_pair(elems[0], attributes)); + m_attributes.emplace_back("nominal"); + return; + } + + auto is_date = std::find(elems.begin(), elems.end(), "date"); + if (is_date != elems.end()) + { + if (elems.begin() == is_date && elems.size() < 2) + { + // TODO: @attribute date [[date-format]] + } + else if (elems.begin() + 1 == is_date && elems.size() < 3) + { + // TODO: @attribute [name] date [[date-format]] + } + else + { + SG_SERROR("Error parsing date on line %d", m_line_number) + } + // m_attributes.emplace(std::make_pair(elems[0], + // "date")); + m_attributes.emplace_back("date"); + } + else if (elems.size() == 2) + { + auto type = string_to_lower(elems[1]); + // numeric attributes + if (type == "numeric" || type == "integer" || type == "real") + { + // m_attributes.emplace(std::make_pair(elems[0], + // "numeric")); + m_attributes.emplace_back("numeric"); + } + else if (type == "string") + { + // @ATTRIBUTE LCC string + // m_attributes.emplace(std::make_pair(elems[0], + // "string")); + m_attributes.emplace_back("string"); + } + else + SG_SERROR( + "Unexpected attribute type identifier \"%s\" " + "on line %d\n", + type.c_str(), m_line_number) + } + else + SG_SERROR( + "Unexpected format in @ATTRIBUTE on line %d\n", + m_line_number); + } + // comments in this section are ignored + else if (m_current_line.substr(0, 1) == m_comment_string) + { + return; + } + // if none of the others are true this is the end of the + // attributes section + else + { + m_state = true; + } + }; + + auto check_attributes = [this]() { + // attributes cannot be empty + return !m_attributes.empty(); + }; + process_chunk(read_attributes, check_attributes, true); + + auto read_data = [this]() { + // it's a comment and can be skipped + if (m_current_line.substr(0, 1) == m_comment_string) + return; + // it's the data string (i.e. @data"), does not provide + // information + if (string_to_lower(m_current_line.substr(0, strlen(m_data_string))) == + m_data_string) + { + return; + } + else + { + std::vector elems; + std::string type; + split(m_current_line, ",", std::back_inserter(elems)); + auto nominal_pos = m_nominal_attributes.begin(); + for (int i = 0; i < elems.size(); ++i) + { + type = m_attributes[i]; + if (type == "numeric") + { + m_data.push_back(std::stod(elems[i])); + } + else if (type == "nominal") + { + if (nominal_pos == m_nominal_attributes.end()) + SG_SERROR( + "Unexpected nominal value \"%s\" on line " + "%d\n", + elems[i].c_str(), m_line_number); + auto encoding = (*nominal_pos).second; + remove_char_inplace(elems[i], '\''); + auto pos = + std::find(encoding.begin(), encoding.end(), elems[i]); + if (pos == encoding.end()) + SG_SERROR( + "Unexpected value \"%s\" on line %d\n", + elems[i].c_str(), m_line_number); + float64_t idx = std::distance(encoding.begin(), pos); + m_data.push_back(idx); + nominal_pos = std::next(nominal_pos); + } + } + } + ++m_row_count; + }; + auto check_data = [this]() { + // check X values + SG_SDEBUG( + "size: %d, cols: %d, rows: %d", m_data.size(), + m_data.size() / m_row_count, m_row_count) + if (!m_data.empty()) + { + auto tmp = + SGMatrix(m_data.size() / m_row_count, m_row_count); + m_data_matrix = + SGMatrix(m_row_count, m_data.size() / m_row_count); + memcpy( + tmp.matrix, m_data.data(), m_data.size() * sizeof(float64_t)); + typename SGMatrix::EigenMatrixXtMap tmp_eigen = tmp; + typename SGMatrix::EigenMatrixXtMap m_data_matrix_eigen = + m_data_matrix; + + m_data_matrix_eigen = tmp_eigen.transpose(); + } + else + return false; + return true; + }; + process_chunk(read_data, check_data, true); +} diff --git a/src/shogun/io/ARFFFile.h b/src/shogun/io/ARFFFile.h new file mode 100644 index 00000000000..693b3119e1e --- /dev/null +++ b/src/shogun/io/ARFFFile.h @@ -0,0 +1,248 @@ +/* + * This software is distributed under BSD 3-clause license (see LICENSE file). + * + * Authors: Gil Hoben + */ + +#ifndef SHOGUN_ARFFFILE_H +#define SHOGUN_ARFFFILE_H + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace shogun +{ + namespace arff_detail + { + /** + * Checks if string is blank + * @param line to check + * @return bool whether line is empty + */ + SG_FORCED_INLINE bool string_is_blank(const std::string& line) + { + return line.find_first_not_of(" \t\r\f\v") == std::string::npos; + } + + /** + * Splits a line given a set of delimiter characters + * + * @tparam Out type of container + * @param s string to split + * @param delimiters a set of delimiter character + * @param result dynamic container where tokens are stored + */ + template + void split(const std::string& s, const char* delimiters, Out result) + { + std::stringstream ss(s); + std::string line; + while (std::getline(ss, line)) + { + size_t prev = 0, pos; + while ((pos = line.find_first_of(delimiters, prev)) != + std::string::npos) + { + if (pos > prev) + *(result++) = line.substr(prev, pos - prev); + prev = pos + 1; + } + if (prev < line.length()) + *(result++) = line.substr(prev, std::string::npos); + } + } + + /** + * Returns a string in lowercase. + * + * @param line string to process + * @return lowercase string + */ + SG_FORCED_INLINE std::string string_to_lower(const std::string& line) + { + std::string result; + std::transform( + line.begin(), line.end(), std::back_inserter(result), + [](const auto& val) { return std::tolower(val); }); + return result; + } + + /** + * Returns string without whitespace + * @param line string to process + * @return string without whitespace + */ + SG_FORCED_INLINE std::string remove_whitespace(const std::string& line) + { + std::string result = line; + result.erase( + std::remove_if(result.begin(), result.end(), ::isspace), + result.end()); + return result; + } + + /** + * Removes all occurences of a character in place + * @param line string to process + * @param character char to remove + */ + SG_FORCED_INLINE void + remove_char_inplace(std::string& line, char character) + { + line.erase( + std::remove_if( + line.begin(), line.end(), + [&character](auto const& val) { return val == character; }), + line.end()); + } + } // namespace arff_details + /** + * ARFFDeserializer parses files in the ARFF format. + * For information about this format see + * https://waikato.github.io/weka-wiki/arff_stable/ + */ + class ARFFDeserializer + { + public: + /** + * ARFFDeserializer constructor with a filename. + * Performs a check to see if a file can be streamed. + * Fails if file does not exist, or it cannot be opened, + * i.e. not the correct permission. + * + * @param filename the name of the file to parse + */ + explicit ARFFDeserializer(const std::string& filename) + { + m_file_stream = std::ifstream(filename); + if (m_file_stream.fail()) + { + SG_SERROR("Cannot open %s\n", filename.c_str()) + } + } + + /** + * Parse the file passed to the contructor. + * + */ + void read(); + + /** + * Returns the data processed after parsing. + * @return matrix with parsed data + */ + SGMatrix get_data() + { + return m_data_matrix; + } + + private: + /** + * Processes a chunk. A chunk is defined as a set of lines that + * are processed in the same way. A chunk ends when the func + * sets the internal m_state to false. + * Parsing can also end when the stream reaches EOF. + * + * @tparam LambdaT type of processing function + * @tparam CheckT type of check function + * @param func processing function that reads each line + * @param check_func function that checks the result from the processing + * function + * @param skip_first whether to stream the first line + */ + template + void process_chunk(LambdaT&& func, CheckT&& check_func, bool skip_first) + { + m_state = false; + + if (skip_first && !m_file_stream.eof()) + func(); + + while (!m_state && !m_file_done) + { + consume_line(func); + } + if (!check_func()) + { + SG_SERROR("Parsing error: %d", m_current_line.c_str()); + } + } + + /** + * Function called by process_chunk to process a "chunk" line by line. + * This function also checks if EOF has been reached. + * + * @tparam T type of processing function + * @param func line processing function + */ + template + void consume_line(T&& func) + { + if (m_file_stream.eof()) + { + m_file_done = true; + return; + } + std::getline(m_file_stream, m_current_line); + m_line_number++; + if (!arff_detail::string_is_blank(m_current_line)) + func(); + } + + /** + * Cleans up the tokens for nominal attributes. + * + * @param line the line with nominal attributes. + * @return returns a vector with the nominal values in the correct + * position. + */ + std::vector clean_up(std::vector& line); + + /** character used in file to comment out a line */ + static const char* m_comment_string; + /** characters to declare relations, i.e. @relation */ + static const char* m_relation_string; + /** characters to declare attributes, i.e. @attribute */ + static const char* m_attribute_string; + /** characters to declare data fields, i.e. @data */ + static const char* m_data_string; + + /** internal line number counter for exceptions */ + size_t m_line_number; + /** internal flag set true when string stream is EOF */ + bool m_file_done; + /** internal state when set to true switches parsing rules */ + bool m_state; + /** current row count of data */ + size_t m_row_count; + /** the string after m_relation_string*/ + std::string m_relation; + /** the shared file stream */ + std::ifstream m_file_stream; + /** the string where comments are stored */ + std::vector m_comments; + /** the string representing the current line being parsed */ + std::string m_current_line; + /** the attribute types in the order they are parsed */ + std::vector m_attributes; + /** a list of the learning target(s), i.e. y */ + std::vector m_target_strings; + /** the mapping of nominal attributes to their value */ + std::vector>> + m_nominal_attributes; + + /** dynamic continuous vector with the parsed data */ + std::vector m_data; + /** sgmatrix with the properly formatted data from m_data */ + SGMatrix m_data_matrix; + }; +} // namespace shogun + +#endif // SHOGUN_ARFFFILE_H