Skip to content

Commit

Permalink
Add greedy decode to forward_gen
Browse files Browse the repository at this point in the history
  • Loading branch information
gwilczynski95 committed Dec 29, 2023
1 parent 71aa07d commit 661efe3
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions transformer/blocks.py
Original file line number Diff line number Diff line change
@@ -561,29 +561,23 @@ def forward(self, x, y, src_lens, tgt_lens, src_mask, tgt_mask):
         out = self.decoder(y, tgt_input_lens, enc_x, src_mask, tgt_mask)
         return out
 
-    def forward_gen(self, x, src_lens, src_mask, max_len, bos_idx, temperature):
+    def forward_gen(self, x, src_lens, src_mask, max_len, bos_idx, temperature, greedy=True):
         enc_x = self.encoder(x, src_lens, src_mask)
         out_tokens = torch.full([x.shape[0], 1], bos_idx, dtype=torch.int64, device=x.device)
         out_lens = [1] * x.shape[0]
-        out_probas = None
         for i in range(max_len):
             tgt_mask = get_time_mask(i + 1).bool().to(self.device)
             dec_probas = self.decoder(out_tokens, out_lens, enc_x, src_mask, tgt_mask)
-            if out_probas is None:
-                out_probas = dec_probas
+            if greedy:
+                _, dec_tokens = torch.max(dec_probas[:, -1, :], dim=-1, keepdim=True)
             else:
-                out_probas = torch.cat([
-                    out_probas,
-                    dec_probas[:, -1:, :]
-                ],
-                    dim=1
+                _p = torch.softmax(dec_probas / temperature, dim=-1).cpu().detach().numpy()[:, -1, :]
+                dec_tokens = np.array(
+                    [
+                        [np.random.choice(np.arange(_p.shape[-1]), p=_p[x])] for x in range(x.shape[0])
+                    ]
                 )
-            _p = torch.softmax(dec_probas / temperature, dim=-1).cpu().detach().numpy()[:, -1, :]
-            dec_tokens = np.array(
-                [
-                    [np.random.choice(np.arange(_p.shape[-1]), p=_p[x])] for x in range(x.shape[0])
-                ]
-            )
-            out_tokens = torch.cat([out_tokens, torch.tensor(dec_tokens, device=x.device)], dim=-1)
+            dec_tokens = torch.tensor(dec_tokens, device=x.device)
+            out_tokens = torch.cat([out_tokens, dec_tokens], dim=-1)
             out_lens = [x + 1 for x in out_lens]
-        return out_tokens, out_probas
+        return out_tokens, dec_probas

0 comments on commit 661efe3

Please sign in to comment.