Skip to content

Commit

Permalink
Add proper subsequent mask to generation
Browse files Browse the repository at this point in the history
  • Loading branch information
gwilczynski95 committed Dec 29, 2023
1 parent 31481bf commit 71aa07d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 5 additions & 1 deletion data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@ def _tensor_transform(token_ids: List[int], bos_idx: int, eos_idx: int):
)


def get_time_mask(size, device=None):
    """Return a causal ("subsequent") attention mask of shape (1, size, size).

    Entry [0, i, j] is True iff j <= i, i.e. position i may attend only to
    itself and earlier positions.

    Args:
        size: Target sequence length.
        device: Optional torch device to allocate the mask on (defaults to
            CPU, matching the previous behavior when omitted).

    Returns:
        torch.Tensor of dtype torch.bool, shape (1, size, size),
        lower-triangular True.
    """
    # Build the lower-triangular pattern directly as bool. Equivalent to the
    # old `torch.triu(torch.ones(...), diagonal=1) == 0`, but avoids the
    # intermediate float tensor and lets callers skip a redundant `.bool()`.
    return torch.tril(torch.ones((1, size, size), dtype=torch.bool, device=device))


def create_masks(src, tgt, pad_idx):
    """Build attention masks for an encoder-decoder transformer.

    Args:
        src: Source token-id tensor; assumed (batch, src_len) — TODO confirm.
        tgt: Target token-id tensor, or None during inference when no
            target is available.
        pad_idx: Token id used for padding; positions equal to it are masked.
    """
    # Source mask: True at non-pad positions; unsqueeze(-2) adds the
    # attention-row broadcast dim, giving (batch, 1, src_len).
    src_mask = (src != pad_idx).unsqueeze(-2)

    tgt_mask = None
    if tgt is not None:
        # Padding mask for the target, (batch, 1, tgt_len).
        pad_tgt_mask = (tgt != pad_idx).unsqueeze(-2)
        # NOTE(review): the next two assignments compute the same causal
        # mask; the first is the pre-refactor inline form superseded by the
        # get_time_mask helper on the following line (harmless but redundant).
        time_mask = torch.triu(torch.ones((pad_tgt_mask.shape[-1], pad_tgt_mask.shape[-1])), diagonal=1) == 0
        time_mask = get_time_mask(pad_tgt_mask.shape[-1])
        # Match the padding mask's dtype so `&` combines them cleanly.
        time_mask = time_mask.type_as(pad_tgt_mask.data)
        # Combined target mask: position j is visible from row i only if
        # j is non-pad AND j <= i (causal).
        tgt_mask = pad_tgt_mask & time_mask

Expand Down
4 changes: 3 additions & 1 deletion transformer/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from data import get_time_mask


class LinearLayer(nn.Module):
def __init__(self, in_dimension, out_dimension, weights_initialization="glorot_uniform",
Expand Down Expand Up @@ -565,7 +567,7 @@ def forward_gen(self, x, src_lens, src_mask, max_len, bos_idx, temperature):
out_lens = [1] * x.shape[0]
out_probas = None
for i in range(max_len):
tgt_mask = torch.ones((1, i + 1, i + 1), dtype=torch.bool).to(self.device)
tgt_mask = get_time_mask(i + 1).bool().to(self.device)
dec_probas = self.decoder(out_tokens, out_lens, enc_x, src_mask, tgt_mask)
if out_probas is None:
out_probas = dec_probas
Expand Down

0 comments on commit 71aa07d

Please sign in to comment.