def forward_gen(self, x, src_lens, src_mask, max_len, bos_idx, temperature, greedy=True):
    """Autoregressively generate target tokens for a batch of source inputs.

    The source is encoded once. Then, for ``max_len`` steps, the decoder is
    re-run on all tokens generated so far and one new token per batch row is
    chosen from the scores at the last time step.

    Args:
        x: Source batch; ``x.shape[0]`` is used as the batch size and
            ``x.device`` as the device for the generated tokens.
        src_lens: Source lengths, passed through to the encoder.
        src_mask: Source mask, passed through to the encoder and decoder.
        max_len: Number of decoding steps, i.e. tokens generated after BOS.
        bos_idx: Token id used to seed every output sequence.
        temperature: Divisor applied to the decoder scores before the
            softmax when sampling. Unused when ``greedy`` is true.
        greedy: If true (default), pick the highest-scoring token at each
            step; otherwise sample from the temperature-scaled softmax
            using ``np.random.choice``.

    Returns:
        A tuple ``(out_tokens, dec_probas)``. ``out_tokens`` is an int64
        tensor with ``1 + max_len`` columns, the first being ``bos_idx``.
        ``dec_probas`` is the decoder output of the final step, or ``None``
        when ``max_len`` is 0 and the decoder never runs.
    """
    # NOTE(review): the decoder output is named "probas" but is passed through
    # softmax below, so it is presumably unnormalised scores -- confirm.
    batch_size = x.shape[0]
    enc_x = self.encoder(x, src_lens, src_mask)
    out_tokens = torch.full([batch_size, 1], bos_idx, dtype=torch.int64, device=x.device)
    out_lens = [1] * batch_size
    # Bound up front so that max_len == 0 returns None instead of raising
    # UnboundLocalError at the return statement.
    dec_probas = None
    for step in range(max_len):
        tgt_mask = get_time_mask(step + 1).bool().to(self.device)
        dec_probas = self.decoder(out_tokens, out_lens, enc_x, src_mask, tgt_mask)
        # Only the newest position decides the next token.
        last_scores = dec_probas[:, -1, :]
        if greedy:
            # Already an index tensor on the right device; no re-wrapping.
            dec_tokens = torch.argmax(last_scores, dim=-1, keepdim=True)
        else:
            # Softmax is taken over the last step only; it is per-position,
            # so this matches normalising every step and slicing afterwards.
            probs = torch.softmax(last_scores / temperature, dim=-1).detach().cpu().numpy()
            vocab = np.arange(probs.shape[-1])
            # One draw per batch row, in row order, from NumPy's global RNG.
            sampled = np.array([[np.random.choice(vocab, p=row)] for row in probs])
            dec_tokens = torch.as_tensor(sampled, dtype=torch.int64, device=x.device)
        out_tokens = torch.cat([out_tokens, dec_tokens], dim=-1)
        # Every row grows in lockstep: BOS plus (step + 1) generated tokens.
        out_lens = [step + 2] * batch_size
    return out_tokens, dec_probas