Skip to content

Commit

Permalink
Add different way of calculating loss for eval UNOPTIMIZED!!!
Browse files Browse the repository at this point in the history
  • Loading branch information
gwilczynski95 committed Dec 29, 2023
1 parent 661efe3 commit 02e5e40
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,13 @@ def calculate_metrics(self, data_loader, loss_fn, temperature=1.):
for src, tgt, src_lens, tgt_lens in data_loader:
src = src.to(self.device).T
tgt = tgt.to(self.device).T
tgt_input = tgt[:, :-1]
tgt = tgt[:, 1:]

src_mask, _ = create_masks(src, None, self.set_loader.pad_idx)
src_mask, tgt_mask = create_masks(src, tgt_input, self.set_loader.pad_idx)

out_tokens, out_probas = self.model.forward_gen(
out_probas = self.model(src, tgt_input, src_lens, tgt_lens, src_mask, tgt_mask)
out_tokens, _ = self.model.forward_gen(
src, src_lens, src_mask, max(tgt_lens) - 1, self.set_loader.bos_idx, temperature
)
output_texts = parse_tokens(out_tokens, self.set_loader.vocab_transform["en"])
Expand Down

0 comments on commit 02e5e40

Please sign in to comment.