
Format & polish training and webui
leng-yue committed Sep 12, 2023
1 parent c99131f commit f63acb9
Showing 18 changed files with 91 additions and 115 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -3,5 +3,5 @@
VITS2 Backbone with bert
## Experienced Travelers/Trailblazers/Captains/Doctors/sensei/Witchers/喵喵露/V should read the code and learn how to train on their own.
### It is strictly forbidden to use this project for any purpose that violates the Constitution of the People's Republic of China, the Criminal Law of the People's Republic of China, the Law of the People's Republic of China on Penalties for Administration of Public Security, or the Civil Code of the People's Republic of China.
- #### Video:https://www.bilibili.com/video/BV1hp4y1K78E
+ #### Video:https://www.bilibili.com/video/BV1hp4y1K78E
#### Demo:https://www.bilibili.com/video/BV1TF411k78w
5 changes: 2 additions & 3 deletions attentions.py
@@ -1,4 +1,3 @@
- import copy
import math
import torch
from torch import nn
@@ -341,7 +340,7 @@ def _matmul_with_relative_keys(self, x, y):
return ret

def _get_relative_embeddings(self, relative_embeddings, length):
- max_relative_position = 2 * self.window_size + 1
+ 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
@@ -385,7 +384,7 @@ def _absolute_position_to_relative_position(self, x):
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
- # padd along column
+ # pad along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
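The pad-then-slice trick referenced by the comment above avoids data-dependent branching. A minimal, self-contained sketch of the same logic (a hypothetical free function with assumed shapes, not the repository's class method):

```python
import torch
import torch.nn.functional as F

def get_relative_embeddings(relative_embeddings, length, window_size):
    # Pad first, then slice, so the result never depends on a runtime branch.
    pad_length = max(length - (window_size + 1), 0)
    slice_start = max((window_size + 1) - length, 0)
    slice_end = slice_start + 2 * length - 1
    if pad_length > 0:
        # Pad only the relative-position axis (dim 1).
        relative_embeddings = F.pad(relative_embeddings, (0, 0, pad_length, pad_length))
    return relative_embeddings[:, slice_start:slice_end]

emb = torch.randn(1, 2 * 4 + 1, 8)  # window_size=4 gives 9 relative positions
print(get_relative_embeddings(emb, length=6, window_size=4).shape)  # torch.Size([1, 11, 8])
```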
11 changes: 4 additions & 7 deletions commons.py
@@ -1,7 +1,5 @@
import math
- import numpy as np
import torch
- from torch import nn
from torch.nn import functional as F


@@ -16,8 +14,8 @@ def get_padding(kernel_size, dilation=1):


def convert_pad_shape(pad_shape):
- l = pad_shape[::-1]
- pad_shape = [item for sublist in l for item in sublist]
+ layer = pad_shape[::-1]
+ pad_shape = [item for sublist in layer for item in sublist]
return pad_shape


@@ -110,8 +108,8 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):


def convert_pad_shape(pad_shape):
- l = pad_shape[::-1]
- pad_shape = [item for sublist in l for item in sublist]
+ layer = pad_shape[::-1]
+ pad_shape = [item for sublist in layer for item in sublist]
return pad_shape


@@ -132,7 +130,6 @@ def generate_path(duration, mask):
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
- device = duration.device

b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
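`convert_pad_shape` is defined twice in commons.py, and both copies get the same rename: `l` becomes `layer`, silencing flake8's E741 (ambiguous variable name) without changing behavior. A quick sketch of what the helper computes:

```python
def convert_pad_shape(pad_shape):
    # Reverse the per-dimension [before, after] pairs and flatten them into
    # the last-dimension-first flat list that torch.nn.functional.pad expects.
    layer = pad_shape[::-1]
    return [item for sublist in layer for item in sublist]

# Padding pairs for [batch, channel, time] -> flat argument for F.pad
print(convert_pad_shape([[0, 0], [0, 0], [0, 4]]))  # [0, 4, 0, 0, 0, 0]
```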
3 changes: 2 additions & 1 deletion configs/config.json
@@ -17,7 +17,8 @@
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0
"c_kl": 1.0,
"skip_optimizer": true
},
"data": {
"training_files": "filelists/train.list",
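The new `skip_optimizer` key decides whether optimizer state is restored when resuming training; the matching train_ms.py change below reads it with a fallback for configs that predate the key. A small illustrative sketch, with the config literal assumed:

```python
import json

# Hypothetical excerpt of configs/config.json after this commit.
config = json.loads('{"train": {"c_mel": 45, "c_kl": 1.0, "skip_optimizer": true}}')

# Older configs lack the key, so fall back to the previously hard-coded True,
# mirroring `hps.train.skip_optimizer if "skip_optimizer" in hps.train else True`.
skip_optimizer = config["train"].get("skip_optimizer", True)
print(skip_optimizer)  # True
```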
6 changes: 2 additions & 4 deletions data_utils.py
@@ -1,13 +1,11 @@
- import time
import os
import random
- import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm
from loguru import logger
import commons
- from mel_processing import spectrogram_torch, mel_spectrogram_torch, spec_to_mel_torch
+ from mel_processing import spectrogram_torch, mel_spectrogram_torch
from utils import load_wav_to_torch, load_filepaths_and_text
from text import cleaned_text_to_sequence, get_bert

@@ -100,7 +98,7 @@ def get_audio(self, filename):
if sampling_rate != self.sampling_rate:
raise ValueError(
"{} {} SR doesn't match target {} SR".format(
- sampling_rate, self.sampling_rate
+ filename, sampling_rate, self.sampling_rate
)
)
audio_norm = audio / self.max_wav_value
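The old message had three `{}` placeholders but only two format arguments, so the error itself would crash; the fix supplies the filename, which also makes the bad clip findable in a large training list. A standalone sketch of the check (the helper name is assumed):

```python
def check_sampling_rate(filename, sampling_rate, target_sampling_rate):
    # Three placeholders now get three arguments, and the offending file is named.
    if sampling_rate != target_sampling_rate:
        raise ValueError(
            "{} {} SR doesn't match target {} SR".format(
                filename, sampling_rate, target_sampling_rate
            )
        )

check_sampling_rate("dataset/spk1/clip_001.wav", 44100, 44100)  # passes silently
```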
3 changes: 0 additions & 3 deletions losses.py
@@ -1,7 +1,4 @@
import torch
- from torch.nn import functional as F

- import commons


def feature_loss(fmap_r, fmap_g):
11 changes: 0 additions & 11 deletions mel_processing.py
@@ -1,16 +1,5 @@
- import math
- import os
- import random
import torch
- from torch import nn
- import torch.nn.functional as F
import torch.utils.data
- import numpy as np
- import librosa
- import librosa.util as librosa_util
- from librosa.util import normalize, pad_center, tiny
- from scipy.signal import get_window
- from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0
5 changes: 1 addition & 4 deletions modules.py
@@ -1,12 +1,9 @@
- import copy
import math
- import numpy as np
- import scipy
import torch
from torch import nn
from torch.nn import functional as F

- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+ from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

import commons
8 changes: 3 additions & 5 deletions resample.py
@@ -1,11 +1,9 @@
import os
import argparse
import librosa
- import numpy as np
from multiprocessing import Pool, cpu_count

import soundfile
- from scipy.io import wavfile
from tqdm import tqdm


@@ -29,9 +27,9 @@ def process(item):
"--out_dir", type=str, default="./dataset", help="path to target dir"
)
args = parser.parse_args()
- # processs = 8
- processs = cpu_count() - 2 if cpu_count() > 4 else 1
- pool = Pool(processes=processs)
+ # processes = 8
+ processes = cpu_count() - 2 if cpu_count() > 4 else 1
+ pool = Pool(processes=processes)

for speaker in os.listdir(args.in_dir):
spk_dir = os.path.join(args.in_dir, speaker)
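Beyond the spelling fix (`processs` → `processes`), the heuristic keeps two cores free on machines with more than four cores and stays single-process otherwise. A self-contained sketch of the same pattern (the worker body is a placeholder):

```python
from multiprocessing import Pool, cpu_count

def process(item):
    return item  # stand-in for the real resampling work

if __name__ == "__main__":
    # Leave two cores for the OS and I/O on larger machines; one worker otherwise.
    processes = cpu_count() - 2 if cpu_count() > 4 else 1
    with Pool(processes=processes) as pool:
        print(pool.map(process, range(8)))
```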
19 changes: 4 additions & 15 deletions server.py
@@ -40,20 +40,9 @@ def get_text(text, language_str, hps):
else:
bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(768, len(phone))
- assert bert.shape[-1] == len(phone), (
- bert.shape,
- len(phone),
- sum(word2ph),
- p1,
- p2,
- t1,
- t2,
- pold,
- pold2,
- word2ph,
- text,
- w2pho,
- )
+ assert bert.shape[-1] == len(
+ phone
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
@@ -126,7 +115,7 @@ def wav2(i, o, format):
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
- **hps.model
+ **hps.model,
).to(dev)
_ = net_g.eval()

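The sprawling assert tuple, which dumped a dozen intermediate variables, collapses into a single f-string naming only the two lengths being compared; the trailing comma after `**hps.model` is a formatting nicety. A minimal sketch of the new assertion with dummy inputs:

```python
import torch

phone = [1, 2, 3, 4]                  # assumed phoneme id sequence
bert = torch.zeros(1024, len(phone))  # per-phone BERT features, as in the diff

assert bert.shape[-1] == len(
    phone
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
```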
2 changes: 0 additions & 2 deletions text/chinese.py
@@ -4,7 +4,6 @@
import cn2an
from pypinyin import lazy_pinyin, Style

- from text import symbols
from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi

@@ -96,7 +95,6 @@ def _g2p(segments):
tones_list = []
word2ph = []
for seg in segments:
- pinyins = []
# Replace all English words in the sentence
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
16 changes: 8 additions & 8 deletions text/cmudict.rep
@@ -7,7 +7,7 @@
## origin.
##
## cmudict.0.6 is the fifth release of cmudict, first released as cmudict.0.1
- ## in September of 1993. There was no generally available public release
+ ## in September of 1993. There was no generally available public release
## of version 0.5.
##
## See the README in this directory before you use this dictionary.
@@ -16,11 +16,11 @@
## Alex Rudnicky, Jack Mostow, Roni Rosenfeld, Richard Stern,
## Matthew Siegler, Kevin Lenzo, Maxine Eskenazi, Mosur Ravishankar,
## Eric Thayer, Kristie Seymore, and Raj Reddy at CMU; Lin Chase at
- ## LIMSI; Doug Paul at MIT Lincoln Labs; Ben Serridge at MIT SLS; Murray
- ## Spiegel at Bellcore; Tony Robinson at Cambridge UK; David Bowness of
- ## CAE Electronics Ltd. and CRIM; Stephen Hocking; Jerry Quinn at BNR
- ## Canada, and Marshal Midden for bringing to our attention problems and
- ## inadequacies with the first releases. Most special thanks to Bob Weide
+ ## LIMSI; Doug Paul at MIT Lincoln Labs; Ben Serridge at MIT SLS; Murray
+ ## Spiegel at Bellcore; Tony Robinson at Cambridge UK; David Bowness of
+ ## CAE Electronics Ltd. and CRIM; Stephen Hocking; Jerry Quinn at BNR
+ ## Canada, and Marshal Midden for bringing to our attention problems and
+ ## inadequacies with the first releases. Most special thanks to Bob Weide
## for all his work on prior versions of the dictionary.
##
## We welcome input from users and will continue to acknowledge such input
@@ -37,12 +37,12 @@
## so keep your eyes open for problems and mail them to me.
##
## We hope this dictionary is an improvement over cmudict.0.4.
- ##
+ ##
## email: [email protected]
## web: http://www.speech.cs.cmu.edu/cgi-bin/cmudict
## ftp: ftp://ftp.cs.cmu.edu/project/speech/dict/
##
- ## Thank you for your continued interest in the CMU Pronouncing
+ ## Thank you for your continued interest in the CMU Pronouncing
## Dictionary. Further addictions and improvements are planned
## for forthcoming releases.
##
1 change: 0 additions & 1 deletion text/english.py
@@ -2,7 +2,6 @@
import os
import re
from g2p_en import G2p
- from string import punctuation

from text import symbols

1 change: 0 additions & 1 deletion text/japanese.py
@@ -1,6 +1,5 @@
# Convert Japanese text to phonemes which is
# compatible with Julius https://github.com/julius-speech/segmentation-kit
- import math
import re
import unicodedata

2 changes: 1 addition & 1 deletion text/tone_sandhi.py
@@ -536,7 +536,7 @@ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
[item.isnumeric() for item in word if item != "一"]
):
return finals
# "一" between reduplication words shold be yi5, e.g. 看一看
# "一" between reduplication words should be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
finals[1] = finals[1][:-1] + "5"
# when "一" is ordinal word, it should be yi1
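For readers unfamiliar with the rule the corrected comment describes, a self-contained sketch (the finals are assumed pypinyin tone3-style strings):

```python
word = "看一看"                 # a reduplicated verb: "take a look"
finals = ["an4", "i1", "an4"]  # assumed finals with trailing tone digits

# "一" between reduplication words becomes yi5 (neutral tone), e.g. 看一看
if len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
    finals[1] = finals[1][:-1] + "5"  # swap the tone digit for 5

print(finals)  # ['an4', 'i5', 'an4']
```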
37 changes: 16 additions & 21 deletions train_ms.py
@@ -1,14 +1,10 @@
+ # flake8: noqa: E402
+
import os
- import json
- import argparse
- import itertools
- import math
import torch
- from torch import nn, optim
- from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
@@ -42,7 +38,7 @@
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(
True
- ) # Not avaliable if torch version is lower than 2.0
+ ) # Not available if torch version is lower than 2.0
torch.backends.cuda.enable_math_sdp(True)
global_step = 0

@@ -97,23 +93,20 @@ def run():
)
if (
"use_noise_scaled_mas" in hps.model.keys()
- and hps.model.use_noise_scaled_mas == True
+ and hps.model.use_noise_scaled_mas is True
):
print("Using noise scaled MAS for VITS2")
- use_noise_scaled_mas = True
mas_noise_scale_initial = 0.01
noise_scale_delta = 2e-6
else:
print("Using normal MAS for VITS1")
- use_noise_scaled_mas = False
mas_noise_scale_initial = 0.0
noise_scale_delta = 0.0
if (
"use_duration_discriminator" in hps.model.keys()
- and hps.model.use_duration_discriminator == True
+ and hps.model.use_duration_discriminator is True
):
print("Using duration discriminator for VITS2")
- use_duration_discriminator = True
net_dur_disc = DurationDiscriminator(
hps.model.hidden_channels,
hps.model.hidden_channels,
@@ -123,16 +116,14 @@ def run():
).cuda(rank)
if (
"use_spk_conditioned_encoder" in hps.model.keys()
- and hps.model.use_spk_conditioned_encoder == True
+ and hps.model.use_spk_conditioned_encoder is True
):
if hps.data.n_speakers == 0:
raise ValueError(
"n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
)
- use_spk_conditioned_encoder = True
else:
print("Using normal encoder for VITS1")
- use_spk_conditioned_encoder = False

net_g = SynthesizerTrn(
len(symbols),
@@ -176,19 +167,25 @@ def run():
utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
net_dur_disc,
optim_dur_disc,
- skip_optimizer=True,
+ skip_optimizer=hps.train.skip_optimizer
+ if "skip_optimizer" in hps.train
+ else True,
)
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
net_g,
optim_g,
- skip_optimizer=True,
+ skip_optimizer=hps.train.skip_optimizer
+ if "skip_optimizer" in hps.train
+ else True,
)
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
net_d,
optim_d,
- skip_optimizer=True,
+ skip_optimizer=hps.train.skip_optimizer
+ if "skip_optimizer" in hps.train
+ else True,
)
if not optim_g.param_groups[0].get("initial_lr"):
optim_g.param_groups[0]["initial_lr"] = g_resume_lr
@@ -371,9 +368,7 @@ def train_and_evaluate(
optim_dur_disc.zero_grad()
scaler.scale(loss_dur_disc_all).backward()
scaler.unscale_(optim_dur_disc)
- grad_norm_dur_disc = commons.clip_grad_value_(
- net_dur_disc.parameters(), None
- )
+ commons.clip_grad_value_(net_dur_disc.parameters(), None)
scaler.step(optim_dur_disc)

optim_d.zero_grad()
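The `== True` → `is True` switches and the dropped unused `grad_norm_dur_disc` binding are lint fixes, while the `skip_optimizer` plumbing makes resume behavior configurable. As a self-contained sketch, the noise-scaled-MAS branch above reduces to the following (`model_cfg` stands in for `hps.model`; the constants are taken from the diff):

```python
# Values taken from the diff above; model_cfg stands in for hps.model.
model_cfg = {"use_noise_scaled_mas": True}

if model_cfg.get("use_noise_scaled_mas") is True:
    print("Using noise scaled MAS for VITS2")
    mas_noise_scale_initial = 0.01  # initial noise added to MAS alignment scores
    noise_scale_delta = 2e-6        # per-step decay of that noise
else:
    print("Using normal MAS for VITS1")
    mas_noise_scale_initial = 0.0
    noise_scale_delta = 0.0
```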
