diff --git a/README.md b/README.md index 80b134d..05165c1 100644 --- a/README.md +++ b/README.md @@ -8,29 +8,41 @@

A Text-to-Speech Transformer in TensorFlow 2

-Implementation of an autoregressive Transformer based neural network for Text-to-Speech (TTS).
-This repo is based on the following paper: +Implementation of a non-autoregressive Transformer based neural network for Text-to-Speech (TTS).
+This repo is based on the following papers: - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895) +- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263) Spectrograms produced with LJSpeech and standard data configuration from this repo are compatible with [WaveRNN](https://github.com/fatchord/WaveRNN). +#### Non-Autoregressive +Being non-autoregressive, this Transformer model is: +- Robust: No repeats and failed attention modes for challenging sentences. +- Fast: With no autoregression, predictions take a fraction of the time. +- Controllable: It is possible to control the speed of the generated utterance. + ## πŸ”ˆ Samples [Can be found here.](https://as-ideas.github.io/TransformerTTS/) These samples' spectrograms are converted using the pre-trained [WaveRNN](https://github.com/fatchord/WaveRNN) vocoder.
-The TTS weights used for these samples can be found [here](https://github.com/as-ideas/tts_model_outputs/tree/master/ljspeech_transformertts). -Check out the notebooks folder for predictions with TransformerTTS and WaveRNN or just try out our Colab notebook: +Try it out on Colab: -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/master/notebooks/synthesize.ipynb) +| Version | Colab Link | +|---|---| +| Forward | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/master/notebooks/synthesize_forward.ipynb) | +Autoregressive | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/master/notebooks/synthesize_autoregressive.ipynb) | ## πŸ“– Contents - [Installation](#installation) - [Dataset](#dataset) - [Training](#training) + - [Autoregressive](#train-autoregressive-model) + - [Forward](#train-forward-model) - [Prediction](#prediction) +- [Model Weights](#model_weights) ## Installation @@ -69,16 +81,28 @@ Prepare a dataset in the following format: where `metadata.csv` has the following format: ``` wav_file_name|transcription ``` +## Training +### Train Autoregressive Model #### Create training dataset ```bash python create_dataset.py --config config/standard ``` - -## Training +#### Training ```bash -python train.py --config config/standard +python train_autoregressive.py --config config/standard +``` +### Train Forward Model +#### Compute alignment dataset +First use the autoregressive model to create the durations dataset +```bash +python extract_durations.py --config config/standard --binary --fix_jumps --fill_mode_next +``` +this will add an additional folder to the dataset folder containing the new datasets for validation and training of the forward model.
+If the rhythm of the trained model is off, play around with the flags of this script to fix the durations. +#### Training +```bash +python train_forward.py --config /path/to/config_folder/ ``` - #### Training & Model configuration - Training and model settings can be configured in `model_config.yaml` @@ -92,12 +116,15 @@ We log some information that can be visualized with TensorBoard: tensorboard --logdir /logs/directory/ ``` +![Tensorboard Demo](https://raw.githubusercontent.com/as-ideas/TransformerTTS/master/docs/tboard_demo.gif) + ## Prediction +Predict with either the Forward or Autoregressive model ```python from utils.config_manager import ConfigManager from utils.audio import reconstruct_waveform -config_loader = ConfigManager('config/standard') +config_loader = ConfigManager('/path/to/config/', model_kind='forward') model = config_loader.load_model() out = model.predict('Please, say something.') @@ -105,12 +132,18 @@ out = model.predict('Please, say something.') wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config) ``` +## Model Weights +| Model URL | Commit | +|---|---| +|[ljspeech_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_forward_transformer.zip)| 4945e775b| +[ljspeech_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_autoregressive_transformer.zip)| 4945e775b| +|[ljspeech_autoregressive_model_v1](https://github.com/as-ideas/tts_model_outputs/tree/master/ljspeech_transformertts)| 2f3a1b5| ## Maintainers * Francesco Cardinale, github: [cfrancesco](https://github.com/cfrancesco) ## Special thanks [WaveRNN](https://github.com/fatchord/WaveRNN): we took the data processing from here and use their vocoder to produce the samples.
-[Erogol](https://github.com/erogol): for the lively exchange on TTS topics.
+[Erogol](https://github.com/erogol) and the Mozilla TTS team: for the lively exchange on TTS topics.
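The forward model's `predict` exposes a `speed_regulator` argument (see `ForwardTransformer.predict` in `model/models.py` further down in this diff), which is what the speed-control samples rely on. Below is a minimal sketch, assuming a trained forward checkpoint and that `/path/to/config/` points to the config folder used for training:

```python
from utils.config_manager import ConfigManager
from utils.audio import reconstruct_waveform

# Load the forward (non-autoregressive) model, as in the Prediction section above.
config_loader = ConfigManager('/path/to/config/', model_kind='forward')
model = config_loader.load_model()

# speed_regulator > 1 shrinks the predicted durations (faster speech), < 1 stretches them.
fast = model.predict('Please, say something.', speed_regulator=1.25)
slow = model.predict('Please, say something.', speed_regulator=0.9)

# Convert the mel spectrogram back to a waveform, as in the Prediction section.
wav = reconstruct_waveform(fast['mel'].numpy().T, config=config_loader.config)
```

The autoregressive model's `predict` does not take `speed_regulator`; its durations come out of the attention alignment instead.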
## Copyright See [LICENSE](LICENSE) for details. diff --git a/config/standard/model_config.yaml b/config/standard/autoregressive_config.yaml similarity index 86% rename from config/standard/model_config.yaml rename to config/standard/autoregressive_config.yaml index 1db2ad3..f5336a7 100644 --- a/config/standard/model_config.yaml +++ b/config/standard/autoregressive_config.yaml @@ -20,21 +20,18 @@ stop_loss_scaling: 8 # TRAINING dropout_rate: 0.1 -decoder_dropout_schedule: # dropout scheduling for the decoder status - - [0, 0.54] +decoder_prenet_dropout_schedule: + - [0, 0.] + - [25_000, 0.] + - [35_000, .5] learning_rate_schedule: - [0, 1.0e-4] head_drop_schedule: # head-level dropout: how many heads to set to zero at training time - [0, 0] - [15_000, 1] - - [30_000, 2] - - [70_000, 3] - - [150_000, 1] reduction_factor_schedule: - [0, 10] - - [20_000, 5] - - [50_000, 2] - - [100_000, 1] + - [80_000, 1] max_steps: 900_000 batch_size: 16 debug: False @@ -44,7 +41,7 @@ validation_frequency: 1_000 prediction_frequency: 10_000 weights_save_frequency: 10_000 train_images_plotting_frequency: 1_000 -keep_n_weights: 5 +keep_n_weights: 2 keep_checkpoint_every_n_hours: 12 n_steps_avg_losses: [100, 500, 1_000, 5_000] n_predictions: 2 # autoregressive predictions take time diff --git a/config/standard/forward_config.yaml b/config/standard/forward_config.yaml new file mode 100644 index 0000000..483dd98 --- /dev/null +++ b/config/standard/forward_config.yaml @@ -0,0 +1,45 @@ +# ARCHITECTURE +decoder_model_dimension: 256 +encoder_model_dimension: 512 +decoder_num_heads: [4, 4, 4, 4] # the length of this defines the number of layers +encoder_num_heads: [4, 4, 4, 4] # the length of this defines the number of layers +encoder_feed_forward_dimension: 1024 +decoder_feed_forward_dimension: 1024 +decoder_prenet_dimension: 256 +encoder_prenet_dimension: 512 +encoder_attention_conv_filters: 512 +decoder_attention_conv_filters: 512 +encoder_attention_conv_kernel: 3 +decoder_attention_conv_kernel: 3 +encoder_max_position_encoding: 1000 +decoder_max_position_encoding: 10000 +postnet_conv_filters: 256 +postnet_conv_layers: 5 +postnet_kernel_size: 5 +encoder_dense_blocks: 1 +decoder_dense_blocks: 0 + +# TRAINING +dropout_rate: 0.1 +decoder_dropout_schedule: # dropout scheduling for the decoder status + - [0, 0.] 
+learning_rate_schedule: + - [0, 1.0e-4] +head_drop_schedule: # head-level dropout: how many heads to set to zero at training time + - [0, 0] +max_steps: 400_000 +batch_size: 16 +debug: False + +# LOGGING +validation_frequency: 1_000 +prediction_frequency: 1_000 +weights_save_frequency: 5_000 +train_images_plotting_frequency: 1_000 +keep_n_weights: 5 +keep_checkpoint_every_n_hours: 12 +n_steps_avg_losses: [100, 500, 1_000, 5_000] +n_predictions: 5 +prediction_start_step: 1_000 +audio_start_step: 5_000 +audio_prediction_frequency: 5_000 # converting to glim takes time diff --git a/create_dataset.py b/create_dataset.py index f75deb9..85317ad 100644 --- a/create_dataset.py +++ b/create_dataset.py @@ -20,7 +20,8 @@ for arg in vars(args): print('{}: {}'.format(arg, getattr(args, arg))) yaml = ruamel.yaml.YAML() -config = yaml.load(open(str(Path(args.CONFIG) / 'data_config.yaml'), 'rb')) +with open(str(Path(args.CONFIG) / 'data_config.yaml'), 'rb') as conf_yaml: + config = yaml.load(conf_yaml) args.DATA_DIR = config['data_directory'] args.META_FILE = os.path.join(args.DATA_DIR, config['metadata_filename']) args.WAV_DIR = os.path.join(args.DATA_DIR, config['wav_subdir_name']) diff --git a/docs/index.md b/docs/index.md index a033c73..fa0f0ec 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,17 +8,52 @@

A Text-to-Speech Transformer in TensorFlow 2

+

All samples are converted with the pre-trained WaveRNN vocoder.

-### 🎧 Samples from the autoregressive model, converted with [WaveRNN](https://github.com/fatchord/WaveRNN) vocoder +## 🎧 Model samples + +

President Trump met with other leaders at the Group of twenty conference.

+ +| forward | autoregressive | +|:---:|:---:| +|||

Scientists, at the CERN laboratory, say they have discovered a new particle.

- + +| forward | autoregressive | +|:---:|:---:| +|||

There’s a way to measure the acute emotional intelligence that has never gone out of style.

- -

President Trump met with other leaders at the Group of twenty conference.

- +| forward | autoregressive | +|:---:|:---:| +|||

The Senate's bill to repeal and replace the Affordable Care Act is now imperiled.

- + +| forward | autoregressive | +|:---:|:---:| +||| + + +### Robustness + +

To deliver interfaces that are significantly better suited to create and process RFC eight twenty one, RFC eight twenty two, RFC nine seventy seven, and MIME content.

+ +| forward | autoregressive | +|:---:|:---:| +||| + +### Speed control +

For a while the preacher addresses himself to the congregation at large, who listen attentively.

+ +| 10% slower | normal speed | 25% faster | +|:---:|:---:|:---:| +|||| + +### Comparison with [ForwardTacotron](https://github.com/as-ideas/ForwardTacotron) +

In a statement announcing his resignation, Mr Ross, said: "While the intentions may have been well meaning, the reaction to this news shows that Mr Cummings interpretation of the government advice was not shared by the vast majority of people who have done as the government asked."

+| ForwardTacotron | TransformerTTS | +|:---:|:---:| +||| diff --git a/docs/tboard_demo.gif b/docs/tboard_demo.gif new file mode 100644 index 0000000..e95fff3 Binary files /dev/null and b/docs/tboard_demo.gif differ diff --git a/extract_durations.py b/extract_durations.py new file mode 100644 index 0000000..192337d --- /dev/null +++ b/extract_durations.py @@ -0,0 +1,208 @@ +import argparse +import traceback +import pickle + +import tensorflow as tf +import numpy as np +from tqdm import tqdm + +from utils.config_manager import ConfigManager +from utils.logging import SummaryManager +from preprocessing.data_handling import load_files, Dataset, DataPrepper +from model.transformer_utils import create_mel_padding_mask +from utils.alignments import get_durations_from_alignment + +# dynamically allocate GPU +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + try: + # Currently, memory growth needs to be the same across GPUs + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.experimental.list_logical_devices('GPU') + print(len(gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs') + except Exception: + traceback.print_exc() + +# consuming CLI, creating paths and directories, load data + +parser = argparse.ArgumentParser() +parser.add_argument('--config', dest='config', type=str) +parser.add_argument('--session_name', dest='session_name', default=None) +parser.add_argument('--recompute_pred', dest='recompute_pred', action='store_true') +parser.add_argument('--best', dest='best', action='store_true') +parser.add_argument('--binary', dest='binary', action='store_true') +parser.add_argument('--fix_jumps', dest='fix_jumps', action='store_true') +parser.add_argument('--fill_mode_max', dest='fill_mode_max', action='store_true') +parser.add_argument('--fill_mode_next', dest='fill_mode_next', action='store_true') +parser.add_argument('--use_GT', action='store_true') +args = parser.parse_args() +assert (args.fill_mode_max is False) or (args.fill_mode_next is False), 'Choose one gap filling mode.' 
+weighted = not args.best +binary = args.binary +fill_gaps = args.fill_mode_max or args.fill_mode_next +fix_jumps = args.fix_jumps +fill_mode = f"{f'max' * args.fill_mode_max}{f'next' * args.fill_mode_next}" +filling_tag = f"{f'(max)' * args.fill_mode_max}{f'(next)' * args.fill_mode_next}" +tag_description = ''.join( + [f'{"_weighted" * weighted}{"_best" * (not weighted)}', + f'{"_binary" * binary}', + f'{"_filled" * fill_gaps}{filling_tag}', + f'{"_fix_jumps" * fix_jumps}']) +writer_tag = f'DurationExtraction{tag_description}' +print(writer_tag) +config_manager = ConfigManager(config_path=args.config, model_kind='autoregressive', session_name=args.session_name) +config = config_manager.config + +meldir = config_manager.train_datadir / 'mels' +target_dir = config_manager.train_datadir / f'forward_data' +train_target_dir = target_dir / 'train' +val_target_dir = target_dir / 'val' +train_predictions_dir = target_dir / f'train_predictions_{config_manager.session_name}' +val_predictions_dir = target_dir / f'val_predictions_{config_manager.session_name}' +target_dir.mkdir(exist_ok=True) +train_target_dir.mkdir(exist_ok=True) +val_target_dir.mkdir(exist_ok=True) +train_predictions_dir.mkdir(exist_ok=True) +val_predictions_dir.mkdir(exist_ok=True) +config_manager.dump_config() +script_batch_size = 5 * config['batch_size'] +val_has_files = len([batch_file for batch_file in val_predictions_dir.iterdir() if batch_file.suffix == '.npy']) +train_has_files = len([batch_file for batch_file in train_predictions_dir.iterdir() if batch_file.suffix == '.npy']) +model = config_manager.load_model() +if args.recompute_pred or (val_has_files == 0) or (train_has_files == 0): + train_meta = config_manager.train_datadir / 'train_metafile.txt' + test_meta = config_manager.train_datadir / 'test_metafile.txt' + train_samples, _ = load_files(metafile=str(train_meta), + meldir=str(meldir), + num_samples=config['n_samples']) # (phonemes, mel) + val_samples, _ = load_files(metafile=str(test_meta), + meldir=str(meldir), + num_samples=config['n_samples']) # (phonemes, text, mel) + + # get model, prepare data for model, create datasets + + data_prep = DataPrepper(config=config, + tokenizer=model.tokenizer) + script_batch_size = 5 * config['batch_size'] # faster parallel computation + train_dataset = Dataset(samples=train_samples, + preprocessor=data_prep, + batch_size=script_batch_size, + shuffle=False, + drop_remainder=False) + val_dataset = Dataset(samples=val_samples, + preprocessor=data_prep, + batch_size=script_batch_size, + shuffle=False, + drop_remainder=False) + if model.r != 1: + print(f"ERROR: model's reduction factor is greater than 1, check config. 
(r={model.r}") + # identify last decoder block + n_layers = len(config_manager.config['decoder_num_heads']) + n_dense = int(config_manager.config['decoder_dense_blocks']) + n_convs = int(n_layers - n_dense) + if n_convs > 0: + last_layer_key = f'Decoder_ConvBlock{n_convs}_CrossfAttention' + else: + last_layer_key = f'Decoder_DenseBlock{n_dense}_CrossAttention' + print(f'Extracting attention from layer {last_layer_key}') + + iterator = tqdm(enumerate(val_dataset.all_batches())) + for c, (val_mel, val_text, val_stop) in iterator: + iterator.set_description(f'Processing validation set') + outputs = model.val_step(inp=val_text, + tar=val_mel, + stop_prob=val_stop) + if args.use_GT: + batch = (val_mel.numpy(), val_text.numpy(), outputs['decoder_attention'][last_layer_key].numpy()) + else: + mask = create_mel_padding_mask(val_mel) + out_val = tf.expand_dims(1 - tf.squeeze(create_mel_padding_mask(val_mel[:, 1:, :])), -1) * outputs[ + 'final_output'].numpy() + batch = (out_val.numpy(), val_text.numpy(), outputs['decoder_attention'][last_layer_key].numpy()) + with open(str(val_predictions_dir / f'{c}_batch_prediction.npy'), 'wb') as file: + pickle.dump(batch, file) + + iterator = tqdm(enumerate(train_dataset.all_batches())) + for c, (train_mel, train_text, train_stop) in iterator: + iterator.set_description(f'Processing training set') + outputs = model.val_step(inp=train_text, + tar=train_mel, + stop_prob=train_stop) + if args.use_GT: + batch = (train_mel.numpy(), train_text.numpy(), outputs['decoder_attention'][last_layer_key].numpy()) + else: + mask = create_mel_padding_mask(train_mel) + out_train = tf.expand_dims(1 - tf.squeeze(create_mel_padding_mask(train_mel[:, 1:, :])), -1) * outputs[ + 'final_output'].numpy() + batch = (out_train.numpy(), train_text.numpy(), outputs['decoder_attention'][last_layer_key].numpy()) + with open(str(train_predictions_dir / f'{c}_batch_prediction.npy'), 'wb') as file: + pickle.dump(batch, file) + +summary_manager = SummaryManager(model=model, log_dir=config_manager.log_dir / writer_tag, config=config, + default_writer=writer_tag) +val_batch_files = [batch_file for batch_file in val_predictions_dir.iterdir() if batch_file.suffix == '.npy'] +iterator = tqdm(enumerate(val_batch_files)) +all_val_durations = np.array([]) +new_alignments = [] +total_val_samples = 0 +for c, batch_file in iterator: + iterator.set_description(f'Extracting validation alignments') + val_mel, val_text, val_alignments = np.load(str(batch_file), allow_pickle=True) + durations, unpad_mels, unpad_phonemes, final_align = get_durations_from_alignment( + batch_alignments=val_alignments, + mels=val_mel, + phonemes=val_text, + weighted=weighted, + binary=binary, + fill_gaps=fill_gaps, + fill_mode=fill_mode, + fix_jumps=fix_jumps) + batch_size = len(val_mel) + for i in range(batch_size): + sample_idx = total_val_samples + i + all_val_durations = np.append(all_val_durations, durations[i]) + new_alignments.append(final_align[i]) + sample = (unpad_mels[i], unpad_phonemes[i], durations[i]) + np.save(str(val_target_dir / f'{sample_idx}_mel_phon_dur.npy'), sample) + total_val_samples += batch_size +all_val_durations[all_val_durations >= 20] = 20 +buckets = len(set(all_val_durations)) +summary_manager.add_histogram(values=all_val_durations, tag='ValidationDurations', buckets=buckets) +for i, alignment in enumerate(new_alignments): + summary_manager.add_image(tag='ExtractedValidationAlignments', + image=tf.expand_dims(tf.expand_dims(alignment, 0), -1), + step=i) + +train_batch_files = [batch_file for 
batch_file in train_predictions_dir.iterdir() if batch_file.suffix == '.npy'] +iterator = tqdm(enumerate(train_batch_files)) +all_train_durations = np.array([]) +new_alignments = [] +total_train_samples = 0 +for c, batch_file in iterator: + iterator.set_description(f'Extracting training alignments') + train_mel, train_text, train_alignments = np.load(str(batch_file), allow_pickle=True) + durations, unpad_mels, unpad_phonemes, final_align = get_durations_from_alignment( + batch_alignments=train_alignments, + mels=train_mel, + phonemes=train_text, + weighted=weighted, + binary=binary, + fill_gaps=fill_gaps, + fill_mode=fill_mode, + fix_jumps=fix_jumps) + batch_size = len(train_mel) + for i in range(batch_size): + sample_idx = total_train_samples + i + sample = (unpad_mels[i], unpad_phonemes[i], durations[i]) + new_alignments.append(final_align[i]) + all_train_durations = np.append(all_train_durations, durations[i]) + np.save(str(train_target_dir / f'{sample_idx}_mel_phon_dur.npy'), sample) + total_train_samples += batch_size +all_train_durations[all_train_durations >= 20] = 20 +buckets = len(set(all_train_durations)) +summary_manager.add_histogram(values=all_train_durations, tag='TrainDurations', buckets=buckets) +for i, alignment in enumerate(new_alignments): + summary_manager.add_image(tag='ExtractedTrainingAlignments', image=tf.expand_dims(tf.expand_dims(alignment, 0), -1), + step=i) +print('Done.') diff --git a/model/layers.py b/model/layers.py index d2d8381..d72336d 100644 --- a/model/layers.py +++ b/model/layers.py @@ -3,57 +3,71 @@ from model.transformer_utils import positional_encoding, scaled_dot_product_attention -class PointWiseFFN(tf.keras.layers.Layer): - - def __init__(self, model_dim: int, dense_hidden_units: int, **kwargs): - super(PointWiseFFN, self).__init__(**kwargs) - self.d1 = tf.keras.layers.Dense(dense_hidden_units, - activation='relu') # (batch_size, seq_len, dense_hidden_units) - self.d2 = tf.keras.layers.Dense(model_dim) # (batch_size, seq_len, model_dim) - - def call(self, x): - x = self.d1(x) - x = self.d2(x) +class CNNResNorm(tf.keras.layers.Layer): + def __init__(self, + out_size: int, + n_layers: int, + hidden_size: int, + kernel_size: int, + inner_activation: str, + last_activation: str, + padding: str, + normalization: str, + **kwargs): + super(CNNResNorm, self).__init__(**kwargs) + self.convolutions = [tf.keras.layers.Conv1D(filters=hidden_size, + kernel_size=kernel_size, + padding=padding) + for _ in range(n_layers - 1)] + self.inner_activations = [tf.keras.layers.Activation(inner_activation) for _ in range(n_layers - 1)] + self.last_conv = tf.keras.layers.Conv1D(filters=out_size, + kernel_size=kernel_size, + padding=padding) + self.last_activation = tf.keras.layers.Activation(last_activation) + if normalization == 'layer': + self.normalization = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(n_layers + 1)] + elif normalization == 'batch': + self.normalization = [tf.keras.layers.BatchNormalization() for _ in range(n_layers + 1)] + else: + assert False is True, f'normalization must be either "layer" or "batch", not {normalization}.' 
+ + def call_convs(self, x, training): + for i in range(0, len(self.convolutions)): + x = self.convolutions[i](x) + x = self.normalization[i](x, training=training) + x = self.inner_activations[i](x) return x + + def call(self, inputs, training): + x = self.call_convs(inputs, training=training) + x = self.last_conv(x) + x = self.normalization[-2](x, training=training) + x = self.last_activation(x) + return self.normalization[-1](inputs + x) class FFNResNorm(tf.keras.layers.Layer): - def __init__(self, model_dim: int, dense_hidden_units: int, dropout_rate: float = 0.1, **kwargs): - super(FFNResNorm, self).__init__(**kwargs) - self.ffn = PointWiseFFN(model_dim, dense_hidden_units) - self.dropout = tf.keras.layers.Dropout(dropout_rate) - self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) - - def call(self, x, training): - ffn_out = self.ffn(x) # (batch_size, input_seq_len, model_dim) - ffn_out = self.dropout(ffn_out, training=training) - out = self.ln(x + ffn_out) # (batch_size, input_seq_len, model_dim) - - return out - - -class Conv1DResNorm(tf.keras.layers.Layer): def __init__(self, model_dim: int, + dense_hidden_units: int, dropout_rate: float, - kernel_size: int = 5, - conv_padding: str = 'same', - activation: str = 'relu', **kwargs): - super(Conv1DResNorm, self).__init__(**kwargs) - self.conv = tf.keras.layers.Conv1D(filters=model_dim, - kernel_size=kernel_size, - padding=conv_padding, - activation=activation) + super(FFNResNorm, self).__init__(**kwargs) + self.d1 = tf.keras.layers.Dense(dense_hidden_units) + self.activation = tf.keras.layers.Activation('relu') + self.d2 = tf.keras.layers.Dense(model_dim) self.dropout = tf.keras.layers.Dropout(dropout_rate) - self.layer_norm = tf.keras.layers.LayerNormalization() + self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) + self.last_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) def call(self, x, training): - convs = self.conv(x) - convs = self.dropout(convs, training=training) - res_norm = self.layer_norm(x + convs) - return res_norm + ffn_out = self.d1(x) + ffn_out = self.d2(ffn_out) # (batch_size, input_seq_len, model_dim) + ffn_out = self.ln(ffn_out) # (batch_size, input_seq_len, model_dim) + ffn_out = self.activation(ffn_out) + ffn_out = self.dropout(ffn_out, training=training) + return self.last_ln(ffn_out + x) class HeadDrop(tf.keras.layers.Layer): @@ -134,26 +148,36 @@ def call(self, v, k, q_in, mask, training, drop_n_heads): class SelfAttentionResNorm(tf.keras.layers.Layer): - def __init__(self, model_dim: int, num_heads: int, dropout_rate: float, **kwargs): + def __init__(self, + model_dim: int, + num_heads: int, + dropout_rate: float, + **kwargs): super(SelfAttentionResNorm, self).__init__(**kwargs) self.mha = MultiHeadAttention(model_dim, num_heads) self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.dropout = tf.keras.layers.Dropout(dropout_rate) + self.last_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) def call(self, x, training, mask, drop_n_heads): attn_out, attn_weights = self.mha(x, x, x, mask, training=training, drop_n_heads=drop_n_heads) # (batch_size, input_seq_len, model_dim) - attn_out = self.dropout(attn_out, training=training) - out = self.ln(x + attn_out) # (batch_size, input_seq_len, model_dim) - return out, attn_weights + attn_out = self.ln(attn_out) # (batch_size, input_seq_len, model_dim) + out = self.dropout(attn_out, training=training) + return self.last_ln(out + x), attn_weights class SelfAttentionDenseBlock(tf.keras.layers.Layer): - def __init__(self, model_dim: int, 
num_heads: int, dense_hidden_units: int, dropout_rate: float = 0.1, **kwargs): + def __init__(self, + model_dim: int, + num_heads: int, + dense_hidden_units: int, + dropout_rate: float, + **kwargs): super(SelfAttentionDenseBlock, self).__init__(**kwargs) self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) - self.ffn = FFNResNorm(model_dim, dense_hidden_units) + self.ffn = FFNResNorm(model_dim, dense_hidden_units, dropout_rate=dropout_rate) def call(self, x, training, mask, drop_n_heads): attn_out, attn_weights = self.sarn(x, mask=mask, training=training, drop_n_heads=drop_n_heads) @@ -166,18 +190,25 @@ def __init__(self, model_dim: int, num_heads: int, dropout_rate: float, - kernel_size: int = 5, - conv_padding: str = 'same', - conv_activation: str = 'relu', + conv_filters: int, + kernel_size: int, + conv_activation: str, **kwargs): super(SelfAttentionConvBlock, self).__init__(**kwargs) self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) - self.conv = Conv1DResNorm(model_dim=model_dim, dropout_rate=dropout_rate, kernel_size=kernel_size, - conv_padding=conv_padding, activation=conv_activation) + self.conv = CNNResNorm(out_size=model_dim, + n_layers=2, + hidden_size=conv_filters, + kernel_size=kernel_size, + inner_activation=conv_activation, + last_activation=conv_activation, + padding='same', + normalization='batch') def call(self, x, training, mask, drop_n_heads): attn_out, attn_weights = self.sarn(x, mask=mask, training=training, drop_n_heads=drop_n_heads) - return self.conv(attn_out, training=training), attn_weights + conv = self.conv(attn_out) + return conv, attn_weights class SelfAttentionBlocks(tf.keras.layers.Layer): @@ -186,8 +217,11 @@ def __init__(self, feed_forward_dimension: int, num_heads: list, maximum_position_encoding: int, - dropout_rate=0.1, - dense_blocks=1, + conv_filters: int, + dropout_rate: float, + dense_blocks: int, + kernel_size: int, + conv_activation: str, **kwargs): super(SelfAttentionBlocks, self).__init__(**kwargs) self.model_dim = model_dim @@ -200,7 +234,8 @@ def __init__(self, for i, n_heads in enumerate(num_heads[:dense_blocks])] self.encoder_SACB = [ SelfAttentionConvBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads, - name=f'{self.name}_SACB_{i}') + name=f'{self.name}_SACB_{i}', kernel_size=kernel_size, + conv_activation=conv_activation, conv_filters=conv_filters) for i, n_heads in enumerate(num_heads[dense_blocks:])] def call(self, inputs, training, padding_mask, drop_n_heads, reduction_factor=1): @@ -221,7 +256,11 @@ def call(self, inputs, training, padding_mask, drop_n_heads, reduction_factor=1) class CrossAttentionResnorm(tf.keras.layers.Layer): - def __init__(self, model_dim: int, num_heads: int, dropout_rate: float = 0.1, **kwargs): + def __init__(self, + model_dim: int, + num_heads: int, + dropout_rate: float, + **kwargs): super(CrossAttentionResnorm, self).__init__(**kwargs) self.mha = MultiHeadAttention(model_dim, num_heads) self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6) @@ -236,7 +275,12 @@ def call(self, q, k, v, training, mask, drop_n_heads): class CrossAttentionDenseBlock(tf.keras.layers.Layer): - def __init__(self, model_dim: int, num_heads: int, dense_hidden_units: int, dropout_rate: float = 0.1, **kwargs): + def __init__(self, + model_dim: int, + num_heads: int, + dense_hidden_units: int, + dropout_rate: float, + **kwargs): super(CrossAttentionDenseBlock, self).__init__(**kwargs) self.sarn = SelfAttentionResNorm(model_dim, num_heads, 
dropout_rate=dropout_rate) self.carn = CrossAttentionResnorm(model_dim, num_heads, dropout_rate=dropout_rate) @@ -256,16 +300,23 @@ class CrossAttentionConvBlock(tf.keras.layers.Layer): def __init__(self, model_dim: int, num_heads: int, - dropout_rate: float = 0.1, - kernel_size: int = 5, - conv_padding: str = 'same', - conv_activation: str = 'relu', + conv_filters: int, + dropout_rate: float, + kernel_size: int, + conv_padding: str, + conv_activation: str, **kwargs): super(CrossAttentionConvBlock, self).__init__(**kwargs) self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) self.carn = CrossAttentionResnorm(model_dim, num_heads, dropout_rate=dropout_rate) - self.conv = Conv1DResNorm(model_dim=model_dim, dropout_rate=dropout_rate, kernel_size=kernel_size, - conv_padding=conv_padding, activation=conv_activation) + self.conv = CNNResNorm(out_size=model_dim, + n_layers=2, + hidden_size=conv_filters, + kernel_size=kernel_size, + inner_activation=conv_activation, + last_activation=conv_activation, + padding=conv_padding, + normalization='batch') def call(self, x, enc_output, training, look_ahead_mask, padding_mask, drop_n_heads): attn1, attn_weights_block1 = self.sarn(x, mask=look_ahead_mask, training=training, drop_n_heads=drop_n_heads) @@ -283,8 +334,12 @@ def __init__(self, feed_forward_dimension: int, num_heads: list, maximum_position_encoding: int, - dropout_rate=0.1, - dense_blocks=1, + dropout_rate: float, + dense_blocks: int, + conv_filters: int, + conv_activation: str, + conv_padding: str, + conv_kernel: int, **kwargs): super(CrossAttentionBlocks, self).__init__(**kwargs) self.model_dim = model_dim @@ -297,7 +352,8 @@ def __init__(self, for i, n_heads in enumerate(num_heads[:dense_blocks])] self.CACB = [ CrossAttentionConvBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads, - name=f'{self.name}_CACB_{i}') + name=f'{self.name}_CACB_{i}', conv_filters=conv_filters, + conv_activation=conv_activation, conv_padding=conv_padding, kernel_size=conv_kernel) for i, n_heads in enumerate(num_heads[dense_blocks:])] def call(self, inputs, enc_output, training, decoder_padding_mask, encoder_padding_mask, drop_n_heads, @@ -321,75 +377,123 @@ def call(self, inputs, enc_output, training, decoder_padding_mask, encoder_paddi class DecoderPrenet(tf.keras.layers.Layer): - def __init__(self, model_dim: int, dense_hidden_units: int, dropout_rate: float = 0.5, **kwargs): + def __init__(self, + model_dim: int, + dense_hidden_units: int, + dropout_rate: float, + **kwargs): super(DecoderPrenet, self).__init__(**kwargs) self.d1 = tf.keras.layers.Dense(dense_hidden_units, activation='relu') # (batch_size, seq_len, dense_hidden_units) self.d2 = tf.keras.layers.Dense(model_dim, activation='relu') # (batch_size, seq_len, model_dim) - self.dropout_1 = tf.keras.layers.Dropout(dropout_rate) - self.dropout_2 = tf.keras.layers.Dropout(dropout_rate) + self.rate = tf.Variable(dropout_rate, trainable=False) + self.dropout_1 = tf.keras.layers.Dropout(self.rate) + self.dropout_2 = tf.keras.layers.Dropout(self.rate) - def call(self, x, dropout_rate: float = 0.5): - self.dropout_1.dropout_rate = dropout_rate - self.dropout_2.dropout_rate = dropout_rate + def call(self, x): + self.dropout_1.rate = self.rate + self.dropout_2.rate = self.rate x = self.d1(x) - # use dropout also in inference for additional noise as suggested in the original tacotron2 paper + # use dropout also in inference for positional encoding relevance x = self.dropout_1(x, training=True) x = self.d2(x) x = 
self.dropout_2(x, training=True) return x -class ConvBatchNormBlock(tf.keras.layers.Layer): - - def __init__(self, out_size: int, n_filters: int = 256, n_layers: int = 5, kernel_size: int = 5, - dropout_prob: float = 0.5, padding='causal', inner_activation='tanh', - last_activation='linear', **kwargs): - super(ConvBatchNormBlock, self).__init__(**kwargs) - self.convolutions = [tf.keras.layers.Conv1D(filters=n_filters, - kernel_size=kernel_size, - padding=padding, - activation=inner_activation) - for _ in range(n_layers - 1)] - self.dropouts = [tf.keras.layers.Dropout(dropout_prob) for _ in range(n_layers - 1)] - self.last_conv = tf.keras.layers.Conv1D(filters=out_size, - kernel_size=kernel_size, - padding=padding, - activation=last_activation) - self.batch_norms = [tf.keras.layers.BatchNormalization() for _ in range(n_layers)] - - def call(self, x, training): - for i in range(0, len(self.convolutions)): - x = self.convolutions[i](x) - x = self.batch_norms[i](x, training=training) - x = self.dropouts[i](x, training=training) - x = self.last_conv(x) - x = self.batch_norms[-1](x, training=training) - return x - - class Postnet(tf.keras.layers.Layer): - def __init__(self, mel_channels: int, conv_filters: int = 256, conv_layers: int = 5, kernel_size: int = 5, + def __init__(self, mel_channels: int, + conv_filters: int, + conv_layers: int, + kernel_size: int, **kwargs): super(Postnet, self).__init__(**kwargs) self.mel_channels = mel_channels self.stop_linear = tf.keras.layers.Dense(3) - self.postnet_conv_layers = ConvBatchNormBlock( - out_size=mel_channels, n_filters=conv_filters, n_layers=conv_layers, kernel_size=kernel_size - ) + self.conv_blocks = CNNResNorm(out_size=mel_channels, + kernel_size=kernel_size, + padding='causal', + inner_activation='tanh', + last_activation='linear', + hidden_size=conv_filters, + n_layers=conv_layers, + normalization='batch') self.add_layer = tf.keras.layers.Add() def call(self, x, training): stop = self.stop_linear(x) - conv_out = self.conv_net(x, training=training) + conv_out = self.conv_blocks(x, training=training) return { 'mel_linear': x, 'final_output': conv_out, 'stop_prob': stop, } + + +class DurationPredictor(tf.keras.layers.Layer): + def __init__(self, + model_dim: int, + kernel_size: int, + conv_padding: str, + conv_activation: str, + conv_block_n: int, + dense_activation: str, + **kwargs): + super(DurationPredictor, self).__init__(**kwargs) + self.conv_blocks = CNNResNorm(out_size=model_dim, + kernel_size=kernel_size, + padding=conv_padding, + inner_activation=conv_activation, + last_activation=conv_activation, + hidden_size=model_dim, + n_layers=conv_block_n, + normalization='layer') + self.linear = tf.keras.layers.Dense(1, activation=dense_activation, + bias_initializer=tf.keras.initializers.Constant(value=1)) - def conv_net(self, x, *, training): - conv_out = self.postnet_conv_layers(x, training) - x = self.add_layer([conv_out, x]) - return x \ No newline at end of file + def call(self, x, training): + x = self.conv_blocks(x, training=training) + x = self.linear(x) + return x + + +class Expand(tf.keras.layers.Layer): + """ Expands a 3D tensor on its second axis given a list of dimensions. 
+ Tensor should be: + batch_size, seq_len, dimension + + E.g: + input = tf.Tensor([[[0.54710746 0.8943467 ] + [0.7140938 0.97968304] + [0.5347662 0.15213418]]], shape=(1, 3, 2), dtype=float32) + dimensions = tf.Tensor([1 3 2], shape=(3,), dtype=int32) + output = tf.Tensor([[[0.54710746 0.8943467 ] + [0.7140938 0.97968304] + [0.7140938 0.97968304] + [0.7140938 0.97968304] + [0.5347662 0.15213418] + [0.5347662 0.15213418]]], shape=(1, 6, 2), dtype=float32) + """ + + def __init__(self, model_dim, **kwargs): + super(Expand, self).__init__(**kwargs) + self.model_dimension = model_dim + + def call(self, x, dimensions): + dimensions = tf.squeeze(dimensions, axis=-1) + dimensions = tf.cast(tf.math.round(dimensions), tf.int32) + seq_len = tf.shape(x)[1] + batch_size = tf.shape(x)[0] + # build masks from dimensions + max_dim = tf.math.reduce_max(dimensions) + tot_dim = tf.math.reduce_sum(dimensions) + index_masks = tf.RaggedTensor.from_row_lengths(tf.ones(tot_dim), tf.reshape(dimensions, [-1])).to_tensor() + index_masks = tf.cast(tf.reshape(index_masks, (batch_size, seq_len * max_dim)), tf.float32) + non_zeros = seq_len * max_dim - tf.reduce_sum(max_dim - dimensions, axis=1) + # stack and mask + tiled = tf.tile(x, [1, 1, max_dim]) + reshaped = tf.reshape(tiled, (batch_size, seq_len * max_dim, self.model_dimension)) + mask_reshape = tf.multiply(reshaped, index_masks[:, :, tf.newaxis]) + ragged = tf.RaggedTensor.from_row_lengths(mask_reshape[index_masks > 0], non_zeros) + return ragged.to_tensor() diff --git a/model/models.py b/model/models.py index 6bc7b66..059729a 100644 --- a/model/models.py +++ b/model/models.py @@ -8,19 +8,16 @@ from utils.losses import masked_mean_absolute_error, new_scaled_crossentropy from preprocessing.data_handling import Tokenizer from preprocessing.text_processing import _phonemes, Phonemizer, _punctuations -from model.layers import SelfAttentionBlocks, CrossAttentionBlocks +from model.layers import DurationPredictor, Expand, SelfAttentionBlocks, CrossAttentionBlocks, CNNResNorm class AutoregressiveTransformer(tf.keras.models.Model): def __init__(self, - mel_channels: int, encoder_model_dimension: int, decoder_model_dimension: int, encoder_num_heads: list, decoder_num_heads: list, - encoder_feed_forward_dimension: int, - decoder_feed_forward_dimension: int, encoder_maximum_position_encoding: int, decoder_maximum_position_encoding: int, encoder_dense_blocks: int, @@ -33,9 +30,16 @@ def __init__(self, dropout_rate: float, mel_start_value: int, mel_end_value: int, + mel_channels: int, + phoneme_language: str, + encoder_attention_conv_filters: int = None, + decoder_attention_conv_filters: int = None, + encoder_attention_conv_kernel: int = None, + decoder_attention_conv_kernel: int = None, + encoder_feed_forward_dimension: int = None, + decoder_feed_forward_dimension: int = None, + decoder_prenet_dropout=0.5, max_r: int = 10, - phoneme_language: str = 'en', - decoder_prenet_dropout=0., debug=False, **kwargs): super(AutoregressiveTransformer, self).__init__(**kwargs) @@ -45,7 +49,6 @@ def __init__(self, self.max_r = max_r self.r = max_r self.mel_channels = mel_channels - self.decoder_prenet_dropout = decoder_prenet_dropout self.drop_n_heads = 0 self.tokenizer = Tokenizer(sorted(list(_phonemes) + list(_punctuations)), add_start_end=True) @@ -58,9 +61,13 @@ def __init__(self, feed_forward_dimension=encoder_feed_forward_dimension, maximum_position_encoding=encoder_maximum_position_encoding, dense_blocks=encoder_dense_blocks, + conv_filters=encoder_attention_conv_filters, + 
kernel_size=encoder_attention_conv_kernel, + conv_activation='relu', name='Encoder') self.decoder_prenet = DecoderPrenet(model_dim=decoder_model_dimension, dense_hidden_units=decoder_prenet_dimension, + dropout_rate=decoder_prenet_dropout, name='DecoderPrenet') self.decoder = CrossAttentionBlocks(model_dim=decoder_model_dimension, dropout_rate=dropout_rate, @@ -68,6 +75,10 @@ def __init__(self, feed_forward_dimension=decoder_feed_forward_dimension, maximum_position_encoding=decoder_maximum_position_encoding, dense_blocks=decoder_dense_blocks, + conv_filters=decoder_attention_conv_filters, + conv_kernel=decoder_attention_conv_kernel, + conv_activation='relu', + conv_padding='causal', name='Decoder') self.final_proj_mel = tf.keras.layers.Dense(self.mel_channels * self.max_r, name='FinalProj') self.decoder_postnet = Postnet(mel_channels=mel_channels, @@ -83,7 +94,7 @@ def __init__(self, ] self.forward_input_signature = [ tf.TensorSpec(shape=(None, None), dtype=tf.int32), - tf.TensorSpec(shape=(None, None, mel_channels), dtype=tf.float32) + tf.TensorSpec(shape=(None, None, mel_channels), dtype=tf.float32), ] self.encoder_signature = [ tf.TensorSpec(shape=(None, None), dtype=tf.int32) @@ -94,25 +105,25 @@ def __init__(self, tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32), ] self.debug = debug - self.__apply_all_signatures() + self._apply_all_signatures() @property def step(self): return int(self.optimizer.iterations) - def __apply_all_signatures(self): - self.forward = self.__apply_signature(self._forward, self.forward_input_signature) - self.train_step = self.__apply_signature(self._train_step, self.training_input_signature) - self.val_step = self.__apply_signature(self._val_step, self.training_input_signature) - self.forward_encoder = self.__apply_signature(self._forward_encoder, self.encoder_signature) - self.forward_decoder = self.__apply_signature(self._forward_decoder, self.decoder_signature) - - def __apply_signature(self, function, signature): + def _apply_signature(self, function, signature): if self.debug: return function else: return tf.function(input_signature=signature)(function) + def _apply_all_signatures(self): + self.forward = self._apply_signature(self._forward, self.forward_input_signature) + self.train_step = self._apply_signature(self._train_step, self.training_input_signature) + self.val_step = self._apply_signature(self._val_step, self.training_input_signature) + self.forward_encoder = self._apply_signature(self._forward_encoder, self.encoder_signature) + self.forward_decoder = self._apply_signature(self._forward_decoder, self.decoder_signature) + def _call_encoder(self, inputs, training): padding_mask = create_encoder_padding_mask(inputs) enc_input = self.encoder_prenet(inputs) @@ -126,7 +137,7 @@ def _call_decoder(self, encoder_output, targets, encoder_padding_mask, training) dec_target_padding_mask = create_mel_padding_mask(targets) look_ahead_mask = create_look_ahead_mask(tf.shape(targets)[1]) combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) - dec_input = self.decoder_prenet(targets, training=training, dropout_rate=self.decoder_prenet_dropout) + dec_input = self.decoder_prenet(targets) dec_output, attention_weights = self.decoder(inputs=dec_input, enc_output=encoder_output, training=training, @@ -202,7 +213,7 @@ def _set_r(self, r): if self.r == r: return self.r = r - self.__apply_all_signatures() + self._apply_all_signatures() def call(self, inputs, targets, training): encoder_output, padding_mask, encoder_attention = 
self._call_encoder(inputs, training) @@ -238,7 +249,7 @@ def predict(self, inp, max_length=1000, encode=True, verbose=True): def set_constants(self, decoder_prenet_dropout: float = None, learning_rate: float = None, reduction_factor: float = None, drop_n_heads: int = None): if decoder_prenet_dropout is not None: - self.decoder_prenet_dropout = decoder_prenet_dropout + self.decoder_prenet.rate.assign(decoder_prenet_dropout) if learning_rate is not None: self.optimizer.lr.assign(learning_rate) if reduction_factor is not None: @@ -249,3 +260,192 @@ def set_constants(self, decoder_prenet_dropout: float = None, learning_rate: flo def encode_text(self, text): phons = self.phonemizer.encode(text, clean=True) return self.tokenizer.encode(phons) + + +class ForwardTransformer(tf.keras.models.Model): + def __init__(self, + encoder_model_dimension: int, + decoder_model_dimension: int, + dropout_rate: float, + decoder_num_heads: list, + encoder_num_heads: list, + encoder_maximum_position_encoding: int, + decoder_maximum_position_encoding: int, + postnet_conv_filters: int, + postnet_conv_layers: int, + postnet_kernel_size: int, + encoder_dense_blocks: int, + decoder_dense_blocks: int, + mel_channels: int, + phoneme_language: str, + encoder_attention_conv_filters: int = None, + decoder_attention_conv_filters: int = None, + encoder_attention_conv_kernel: int = None, + decoder_attention_conv_kernel: int = None, + encoder_feed_forward_dimension: int = None, + decoder_feed_forward_dimension: int = None, + debug=False, + decoder_prenet_dropout=0., + **kwargs): + super(ForwardTransformer, self).__init__(**kwargs) + self.tokenizer = Tokenizer(sorted(list(_phonemes) + list(_punctuations)), add_start_end=False) + self.phonemizer = Phonemizer(language=phoneme_language) + self.drop_n_heads = 0 + self.mel_channels = mel_channels + self.encoder_prenet = tf.keras.layers.Embedding(self.tokenizer.vocab_size, encoder_model_dimension, + name='Embedding') + self.encoder = SelfAttentionBlocks(model_dim=encoder_model_dimension, + dropout_rate=dropout_rate, + num_heads=encoder_num_heads, + feed_forward_dimension=encoder_feed_forward_dimension, + maximum_position_encoding=encoder_maximum_position_encoding, + dense_blocks=encoder_dense_blocks, + conv_filters=encoder_attention_conv_filters, + kernel_size=encoder_attention_conv_kernel, + conv_activation='relu', + name='Encoder') + self.dur_pred = DurationPredictor(model_dim=encoder_model_dimension, + kernel_size=3, + conv_padding='same', + conv_activation='relu', + conv_block_n=2, + dense_activation='relu', + name='dur_pred') + self.expand = Expand(name='expand', model_dim=encoder_model_dimension) + self.decoder_prenet = DecoderPrenet(model_dim=decoder_model_dimension, + dense_hidden_units=decoder_feed_forward_dimension, + dropout_rate=decoder_prenet_dropout, + name='DecoderPrenet') + self.decoder = SelfAttentionBlocks(model_dim=decoder_model_dimension, + dropout_rate=dropout_rate, + num_heads=decoder_num_heads, + feed_forward_dimension=decoder_feed_forward_dimension, + maximum_position_encoding=decoder_maximum_position_encoding, + dense_blocks=decoder_dense_blocks, + conv_filters=decoder_attention_conv_filters, + kernel_size=decoder_attention_conv_kernel, + conv_activation='relu', + name='Decoder') + self.out = tf.keras.layers.Dense(mel_channels) + self.decoder_postnet = CNNResNorm(out_size=mel_channels, + kernel_size=postnet_kernel_size, + padding='same', + inner_activation='tanh', + last_activation='linear', + hidden_size=postnet_conv_filters, + n_layers=postnet_conv_layers, + 
normalization='batch', + name='Postnet') + self.training_input_signature = [ + tf.TensorSpec(shape=(None, None), dtype=tf.int32), + tf.TensorSpec(shape=(None, None, mel_channels), dtype=tf.float32), + tf.TensorSpec(shape=(None, None), dtype=tf.int32) + ] + self.forward_input_signature = [ + tf.TensorSpec(shape=(None, None), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=tf.float32), + ] + self.debug = debug + self._apply_all_signatures() + + def _apply_signature(self, function, signature): + if self.debug: + return function + else: + return tf.function(input_signature=signature)(function) + + def _apply_all_signatures(self): + self.forward = self._apply_signature(self._forward, self.forward_input_signature) + self.train_step = self._apply_signature(self._train_step, self.training_input_signature) + self.val_step = self._apply_signature(self._val_step, self.training_input_signature) + + def _train_step(self, input_sequence, target_sequence, target_durations): + target_durations = tf.expand_dims(target_durations, -1) + mel_len = int(tf.shape(target_sequence)[1]) + with tf.GradientTape() as tape: + model_out = self.__call__(input_sequence, target_durations, training=True) + loss, loss_vals = weighted_sum_losses((target_sequence, + target_durations), + (model_out['mel'][:, :mel_len, :], + model_out['duration']), + self.loss, + self.loss_weights) + model_out.update({'loss': loss}) + model_out.update({'losses': {'mel': loss_vals[0], 'duration': loss_vals[1]}}) + gradients = tape.gradient(model_out['loss'], self.trainable_variables) + self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + return model_out + + def _compile(self, optimizer): + self.loss_weights = [3., 1.] + self.compile(loss=[masked_mean_absolute_error, + masked_mean_absolute_error], + loss_weights=self.loss_weights, + optimizer=optimizer) + + def _val_step(self, input_sequence, target_sequence, target_durations): + target_durations = tf.expand_dims(target_durations, -1) + mel_len = int(tf.shape(target_sequence)[1]) + model_out = self.__call__(input_sequence, target_durations, training=False) + loss, loss_vals = weighted_sum_losses((target_sequence, + target_durations), + (model_out['mel'][:, :mel_len, :], + model_out['duration']), + self.loss, + self.loss_weights) + model_out.update({'loss': loss}) + model_out.update({'losses': {'mel': loss_vals[0], 'duration': loss_vals[1]}}) + return model_out + + def _forward(self, input_sequence, durations_scalar): + return self.__call__(input_sequence, target_durations=None, training=False, durations_scalar=durations_scalar) + + @property + def step(self): + return int(self.optimizer.iterations) + + def call(self, x, target_durations, training, durations_scalar=1.): + padding_mask = create_encoder_padding_mask(x) + x = self.encoder_prenet(x) + x, encoder_attention = self.encoder(x, training=training, padding_mask=padding_mask, + drop_n_heads=self.drop_n_heads) + durations = self.dur_pred(x, training=training) * durations_scalar + durations = (1. 
- tf.reshape(padding_mask, tf.shape(durations))) * durations + if target_durations is not None: + mels = self.expand(x, target_durations) + else: + mels = self.expand(x, durations) + expanded_mask = create_mel_padding_mask(mels) + mels = self.decoder_prenet(mels) + mels, decoder_attention = self.decoder(mels, training=training, padding_mask=expanded_mask, + drop_n_heads=self.drop_n_heads, reduction_factor=1) + mels = self.out(mels) + mels = self.decoder_postnet(mels, training=training) + model_out = {'mel': mels, + 'duration': durations, + 'expanded_mask': expanded_mask, + 'encoder_attention': encoder_attention, + 'decoder_attention': decoder_attention} + return model_out + + def set_constants(self, decoder_prenet_dropout: float = None, learning_rate: float = None, + drop_n_heads: int = None, **kwargs): + if decoder_prenet_dropout is not None: + self.decoder_prenet.rate.assign(decoder_prenet_dropout) + if learning_rate is not None: + self.optimizer.lr.assign(learning_rate) + if drop_n_heads is not None: + self.drop_n_heads = drop_n_heads + + def encode_text(self, text): + phons = self.phonemizer.encode(text, clean=True) + return self.tokenizer.encode(phons) + + def predict(self, inp, encode=True, speed_regulator=1.): + if encode: + inp = self.encode_text(inp) + inp = tf.cast(tf.expand_dims(inp, 0), tf.int32) + duration_scalar = tf.cast(1. / speed_regulator, tf.float32) + out = self.forward(inp, durations_scalar=duration_scalar) + out['mel'] = tf.squeeze(out['mel']) + return out diff --git a/notebooks/Prediction for WaveRNN.ipynb b/notebooks/Prediction for WaveRNN.ipynb index 770b58c..e4fb4bd 100644 --- a/notebooks/Prediction for WaveRNN.ipynb +++ b/notebooks/Prediction for WaveRNN.ipynb @@ -34,12 +34,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING: could not check git hash. 'git_hash'\n" + "WARNING: git hash mismatch. Current: c9917ac. 
Config hash: 1fec787\n", + "restored weights from /Users/fcardina/logs/TTS4wavernn/NSSTN_paper_62e2566_W_B_FN_FJ_with_postnet/forward_weights/ckpt-124 at step 620000\n" ] } ], "source": [ - "config_loader = ConfigManager(config_path)\n", + "config_loader = ConfigManager(config_path, model_kind='forward')\n", "model = config_loader.load_model()" ] }, @@ -49,15 +50,7 @@ "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pred text mel: 399 stop out: -5.445310592651367Stopping\n" - ] - } - ], + "outputs": [], "source": [ "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", "out = model.predict(sentence)" @@ -73,7 +66,7 @@ "text/html": [ "\n", " \n", " " diff --git a/notebooks/WaveRNN prediction.ipynb b/notebooks/WaveRNN prediction.ipynb index 14a15e4..3ad8aff 100644 --- a/notebooks/WaveRNN prediction.ipynb +++ b/notebooks/WaveRNN prediction.ipynb @@ -70,26 +70,20 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [], - "source": [ - "mel = np.load(WaveRNN_path / 'scientists.npy')" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "| β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 27400/27500 | Batch Size: 1 | Gen Rate: 1.2kHz | " + "| β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 109900/110000 | Batch Size: 1 | Gen Rate: 0.8kHz | " ] } ], "source": [ - "_ = model.generate(mel[np.newaxis,:,:], 'scientists.wav', False, 1, hp.voc_overlap, hp.mu_law)" + "file_name = Path('scientists.npy')\n", + "mel = np.load(WaveRNN_path / file_name)\n", + "batch_pred = True # False is slower but possibly better\n", + "_ = model.generate(mel.clip(0,1)[np.newaxis,:,:], file_name.stem + '.wav', batch_pred, 110_000, hp.voc_overlap, hp.mu_law)" ] } ], diff --git a/notebooks/synthesize.ipynb b/notebooks/synthesize.ipynb deleted file mode 100644 index 473db5b..0000000 --- a/notebooks/synthesize.ipynb +++ /dev/null @@ -1,550 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "synthesize", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "zdMgfG7GMF_R", - "colab_type": "text" - }, - "source": [ - "# Transformer TTS: A Text-to-Speech Transformer in TensorFlow 2" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "JQ5YuFPAxXUy", - "colab_type": "code", - "outputId": "e9f81ab0-adbe-4741-daee-fd115387b047", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 323 - } - }, - "source": [ - "# Clone the repo, the pretrained model and WaveRNN for the vocoder\n", - "!git clone https://github.com/as-ideas/TransformerTTS.git\n", - "!git clone https://github.com/as-ideas/tts_model_outputs.git\n", - "!git clone https://github.com/fatchord/WaveRNN" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Cloning into 'TransformerTTS'...\n", - "remote: Enumerating objects: 110, done.\u001b[K\n", - "remote: Counting objects: 100% (110/110), done.\u001b[K\n", - "remote: Compressing objects: 100% (90/90), done.\u001b[K\n", - "remote: Total 2334 (delta 55), reused 48 (delta 17), pack-reused 2224\u001b[K\n", - "Receiving objects: 100% (2334/2334), 1.60 MiB | 1.82 MiB/s, done.\n", - "Resolving deltas: 100% (1573/1573), done.\n", - "Cloning into 
'tts_model_outputs'...\n", - "remote: Enumerating objects: 22, done.\u001b[K\n", - "remote: Counting objects: 100% (22/22), done.\u001b[K\n", - "remote: Compressing objects: 100% (21/21), done.\u001b[K\n", - "remote: Total 65 (delta 9), reused 0 (delta 0), pack-reused 43\u001b[K\n", - "Unpacking objects: 100% (65/65), done.\n", - "Cloning into 'WaveRNN'...\n", - "remote: Enumerating objects: 928, done.\u001b[K\n", - "remote: Total 928 (delta 0), reused 0 (delta 0), pack-reused 928\n", - "Receiving objects: 100% (928/928), 241.65 MiB | 13.75 MiB/s, done.\n", - "Resolving deltas: 100% (540/540), done.\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9bIzkIGjMRwm", - "colab_type": "code", - "outputId": "89e451ea-c101-4694-c404-d3c15a358854", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - } - }, - "source": [ - "# Install requirements\n", - "!apt-get install -y espeak\n", - "!pip install -r TransformerTTS/requirements.txt" - ], - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Reading package lists... Done\n", - "Building dependency tree \n", - "Reading state information... Done\n", - "The following additional packages will be installed:\n", - " espeak-data libespeak1 libportaudio2 libsonic0\n", - "The following NEW packages will be installed:\n", - " espeak espeak-data libespeak1 libportaudio2 libsonic0\n", - "0 upgraded, 5 newly installed, 0 to remove and 31 not upgraded.\n", - "Need to get 1,219 kB of archives.\n", - "After this operation, 3,031 kB of additional disk space will be used.\n", - "Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libportaudio2 amd64 19.6.0-1 [64.6 kB]\n", - "Get:2 http://archive.ubuntu.com/ubuntu bionic/main amd64 libsonic0 amd64 0.2.0-6 [13.4 kB]\n", - "Get:3 http://archive.ubuntu.com/ubuntu bionic/universe amd64 espeak-data amd64 1.48.04+dfsg-5 [934 kB]\n", - "Get:4 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libespeak1 amd64 1.48.04+dfsg-5 [145 kB]\n", - "Get:5 http://archive.ubuntu.com/ubuntu bionic/universe amd64 espeak amd64 1.48.04+dfsg-5 [61.6 kB]\n", - "Fetched 1,219 kB in 3s (457 kB/s)\n", - "Selecting previously unselected package libportaudio2:amd64.\n", - "(Reading database ... 
144433 files and directories currently installed.)\n", - "Preparing to unpack .../libportaudio2_19.6.0-1_amd64.deb ...\n", - "Unpacking libportaudio2:amd64 (19.6.0-1) ...\n", - "Selecting previously unselected package libsonic0:amd64.\n", - "Preparing to unpack .../libsonic0_0.2.0-6_amd64.deb ...\n", - "Unpacking libsonic0:amd64 (0.2.0-6) ...\n", - "Selecting previously unselected package espeak-data:amd64.\n", - "Preparing to unpack .../espeak-data_1.48.04+dfsg-5_amd64.deb ...\n", - "Unpacking espeak-data:amd64 (1.48.04+dfsg-5) ...\n", - "Selecting previously unselected package libespeak1:amd64.\n", - "Preparing to unpack .../libespeak1_1.48.04+dfsg-5_amd64.deb ...\n", - "Unpacking libespeak1:amd64 (1.48.04+dfsg-5) ...\n", - "Selecting previously unselected package espeak.\n", - "Preparing to unpack .../espeak_1.48.04+dfsg-5_amd64.deb ...\n", - "Unpacking espeak (1.48.04+dfsg-5) ...\n", - "Setting up libportaudio2:amd64 (19.6.0-1) ...\n", - "Setting up espeak-data:amd64 (1.48.04+dfsg-5) ...\n", - "Setting up libsonic0:amd64 (0.2.0-6) ...\n", - "Setting up libespeak1:amd64 (1.48.04+dfsg-5) ...\n", - "Setting up espeak (1.48.04+dfsg-5) ...\n", - "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n", - "Processing triggers for libc-bin (2.27-3ubuntu1) ...\n", - "/sbin/ldconfig.real: /usr/local/lib/python3.6/dist-packages/ideep4py/lib/libmkldnn.so.0 is not a symbolic link\n", - "\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from -r TransformerTTS/requirements.txt (line 1)) (3.2.1)\n", - "Collecting librosa>=0.7.1\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/77/b5/1817862d64a7c231afd15419d8418ae1f000742cac275e85c74b219cbccb/librosa-0.7.2.tar.gz (1.6MB)\n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1.6MB 2.8MB/s \n", - "\u001b[?25hRequirement already satisfied: numpy>=1.17.4 in /usr/local/lib/python3.6/dist-packages (from -r TransformerTTS/requirements.txt (line 3)) (1.18.4)\n", - "Collecting phonemizer==2.1\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d3/82/666045375029df9c2f274923539f43346a7b7abc349b02e33dff585da56f/phonemizer-2.1-py3-none-any.whl (47kB)\n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 51kB 6.4MB/s \n", - "\u001b[?25hCollecting ruamel.yaml>=0.16.6\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/92/59af3e38227b9cc14520bf1e59516d99ceca53e3b8448094248171e9432b/ruamel.yaml-0.16.10-py2.py3-none-any.whl (111kB)\n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 112kB 19.5MB/s \n", - "\u001b[?25hRequirement already satisfied: tensorflow>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from -r TransformerTTS/requirements.txt (line 6)) (2.2.0)\n", - "Requirement already satisfied: tqdm>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from -r TransformerTTS/requirements.txt (line 7)) (4.41.1)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->-r TransformerTTS/requirements.txt (line 1)) (1.2.0)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->-r TransformerTTS/requirements.txt (line 1)) (0.10.0)\n", - "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->-r 
TransformerTTS/requirements.txt (line 1)) (2.4.7)\n", - "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->-r TransformerTTS/requirements.txt (line 1)) (2.8.1)\n", - "Requirement already satisfied: audioread>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (2.1.8)\n", - "Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (1.4.1)\n", - "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (0.22.2.post1)\n", - "Requirement already satisfied: joblib>=0.12 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (0.15.0)\n", - "Requirement already satisfied: decorator>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (4.4.2)\n", - "Requirement already satisfied: six>=1.3 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (1.12.0)\n", - "Requirement already satisfied: resampy>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (0.2.2)\n", - "Requirement already satisfied: numba>=0.43.0 in /usr/local/lib/python3.6/dist-packages (from librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (0.48.0)\n", - "Collecting soundfile>=0.9.0\n", - " Downloading https://files.pythonhosted.org/packages/eb/f2/3cbbbf3b96fb9fa91582c438b574cff3f45b29c772f94c400e2c99ef5db9/SoundFile-0.10.3.post1-py2.py3-none-any.whl\n", - "Requirement already satisfied: attrs>=18.1 in /usr/local/lib/python3.6/dist-packages (from phonemizer==2.1->-r TransformerTTS/requirements.txt (line 4)) (19.3.0)\n", - "Collecting segments\n", - " Downloading https://files.pythonhosted.org/packages/5b/a0/0c3fe64787745c39eb3f2f5f5f9ed8d008d9ef22e9d7f9f52f71ea4712f7/segments-2.1.3-py2.py3-none-any.whl\n", - "Collecting ruamel.yaml.clib>=0.1.2; platform_python_implementation == \"CPython\" and python_version < \"3.9\"\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/53/77/4bcd63f362bcb6c8f4f06253c11f9772f64189bf08cf3f40c5ccbda9e561/ruamel.yaml.clib-0.2.0-cp36-cp36m-manylinux1_x86_64.whl (548kB)\n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 552kB 14.8MB/s \n", - "\u001b[?25hRequirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.10.0)\n", - "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.1.0)\n", - "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.9.0)\n", - "Requirement already satisfied: tensorboard<2.3.0,>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (2.2.1)\n", - "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.29.0)\n", - "Requirement already satisfied: 
wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.12.1)\n", - "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.3.3)\n", - "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.6.3)\n", - "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.2.1)\n", - "Requirement already satisfied: tensorflow-estimator<2.3.0,>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (2.2.0)\n", - "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.34.2)\n", - "Requirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.2.0)\n", - "Requirement already satisfied: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.1.2)\n", - "Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (2.10.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from numba>=0.43.0->librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (46.3.0)\n", - "Requirement already satisfied: llvmlite<0.32.0,>=0.31.0dev0 in /usr/local/lib/python3.6/dist-packages (from numba>=0.43.0->librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (0.31.0)\n", - "Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.6/dist-packages (from soundfile>=0.9.0->librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (1.14.0)\n", - "Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from segments->phonemizer==2.1->-r TransformerTTS/requirements.txt (line 4)) (2019.12.20)\n", - "Collecting clldutils>=1.7.3\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f1/ec/76860c7c36e8f6683a6d5041ebda054f4c1deca1a8aac9ea3357105139f5/clldutils-3.5.1-py2.py3-none-any.whl (188kB)\n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 194kB 20.4MB/s \n", - "\u001b[?25hCollecting csvw>=1.5.6\n", - " Downloading https://files.pythonhosted.org/packages/d1/b6/8fef6788b8f05b21424a17ae3881eff916d42e5c7e87f57a85d9d7abf0a1/csvw-1.7.0-py2.py3-none-any.whl\n", - "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.2.2)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.4.1)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.6.0.post3)\n", - "Requirement already 
satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.7.2)\n", - "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.0.1)\n", - "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (2.23.0)\n", - "Requirement already satisfied: pycparser in /usr/local/lib/python3.6/dist-packages (from cffi>=1.0->soundfile>=0.9.0->librosa>=0.7.1->-r TransformerTTS/requirements.txt (line 2)) (2.20)\n", - "Collecting colorlog\n", - " Downloading https://files.pythonhosted.org/packages/00/0d/22c73c2eccb21dd3498df7d22c0b1d4a30f5a5fb3feb64e1ce06bc247747/colorlog-4.1.0-py2.py3-none-any.whl\n", - "Requirement already satisfied: tabulate>=0.7.7 in /usr/local/lib/python3.6/dist-packages (from clldutils>=1.7.3->segments->phonemizer==2.1->-r TransformerTTS/requirements.txt (line 4)) (0.8.7)\n", - "Collecting rfc3986\n", - " Downloading https://files.pythonhosted.org/packages/78/be/7b8b99fd74ff5684225f50dd0e865393d2265656ef3b4ba9eaaaffe622b8/rfc3986-1.4.0-py2.py3-none-any.whl\n", - "Collecting isodate\n", - "\u001b[?25l Downloading https://files.pythonhosted.org/packages/9b/9f/b36f7774ff5ea8e428fdcfc4bb332c39ee5b9362ddd3d40d9516a55221b2/isodate-0.6.0-py2.py3-none-any.whl (45kB)\n", - "\u001b[K |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 51kB 6.1MB/s \n", - "\u001b[?25hRequirement already satisfied: uritemplate>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from csvw>=1.5.6->segments->phonemizer==2.1->-r TransformerTTS/requirements.txt (line 4)) (3.0.1)\n", - "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.6.0)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.3.0)\n", - "Requirement already satisfied: rsa<4.1,>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (4.0)\n", - "Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.1.1)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.2.8)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.0.4)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (2020.4.5.1)\n", - "Requirement 
already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (1.24.3)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (2.9)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.1.0)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (3.1.0)\n", - "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<4.1,>=3.1.4->google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow>=2.2.0->-r TransformerTTS/requirements.txt (line 6)) (0.4.8)\n", - "Building wheels for collected packages: librosa\n", - " Building wheel for librosa (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for librosa: filename=librosa-0.7.2-cp36-none-any.whl size=1612885 sha256=24d7bbec2757303377b47d49db4db3bfc6608222ba90930a2c1e0564e6a04aee\n", - " Stored in directory: /root/.cache/pip/wheels/4c/6e/d7/bb93911540d2d1e44d690a1561871e5b6af82b69e80938abef\n", - "Successfully built librosa\n", - "Installing collected packages: soundfile, librosa, colorlog, rfc3986, isodate, csvw, clldutils, segments, phonemizer, ruamel.yaml.clib, ruamel.yaml\n", - " Found existing installation: librosa 0.6.3\n", - " Uninstalling librosa-0.6.3:\n", - " Successfully uninstalled librosa-0.6.3\n", - "Successfully installed clldutils-3.5.1 colorlog-4.1.0 csvw-1.7.0 isodate-0.6.0 librosa-0.7.2 phonemizer-2.1 rfc3986-1.4.0 ruamel.yaml-0.16.10 ruamel.yaml.clib-0.2.0 segments-2.1.3 soundfile-0.10.3.post1\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "LucwkAK1yEVq", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 88 - }, - "outputId": "bef65ba5-1549-488c-d4dc-788965fa93ee" - }, - "source": [ - "# Load pretrained models\n", - "config_path = 'tts_model_outputs/ljspeech_transformertts/standard'\n", - "project_path = 'TransformerTTS'\n", - "\n", - "import sys\n", - "sys.path.append(project_path)\n", - "from utils.config_manager import ConfigManager\n", - "from utils.audio import reconstruct_waveform\n", - "\n", - "import IPython.display as ipd\n", - "\n", - "config_loader = ConfigManager(config_path)\n", - "model = config_loader.load_model('tts_model_outputs/ljspeech_transformertts/standard/model_weights/ckpt-90')" - ], - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "text": [ - "WARNING: could not retrieve git hash. Command '['git', 'describe', '--always']' returned non-zero exit status 128.\n", - "WARNING: could not check git hash. 
Command '['git', 'describe', '--always']' returned non-zero exit status 128.\n", - "restored weights from tts_model_outputs/ljspeech_transformertts/standard/model_weights/ckpt-90 at step 900000\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "_5RKHIDQyZvo", - "colab_type": "code", - "outputId": "a8c04963-ab23-480e-9826-53de2db0c67c", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - } - }, - "source": [ - "# Synthesize text\n", - "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", - "out = model.predict(sentence)" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "text": [ - "pred text mel: 397 stop out: -1.9915766716003418Stopping\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "GXxdDHOAyZ6f", - "colab_type": "code", - "outputId": "d319bc2c-2843-4b51-e1ce-76c2857f255e", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 75 - } - }, - "source": [ - "# Convert spectrogram to wav (with griffin lim)\n", - "wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config)\n", - "ipd.display(ipd.Audio(wav, rate=config_loader.config['sampling_rate']))" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eZJo81viVus-", - "colab_type": "text" - }, - "source": [ - "### WaveRNN" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "5oQhgBhUPB9C", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Export for WaveRNN\n", - "import numpy as np\n", - "from pathlib import Path\n", - "WaveRNN_path = Path('WaveRNN/')\n", - "np.save(WaveRNN_path / 'scientists.npy', (out['mel'].numpy().T+4.)/8.)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "WjIuQALHTr-R", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Do some sys cleaning and imports\n", - "sys.path.remove('TransformerTTS')\n", - "sys.modules.pop('utils')\n", - "\n", - "import sys\n", - "sys.path.append('WaveRNN/')\n", - "from utils.dsp import hp\n", - "from models.fatchord_version import WaveRNN\n", - "import torch\n", - "import numpy as np\n", - "from pathlib import Path" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "dptoYzL1XFAr", - "colab_type": "code", - "outputId": "a87f9520-94cb-4306-d1b9-b8aa6b5b68bc", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - } - }, - "source": [ - "# Unzip the pretrained model\n", - "!unzip WaveRNN/pretrained/ljspeech.wavernn.mol.800k.zip -d WaveRNN/pretrained/" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Archive: WaveRNN/pretrained/ljspeech.wavernn.mol.800k.zip\n", - " inflating: WaveRNN/pretrained/latest_weights.pyt \n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "rKixR97aTtwX", - "colab_type": "code", - "outputId": "5bf538f8-bf7c-4ca3-f6a8-93926a457ba3", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - } - }, - "source": [ - "# Load pretrained model\n", - "hp.configure(WaveRNN_path / 'hparams.py') # Load hparams from file\n", - "device = torch.device('cpu')\n", - "model = WaveRNN(rnn_dims=hp.voc_rnn_dims,\n", - " 
fc_dims=hp.voc_fc_dims,\n", - " bits=hp.bits,\n", - " pad=hp.voc_pad,\n", - " upsample_factors=hp.voc_upsample_factors,\n", - " feat_dims=hp.num_mels,\n", - " compute_dims=hp.voc_compute_dims,\n", - " res_out_dims=hp.voc_res_out_dims,\n", - " res_blocks=hp.voc_res_blocks,\n", - " hop_length=hp.hop_length,\n", - " sample_rate=hp.sample_rate,\n", - " mode=hp.voc_mode).to(device)\n", - "\n", - "model.load(str(WaveRNN_path / 'pretrained/latest_weights.pyt'))" - ], - "execution_count": 9, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Trainable Parameters: 4.234M\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "mPF7TrqDOE8S", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Ignore some TF warnings\n", - "import tensorflow as tf\n", - "tf.get_logger().setLevel('ERROR')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "EVkdFQeRUGQ-", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "outputId": "62d4a302-de02-4acb-8979-4b30b3903db5" - }, - "source": [ - "# Generate sample with pre-trained WaveRNN vocoder\n", - "mel = np.load(WaveRNN_path / 'scientists.npy')\n", - "_ = model.generate(mel[np.newaxis,:,:], 'scientists.wav', False, 1, hp.voc_overlap, hp.mu_law)" - ], - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "text": [ - "| β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 109400/109450 | Batch Size: 1 | Gen Rate: 0.7kHz | " - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vQYaZawLXTJI", - "colab_type": "code", - "outputId": "bc677767-da4c-4125-b4a0-0e2f43a93efc", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 75 - } - }, - "source": [ - "# Load wav file\n", - "ipd.display(ipd.Audio('scientists.wav'))" - ], - "execution_count": 12, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "jWX00MuHYojU", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/notebooks/synthesize_autoregressive.ipynb b/notebooks/synthesize_autoregressive.ipynb new file mode 100644 index 0000000..c89260b --- /dev/null +++ b/notebooks/synthesize_autoregressive.ipynb @@ -0,0 +1,343 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zdMgfG7GMF_R" + }, + "source": [ + "# Transformer TTS: A Text-to-Speech Transformer in TensorFlow 2\n", + "## Autoregressive Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 323 + }, + "colab_type": "code", + "id": "JQ5YuFPAxXUy", + "outputId": "e9f81ab0-adbe-4741-daee-fd115387b047" + }, + "outputs": [], + "source": [ + "# Clone the Transformer TTS and WaveRNN repos\n", + "!git clone https://github.com/as-ideas/TransformerTTS.git\n", + "!git clone https://github.com/fatchord/WaveRNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "colab_type": "code", + "id": "9bIzkIGjMRwm", + "outputId": "89e451ea-c101-4694-c404-d3c15a358854" + }, + "outputs": [], + "source": [ + "# Install requirements\n", + "!apt-get 
install -y espeak\n", + "!pip install -r TransformerTTS/requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download the pre-trained weights\n", + "! wget https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_autoregressive_transformer.zip\n", + "! unzip ljspeech_autoregressive_transformer.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the paths\n", + "from pathlib import Path\n", + "WaveRNN_path = 'WaveRNN/'\n", + "TTS_path = 'TransformerTTS/'\n", + "config_path = Path('ljspeech_autoregressive_transformer/standard')\n", + "\n", + "import sys\n", + "sys.path.append(TTS_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 88 + }, + "colab_type": "code", + "id": "LucwkAK1yEVq", + "outputId": "bef65ba5-1549-488c-d4dc-788965fa93ee" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING: git hash mismatch. Current: 0afbb76. Config hash: 1fec787\n", + "restored weights from standard/autoregressive_weights/ckpt-38 at step 380000\n" + ] + } + ], + "source": [ + "# Load pretrained models\n", + "from utils.config_manager import ConfigManager\n", + "from utils.audio import reconstruct_waveform\n", + "\n", + "import IPython.display as ipd\n", + "\n", + "config_loader = ConfigManager(str(config_path), model_kind='autoregressive')\n", + "model = config_loader.load_model(str(config_path / 'autoregressive_weights/ckpt-38'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Synthesize text\n", + "sentence = 'Scientists at the CERN laboratory, say that they have discovered a new particle.'\n", + "out = model.predict(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pred text mel: 372 stop out: -1.466308832168579Stopping\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Convert spectrogram to wav (with griffin lim)\n", + "wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config)\n", + "ipd.display(ipd.Audio(wav, rate=config_loader.config['sampling_rate']))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Normalize for WaveRNN\n", + "mel = (out['mel'].numpy().T+4.)/8." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eZJo81viVus-" + }, + "source": [ + "### WaveRNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5oQhgBhUPB9C" + }, + "outputs": [], + "source": [ + "# Do some sys cleaning and imports\n", + "sys.path.remove(TTS_path)\n", + "sys.modules.pop('utils')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WjIuQALHTr-R" + }, + "outputs": [], + "source": [ + "sys.path.append(WaveRNN_path)\n", + "from utils.dsp import hp\n", + "from models.fatchord_version import WaveRNN\n", + "import torch\n", + "import numpy as np\n", + "WaveRNN_path = Path(WaveRNN_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "dptoYzL1XFAr", + "outputId": "a87f9520-94cb-4306-d1b9-b8aa6b5b68bc" + }, + "outputs": [], + "source": [ + "# Unzip the pretrained model\n", + "!unzip WaveRNN/pretrained/ljspeech.wavernn.mol.800k.zip -d WaveRNN/pretrained/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "rKixR97aTtwX", + "outputId": "5bf538f8-bf7c-4ca3-f6a8-93926a457ba3" + }, + "outputs": [], + "source": [ + "# Load pretrained model\n", + "hp.configure(WaveRNN_path / 'hparams.py') # Load hparams from file\n", + "device = torch.device('cpu')\n", + "model = WaveRNN(rnn_dims=hp.voc_rnn_dims,\n", + " fc_dims=hp.voc_fc_dims,\n", + " bits=hp.bits,\n", + " pad=hp.voc_pad,\n", + " upsample_factors=hp.voc_upsample_factors,\n", + " feat_dims=hp.num_mels,\n", + " compute_dims=hp.voc_compute_dims,\n", + " res_out_dims=hp.voc_res_out_dims,\n", + " res_blocks=hp.voc_res_blocks,\n", + " hop_length=hp.hop_length,\n", + " sample_rate=hp.sample_rate,\n", + " mode=hp.voc_mode).to(device)\n", + "\n", + "model.load(str(WaveRNN_path / 'pretrained/latest_weights.pyt'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "mPF7TrqDOE8S" + }, + "outputs": [], + "source": [ + "# Ignore some TF warnings\n", + "import tensorflow as tf\n", + "tf.get_logger().setLevel('ERROR')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "EVkdFQeRUGQ-", + "outputId": "62d4a302-de02-4acb-8979-4b30b3903db5" + }, + "outputs": [], + "source": [ + "# Generate sample with pre-trained WaveRNN vocoder\n", + "batch_pred = True # False is slower but possibly better\n", + "_ = model.generate(mel.clip(0,1)[np.newaxis,:,:], 'scientists.wav', batch_pred, 11_000, hp.voc_overlap, hp.mu_law)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "colab_type": "code", + "id": "vQYaZawLXTJI", + "outputId": "bc677767-da4c-4125-b4a0-0e2f43a93efc" + }, + "outputs": [], + "source": [ + "# Load wav file\n", + "ipd.display(ipd.Audio('scientists.wav'))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "synthesize", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" 
+ }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/notebooks/synthesize_forward.ipynb b/notebooks/synthesize_forward.ipynb new file mode 100644 index 0000000..7897d75 --- /dev/null +++ b/notebooks/synthesize_forward.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zdMgfG7GMF_R" + }, + "source": [ + "# Transformer TTS: A Text-to-Speech Transformer in TensorFlow 2\n", + "## Forward Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 323 + }, + "colab_type": "code", + "id": "JQ5YuFPAxXUy", + "outputId": "e9f81ab0-adbe-4741-daee-fd115387b047" + }, + "outputs": [], + "source": [ + "# Clone the Transformer TTS and WaveRNN repos\n", + "!git clone https://github.com/as-ideas/TransformerTTS.git\n", + "!git clone https://github.com/fatchord/WaveRNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "colab_type": "code", + "id": "9bIzkIGjMRwm", + "outputId": "89e451ea-c101-4694-c404-d3c15a358854" + }, + "outputs": [], + "source": [ + "# Install requirements\n", + "!apt-get install -y espeak\n", + "!pip install -r TransformerTTS/requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download the pre-trained weights\n", + "! wget https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_forward_transformer.zip\n", + "! 
unzip ljspeech_forward_transformer.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the paths\n", + "from pathlib import Path\n", + "WaveRNN_path = 'WaveRNN/'\n", + "TTS_path = 'TransformerTTS/'\n", + "config_path = Path('ljspeech_forward_transformer/standard')\n", + "\n", + "import sys\n", + "sys.path.append(TTS_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 88 + }, + "colab_type": "code", + "id": "LucwkAK1yEVq", + "outputId": "bef65ba5-1549-488c-d4dc-788965fa93ee" + }, + "outputs": [], + "source": [ + "# Load pretrained models\n", + "from utils.config_manager import ConfigManager\n", + "from utils.audio import reconstruct_waveform\n", + "\n", + "import IPython.display as ipd\n", + "\n", + "config_loader = ConfigManager(str(config_path), model_kind='forward')\n", + "model = config_loader.load_model(str(config_path / 'forward_weights/ckpt-133'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "_5RKHIDQyZvo", + "outputId": "a8c04963-ab23-480e-9826-53de2db0c67c" + }, + "outputs": [], + "source": [ + "# Synthesize text\n", + "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", + "out = model.predict(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "colab_type": "code", + "id": "GXxdDHOAyZ6f", + "outputId": "d319bc2c-2843-4b51-e1ce-76c2857f255e" + }, + "outputs": [], + "source": [ + "# Convert spectrogram to wav (with griffin lim)\n", + "wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config)\n", + "ipd.display(ipd.Audio(wav, rate=config_loader.config['sampling_rate']))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Normalize for WaveRNN\n", + "mel = (out['mel'].numpy().T+4.)/8." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also vary the speech speed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 20% faster\n", + "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", + "out = model.predict(sentence, speed_regulator=1.20)\n", + "wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config)\n", + "ipd.display(ipd.Audio(wav, rate=config_loader.config['sampling_rate']))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 10% slower\n", + "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", + "out = model.predict(sentence, speed_regulator=.9)\n", + "wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config)\n", + "ipd.display(ipd.Audio(wav, rate=config_loader.config['sampling_rate']))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eZJo81viVus-" + }, + "source": [ + "### WaveRNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WjIuQALHTr-R" + }, + "outputs": [], + "source": [ + "# Do some sys cleaning and imports\n", + "sys.path.remove(TTS_path)\n", + "sys.modules.pop('utils')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WjIuQALHTr-R" + }, + "outputs": [], + "source": [ + "sys.path.append(WaveRNN_path)\n", + "from utils.dsp import hp\n", + "from models.fatchord_version import WaveRNN\n", + "import torch\n", + "import numpy as np\n", + "WaveRNN_path = Path(WaveRNN_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "dptoYzL1XFAr", + "outputId": "a87f9520-94cb-4306-d1b9-b8aa6b5b68bc" + }, + "outputs": [], + "source": [ + "# Unzip the pretrained model\n", + "!unzip WaveRNN/pretrained/ljspeech.wavernn.mol.800k.zip -d WaveRNN/pretrained/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "rKixR97aTtwX", + "outputId": "5bf538f8-bf7c-4ca3-f6a8-93926a457ba3" + }, + "outputs": [], + "source": [ + "# Load pretrained model\n", + "hp.configure(WaveRNN_path / 'hparams.py') # Load hparams from file\n", + "device = torch.device('cpu')\n", + "model = WaveRNN(rnn_dims=hp.voc_rnn_dims,\n", + " fc_dims=hp.voc_fc_dims,\n", + " bits=hp.bits,\n", + " pad=hp.voc_pad,\n", + " upsample_factors=hp.voc_upsample_factors,\n", + " feat_dims=hp.num_mels,\n", + " compute_dims=hp.voc_compute_dims,\n", + " res_out_dims=hp.voc_res_out_dims,\n", + " res_blocks=hp.voc_res_blocks,\n", + " hop_length=hp.hop_length,\n", + " sample_rate=hp.sample_rate,\n", + " mode=hp.voc_mode).to(device)\n", + "\n", + "model.load(str(WaveRNN_path / 'pretrained/latest_weights.pyt'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "mPF7TrqDOE8S" + }, + "outputs": [], + "source": [ + "# Ignore some TF warnings\n", + "import tensorflow as tf\n", + "tf.get_logger().setLevel('ERROR')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "EVkdFQeRUGQ-", + "outputId": "62d4a302-de02-4acb-8979-4b30b3903db5" + }, + "outputs": [], + "source": [ + "# Generate sample with pre-trained WaveRNN vocoder\n", + "batch_pred = True # False is slower but possibly better\n", + "_ = model.generate(mel.clip(0,1)[np.newaxis,:,:], 'scientists.wav', batch_pred, 11_000, hp.voc_overlap, hp.mu_law)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "colab_type": "code", + "id": "vQYaZawLXTJI", + "outputId": "bc677767-da4c-4125-b4a0-0e2f43a93efc" + }, + "outputs": [], + "source": [ + "# Load wav file\n", + "ipd.display(ipd.Audio('scientists.wav'))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "synthesize", + "provenance": [] + }, + "kernelspec": { + "display_name": "ttsTF", + "language": "python", + "name": "ttstf" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/preprocessing/data_handling.py b/preprocessing/data_handling.py index 68c4862..35cea41 100644 --- a/preprocessing/data_handling.py +++ b/preprocessing/data_handling.py @@ -113,3 +113,10 @@ def _run(self, phonemes, text, mel): stop_probs = np.ones((norm_mel.shape[0])) stop_probs[-1] = 2 return norm_mel, encoded_phonemes, stop_probs + + +class ForwardDataPrepper: + + def __call__(self, sample): + mel, encoded_phonemes, durations = np.load(str(sample), allow_pickle=True) + return mel, encoded_phonemes, durations diff --git a/train.py b/train_autoregressive.py similarity index 95% rename from train.py rename to train_autoregressive.py index 904cb2c..11a90a2 100644 --- a/train.py +++ b/train_autoregressive.py @@ -1,4 +1,5 @@ import argparse +import traceback import tensorflow as tf import numpy as np @@ -13,7 +14,7 @@ np.random.seed(42) tf.random.set_seed(42) -# dinamically allocate GPU +# dynamically allocate GPU gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: try: @@ -21,11 +22,10 @@ for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) logical_gpus = tf.config.experimental.list_logical_devices('GPU') - print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") - except RuntimeError as e: - # Memory growth must be set before GPUs have been initialized - print(e) - + print(len(gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs') + except Exception: + traceback.print_exc() + @ignore_exception @time_it @@ -63,7 +63,7 @@ def validate(model, help="deletes weights under this config's folder.") parser.add_argument('--session_name', dest='session_name', default=None) args = parser.parse_args() -config_manager = ConfigManager(config_path=args.config, session_name=args.session_name) +config_manager = ConfigManager(config_path=args.config, model_kind='autoregressive', session_name=args.session_name) config = config_manager.config config_manager.create_remove_dirs(clear_dir=args.clear_dir, clear_logs=args.clear_logs, @@ -118,7 +118,7 @@ def validate(model, for _ in t: t.set_description(f'step {model.step}') mel, phonemes, stop = train_dataset.next_batch() - decoder_prenet_dropout = piecewise_linear_schedule(model.step, 
config['decoder_dropout_schedule']) + decoder_prenet_dropout = piecewise_linear_schedule(model.step, config['decoder_prenet_dropout_schedule']) learning_rate = piecewise_linear_schedule(model.step, config['learning_rate_schedule']) reduction_factor = reduction_schedule(model.step, config['reduction_factor_schedule']) drop_n_heads = tf.cast(reduction_schedule(model.step, config['head_drop_schedule']), tf.int32) @@ -138,7 +138,7 @@ def validate(model, t.display(f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}', pos=pos + 2) summary_manager.display_loss(output, tag='Train') - summary_manager.display_scalar(tag='Meta/decoder_prenet_dropout', scalar_value=model.decoder_prenet_dropout) + summary_manager.display_scalar(tag='Meta/decoder_prenet_dropout', scalar_value=model.decoder_prenet.rate) summary_manager.display_scalar(tag='Meta/learning_rate', scalar_value=model.optimizer.lr) summary_manager.display_scalar(tag='Meta/reduction_factor', scalar_value=model.r) summary_manager.display_scalar(tag='Meta/drop_n_heads', scalar_value=model.drop_n_heads) diff --git a/train_forward.py b/train_forward.py new file mode 100644 index 0000000..715caf4 --- /dev/null +++ b/train_forward.py @@ -0,0 +1,195 @@ +import argparse +import traceback +from pathlib import Path +from time import time + +import tensorflow as tf +import numpy as np +from tqdm import trange + +from utils.config_manager import ConfigManager +from preprocessing.data_handling import Dataset, ForwardDataPrepper +from utils.decorators import ignore_exception, time_it +from utils.scheduling import piecewise_linear_schedule, reduction_schedule +from utils.logging import SummaryManager +from model.transformer_utils import create_mel_padding_mask + +np.random.seed(42) +tf.random.set_seed(42) + +# dynamically allocate GPU +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + try: + # Currently, memory growth needs to be the same across GPUs + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.experimental.list_logical_devices('GPU') + print(len(gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs') + except Exception: + traceback.print_exc() + + +def build_file_list(data_dir: Path): + sample_paths = [] + for item in data_dir.iterdir(): + if item.suffix == '.npy': + sample_paths.append(str(item)) + return sample_paths + + +@ignore_exception +@time_it +def validate(model, + val_dataset, + summary_manager): + val_loss = {'loss': 0.} + norm = 0. 
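+    # accumulate the loss over every validation batch, then log attention, durations and mels from the last batch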
+ for mel, phonemes, durations in val_dataset.all_batches(): + model_out = model.val_step(input_sequence=phonemes, + target_sequence=mel, + target_durations=durations) + norm += 1 + val_loss['loss'] += model_out['loss'] + val_loss['loss'] /= norm + summary_manager.display_loss(model_out, tag='Validation', plot_all=True) + summary_manager.display_attention_heads(model_out, tag='ValidationAttentionHeads') + summary_manager.add_histogram(tag=f'Validation/Predicted durations', values=model_out['duration']) + summary_manager.add_histogram(tag=f'Validation/Target durations', values=durations) + summary_manager.display_mel(mel=model_out['mel'][0], tag=f'Validation/linear_mel_out') + summary_manager.display_mel(mel=mel[0], tag=f'Validation/target_mel') + return val_loss['loss'] + + +# consuming CLI, creating paths and directories, load data + +parser = argparse.ArgumentParser() +parser.add_argument('--config', dest='config', type=str) +parser.add_argument('--reset_dir', dest='clear_dir', action='store_true', + help="deletes everything under this config's folder.") +parser.add_argument('--reset_logs', dest='clear_logs', action='store_true', + help="deletes logs under this config's folder.") +parser.add_argument('--reset_weights', dest='clear_weights', action='store_true', + help="deletes weights under this config's folder.") +parser.add_argument('--session_name', dest='session_name', default=None) +args = parser.parse_args() + +config_manager = ConfigManager(config_path=args.config, model_kind='forward', session_name=args.session_name) +config = config_manager.config +config_manager.create_remove_dirs(clear_dir=args.clear_dir, + clear_logs=args.clear_logs, + clear_weights=args.clear_weights) +config_manager.dump_config() +config_manager.print_config() + +train_data_list = build_file_list(config_manager.train_datadir / 'forward_data/train') +dataprep = ForwardDataPrepper() +train_dataset = Dataset(samples=train_data_list, + mel_channels=config['mel_channels'], + preprocessor=dataprep, + batch_size=config['batch_size'], + shuffle=True) +val_data_list = build_file_list(config_manager.train_datadir / 'forward_data/val') +val_dataset = Dataset(samples=val_data_list, + mel_channels=config['mel_channels'], + preprocessor=dataprep, + batch_size=config['batch_size'], + shuffle=False) + +# get model, prepare data for model, create datasets +model = config_manager.get_model() +config_manager.compile_model(model) + +# create logger and checkpointer and restore latest model +summary_manager = SummaryManager(model=model, log_dir=config_manager.log_dir, config=config) +checkpoint = tf.train.Checkpoint(step=tf.Variable(1), + optimizer=model.optimizer, + net=model) +manager = tf.train.CheckpointManager(checkpoint, config_manager.weights_dir, + max_to_keep=config['keep_n_weights'], + keep_checkpoint_every_n_hours=config['keep_checkpoint_every_n_hours']) +checkpoint.restore(manager.latest_checkpoint) +if manager.latest_checkpoint: + print(f'\nresuming training from step {model.step} ({manager.latest_checkpoint})') +else: + print(f'\nstarting training from scratch') +# main event +print('\nTRAINING') +losses = [] +test_batch = val_dataset.next_batch() +t = trange(model.step, config['max_steps'], leave=True) +for _ in t: + t.set_description(f'step {model.step}') + mel, phonemes, durations = train_dataset.next_batch() + learning_rate = piecewise_linear_schedule(model.step, config['learning_rate_schedule']) + decoder_prenet_dropout = piecewise_linear_schedule(model.step, config['decoder_dropout_schedule']) + drop_n_heads 
= tf.cast(reduction_schedule(model.step, config['head_drop_schedule']), tf.int32) + model.set_constants(decoder_prenet_dropout=decoder_prenet_dropout, + learning_rate=learning_rate, + drop_n_heads=drop_n_heads) + output = model.train_step(input_sequence=phonemes, + target_sequence=mel, + target_durations=durations) + losses.append(float(output['loss'])) + + t.display(f'step loss: {losses[-1]}', pos=1) + for pos, n_steps in enumerate(config['n_steps_avg_losses']): + if len(losses) > n_steps: + t.display(f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}', pos=pos + 2) + + summary_manager.display_loss(output, tag='Train') + summary_manager.display_scalar(tag='Meta/learning_rate', scalar_value=model.optimizer.lr) + summary_manager.display_scalar(tag='Meta/decoder_prenet_dropout', scalar_value=model.decoder_prenet.rate) + summary_manager.display_scalar(tag='Meta/drop_n_heads', scalar_value=model.drop_n_heads) + if model.step % config['train_images_plotting_frequency'] == 0: + summary_manager.display_attention_heads(output, tag='TrainAttentionHeads') + summary_manager.display_mel(mel=output['mel'][0], tag=f'Train/linear_mel_out') + summary_manager.display_mel(mel=mel[0], tag=f'Train/target_mel') + summary_manager.add_histogram(tag=f'Train/Predicted durations', values=output['duration']) + summary_manager.add_histogram(tag=f'Train/Target durations', values=durations) + + if model.step % config['weights_save_frequency'] == 0: + save_path = manager.save() + t.display(f'checkpoint at step {model.step}: {save_path}', pos=len(config['n_steps_avg_losses']) + 2) + + if model.step % config['validation_frequency'] == 0: + t.display(f'Validating', pos=len(config['n_steps_avg_losses']) + 3) + val_loss, time_taken = validate(model=model, + val_dataset=val_dataset, + summary_manager=summary_manager) + t.display(f'validation loss at step {model.step}: {val_loss} (took {time_taken}s)', + pos=len(config['n_steps_avg_losses']) + 3) + + if model.step % config['prediction_frequency'] == 0 and (model.step >= config['prediction_start_step']): + tar_mel, phonemes, durs = test_batch + t.display(f'Predicting', pos=len(config['n_steps_avg_losses']) + 4) + timed_pred = time_it(model.predict) + model_out, time_taken = timed_pred(phonemes, encode=False) + summary_manager.display_attention_heads(model_out, tag='TestAttentionHeads') + summary_manager.add_histogram(tag=f'Test/Predicted durations', values=model_out['duration']) + summary_manager.add_histogram(tag=f'Test/Target durations', values=durs) + if model.r > 1: + pred_lengths = tf.cast(tf.math.round(model_out['duration']), tf.int32) + pred_lengths = tf.reduce_sum(pred_lengths, axis=1) + else: + pred_lengths = tf.cast(tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32) + pred_lengths = tf.squeeze(pred_lengths) + tar_lengths = tf.cast(tf.reduce_sum(1 - create_mel_padding_mask(tar_mel), axis=-1), tf.int32) + tar_lengths = tf.squeeze(tar_lengths) + display_start = time() + for j, pred_mel in enumerate(model_out['mel']): + predval = pred_mel[:pred_lengths[j], :] + tar_value = tar_mel[j, :tar_lengths[j], :] + summary_manager.display_mel(mel=predval, tag=f'Test/sample {j}/predicted_mel') + summary_manager.display_mel(mel=tar_value, tag=f'Test/sample {j}/target_mel') + if j < config['n_predictions']: + if model.step >= config['audio_start_step'] and ( + model.step % config['audio_prediction_frequency'] == 0): + summary_manager.display_audio(tag=f'Target/sample {j}', mel=tar_value) + summary_manager.display_audio(tag=f'Prediction/sample {j}', 
mel=predval) + else: + break + display_end = time() + t.display(f'Predictions took {time_taken}. Displaying took {display_end - display_start}.', + pos=len(config['n_steps_avg_losses']) + 4) +print('Done.') diff --git a/utils/alignments.py b/utils/alignments.py new file mode 100644 index 0000000..e132e05 --- /dev/null +++ b/utils/alignments.py @@ -0,0 +1,165 @@ +import numpy as np +import tensorflow as tf + +from model.transformer_utils import create_mel_padding_mask, create_encoder_padding_mask + +logger = tf.get_logger() +logger.setLevel('ERROR') + + +def duration_to_alignment_matrix(durations): + starts = np.cumsum(np.append([0], durations[:-1])) + tot_duration = np.sum(durations) + pads = tot_duration - starts - durations + alignments = [np.concatenate([np.zeros(starts[i]), np.ones(durations[i]), np.zeros(pads[i])]) for i in + range(len(durations))] + return np.array(alignments) + + +def clean_attention(binary_attention, jump_threshold): + phon_idx = 0 + clean_attn = np.zeros(binary_attention.shape) + for i, av in enumerate(binary_attention): + next_phon_idx = np.argmax(av) + if abs(next_phon_idx - phon_idx) > jump_threshold: + next_phon_idx = phon_idx + phon_idx = next_phon_idx + clean_attn[i, min(phon_idx, clean_attn.shape[1] - 1)] = 1 + return clean_attn + + +def weight_mask(attention_weights): + """ Exponential loss mask based on distance from approximate diagonal""" + max_m, max_n = attention_weights.shape + I = np.tile(np.arange(max_n), (max_m, 1)) / max_n + J = np.swapaxes(np.tile(np.arange(max_m), (max_n, 1)), 0, 1) / max_m + return np.sqrt(np.square(I - J)) + + +def fill_zeros(duration, take_from='next'): + """ Fills zero durations with one. Takes the frame either from the next non-zero duration, or from the max.""" + for i in range(len(duration)): + if i < (len(duration) - 1): + if duration[i] == 0: + if take_from == 'next': + next_avail = np.where(duration[i:] > 1)[0] + if len(next_avail) > 1: + next_avail = next_avail[0] + elif take_from == 'max': + next_avail = np.argmax(duration[i:]) + if next_avail: + duration[i] = 1 + duration[i + next_avail] -= 1 + return duration + + +def fix_attention_jumps(binary_attn, alignments_weights, binary_score): + """ Scans for jumps in attention and attempts to fix them. If the score decreases, a collapse + is likely, so it tries to relax the jump size. + Lower jump sizes are more accurate, but more prone to collapse. + """ + clean_scores = [] + clean_attns = [] + for jumpth in [2, 3, 4, 5]: + cl_at = clean_attention(binary_attention=binary_attn, jump_threshold=jumpth) + clean_attns.append(cl_at) + sclean_score = np.sum(alignments_weights * cl_at) + clean_scores.append(sclean_score) + best_idx = np.argmin(clean_scores) + best_score = clean_scores[best_idx] + best_cleaned_attention = clean_attns[best_idx] + while ((best_score - binary_score) > 2.)
and (jumpth < 20): + jumpth += 1 + best_cleaned_attention = clean_attention(binary_attention=binary_attn, jump_threshold=jumpth) + best_score = np.sum(alignments_weights * best_cleaned_attention) + return best_cleaned_attention + + +def binary_attention(attention_weights): + attention_peak_per_phoneme = attention_weights.max(axis=1) + binary_attn = (attention_weights.T == attention_peak_per_phoneme).astype(int).T + assert np.sum( + np.sum(attention_weights.T == attention_peak_per_phoneme, axis=0) != 1) == 0 # single peak per mel step + binary_score = np.sum(attention_weights * binary_attn) + return binary_attn, binary_score + + +def get_durations_from_alignment(batch_alignments, mels, phonemes, weighted=False, binary=False, fill_gaps=False, + fix_jumps=False, fill_mode='max'): + """ + + :param batch_alignments: attention weights from the autoregressive model. + :param mels: mel spectrograms. + :param phonemes: phoneme sequence. + :param weighted: if True, use the weighted average of the heads' durations; otherwise use the best head. + :param binary: if True, take the maximum attention peak; otherwise sum. + :param fill_gaps: if True, fills zero durations with ones. + :param fix_jumps: if True, tries to scan alignments for attention jumps and interpolate. + :param fill_mode: used only if fill_gaps is True. Is either 'max' or 'next'. Defines where to take the duration + needed to fill the gap. 'next' takes it from the next non-zero duration value, 'max' from the sequence maximum. + :return: + """ + assert (binary is True) or (fix_jumps is False), 'Cannot fix jumps in non-binary attention.' + mel_pad_mask = create_mel_padding_mask(mels) + phon_pad_mask = create_encoder_padding_mask(phonemes) + durations = [] + # remove start/end token or vector + unpad_mels = [] + unpad_phonemes = [] + final_alignment = [] + for i, al in enumerate(batch_alignments): + mel_len = int(mel_pad_mask[i].shape[-1] - np.sum(mel_pad_mask[i])) + phon_len = int(phon_pad_mask[i].shape[-1] - np.sum(phon_pad_mask[i])) + unpad_alignments = al[:, 1:mel_len - 1, 1:phon_len - 1] # first dim is heads + unpad_mels.append(mels[i, 1:mel_len - 1, :]) + unpad_phonemes.append(phonemes[i, 1:phon_len - 1]) + alignments_weights = weight_mask(unpad_alignments[0]) + heads_scores = [] + scored_attention = [] + for _, attention_weights in enumerate(unpad_alignments): + score = np.sum(alignments_weights * attention_weights) + scored_attention.append(attention_weights / score) + heads_scores.append(score) + + if weighted: + ref_attention_weights = np.sum(scored_attention, axis=0) + else: + best_head = np.argmin(heads_scores) + ref_attention_weights = unpad_alignments[best_head] + + if binary: # pick max attention for each mel time-step + binary_attn, binary_score = binary_attention(ref_attention_weights) + if fix_jumps: + binary_attn = fix_attention_jumps( + binary_attn=binary_attn, + alignments_weights=alignments_weights, + binary_score=binary_score) + integer_durations = binary_attn.sum(axis=0) + + else: # takes actual attention values and normalizes to mel_len + attention_durations = np.sum(ref_attention_weights, axis=0) + normalized_durations = attention_durations * ((mel_len - 2) / np.sum(attention_durations)) + integer_durations = np.round(normalized_durations) + tot_duration = np.sum(integer_durations) + duration_diff = tot_duration - (mel_len - 2) + while duration_diff != 0: + rounding_diff = integer_durations - normalized_durations + if duration_diff > 0: # duration is too long -> reduce highest (positive) rounding difference + max_error_idx = 
+                    integer_durations[max_error_idx] -= 1
+                elif duration_diff < 0:  # duration is too short -> increase lowest (negative) rounding difference
+                    min_error_idx = np.argmin(rounding_diff)
+                    integer_durations[min_error_idx] += 1
+                tot_duration = np.sum(integer_durations)
+                duration_diff = tot_duration - (mel_len - 2)
+
+        if fill_gaps:  # fill zero durations
+            integer_durations = fill_zeros(integer_durations, take_from=fill_mode)
+
+        assert np.sum(integer_durations) == mel_len - 2, f'{np.sum(integer_durations)} vs {mel_len - 2}'
+        new_alignment = duration_to_alignment_matrix(integer_durations.astype(int))
+        best_head = np.argmin(heads_scores)
+        best_attention = unpad_alignments[best_head]
+        final_alignment.append(best_attention.T + new_alignment)
+        durations.append(integer_durations)
+    return durations, unpad_mels, unpad_phonemes, final_alignment
diff --git a/utils/config_manager.py b/utils/config_manager.py
index c6924c6..1c958df 100644
--- a/utils/config_manager.py
+++ b/utils/config_manager.py
@@ -6,14 +6,17 @@ import tensorflow as tf
 import ruamel.yaml
-from model.models import AutoregressiveTransformer
+from model.models import AutoregressiveTransformer, ForwardTransformer
 from utils.scheduling import piecewise_linear_schedule, reduction_schedule
 class ConfigManager:
-    def __init__(self, config_path: str, session_name: str = None):
+    def __init__(self, config_path: str, model_kind: str, session_name: str = None):
+        if model_kind not in ['autoregressive', 'forward']:
+            raise TypeError(f"model_kind must be in {['autoregressive', 'forward']}")
         self.config_path = Path(config_path)
+        self.model_kind = model_kind
         self.yaml = ruamel.yaml.YAML()
         self.config, self.data_config, self.model_config = self._load_config()
         self.git_hash = self._get_git_hash()
@@ -23,12 +26,15 @@ def __init__(self, config_path: str, session_name: str = None):
         self.session_name = '_'.join(filter(None, [self.config_path.name, session_name]))
         self.base_dir, self.log_dir, self.train_datadir, self.weights_dir = self._make_folder_paths()
         self.learning_rate = np.array(self.config['learning_rate_schedule'])[0, 1].astype(np.float32)
-        self.max_r = np.array(self.config['reduction_factor_schedule'])[0, 1].astype(np.int32)
-        self.stop_scaling = self.config.get('stop_loss_scaling', 1.)
+        if model_kind == 'autoregressive':
+            self.max_r = np.array(self.config['reduction_factor_schedule'])[0, 1].astype(np.int32)
+            self.stop_scaling = self.config.get('stop_loss_scaling', 1.)
     def _load_config(self):
-        data_config = self.yaml.load(open(str(self.config_path / 'data_config.yaml'), 'rb'))
-        model_config = self.yaml.load(open(str(self.config_path / f'model_config.yaml'), 'rb'))
+        with open(str(self.config_path / 'data_config.yaml'), 'rb') as data_yaml:
+            data_config = self.yaml.load(data_yaml)
+        with open(str(self.config_path / f'{self.model_kind}_config.yaml'), 'rb') as model_yaml:
+            model_config = self.yaml.load(model_yaml)
         all_config = {}
         all_config.update(model_config)
         all_config.update(data_config)
@@ -51,8 +57,8 @@ def _check_hash(self):
     def _make_folder_paths(self):
         base_dir = Path(self.config['log_directory']) / self.session_name
-        log_dir = base_dir / f'training_logs'
-        weights_dir = base_dir / f'model_weights'
+        log_dir = base_dir / f'{self.model_kind}_logs'
+        weights_dir = base_dir / f'{self.model_kind}_weights'
         train_datadir = self.config['train_data_directory']
         if train_datadir is None:
             train_datadir = self.config['data_directory']
@@ -86,31 +92,64 @@ def update_config(self):
     def get_model(self, ignore_hash=False):
         if not ignore_hash:
             self._check_hash()
-        return AutoregressiveTransformer(mel_channels=self.config['mel_channels'],
-                                         encoder_model_dimension=self.config['encoder_model_dimension'],
-                                         decoder_model_dimension=self.config['decoder_model_dimension'],
-                                         encoder_num_heads=self.config['encoder_num_heads'],
-                                         decoder_num_heads=self.config['decoder_num_heads'],
-                                         encoder_feed_forward_dimension=self.config['encoder_feed_forward_dimension'],
-                                         decoder_feed_forward_dimension=self.config['decoder_feed_forward_dimension'],
-                                         encoder_maximum_position_encoding=self.config['encoder_max_position_encoding'],
-                                         decoder_maximum_position_encoding=self.config['decoder_max_position_encoding'],
-                                         encoder_dense_blocks=self.config['encoder_dense_blocks'],
-                                         decoder_dense_blocks=self.config['decoder_dense_blocks'],
-                                         decoder_prenet_dimension=self.config['decoder_prenet_dimension'],
-                                         encoder_prenet_dimension=self.config['encoder_prenet_dimension'],
-                                         postnet_conv_filters=self.config['postnet_conv_filters'],
-                                         postnet_conv_layers=self.config['postnet_conv_layers'],
-                                         postnet_kernel_size=self.config['postnet_kernel_size'],
-                                         dropout_rate=self.config['dropout_rate'],
-                                         max_r=self.max_r,
-                                         mel_start_value=self.config['mel_start_value'],
-                                         mel_end_value=self.config['mel_end_value'],
-                                         phoneme_language=self.config['phoneme_language'],
-                                         debug=self.config['debug'])
+        if self.model_kind == 'autoregressive':
+            return AutoregressiveTransformer(mel_channels=self.config['mel_channels'],
+                                             encoder_model_dimension=self.config['encoder_model_dimension'],
+                                             decoder_model_dimension=self.config['decoder_model_dimension'],
+                                             encoder_num_heads=self.config['encoder_num_heads'],
+                                             decoder_num_heads=self.config['decoder_num_heads'],
+                                             encoder_feed_forward_dimension=self.config[
+                                                 'encoder_feed_forward_dimension'],
+                                             decoder_feed_forward_dimension=self.config[
+                                                 'decoder_feed_forward_dimension'],
+                                             encoder_maximum_position_encoding=self.config[
+                                                 'encoder_max_position_encoding'],
+                                             decoder_maximum_position_encoding=self.config[
+                                                 'decoder_max_position_encoding'],
+                                             encoder_dense_blocks=self.config['encoder_dense_blocks'],
+                                             decoder_dense_blocks=self.config['decoder_dense_blocks'],
+                                             decoder_prenet_dimension=self.config['decoder_prenet_dimension'],
+                                             encoder_prenet_dimension=self.config['encoder_prenet_dimension'],
+                                             postnet_conv_filters=self.config['postnet_conv_filters'],
+                                             postnet_conv_layers=self.config['postnet_conv_layers'],
+                                             postnet_kernel_size=self.config['postnet_kernel_size'],
+                                             dropout_rate=self.config['dropout_rate'],
+                                             max_r=self.max_r,
+                                             mel_start_value=self.config['mel_start_value'],
+                                             mel_end_value=self.config['mel_end_value'],
+                                             phoneme_language=self.config['phoneme_language'],
+                                             debug=self.config['debug'])
+
+        else:
+            return ForwardTransformer(encoder_model_dimension=self.config['encoder_model_dimension'],
+                                      decoder_model_dimension=self.config['decoder_model_dimension'],
+                                      dropout_rate=self.config['dropout_rate'],
+                                      decoder_num_heads=self.config['decoder_num_heads'],
+                                      encoder_num_heads=self.config['encoder_num_heads'],
+                                      encoder_maximum_position_encoding=self.config['encoder_max_position_encoding'],
+                                      decoder_maximum_position_encoding=self.config['decoder_max_position_encoding'],
+                                      encoder_feed_forward_dimension=self.config['encoder_feed_forward_dimension'],
+                                      decoder_feed_forward_dimension=self.config['decoder_feed_forward_dimension'],
+                                      encoder_attention_conv_filters=self.config[
+                                          'encoder_attention_conv_filters'],
+                                      decoder_attention_conv_filters=self.config[
+                                          'decoder_attention_conv_filters'],
+                                      encoder_attention_conv_kernel=self.config['encoder_attention_conv_kernel'],
+                                      decoder_attention_conv_kernel=self.config['decoder_attention_conv_kernel'],
+                                      mel_channels=self.config['mel_channels'],
+                                      postnet_conv_filters=self.config['postnet_conv_filters'],
+                                      postnet_conv_layers=self.config['postnet_conv_layers'],
+                                      postnet_kernel_size=self.config['postnet_kernel_size'],
+                                      encoder_dense_blocks=self.config['encoder_dense_blocks'],
+                                      decoder_dense_blocks=self.config['decoder_dense_blocks'],
+                                      phoneme_language=self.config['phoneme_language'],
+                                      debug=self.config['debug'])
     def compile_model(self, model):
-        model._compile(stop_scaling=self.stop_scaling, optimizer=self.new_adam(self.learning_rate))
+        if self.model_kind == 'autoregressive':
+            model._compile(stop_scaling=self.stop_scaling, optimizer=self.new_adam(self.learning_rate))
+        else:
+            model._compile(optimizer=self.new_adam(self.learning_rate))
     # TODO: move to model
     @staticmethod
@@ -122,8 +161,10 @@ def new_adam(learning_rate):
     def dump_config(self):
         self.update_config()
-        self.yaml.dump(self.model_config, open(self.base_dir / f'model_config.yaml', 'w'))
-        self.yaml.dump(self.data_config, open(self.base_dir / 'data_config.yaml', 'w'))
+        with open(self.base_dir / f'{self.model_kind}_config.yaml', 'w') as model_yaml:
+            self.yaml.dump(self.model_config, model_yaml)
+        with open(self.base_dir / 'data_config.yaml', 'w') as data_yaml:
+            self.yaml.dump(self.data_config, data_yaml)
     def create_remove_dirs(self, clear_dir: False, clear_logs: False, clear_weights: False):
         self.base_dir.mkdir(exist_ok=True)
@@ -160,8 +201,9 @@ def load_model(self, checkpoint_path: str = None, verbose=True):
         ckpt.restore(manager.latest_checkpoint)
         if verbose:
             print(f'restored weights from {manager.latest_checkpoint} at step {model.step}')
-        decoder_prenet_dropout = piecewise_linear_schedule(model.step, self.config['decoder_dropout_schedule'])
-        reduction_factor = reduction_schedule(model.step, self.config['reduction_factor_schedule'])
-        model.set_constants(decoder_prenet_dropout=decoder_prenet_dropout,
-                            reduction_factor=reduction_factor)
+        decoder_prenet_dropout = piecewise_linear_schedule(model.step, self.config['decoder_prenet_dropout_schedule'])
+        reduction_factor = None
+        if self.model_kind == 'autoregressive':
+            reduction_factor = reduction_schedule(model.step, self.config['reduction_factor_schedule'])
+        model.set_constants(reduction_factor=reduction_factor, decoder_prenet_dropout=decoder_prenet_dropout)
         return model
diff --git a/utils/logging.py b/utils/logging.py
index e070f3b..9d8f39a 100644
--- a/utils/logging.py
+++ b/utils/logging.py
@@ -34,22 +34,29 @@ def __init__(self,
                  model,
                  log_dir,
                  config,
-                 max_plot_frequency=10):
+                 max_plot_frequency=10,
+                 default_writer='log_dir'):
         self.model = model
         self.log_dir = Path(log_dir)
         self.config = config
         self.plot_frequency = max_plot_frequency
-        self.writers = {'log_dir': tf.summary.create_file_writer(str(self.log_dir))}
+        self.default_writer = default_writer
+        self.writers = {}
+        self.add_writer(tag=default_writer, path=self.log_dir, default=True)
-    def add_writer(self, path):
+    def add_writer(self, path, tag=None, default=False):
         """ Adds a writer to self.writers if the writer does not exist already.
             To avoid spamming writers on disk.
-            :returns the writer with path as tag
+            :returns the writer on path, registered under tag (or under path if no tag is given)
         """
-        if path not in self.writers.keys():
-            self.writers[path] = tf.summary.create_file_writer(str(path))
-        return self.writers[path]
+        if not tag:
+            tag = path
+        if tag not in self.writers.keys():
+            self.writers[tag] = tf.summary.create_file_writer(str(path))
+        if default:
+            self.default_writer = tag
+        return self.writers[tag]
     @property
     def global_step(self):
@@ -61,19 +68,21 @@ def add_scalars(self, tag, dictionary):
             tf.summary.scalar(name=tag, data=dictionary[k], step=self.global_step)
     def add_scalar(self, tag, scalar_value):
-        with self.writers['log_dir'].as_default():
+        with self.writers[self.default_writer].as_default():
             tf.summary.scalar(name=tag, data=scalar_value, step=self.global_step)
-    def add_image(self, tag, image):
-        with self.writers['log_dir'].as_default():
-            tf.summary.image(name=tag, data=image, step=self.global_step, max_outputs=4)
+    def add_image(self, tag, image, step=None):
+        if step is None:
+            step = self.global_step
+        with self.writers[self.default_writer].as_default():
+            tf.summary.image(name=tag, data=image, step=step, max_outputs=4)
-    def add_histogram(self, tag, values):
-        with self.writers['log_dir'].as_default():
-            tf.summary.histogram(name=tag, data=values, step=self.global_step)
+    def add_histogram(self, tag, values, buckets=None):
+        with self.writers[self.default_writer].as_default():
+            tf.summary.histogram(name=tag, data=values, step=self.global_step, buckets=buckets)
     def add_audio(self, tag, wav, sr):
-        with self.writers['log_dir'].as_default():
+        with self.writers[self.default_writer].as_default():
             tf.summary.audio(name=tag,
                              data=wav,
                              sample_rate=sr,
@@ -87,7 +96,7 @@ def display_attention_heads(self, outputs, tag=''):
             # dim 0 of image_batch is now number of heads
             batch_plot_path = f'{tag}/{layer}/{k}'
             self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1))
-
+
     @ignore_exception
     def display_mel(self, mel, tag='', sr=22050):
         amp_mel = denormalize(mel, self.config)
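# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch above): a minimal, self-contained
# sketch of two ideas used in utils/alignments.py -- turning per-phoneme
# durations into a hard alignment matrix (cf. duration_to_alignment_matrix)
# and rounding fractional, attention-derived durations so that they sum
# exactly to the number of mel frames (cf. the correction loop in
# get_durations_from_alignment). The helper names and values below are
# hypothetical and exist only for illustration.
import numpy as np


def durations_to_matrix(durations):
    # Row i is 1 on the mel frames assigned to phoneme i and 0 elsewhere.
    starts = np.cumsum(np.append([0], durations[:-1]))
    total = int(np.sum(durations))
    matrix = np.zeros((len(durations), total))
    for i, (start, dur) in enumerate(zip(starts, durations)):
        matrix[i, int(start):int(start) + int(dur)] = 1
    return matrix


def round_durations(fractional, n_frames):
    # Round, then nudge the entries with the largest rounding error until the
    # total matches n_frames, mirroring the while-loop correction above.
    rounded = np.round(fractional)
    while rounded.sum() != n_frames:
        error = rounded - fractional
        if rounded.sum() > n_frames:
            rounded[np.argmax(error)] -= 1
        else:
            rounded[np.argmin(error)] += 1
    return rounded.astype(int)


fractional = np.array([1.5, 2.6, 0.4, 3.5])  # made-up attention-derived durations
durations = round_durations(fractional, n_frames=8)
print(durations, durations.sum())            # [1 3 0 4] 8
# The zero duration left here is exactly the case that fill_zeros() patches up.
print(durations_to_matrix(durations).shape)  # (4, 8)
# ---------------------------------------------------------------------------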