diff --git a/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py b/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py index 6a6054c..3add5dd 100644 --- a/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py +++ b/beat/algorithms/loicbarrault/mt_lifelong_loop/1.py @@ -289,11 +289,13 @@ def get_sen_vecs(data_dict_train, src_vocab, word_embs): for index_sen, sen in enumerate(data_dict_train['src']): splitsen = sen.split() + sen_length = 0 for token in splitsen: if token in src_vocab.keys(): + sen_length += 1 index_token = int(src_vocab[token].split()[0]) - sen_vecs[index_sen] = word_embs[index_token] - sen_vecs[index_sen] /= len(splitsen) + sen_vecs[index_sen] += word_embs[index_token] + sen_vecs[index_sen] /= sen_length return sen_vecs @@ -324,11 +326,13 @@ def data_selection_emb(N, data_dict_train, train_sen_vecs, doc_input, src_vocab, for sen in doc_input: sen_vec = torch.zeros(1, word_embs.size()[1], dtype=word_embs.dtype, device=word_embs.device) splitsen = sen.split() + sen_length = 0 for token in splitsen: if token in src_vocab.keys(): + sen_length += 1 index = int(src_vocab[token].split()[0]) sen_vec += word_embs[index] - sen_vec /= len(splitsen) + sen_vec /= sen_length coss = cos(sen_vec, train_sen_vecs) coss[torch.isnan(coss)] = 0 index_topn_sens = torch.topk(coss, n, largest=True).indices