forked from as-ideas/ForwardTacotron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_tacotron.py
167 lines (138 loc) · 6.97 KB
/
gen_tacotron.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
from models.fatchord_version import WaveRNN
from utils import hparams as hp
from utils.text.symbols import phonemes
from utils.paths import Paths
from models.tacotron import Tacotron
import argparse
from utils.text import text_to_sequence, clean_text
from utils.display import save_attention, simple_table
from utils.dsp import reconstruct_waveform, save_wav
import numpy as np
def _build_arg_parser():
    """Build the command-line parser for the Tacotron TTS generator.

    The vocoder is selected with a required subcommand -- ``wavernn`` (alias
    ``wr``) or ``griffinlim`` (alias ``gl``) -- whose name, exactly as typed,
    is stored in ``args.vocoder``.
    """
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
    parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights')
    parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')

    # Name of the chosen subcommand (possibly an alias) goes to args.vocoder.
    subparsers = parser.add_subparsers(required=True, dest='vocoder')

    wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
    wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
    wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
    wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
    wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
    wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
    # None marks "not given on the command line"; _apply_hparam_defaults
    # replaces it with the hparams value once hparams are loaded.
    wr_parser.set_defaults(batched=None)

    gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
    gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')
    return parser


def _parse_args(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and canonicalise the vocoder.

    After this call ``args.vocoder`` is either ``'wavernn'`` or
    ``'griffinlim'``, whichever spelling/alias the user typed.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    canonical = {
        'wavernn': 'wavernn',
        'wr': 'wavernn',
        'griffinlim': 'griffinlim',
        'gl': 'griffinlim',
    }
    if args.vocoder not in canonical:
        # Not reachable while the subcommand is required, but report it as a
        # usage error (exit status 2). argparse.ArgumentError needs an
        # (argument, message) pair, so raising it with a message alone would
        # itself fail with TypeError.
        parser.error('Must provide a valid vocoder type!')
    args.vocoder = canonical[args.vocoder]
    return args


def _apply_hparam_defaults(args):
    """Fill WaveRNN options omitted on the command line from ``hp``.

    Must be called after ``hp.configure``; does nothing for Griffin-Lim.
    """
    if args.vocoder != 'wavernn':
        return
    if args.target is None:
        args.target = hp.voc_target
    if args.overlap is None:
        args.overlap = hp.voc_overlap
    if args.batched is None:
        args.batched = hp.voc_gen_batched


def _select_device(force_cpu):
    """Return the CUDA device when available and not disabled, else the CPU."""
    if not force_cpu and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def _load_wavernn(paths, device, voc_weights):
    """Build the WaveRNN vocoder from hparams and load ``voc_weights``
    (falling back to ``paths.voc_latest_weights`` when not given)."""
    print('\nInitialising WaveRNN Model...\n')
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)
    voc_model.load(voc_weights if voc_weights else paths.voc_latest_weights)
    return voc_model


def _load_tacotron(paths, device, tts_weights):
    """Build the Tacotron model from hparams and load ``tts_weights``
    (falling back to ``paths.tts_latest_weights`` when not given)."""
    print('\nInitialising Tacotron Model...\n')
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(phonemes),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)
    tts_model.load(tts_weights if tts_weights else paths.tts_latest_weights)
    return tts_model


def _load_inputs(input_text, sentences_file='sentences.txt'):
    """Return the list of ``text_to_sequence`` outputs to synthesise.

    Uses ``input_text`` when given; otherwise reads ``sentences_file`` (UTF-8)
    and takes one sentence per non-blank line.
    """
    if input_text:
        texts = [clean_text(input_text.strip())]
    else:
        with open(sentences_file, encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) so no empty input
            # is ever passed to the model.
            texts = [clean_text(line.strip()) for line in f if line.strip()]
    return [text_to_sequence(t) for t in texts]


def _print_summary(args, tts_model, tts_k, voc_model):
    """Print a table of the checkpoints and vocoder settings in use."""
    rows = [('Tacotron', str(tts_k) + 'k'),
            ('r', tts_model.r)]
    if args.vocoder == 'wavernn':
        voc_k = voc_model.get_step() // 1000
        rows += [('Vocoder Type', 'WaveRNN'),
                 ('WaveRNN', str(voc_k) + 'k'),
                 ('Generation Mode', 'Batched' if args.batched else 'Unbatched'),
                 ('Target Samples', args.target if args.batched else 'N/A'),
                 ('Overlap Samples', args.overlap if args.batched else 'N/A')]
    else:
        rows += [('Vocoder Type', 'Griffin-Lim'),
                 ('GL Iters', args.iters)]
    simple_table(rows)


def _safe_name_part(text):
    """Return ``text`` with path separators, file-system-reserved characters
    and non-printable characters replaced by ``'_'``, so user text can be
    embedded in a single file name without creating or escaping directories."""
    reserved = set('/\\:*?"<>|')
    return ''.join('_' if (ch in reserved or not ch.isprintable()) else ch
                   for ch in text)


def _output_path(paths, input_text, index, v_type, tts_k):
    """Return the .wav path for the ``index``-th (1-based) generated sample."""
    if input_text:
        name = f'__input_{_safe_name_part(input_text[:10])}_{v_type}_{tts_k}k.wav'
    else:
        name = f'{index}_{v_type}_{tts_k}k.wav'
    return paths.tts_output / name


def main(argv=None):
    """Synthesise speech with Tacotron plus the chosen vocoder.

    Writes one .wav file per input sentence into ``paths.tts_output``.
    """
    args = _parse_args(argv)
    hp.configure(args.hp_file)  # Load hparams from file
    # Set defaults for any arguments that depend on hparams.
    _apply_hparam_defaults(args)

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
    device = _select_device(args.force_cpu)
    print('Using device:', device)

    voc_model = None
    if args.vocoder == 'wavernn':
        voc_model = _load_wavernn(paths, device, args.voc_weights)
    tts_model = _load_tacotron(paths, device, args.tts_weights)

    inputs = _load_inputs(args.input_text)

    tts_k = tts_model.get_step() // 1000
    _print_summary(args, tts_model, tts_k, voc_model)

    # The vocoder tag in the file name does not change between sentences.
    if args.vocoder == 'griffinlim':
        v_type = 'griffinlim'
    elif args.batched:
        v_type = 'wavernn_batched'
    else:
        v_type = 'wavernn_unbatched'

    for i, x in enumerate(inputs, 1):
        print(f'\n| Generating {i}/{len(inputs)}')
        _, m, attention = tts_model.generate(x)
        save_path = _output_path(paths, args.input_text, i, v_type, tts_k)
        if args.save_attn:
            save_attention(attention, save_path)
        if args.vocoder == 'wavernn':
            # Add a leading dimension (presumably batch) before vocoding.
            m = torch.tensor(m).unsqueeze(0)
            # Use the resolved --target/--overlap values (hparams fill them in
            # when omitted) so the CLI flags and the summary table take effect.
            voc_model.generate(m, save_path, args.batched, args.target,
                               args.overlap, hp.mu_law)
        else:
            wav = reconstruct_waveform(m, n_iter=args.iters)
            save_wav(wav, save_path)

    print('\n\nDone.\n')


if __name__ == '__main__':
    main()