Skip to content

Commit

Permalink
implement the taxonomic classification with queries on the annotation…
Browse files Browse the repository at this point in the history
… matrix

Signed-off-by: Radu Muntean <[email protected]>
  • Loading branch information
heracle committed Aug 2, 2021
1 parent 805735c commit 22eec24
Show file tree
Hide file tree
Showing 4 changed files with 467 additions and 19 deletions.
5 changes: 2 additions & 3 deletions metagraph/src/annotation/taxonomy/label_to_taxid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ namespace annot {
using mtg::common::logger;

void TaxonomyBase::assign_label_type(const std::string &label, bool *require_accversion_to_taxid_map) {
if (utils::starts_with(label, ">gi|")) {
if (utils::starts_with(label, "gi|")) {
// e.g. >gi|1070643132|ref|NC_031224.1| Arthrobacter phage Mudcat, complete genome
label_type = GEN_BANK;
*require_accversion_to_taxid_map = true;
} else if (utils::starts_with(label, ">") &&
utils::starts_with(utils::split_string(label, ":")[1], "taxid|")) {
} else if (utils::starts_with(utils::split_string(label, ":")[1], "taxid|")) {
// e.g. >kraken:taxid|2016032|NC_047834.1 Alteromonas virus vB_AspP-H4/4, complete genome
label_type = TAXID;
*require_accversion_to_taxid_map = false;
Expand Down
240 changes: 238 additions & 2 deletions metagraph/src/annotation/taxonomy/tax_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

#include "annotation/representation/annotation_matrix/annotation_matrix.hpp"
#include "common/unix_tools.hpp"
#include "common/seq_tools/reverse_complement.hpp"

#include "common/logger.hpp"

// Floor of the base-2 logarithm of an unsigned 64-bit value X.
// __builtin_clzll has an undefined result for 0, so X == 0 is mapped to 0 explicitly.
// NOTE: X is expanded twice; pass only side-effect-free expressions.
#define LOG2(X) ((X) == 0 ? 0u : (unsigned) (8*sizeof (unsigned long long) - __builtin_clzll((X)) - 1))

namespace mtg {
namespace annot {

Expand Down Expand Up @@ -177,8 +180,241 @@ void TaxonomyClsAnno::rmq_preprocessing(const std::vector<TaxId> &tree_lineariza
}
}

TaxId TaxonomyClsAnno::assign_class(const std::string &sequence) const {
throw std::runtime_error("Assign class not implemented. Received " + sequence);
std::vector<TaxId> TaxonomyClsAnno::get_lca_taxids_for_seq(const std::string_view &sequence, bool reversed) const {
// num_kmers represents the total number of kmers parsed until the current time.
uint32_t num_kmers = 0;

// 'kmer_idx' and 'kmer_val' are storing the indexes and values of all the nonzero kmers in the given read.
// The list of kmers, 'kmer_val', will be further sent to "matrix.getrows()" method;
// The list of indexes, 'kmer_idx', will be used to associate one row from "matrix.getrows()" with the corresponding kmer index.
std::vector<uint32_t> kmer_idx;
std::vector<node_index> kmer_val;

if (sequence.size() >= std::numeric_limits<uint32_t>::max()) {
logger->error("The given sequence contains more than 2^32 bp.");
std::exit(1);
}

auto anno_graph = _anno_matrix->get_graph_ptr();
anno_graph->map_to_nodes(sequence, [&](node_index i) {
num_kmers++;
if (i <= 0 || i >= anno_graph->max_index()) {
return;
}
kmer_val.push_back(i - 1);
kmer_idx.push_back(num_kmers - 1);
});

// Compute the LCA normalized taxid for each nonzero kmer in the given read.
const auto unique_matrix_rows = _anno_matrix->get_annotation().get_matrix().get_rows(kmer_val);
//TODO make sure that this function works even if we have duplications in 'rows'. Then, delete this error catch.
if (kmer_val.size() != unique_matrix_rows.size()) {
throw std::runtime_error("Internal error: There must be no duplications in the received set of 'rows' in 'call_annotated_rows'.");
}

if (unique_matrix_rows.size() >= std::numeric_limits<uint32_t>::max()) {
throw std::runtime_error("Internal error: There must be less than 2^32 unique rows. Reduce the query batch size.");
}

const auto &label_encoder = _anno_matrix->get_annotation().get_label_encoder();

TaxId taxid;
uint64_t cnt_kmer_idx = 0;
std::vector<TaxId> curr_kmer_taxids;
std::vector<TaxId> seq_taxids(num_kmers);

for (auto row : unique_matrix_rows) {
for (auto cell : row) {
if (get_taxid_from_label(label_encoder.decode(cell), &taxid)) {
curr_kmer_taxids.push_back(taxid);
}
}
if (curr_kmer_taxids.size() != 0) {
if (not reversed) {
seq_taxids[kmer_idx[cnt_kmer_idx]] = find_lca(curr_kmer_taxids);
} else {
seq_taxids[num_kmers - 1 - kmer_idx[cnt_kmer_idx]] = find_lca(curr_kmer_taxids);
}
}
cnt_kmer_idx++;
curr_kmer_taxids.clear();
}

return seq_taxids;
}

/**
 * Assign a taxid to the given sequence.
 *
 * The per-kmer LCA taxids are computed for the sequence and for its reverse complement,
 * merged position by position, and then the scores are propagated through
 * 'update_scores_and_lca' to select the returned node.
 *
 * @param sequence the read to classify
 * @return the selected taxid, or 0 when the number of kmers with a nonzero taxid does not
 *         exceed '_kmers_discovery_rate * <total number of kmers>'.
 * @throws std::runtime_error if the forward and reverse-complement queries return a
 *         different number of kmers.
 */
TaxId TaxonomyBase::assign_class(const std::string &sequence) const {
    const std::vector<TaxId> forward_taxids = get_lca_taxids_for_seq(sequence, false);

    std::string reversed_sequence(sequence);
    reverse_complement(reversed_sequence.begin(), reversed_sequence.end());
    const std::vector<TaxId> backward_taxids = get_lca_taxids_for_seq(reversed_sequence, true);

    // Both vectors are indexed in parallel below, so they must have the same length.
    if (forward_taxids.size() != backward_taxids.size()) {
        throw std::runtime_error("Internal error: forward and reverse-complement k-mer counts differ.");
    }

    tsl::hopscotch_map<TaxId, uint64_t> num_kmers_per_node;

    // 'num_discovered_kmers' is the number of kmers with a nonzero taxid according to
    // the forward read, the reversed read, or both.
    uint32_t num_discovered_kmers = 0;
    const uint32_t num_total_kmers = forward_taxids.size();

    // Find the LCA taxid for each kmer without any dependency on the orientation of the read.
    for (uint32_t i = 0; i < num_total_kmers; ++i) {
        const TaxId forward_taxid = forward_taxids[i];
        const TaxId backward_taxid = backward_taxids[i];

        TaxId curr_taxid;
        if (forward_taxid == 0 && backward_taxid == 0) {
            continue;
        } else if (backward_taxid == 0) {
            curr_taxid = forward_taxid;
        } else if (forward_taxid == 0) {
            curr_taxid = backward_taxid;
        } else {
            // Both orientations produced a taxid: merge them through their LCA.
            curr_taxid = find_lca({forward_taxid, backward_taxid});
        }

        if (curr_taxid) {
            num_discovered_kmers++;
            num_kmers_per_node[curr_taxid]++;
        }
    }

    if (num_discovered_kmers <= _kmers_discovery_rate * num_total_kmers) {
        return 0; // 0 is a wildcard for not enough discovered kmers.
    }

    tsl::hopscotch_set<TaxId> nodes_already_propagated;
    tsl::hopscotch_map<TaxId, uint64_t> node_scores;

    // Minimum score a node needs in order to be accepted as the result.
    uint32_t desired_number_kmers = num_discovered_kmers * _lca_coverage_rate;
    TaxId best_lca = root_node;
    uint32_t best_lca_dist_to_root = 1;

    // Update the nodes' score by iterating through all the nodes with nonzero kmers.
    for (const auto &node_pair : num_kmers_per_node) {
        const TaxId start_node = node_pair.first;
        this->update_scores_and_lca(start_node, num_kmers_per_node, desired_number_kmers, &node_scores,
                                    &nodes_already_propagated, &best_lca, &best_lca_dist_to_root);
    }
    return best_lca;
}


/**
 * Propagate the kmer weight of 'start_node' and of its ancestors along the path
 * start_node -> root, updating 'node_scores' and the best LCA candidate found so far.
 *
 * Nodes already present in 'nodes_already_propagated' only receive the weight gathered
 * from the nodes that are visited for the first time in this call; the newly visited nodes
 * receive the weight of the entire path and are marked as propagated.
 * A node replaces the current best LCA if its score reaches 'desired_number_kmers' and it
 * is either farther from the root than the current best, or equally far with a strictly
 * larger score.
 */
void TaxonomyBase::update_scores_and_lca(const TaxId start_node,
                                         const tsl::hopscotch_map<TaxId, uint64_t> &num_kmers_per_node,
                                         const uint64_t desired_number_kmers,
                                         tsl::hopscotch_map<TaxId, uint64_t> *node_scores,
                                         tsl::hopscotch_set<TaxId> *nodes_already_propagated,
                                         TaxId *best_lca,
                                         uint32_t *best_lca_dist_to_root) const {
    if (nodes_already_propagated->count(start_node)) {
        return;
    }

    // Weight collected from path nodes that are new in this call / that were handled earlier.
    uint64_t fresh_weight = num_kmers_per_node.at(start_node);
    uint64_t seen_weight = 0;

    // Path nodes in the order start_node -> root, split by whether they were handled earlier.
    std::vector<TaxId> fresh_path;
    std::vector<TaxId> seen_path;
    fresh_path.push_back(start_node);

    for (TaxId node = start_node; node != root_node; ) {
        node = node_parent.at(node);
        const auto weight_it = num_kmers_per_node.find(node);
        const uint64_t weight = (weight_it != num_kmers_per_node.end()) ? weight_it->second : 0;
        if (nodes_already_propagated->count(node)) {
            seen_weight += weight;
            seen_path.push_back(node);
        } else {
            fresh_weight += weight;
            fresh_path.push_back(node);
        }
    }

    // Test if the given node's score would be a better LCA result.
    const auto consider_candidate = [&](const TaxId node, const uint64_t dist_to_root) {
        const uint64_t score = (*node_scores)[node];
        if (score < desired_number_kmers) {
            return;
        }
        if (dist_to_root > *best_lca_dist_to_root
                || (dist_to_root == *best_lca_dist_to_root && score > (*node_scores)[*best_lca])) {
            *best_lca = node;
            *best_lca_dist_to_root = dist_to_root;
        }
    };

    // Newly visited nodes get the weight of the whole path.
    for (uint64_t i = 0; i < fresh_path.size(); ++i) {
        const TaxId node = fresh_path[i];
        (*node_scores)[node] = seen_weight + fresh_weight;
        nodes_already_propagated->insert(node);
        consider_candidate(node, seen_path.size() + fresh_path.size() - i);
    }

    // Previously handled nodes are only increased by the newly discovered weight.
    for (uint64_t i = 0; i < seen_path.size(); ++i) {
        const TaxId node = seen_path[i];
        (*node_scores)[node] += fresh_weight;
        consider_candidate(node, seen_path.size() - i);
    }
}

/**
 * Return the LCA of a nonempty set of taxids, using the precomputed 'rmq_data' table.
 *
 * The taxids are mapped to their indexes in the tree linearization; the answer is looked up
 * in the linearization range delimited by the smallest and the largest of these indexes.
 * The process is terminated (exit code 1) if 'taxids' is empty or if 'rmq_data' does not
 * contain the required level.
 */
TaxId TaxonomyClsAnno::find_lca(const std::vector<TaxId> &taxids) const {
    if (taxids.empty()) {
        logger->error("Internal error: Can't find LCA for an empty set of normalized taxids.");
        std::exit(1);
    }
    uint64_t left_idx = node_to_linearization_idx.at(taxids[0]);
    uint64_t right_idx = left_idx;

    for (const TaxId &taxid : taxids) {
        const uint64_t idx = node_to_linearization_idx.at(taxid);
        if (idx < left_idx) {
            left_idx = idx;
        }
        if (idx > right_idx) {
            right_idx = idx;
        }
    }
    // The node with maximum node_depth in 'linearization[left_idx : right_idx+1]' is the LCA of the given set.

    // Find the maximum node_depth between the 2 overlapping intervals of size 2^log_dist.
    // When all the taxids share one linearization index, the range has a single element:
    // level 0 is used directly, since LOG2(0) would call __builtin_clzll(0), which is undefined.
    const uint64_t idx_span = right_idx - left_idx;
    const uint32_t log_dist = (idx_span == 0) ? 0u : LOG2(idx_span);
    if (rmq_data.size() <= log_dist) {
        logger->error("Internal error: the RMQ was not precomputed before the LCA queries.");
        std::exit(1);
    }

    // 64-bit shift, so that the interval length is not computed in (possibly overflowing) int.
    const uint64_t interval_len = uint64_t(1) << log_dist;
    const uint32_t left_lca = rmq_data[log_dist][left_idx];
    const uint32_t right_lca = rmq_data[log_dist][right_idx - interval_len + 1];

    if (node_depth.at(left_lca) > node_depth.at(right_lca)) {
        return left_lca;
    }
    return right_lca;
}

/**
 * Placeholder: the per-kmer LCA query is not implemented for TaxonomyClsImportDB.
 * Always throws std::runtime_error.
 */
std::vector<TaxId> TaxonomyClsImportDB::get_lca_taxids_for_seq(const std::string_view &sequence, bool reversed) const {
    cerr << "Assign class not implemented reversed = " << reversed << "\n";
    // The size must be converted with std::to_string: adding an integer to a string literal
    // is pointer arithmetic, not concatenation.
    throw std::runtime_error("get_lca_taxids_for_seq TaxonomyClsImportDB not implemented. Received seq size "
                             + std::to_string(sequence.size()));
}

/**
 * Placeholder: the LCA query is not implemented for TaxonomyClsImportDB.
 * Always throws std::runtime_error.
 */
TaxId TaxonomyClsImportDB::find_lca(const std::vector<TaxId> &taxids) const {
    // The size must be converted with std::to_string: adding an integer to a string literal
    // is pointer arithmetic, not concatenation.
    throw std::runtime_error("find_lca TaxonomyClsImportDB not implemented. Received taxids size "
                             + std::to_string(taxids.size()));
}

} // namespace annot
Expand Down
27 changes: 17 additions & 10 deletions metagraph/src/annotation/taxonomy/tax_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ class TaxonomyBase {

virtual ~TaxonomyBase() {};

// TODO implement
virtual TaxId assign_class(const std::string &sequence) const = 0;
TaxId assign_class(const std::string &sequence) const;

PROTECTED_TESTABLE:
void assign_label_type(const std::string &label, bool *require_accversion_to_taxid_map);

// TODO implement.
TaxId find_lca(const std::vector<TaxId> &taxids) const;
virtual TaxId find_lca(const std::vector<TaxId> &taxids) const = 0;

std::string get_accession_version_from_label(const std::string &label) const;

Expand All @@ -57,7 +55,6 @@ class TaxonomyBase {
*/
void read_accversion_to_taxid_map(const std::string &filepath, const graph::AnnotatedDBG *anno_matrix);

// TODO implement.
/**
* Update the current node_scores and best_lca by taking into account the weight of the start_node and all its ancestors.
*
Expand All @@ -75,7 +72,13 @@ class TaxonomyBase {
tsl::hopscotch_map<TaxId, uint64_t> *node_scores,
tsl::hopscotch_set<TaxId> *nodes_already_propagated,
TaxId *best_lca,
uint32_t *best_lca_dist_to_root);
uint32_t *best_lca_dist_to_root) const;

/**
* Get the list of LCA taxids, one for each kmer in the given sequence.
* The sequence can be given in forward or in reversed orientation.
*/
virtual std::vector<TaxId> get_lca_taxids_for_seq(const std::string_view &sequence, bool reversed) const = 0;

LabelType label_type;

Expand Down Expand Up @@ -104,7 +107,10 @@ class TaxonomyClsImportDB : public TaxonomyBase {
TaxonomyClsImportDB(const std::string &taxdb_filepath,
const double lca_coverage_rate,
const double kmers_discovery_rate);
TaxId assign_class(const std::string &sequence) const;

PRIVATE_TESTABLE:
std::vector<TaxId> get_lca_taxids_for_seq(const std::string_view &sequence, bool reversed) const;
TaxId find_lca(const std::vector<TaxId> &taxids) const;
};

class TaxonomyClsAnno : public TaxonomyBase {
Expand All @@ -128,9 +134,6 @@ class TaxonomyClsAnno : public TaxonomyBase {
// todo implement
void export_taxdb(const std::string &filepath) const;

// todo implement
TaxId assign_class(const std::string &sequence) const;

PRIVATE_TESTABLE:
/**
* Reads and returns the taxonomic tree as a list of children.
Expand Down Expand Up @@ -162,6 +165,10 @@ class TaxonomyClsAnno : public TaxonomyBase {
const ChildrenList &tree,
std::vector<TaxId> *tree_linearization);

TaxId find_lca(const std::vector<TaxId> &taxids) const;

std::vector<TaxId> get_lca_taxids_for_seq(const std::string_view &sequence, bool reversed) const;

/**
* rmq_data[0] contains the taxonomic tree linearization
* (e.g. for root 1 and edges={1-2; 1-3}, the linearization is "1 2 1 3 1").
Expand Down
Loading

0 comments on commit 22eec24

Please sign in to comment.