From c003ccf49c1ab2f7e306275bbbd09400619355f0 Mon Sep 17 00:00:00 2001
From: Ziqing Dong <55984224+d16144504@users.noreply.github.com>
Date: Thu, 8 Oct 2020 22:43:13 +0800
Subject: [PATCH] Fixed bugs for data selection

1. Fixed bugs when generating sentence vectors from averaged word embeddings.
2. Fixed bugs when computing cosine similarities between sentences.
---
 beat/algorithms/loicbarrault/mt_lifelong_loop/1.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py b/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py
index 6a6054c..3add5dd 100644
--- a/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py
+++ b/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py
@@ -289,11 +289,13 @@ def get_sen_vecs(data_dict_train, src_vocab, word_embs):
 
     for index_sen, sen in enumerate(data_dict_train['src']):
         splitsen = sen.split()
+        sen_length = 0
         for token in splitsen:
             if token in src_vocab.keys():
+                sen_length += 1
                 index_token = int(src_vocab[token].split()[0])
-                sen_vecs[index_sen] = word_embs[index_token]
-        sen_vecs[index_sen] /= len(splitsen)
+                sen_vecs[index_sen] += word_embs[index_token]
+        sen_vecs[index_sen] /= sen_length
 
     return sen_vecs
 
@@ -324,11 +326,13 @@ def data_selection_emb(N, data_dict_train, train_sen_vecs, doc_input, src_vocab,
     for sen in doc_input:
         sen_vec = torch.zeros(1, word_embs.size()[1], dtype=word_embs.dtype, device=word_embs.device)
         splitsen = sen.split()
+        sen_length = 0
         for token in splitsen:
             if token in src_vocab.keys():
+                sen_length += 1
                 index = int(src_vocab[token].split()[0])
                 sen_vec += word_embs[index]
-        sen_vec /= len(splitsen)
+        sen_vec /= sen_length
         coss = cos(sen_vec, train_sen_vecs)
         coss[torch.isnan(coss)] = 0
         index_topn_sens = torch.topk(coss, n, largest=True).indices
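
For context, here is a minimal standalone sketch of the corrected logic outside the patched file. The helper name average_sentence_vector, the toy vocabulary, and the embedding table are hypothetical illustrations, not part of the repository; the vocabulary format (token mapped to a string whose first field is the embedding row index) and the torch calls mirror the diff above. The sketch sums the embeddings of in-vocabulary tokens and divides by the number of matched tokens rather than by the raw sentence length, then ranks candidate training sentences by cosine similarity.

import torch

def average_sentence_vector(sentence, src_vocab, word_embs):
    """Return the mean embedding of the tokens found in src_vocab."""
    sen_vec = torch.zeros(1, word_embs.size()[1],
                          dtype=word_embs.dtype, device=word_embs.device)
    sen_length = 0
    for token in sentence.split():
        if token in src_vocab:
            sen_length += 1
            index = int(src_vocab[token].split()[0])
            sen_vec += word_embs[index]     # accumulate (+=), not overwrite
    if sen_length > 0:                      # guard not in the patch: skip divide if no token is known
        sen_vec /= sen_length               # divide by matched tokens, not len(splitsen)
    return sen_vec

if __name__ == "__main__":
    # Hypothetical toy vocabulary and embedding table.
    src_vocab = {"hello": "0", "world": "1"}
    word_embs = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

    vec = average_sentence_vector("hello world unknown", src_vocab, word_embs)
    print(vec)  # tensor([[0.5000, 0.5000]]) -- averaged over the 2 known tokens

    # Cosine similarity against a small pool of training sentence vectors,
    # keeping the top-n most similar, as data_selection_emb does.
    train_sen_vecs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    cos = torch.nn.CosineSimilarity(dim=1)
    coss = cos(vec, train_sen_vecs)
    coss[torch.isnan(coss)] = 0
    index_topn_sens = torch.topk(coss, 2, largest=True).indices
    print(index_topn_sens)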