Skip to content

Commit

Permalink
Add and use wer tag data structure
Browse files Browse the repository at this point in the history
  • Loading branch information
nishchalb committed Apr 17, 2024
1 parent f1fa887 commit 7f8fae1
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
18 changes: 9 additions & 9 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
std::string last_label;
bool firstTk = true;

auto logger = logger::GetOrCreateLogger("NlpFstLoader");
// fuse multiple rows that have the same id/label into one entry only
for (auto &row : records) {
auto curr_tk = row.token;
Expand All @@ -37,15 +38,13 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
auto curr_row_tags = row.wer_tags;

// Update wer tags in records to real string labels
vector<string> real_wer_tags;
for (auto &tag : curr_row_tags) {
auto real_tag = tag;
if (mWerSidecar != Json::nullValue) {
real_tag = "###" + real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
tag.entity_type = mWerSidecar[tag.tag_id]["entity_type"].asString();
logger->info(tag.entity_type);
}
real_wer_tags.push_back(real_tag);
}
row.wer_tags = real_wer_tags;
row.wer_tags = curr_row_tags;
std::string speaker = row.speakerId;
mNlpRows.push_back(row);

Expand Down Expand Up @@ -411,17 +410,18 @@ std::string NlpReader::GetBestLabel(std::string &labels) {
return labels;
}

std::vector<std::string> NlpReader::GetWerTags(std::string &wer_tags_str) {
std::vector<std::string> wer_tags;
std::vector<WerTagEntry> NlpReader::GetWerTags(std::string &wer_tags_str) {
std::vector<WerTagEntry> wer_tags;
if (wer_tags_str == "[]") {
return wer_tags;
}
// wer_tags_str looks like: ['89', '90', '100']
int current_pos = 2;
auto pos = wer_tags_str.find("'", current_pos);
while (pos != -1) {
std::string wer_tag = wer_tags_str.substr(current_pos, pos - current_pos);
wer_tags.push_back(wer_tag);
WerTagEntry entry;
entry.tag_id = wer_tags_str.substr(current_pos, pos - current_pos);
wer_tags.push_back(entry);
current_pos = wer_tags_str.find("'", pos + 1) + 1;
if (current_pos == 0) {
break;
Expand Down
9 changes: 7 additions & 2 deletions src/Nlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
using namespace std;
using namespace fst;

// A single WER tag attached to an NLP row.
struct WerTagEntry {
  // Identifier of the tag, taken from the row's wer_tags list (e.g. '89').
  std::string tag_id;
  // Entity type read from the WER sidecar's "entity_type" field for this
  // tag_id; stays default-constructed (empty) until a loader assigns it.
  std::string entity_type;
};

struct RawNlpRecord {
string token;
string speakerId;
Expand All @@ -27,7 +32,7 @@ struct RawNlpRecord {
string labels;
string best_label;
string best_label_id;
vector<string> wer_tags;
vector<WerTagEntry> wer_tags;
string confidence;
};

Expand All @@ -37,7 +42,7 @@ class NlpReader {
virtual ~NlpReader();
vector<RawNlpRecord> read_from_disk(const std::string &filename);
string GetBestLabel(std::string &labels);
vector<string> GetWerTags(std::string &wer_tags_str);
vector<WerTagEntry> GetWerTags(std::string &wer_tags_str);
string GetLabelId(std::string &label);
};

Expand Down
2 changes: 1 addition & 1 deletion src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
<< "[";
/* for (auto wer_tag : nlpRow.wer_tags) { */
for (auto it = stitch.nlpRow.wer_tags.begin(); it != stitch.nlpRow.wer_tags.end(); ++it) {
output_nlp_file << "'" << *it << "'";
output_nlp_file << "'" << it->tag_id << "'";
if (std::next(it) != stitch.nlpRow.wer_tags.end()) {
output_nlp_file << ", ";
}
Expand Down
15 changes: 6 additions & 9 deletions src/wer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,19 +350,16 @@ void RecordTagWer(const vector<Stitching>& stitches) {
for (const auto &stitch : stitches) {
if (!stitch.nlpRow.wer_tags.empty()) {
for (auto wer_tag : stitch.nlpRow.wer_tags) {
int tag_start = wer_tag.find_first_not_of('#');
int tag_end = wer_tag.find('_');
string wer_tag_id = wer_tag.substr(tag_start, tag_end - tag_start);
wer_results.insert(std::pair<std::string, WerResult>(wer_tag_id, {0, 0, 0, 0, 0}));
wer_results.insert(std::pair<std::string, WerResult>(wer_tag.tag_id, {0, 0, 0, 0, 0}));
// Check with rfind since other comments can be there
bool del = stitch.comment.rfind("del", 0) == 0;
bool ins = stitch.comment.rfind("ins", 0) == 0;
bool sub = stitch.comment.rfind("sub", 0) == 0;
wer_results[wer_tag_id].insertions += ins;
wer_results[wer_tag_id].deletions += del;
wer_results[wer_tag_id].substitutions += sub;
wer_results[wer_tag.tag_id].insertions += ins;
wer_results[wer_tag.tag_id].deletions += del;
wer_results[wer_tag.tag_id].substitutions += sub;
if (!ins) {
wer_results[wer_tag_id].numWordsInReference += 1;
wer_results[wer_tag.tag_id].numWordsInReference += 1;
}
}
}
Expand Down Expand Up @@ -555,7 +552,7 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
string tk_wer_tags = "";
auto wer_tags = p_stitch.nlpRow.wer_tags;
for (auto wer_tag: wer_tags) {
tk_wer_tags = tk_wer_tags + wer_tag + "|";
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
}
string ref_tk = p_stitch.reftk;
string hyp_tk = p_stitch.hyptk;
Expand Down

0 comments on commit 7f8fae1

Please sign in to comment.