Merge pull request #12 from as-ideas/mergeable_dev
Add Forward Model. Fix Autoregressive.
cfrancesco authored May 29, 2020
2 parents 2f3a1b5 + cb2fe3a commit a138f50
Showing 20 changed files with 1,970 additions and 791 deletions.
55 changes: 44 additions & 11 deletions README.md
@@ -8,29 +8,41 @@
<p>A Text-to-Speech Transformer in TensorFlow 2</p>
</h2>

Implementation of an autoregressive Transformer based neural network for Text-to-Speech (TTS). <br>
This repo is based on the following paper:
Implementation of a non-autoregressive Transformer based neural network for Text-to-Speech (TTS). <br>
This repo is based on the following papers:
- [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263)

Spectrograms produced with LJSpeech and standard data configuration from this repo are compatible with [WaveRNN](https://github.com/fatchord/WaveRNN).

#### Non-Autoregressive
Being non-autoregressive, this Transformer model is:
- Robust: No repeats or failed attention modes on challenging sentences.
- Fast: With no autoregression, predictions take a fraction of the time.
- Controllable: It is possible to control the speed of the generated utterance (see the conceptual sketch below).
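
As a conceptual illustration of how a forward, FastSpeech-style model controls speed: it predicts one duration per input phoneme, and scaling those durations before they are expanded into spectrogram frames changes the speaking rate. The values and names below are toy examples, not this repo's API.

```python
import numpy as np

# Conceptual sketch of duration-based speed control (toy values, not the repo's API).
# A forward model predicts one duration (in mel frames) per input phoneme;
# scaling those durations before expansion changes the speaking rate.
predicted_durations = np.array([6, 4, 9, 3])                  # toy per-phoneme durations
slower = np.round(predicted_durations * 1.10).astype(int)     # ~10% slower
faster = np.round(predicted_durations * 0.75).astype(int)     # ~25% faster
print(slower, faster)
```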

## 🔈 Samples

[Can be found here.](https://as-ideas.github.io/TransformerTTS/)

These samples' spectrograms are converted using the pre-trained [WaveRNN](https://github.com/fatchord/WaveRNN) vocoder.<br>

The TTS weights used for these samples can be found [here](https://github.com/as-ideas/tts_model_outputs/tree/master/ljspeech_transformertts).

Check out the notebooks folder for predictions with TransformerTTS and WaveRNN or just try out our Colab notebook:
Try it out on Colab:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/master/notebooks/synthesize.ipynb)
| Version | Colab Link |
|---|---|
| Forward | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/master/notebooks/synthesize_forward.ipynb) |
| Autoregressive | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/master/notebooks/synthesize_autoregressive.ipynb) |

## 📖 Contents
- [Installation](#installation)
- [Dataset](#dataset)
- [Training](#training)
- [Autoregressive](#train-autoregressive-model)
- [Forward](#train-forward-model)
- [Prediction](#prediction)
- [Model Weights](#model-weights)

## Installation

@@ -69,16 +81,28 @@ Prepare a dataset in the following format:
where `metadata.csv` has the following format:
``` wav_file_name|transcription ```
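
As an illustrative example (whether the file name carries a `.wav` extension depends on your data configuration), two rows of such a `metadata.csv` could look like:
```
LJ001-0001|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
LJ001-0002|in being comparatively modern.
```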

## Training
### Train Autoregressive Model
#### Create training dataset
```bash
python create_dataset.py --config config/standard
```

## Training
#### Training
```bash
python train.py --config config/standard
python train_autoregressive.py --config config/standard
```
### Train Forward Model
#### Compute alignment dataset
First, use the autoregressive model to create the durations dataset:
```bash
python extract_durations.py --config config/standard --binary --fix_jumps --fill_mode_next
```
This will add an additional folder to the dataset directory, containing the new training and validation sets for the forward model.<br>
If the rhythm of the trained model is off, play around with the flags of this script to fix the durations.
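
Conceptually, durations are recovered from the autoregressive model's attention: each mel frame is assigned to the phoneme it attends to most, and the per-phoneme frame counts become the durations. The sketch below illustrates that idea only; it is not the repo's exact implementation (which also applies the `--binary`, `--fix_jumps` and `--fill_mode_next` handling).

```python
import numpy as np

def durations_from_attention(attn: np.ndarray) -> np.ndarray:
    """Sketch: derive integer phoneme durations from an attention matrix of
    shape (mel_frames, phonemes) by assigning each frame to its most-attended
    phoneme and counting frames per phoneme."""
    assignments = attn.argmax(axis=1)                         # winning phoneme per frame
    return np.bincount(assignments, minlength=attn.shape[1])  # frames per phoneme

# Toy example: 6 mel frames attending over 3 phonemes -> durations [2, 2, 2]
attn = np.array([[.9, .1, .0], [.8, .2, .0],
                 [.2, .7, .1], [.1, .8, .1],
                 [.0, .3, .7], [.0, .1, .9]])
print(durations_from_attention(attn))
```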
#### Training
```bash
python train_forward.py --config /path/to/config_folder/
```

#### Training & Model configuration
- Training and model settings can be configured in `model_config.yaml`

@@ -92,25 +116,34 @@ We log some information that can be visualized with TensorBoard:
tensorboard --logdir /logs/directory/
```

![Tensorboard Demo](https://raw.githubusercontent.com/as-ideas/TransformerTTS/master/docs/tboard_demo.gif)

## Prediction
Predict with either the Forward or Autoregressive model
```python
from utils.config_manager import ConfigManager
from utils.audio import reconstruct_waveform

config_loader = ConfigManager('config/standard')
config_loader = ConfigManager('/path/to/config/', model_kind='forward')
model = config_loader.load_model()
out = model.predict('Please, say something.')

# Convert spectrogram to wav (with griffin lim)
wav = reconstruct_waveform(out['mel'].numpy().T, config=config_loader.config)
```
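
To listen to the result outside a notebook, the waveform can be written to disk. A minimal sketch, assuming SciPy is installed and the LJSpeech default sampling rate of 22050 Hz (use the rate from your data config if it differs):

```python
import numpy as np
from scipy.io import wavfile

# Save the Griffin-Lim reconstruction as a 32-bit float WAV file.
wavfile.write('sample.wav', 22050, wav.astype(np.float32))
```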

## Model Weights
| Model URL | Commit |
|---|---|
|[ljspeech_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_forward_transformer.zip)| 4945e775b|
|[ljspeech_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_autoregressive_transformer.zip)| 4945e775b|
|[ljspeech_autoregressive_model_v1](https://github.com/as-ideas/tts_model_outputs/tree/master/ljspeech_transformertts)| 2f3a1b5|
## Maintainers
* Francesco Cardinale, github: [cfrancesco](https://github.com/cfrancesco)

## Special thanks
[WaveRNN](https://github.com/fatchord/WaveRNN): we took the data processing from here and use their vocoder to produce the samples. <br>
[Erogol](https://github.com/erogol): for the lively exchange on TTS topics. <br>
[Erogol](https://github.com/erogol) and the Mozilla TTS team for the lively exchange on the topic. <br>

## Copyright
See [LICENSE](LICENSE) for details.
@@ -20,21 +20,18 @@ stop_loss_scaling: 8

# TRAINING
dropout_rate: 0.1
decoder_dropout_schedule: # dropout scheduling for the decoder status
- [0, 0.54]
decoder_prenet_dropout_schedule:
- [0, 0.]
- [25_000, 0.]
- [35_000, .5]
learning_rate_schedule:
- [0, 1.0e-4]
head_drop_schedule: # head-level dropout: how many heads to set to zero at training time
- [0, 0]
- [15_000, 1]
- [30_000, 2]
- [70_000, 3]
- [150_000, 1]
reduction_factor_schedule:
- [0, 10]
- [20_000, 5]
- [50_000, 2]
- [100_000, 1]
- [80_000, 1]
max_steps: 900_000
batch_size: 16
debug: False
Expand All @@ -44,7 +41,7 @@ validation_frequency: 1_000
prediction_frequency: 10_000
weights_save_frequency: 10_000
train_images_plotting_frequency: 1_000
keep_n_weights: 5
keep_n_weights: 2
keep_checkpoint_every_n_hours: 12
n_steps_avg_losses: [100, 500, 1_000, 5_000]
n_predictions: 2 # autoregressive predictions take time
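
The `*_schedule` entries above are lists of `[step, value]` breakpoints. The trainer's exact interpolation is not shown in this diff; as a hedged illustration, one common reading is linear interpolation between breakpoints with clamping at the ends (integer quantities such as dropped heads would then be rounded):

```python
import numpy as np

def schedule_value(step, breakpoints):
    """Illustrative reading of a [[step, value], ...] schedule: linear
    interpolation between breakpoints, clamped outside the given range.
    (A piecewise-constant reading is equally plausible.)"""
    steps, values = zip(*breakpoints)
    return float(np.interp(step, steps, values))

head_drop_schedule = [[0, 0], [15_000, 1], [30_000, 2], [70_000, 3], [150_000, 1]]
print(schedule_value(40_000, head_drop_schedule))  # 2.25 -> rounds to 2 dropped heads
```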
45 changes: 45 additions & 0 deletions config/standard/forward_config.yaml
@@ -0,0 +1,45 @@
# ARCHITECTURE
decoder_model_dimension: 256
encoder_model_dimension: 512
decoder_num_heads: [4, 4, 4, 4] # the length of this defines the number of layers
encoder_num_heads: [4, 4, 4, 4] # the length of this defines the number of layers
encoder_feed_forward_dimension: 1024
decoder_feed_forward_dimension: 1024
decoder_prenet_dimension: 256
encoder_prenet_dimension: 512
encoder_attention_conv_filters: 512
decoder_attention_conv_filters: 512
encoder_attention_conv_kernel: 3
decoder_attention_conv_kernel: 3
encoder_max_position_encoding: 1000
decoder_max_position_encoding: 10000
postnet_conv_filters: 256
postnet_conv_layers: 5
postnet_kernel_size: 5
encoder_dense_blocks: 1
decoder_dense_blocks: 0

# TRAINING
dropout_rate: 0.1
decoder_dropout_schedule: # dropout scheduling for the decoder status
- [0, 0.]
learning_rate_schedule:
- [0, 1.0e-4]
head_drop_schedule: # head-level dropout: how many heads to set to zero at training time
- [0, 0]
max_steps: 400_000
batch_size: 16
debug: False

# LOGGING
validation_frequency: 1_000
prediction_frequency: 1_000
weights_save_frequency: 5_000
train_images_plotting_frequency: 1_000
keep_n_weights: 5
keep_checkpoint_every_n_hours: 12
n_steps_avg_losses: [100, 500, 1_000, 5_000]
n_predictions: 5
prediction_start_step: 1_000
audio_start_step: 5_000
audio_prediction_frequency: 5_000 # converting to glim takes time
3 changes: 2 additions & 1 deletion create_dataset.py
@@ -20,7 +20,8 @@
for arg in vars(args):
    print('{}: {}'.format(arg, getattr(args, arg)))
yaml = ruamel.yaml.YAML()
config = yaml.load(open(str(Path(args.CONFIG) / 'data_config.yaml'), 'rb'))
with open(str(Path(args.CONFIG) / 'data_config.yaml'), 'rb') as conf_yaml:
    config = yaml.load(conf_yaml)
args.DATA_DIR = config['data_directory']
args.META_FILE = os.path.join(args.DATA_DIR, config['metadata_filename'])
args.WAV_DIR = os.path.join(args.DATA_DIR, config['wav_subdir_name'])
47 changes: 41 additions & 6 deletions docs/index.md
@@ -8,17 +8,52 @@
<p>A Text-to-Speech Transformer in TensorFlow 2</p>
</h2>

<p class="text">All samples are converted with the pre-trained <a href="https://github.com/fatchord/WaveRNN"> WaveRNN </a>vocoder.</p>

### 🎧 Samples from the autoregressive model, converted with [WaveRNN](https://github.com/fatchord/WaveRNN) vocoder
## 🎧 Model samples

<p class="text">President Trump met with other leaders at the Group of twenty conference.</p>

| forward | autoregressive |
|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/trump.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/Trump.wav?raw=true" controls preload></audio>|

<p class="text">Scientists, at the CERN laboratory, say they have discovered a new particle.</p>
<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/cern_particle.wav?raw=true" controls preload></audio>

| forward | autoregressive |
|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/scientists.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/cern_particle.wav?raw=true" controls preload></audio>|

<p class="text">There’s a way to measure the acute emotional intelligence that has never gone out of style.</p>
<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/EQ.wav?raw=true" controls preload></audio>

<p class="text">President Trump met with other leaders at the Group of twenty conference.</p>
<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/Trump.wav?raw=true" controls preload></audio>
| forward | autoregressive |
|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/EQ.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/EQ.wav?raw=true" controls preload></audio>|

<p class="text">The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled.</p>
<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/affordablecareact.wav?raw=true" controls preload></audio>

| forward | autoregressive |
|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/senate.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/affordablecareact.wav?raw=true" controls preload></audio>|


### Robustness

<p class="text">To deliver interfaces that are significantly better suited to create and process RFC eight twenty one , RFC eight twenty two , RFC nine seventy seven , and MIME content.</p>

| forward | autoregressive |
|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/hard.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_transformertts/hard.wav?raw=true" controls preload></audio>|

### Speed control
<p class="text">For a while the preacher addresses himself to the congregation at large, who listen attentively.</p>

| 10% slower | normal speed | 25% faster |
|:---:|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/speed_090.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/speed_100.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/speed_125.wav?raw=true" controls preload></audio>|

### Comparison with [ForwardTacotron](https://github.com/as-ideas/ForwardTacotron)
<p class="text"> In a statement announcing his resignation, Mr Ross, said: "While the intentions may have been well meaning, the reaction to this news shows that Mr Cummings interpretation of the government advice was not shared by the vast majority of people who have done as the government asked."</p>
| ForwardTacotron | TransformerTTS |
|:---:|:---:|
|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward/forward_transformer_comparison.wav?raw=true" controls preload></audio>|<audio src="https://github.com/as-ideas/tts_model_outputs/blob/master/ljspeech_forward_transformer/tacotron_comparison.wav?raw=true" controls preload></audio>|
Binary file added docs/tboard_demo.gif
