Skip to content

Commit

Permalink
A TTS recipe VITS on VCTK dataset (k2-fsa#1380)
Browse files Browse the repository at this point in the history
* init

* isort formatted

* minor updates

* Create shared

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare.sh

* updated

* Update train.py

* Update train.py

* Update tts_datamodule.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* fixed formatting issue

* Update infer.py

* removed redundant files

* Create monotonic_align

* removed redundant files

* created symlinks

* Update prepare.sh

* minor adjustments

* Create requirements_tts.txt

* Update requirements_tts.txt

added version constraints

* Update infer.py

* Update infer.py

* Update infer.py

* updated docs

* Update export-onnx.py

* Update export-onnx.py

* Update test_onnx.py

* updated requirements.txt

* Update test_onnx.py

* Update test_onnx.py

* docs updated

* docs fixed

* minor updates
  • Loading branch information
JinZr authored Dec 6, 2023
1 parent f08af2f commit 735fb9a
Show file tree
Hide file tree
Showing 48 changed files with 2,904 additions and 84 deletions.
1 change: 1 addition & 0 deletions docs/source/recipes/TTS/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ TTS
:maxdepth: 2

ljspeech/vits
vctk/vits
12 changes: 11 additions & 1 deletion docs/source/recipes/TTS/ljspeech/vits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ VITS
This tutorial shows you how to train an VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

.. note::

TTS related recipes require packages in ``requirements-tts.txt``.

.. note::

The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
Expand All @@ -27,6 +31,12 @@ To run stage 1 to stage 5, use
Build Monotonic Alignment Search
--------------------------------

.. code-block:: bash
$ ./prepare.sh --stage -1 --stop_stage -1
or

.. code-block:: bash
$ cd vits/monotonic_align
Expand Down Expand Up @@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir
$ ./vits/infer.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--tokens data/tokens.txt \
--max-duration 500
.. note::
Expand Down
125 changes: 125 additions & 0 deletions docs/source/recipes/TTS/vctk/vits.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
VITS
===============

This tutorial shows you how to train an VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.

.. note::

TTS related recipes require packages in ``requirements-tts.txt``.

.. note::

The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_


Data preparation
----------------

.. code-block:: bash
$ cd egs/vctk/TTS
$ ./prepare.sh
To run stage 1 to stage 6, use

.. code-block:: bash
$ ./prepare.sh --stage 1 --stop_stage 6
Build Monotonic Alignment Search
--------------------------------

To build the monotonic alignment search, use the following commands:

.. code-block:: bash
$ ./prepare.sh --stage -1 --stop_stage -1
or

.. code-block:: bash
$ cd vits/monotonic_align
$ python setup.py build_ext --inplace
$ cd ../../
Training
--------

.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0,1,2,3"
$ ./vits/train.py \
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--max-duration 350
.. note::

You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.

.. note::

The training can take a long time (usually a couple of days).

Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.


Inference
---------

The inference part uses checkpoints saved by the training part, so you have to run the
training part first. It will save the ground-truth and generated wavs to the directory
``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.

.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0"
$ ./vits/infer.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt \
--max-duration 500
.. note::

For more details, please run ``./vits/infer.py --help``.


Export models
-------------

Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.

.. code-block:: bash
$ ./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
You can test the exported ONNX model with:

.. code-block:: bash
$ ./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
Download pretrained models
--------------------------

If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:

- `<https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05>`_
16 changes: 12 additions & 4 deletions egs/ljspeech/TTS/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=1
stage=-1
stage=0
stop_stage=100

dl_dir=$PWD/download
Expand All @@ -25,6 +24,17 @@ log() {

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"

Expand Down Expand Up @@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--tokens data/tokens.txt
fi
fi


1 change: 0 additions & 1 deletion egs/ljspeech/TTS/vits/duration_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import torch.nn.functional as F

from flow import (
ConvFlow,
DilatedDepthSeparableConv,
Expand Down
8 changes: 7 additions & 1 deletion egs/ljspeech/TTS/vits/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,13 @@ def export_model_onnx(
model_filename,
verbose=False,
opset_version=opset_version,
input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"alpha",
],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
Expand Down
1 change: 0 additions & 1 deletion egs/ljspeech/TTS/vits/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Optional, Tuple, Union

import torch

from transform import piecewise_rational_quadratic_transform


Expand Down
5 changes: 2 additions & 3 deletions egs/ljspeech/TTS/vits/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
import numpy as np
import torch
import torch.nn.functional as F

from icefall.utils import make_pad_mask

from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments

from icefall.utils import make_pad_mask


class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
Expand Down
29 changes: 20 additions & 9 deletions egs/ljspeech/TTS/vits/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@
import torch
import torch.nn as nn
import torchaudio

from train import get_model, get_params
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from tts_datamodule import LJSpeechTtsDataModule


def get_parser():
Expand Down Expand Up @@ -107,12 +106,12 @@ def _save_worker(
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
audio[i:i + 1, :audio_lens[i]],
audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
audio_pred[i:i + 1, :audio_lens_pred[i]],
audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)

Expand Down Expand Up @@ -144,14 +143,24 @@ def _save_worker(
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]

audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
audio_pred, _, durations = model.inference_batch(
text=tokens, text_lengths=tokens_lens
)
audio_pred = audio_pred.detach().cpu()
# convert to samples
audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
audio_lens_pred = (
(durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
)

futures.append(
executor.submit(
_save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
_save_worker,
batch_size,
cut_ids,
audio,
audio_pred,
audio_lens,
audio_lens_pred,
)
)

Expand All @@ -160,7 +169,9 @@ def _save_worker(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
# return results
for f in futures:
f.result()
Expand Down
1 change: 0 additions & 1 deletion egs/ljspeech/TTS/vits/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torch
import torch.distributions as D
import torch.nn.functional as F

from lhotse.features.kaldi import Wav2LogFilterBank


Expand Down
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/vits/posterior_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from typing import Optional, Tuple

import torch
from wavenet import Conv1d, WaveNet

from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d


class PosteriorEncoder(torch.nn.Module):
Expand Down
1 change: 0 additions & 1 deletion egs/ljspeech/TTS/vits/residual_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import Optional, Tuple, Union

import torch

from flow import FlipFlow
from wavenet import WaveNet

Expand Down
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/vits/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@

import argparse
import logging

import onnxruntime as ort
import torch
import torchaudio

from tokenizer import Tokenizer


Expand Down
Loading

0 comments on commit 735fb9a

Please sign in to comment.