From dd5037045685d68c00935a0176d322bd4778f2e2 Mon Sep 17 00:00:00 2001
From: hecko
Date: Tue, 19 Dec 2023 11:29:57 +0100
Subject: [PATCH] Fix batch size 1 by specifying squeeze dims

---
 Modules/slmadv.py            | 24 ++++++++++++------------
 Utils/JDC/model.py           |  2 +-
 losses.py                    |  2 +-
 train_finetune.py            |  8 ++++----
 train_finetune_accelerate.py |  8 ++++----
 train_second.py              |  8 ++++----
 6 files changed, 26 insertions(+), 26 deletions(-)
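Note on the root cause: torch.Tensor.squeeze() with no argument removes
every size-1 dimension, so when the batch dimension is itself 1 it
disappears along with the intended singleton axis, and the discriminator
and WavLM calls below then see tensors of the wrong rank. squeeze(dim)
removes only the named axis (and is a no-op if that axis is not size 1).
A minimal sketch of the failure mode, with illustrative shapes:

    import torch

    wav = torch.randn(1, 1, 8)    # (batch, channel, samples), batch size 1
    print(wav.squeeze().shape)    # torch.Size([8])    - batch axis lost too
    print(wav.squeeze(0).shape)   # torch.Size([1, 8]) - only dim 0 removed

    wav = torch.randn(4, 1, 8)    # batch size > 1 masked the bug:
    print(wav.squeeze().shape)    # torch.Size([4, 8])
    print(wav.squeeze(1).shape)   # torch.Size([4, 8]) - identical here

One caveat: squeeze() with a tuple of dims, as used in Utils/JDC/model.py
below, requires PyTorch 2.0 or newer; older versions accept a single int
only.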
diff --git a/Modules/slmadv.py b/Modules/slmadv.py
index 11acb914..5ab3c50d 100644
--- a/Modules/slmadv.py
+++ b/Modules/slmadv.py
@@ -149,41 +149,41 @@ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_t
             if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
                 if wav.size(-1) > y_pred.size(-1):
                     real_GP = wav[:, : , :crop_size]
-                    out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
-                    out_org = self.wl.discriminator_forward(wav.detach().squeeze())
+                    out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze(0))
+                    out_org = self.wl.discriminator_forward(wav.detach().squeeze(0))
                     loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
 
                     if np.random.randint(0, 2) == 0:
-                        d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
+                        d_loss = self.wl.discriminator(real_GP.detach().squeeze(0), y_pred.detach().squeeze(0)).mean()
                     else:
-                        d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+                        d_loss = self.wl.discriminator(wav.detach().squeeze(0), y_pred.detach().squeeze(0)).mean()
                 else:
                     real_GP = y_pred[:, : , :crop_size]
-                    out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
-                    out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
+                    out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze(0))
+                    out_org = self.wl.discriminator_forward(y_pred.detach().squeeze(0))
                     loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
 
                     if np.random.randint(0, 2) == 0:
-                        d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
+                        d_loss = self.wl.discriminator(wav.detach().squeeze(0), real_GP.detach().squeeze(0)).mean()
                     else:
-                        d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+                        d_loss = self.wl.discriminator(wav.detach().squeeze(0), y_pred.detach().squeeze(0)).mean()
 
                 # regularization (ignore length variation)
                 d_loss += loss_reg
 
-                out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
-                out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
+                out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze(0))
+                out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze(0))
 
                 # regularization (ignore reconstruction artifacts)
                 d_loss += F.l1_loss(out_gt, out_rec)
 
             else:
-                d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
+                d_loss = self.wl.discriminator(wav.detach().squeeze(0), y_pred.detach().squeeze(0)).mean()
         else:
             d_loss = 0
 
         # generator loss
-        gen_loss = self.wl.generator(y_pred.squeeze())
+        gen_loss = self.wl.generator(y_pred.squeeze(0))
 
         gen_loss = gen_loss.mean()

diff --git a/Utils/JDC/model.py b/Utils/JDC/model.py
index 83cd266d..a59866a4 100644
--- a/Utils/JDC/model.py
+++ b/Utils/JDC/model.py
@@ -134,7 +134,7 @@ def forward(self, x):
         # sizes: (b, 31, 722), (b, 31, 2)
         # classifier output consists of predicted pitch classes per frame
         # detector output consists of: (isvoice, notvoice) estimates per frame
-        return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+        return torch.abs(classifier_out.squeeze((1, 2))), GAN_feature, poolblock_out
 
     @staticmethod
     def init_weights(m):

diff --git a/losses.py b/losses.py
index 1766acd1..ce091d3f 100644
--- a/losses.py
+++ b/losses.py
@@ -203,7 +203,7 @@ def forward(self, wav, y_rec):
             wav_16 = self.resample(wav)
             wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
         y_rec_16 = self.resample(y_rec)
-        y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True).hidden_states
+        y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states
 
         floss = 0
         for er, eg in zip(wav_embeddings, y_rec_embeddings):

diff --git a/train_finetune.py b/train_finetune.py
index 3c650747..ac5b7d2e 100644
--- a/train_finetune.py
+++ b/train_finetune.py
@@ -304,8 +304,8 @@ def main(config_path):
                     s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                     gs.append(s)
 
-                s_dur = torch.stack(ss).squeeze() # global prosodic styles
-                gs = torch.stack(gs).squeeze() # global acoustic styles
+                s_dur = torch.stack(ss).squeeze(1) # global prosodic styles
+                gs = torch.stack(gs).squeeze(1) # global acoustic styles
                 s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
 
                 bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
@@ -388,7 +388,7 @@ def main(config_path):
 
             with torch.no_grad():
                 F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
-                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])
 
                 N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
 
@@ -415,7 +415,7 @@ def main(config_path):
 
             loss_mel = stft_loss(y_rec, wav)
             loss_gen_all = gl(wav, y_rec).mean()
-            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+            loss_lm = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()
 
             loss_ce = 0
             loss_dur = 0

diff --git a/train_finetune_accelerate.py b/train_finetune_accelerate.py
index 4cfd95f7..58ffeeaf 100644
--- a/train_finetune_accelerate.py
+++ b/train_finetune_accelerate.py
@@ -311,8 +311,8 @@ def main(config_path):
                     s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                     gs.append(s)
 
-                s_dur = torch.stack(ss).squeeze() # global prosodic styles
-                gs = torch.stack(gs).squeeze() # global acoustic styles
+                s_dur = torch.stack(ss).squeeze(1) # global prosodic styles
+                gs = torch.stack(gs).squeeze(1) # global acoustic styles
                 s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
 
                 bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
@@ -395,7 +395,7 @@ def main(config_path):
 
             with torch.no_grad():
                 F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
-                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])
 
                 N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
 
@@ -422,7 +422,7 @@ def main(config_path):
 
             loss_mel = stft_loss(y_rec, wav)
             loss_gen_all = gl(wav, y_rec).mean()
-            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+            loss_lm = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()
 
             loss_ce = 0
             loss_dur = 0
diff --git a/train_second.py b/train_second.py
index fb1048dc..70bdbc71 100644
--- a/train_second.py
+++ b/train_second.py
@@ -302,8 +302,8 @@ def main(config_path):
                     s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                     gs.append(s)
 
-                s_dur = torch.stack(ss).squeeze() # global prosodic styles
-                gs = torch.stack(gs).squeeze() # global acoustic styles
+                s_dur = torch.stack(ss).squeeze(1) # global prosodic styles
+                gs = torch.stack(gs).squeeze(1) # global acoustic styles
                 s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
 
                 bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
@@ -381,7 +381,7 @@ def main(config_path):
 
            with torch.no_grad():
                F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
-                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
+                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])
 
                 asr_real = model.text_aligner.get_feature(gt)
 
@@ -421,7 +421,7 @@ def main(config_path):
                 loss_gen_all = gl(wav, y_rec).mean()
             else:
                 loss_gen_all = 0
-            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
+            loss_lm = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()
 
             loss_ce = 0
             loss_dur = 0
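
Note on the F0 reshape in the three train scripts: the old code appended a
dummy trailing axis and relied on squeeze() to drop it, which at batch
size 1 also dropped the batch axis. Reshaping straight to the target rank
sidesteps squeeze() entirely. A sketch with made-up sizes (the real ones
come from the pitch extractor):

    import torch

    b, t, f = 1, 5, 3
    F0 = torch.randn(b, t, f, 2)   # illustrative stand-in for the extractor output

    old = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
    new = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])

    print(old.shape)  # torch.Size([10, 3])    - batch axis collapsed when b == 1
    print(new.shape)  # torch.Size([1, 10, 3]) - rank is stable for any b

For b > 1 both expressions agree, which is why the bug only surfaced with
batch_size: 1.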