From 02e5e40a53f5c1f902c78d5ecdf014ccd4d02e4a Mon Sep 17 00:00:00 2001 From: gwilczynski95 Date: Fri, 29 Dec 2023 11:31:49 +0100 Subject: [PATCH] Add different way of calculating loss for eval UNOPTIMIZED!!! --- trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trainer.py b/trainer.py index 5db8177..13c0ecf 100644 --- a/trainer.py +++ b/trainer.py @@ -109,11 +109,13 @@ def calculate_metrics(self, data_loader, loss_fn, temperature=1.): for src, tgt, src_lens, tgt_lens in data_loader: src = src.to(self.device).T tgt = tgt.to(self.device).T + tgt_input = tgt[:, :-1] tgt = tgt[:, 1:] - src_mask, _ = create_masks(src, None, self.set_loader.pad_idx) + src_mask, tgt_mask = create_masks(src, tgt_input, self.set_loader.pad_idx) - out_tokens, out_probas = self.model.forward_gen( + out_probas = self.model(src, tgt_input, src_lens, tgt_lens, src_mask, tgt_mask) + out_tokens, _ = self.model.forward_gen( src, src_lens, src_mask, max(tgt_lens) - 1, self.set_loader.bos_idx, temperature ) output_texts = parse_tokens(out_tokens, self.set_loader.vocab_transform["en"])