From 6e9d34f498b41f12f889923566ddaafd6f50e8cc Mon Sep 17 00:00:00 2001 From: zyingt <88301578+zyingt@users.noreply.github.com> Date: Fri, 23 Feb 2024 23:03:35 +0800 Subject: [PATCH] Support Multi-speaker VITS (#131) Support Multi-speaker VITS & Hi-Fi TTS dataset preprocessing --- bins/tts/preprocess.py | 4 +- config/tts.json | 1 + egs/datasets/README.md | 31 ++++++++++++ egs/tts/VITS/README.md | 79 +++++++++++++++++++++++-------- egs/tts/VITS/exp_config.json | 21 +++++--- egs/tts/VITS/run.sh | 23 ++++++--- models/tts/base/tts_dataset.py | 3 ++ models/tts/base/tts_inferece.py | 12 ++++- models/tts/vits/vits_dataset.py | 25 ++++++++-- models/tts/vits/vits_inference.py | 15 +++++- preprocessors/hifitts.py | 2 +- preprocessors/processor.py | 3 ++ processors/phone_extractor.py | 7 +-- utils/data_utils.py | 13 +++++ 14 files changed, 191 insertions(+), 48 deletions(-) diff --git a/bins/tts/preprocess.py b/bins/tts/preprocess.py index 39e955c8..914c0b44 100644 --- a/bins/tts/preprocess.py +++ b/bins/tts/preprocess.py @@ -88,11 +88,11 @@ def extract_phonme_sequences(dataset, output_path, cfg, dataset_types): dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type)) with open(dataset_file, "r") as f: metadata.extend(json.load(f)) - phone_extractor.extract_utt_phone_sequence(cfg, metadata) + phone_extractor.extract_utt_phone_sequence(dataset, cfg, metadata) def preprocess(cfg, args): - """Proprocess raw data of single or multiple datasets (in cfg.dataset) + """Preprocess raw data of single or multiple datasets (in cfg.dataset) Args: cfg (dict): dictionary that stores configurations diff --git a/config/tts.json b/config/tts.json index 31b53df3..882726db 100644 --- a/config/tts.json +++ b/config/tts.json @@ -16,6 +16,7 @@ // Directory names of processed data or extracted features "phone_dir": "phones", "use_phone": true, + "add_blank": true }, "model": { "text_token_num": 512, diff --git a/egs/datasets/README.md b/egs/datasets/README.md index 78a93266..426a4754 100644 --- a/egs/datasets/README.md +++ b/egs/datasets/README.md @@ -6,6 +6,7 @@ Amphion support the following academic datasets (sort alphabetically): - [AudioCaps](#audiocaps) - [CSD](#csd) - [CustomSVCDataset](#customsvcdataset) + - [Hi-Fi TTS](#hifitts) - [KiSing](#kising) - [LibriLight](#librilight) - [LibriTTS](#libritts) @@ -75,6 +76,36 @@ We support custom dataset for Singing Voice Conversion. Organize your data in th ┣ ... ``` + +## Hi-Fi TTS + +Download the official Hi-Fi TTS dataset [here](https://www.openslr.org/109/). The file structure looks like below: + +```plaintext +[Hi-Fi TTS dataset path] + ┣ audio + ┃ ┣ 11614_other {Speaker_ID}_{SNR_subset} + ┃ ┃ ┣ 10547 {Book_ID} + ┃ ┃ ┃ ┣ thousandnights8_04_anonymous_0001.flac + ┃ ┃ ┃ ┣ thousandnights8_04_anonymous_0003.flac + ┃ ┃ ┃ ┣ thousandnights8_04_anonymous_0004.flac + ┃ ┃ ┃ ┣ ... + ┃ ┃ ┣ ... + ┃ ┣ ... + ┣ 92_manifest_clean_dev.json + ┣ 92_manifest_clean_test.json + ┣ 92_manifest_clean_train.json + ┣ ... + ┣ {Speaker_ID}_manifest_{SNR_subset}_{dataset_split}.json + ┣ ... + ┣ books_bandwidth.tsv + ┣ LICENSE.txt + ┣ readers_books_clean.txt + ┣ readers_books_other.txt + ┣ README.txt + +``` + ## KiSing Download the official KiSing dataset [here](http://shijt.site/index.php/2021/05/16/kising-the-first-open-source-mandarin-singing-voice-synthesis-corpus/). The file structure looks like below: diff --git a/egs/tts/VITS/README.md b/egs/tts/VITS/README.md index a5147c3a..ff489419 100644 --- a/egs/tts/VITS/README.md +++ b/egs/tts/VITS/README.md @@ -3,7 +3,7 @@ [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/Text-to-Speech) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/Text-to-Speech) -In this recipe, we will show how to train [VITS](https://arxiv.org/abs/2106.06103) using Amphion's infrastructure. VITS is an end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning. +In this recipe, we will show how to train VITS using Amphion's infrastructure. [VITS](https://arxiv.org/abs/2106.06103) is an end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning. There are four stages in total: @@ -20,7 +20,7 @@ There are four stages in total: ## 1. Data Preparation ### Dataset Download -You can use the commonly used TTS dataset to train TTS model, e.g., LJSpeech, VCTK, LibriTTS, etc. We strongly recommend you use LJSpeech to train TTS model for the first time. How to download dataset is detailed [here](../../datasets/README.md). +You can use the commonly used TTS dataset to train TTS model, e.g., LJSpeech, VCTK, Hi-Fi TTS, LibriTTS, etc. We strongly recommend using LJSpeech to train single-speaker TTS model for the first time. While for training multi-speaker TTS model for the first time, we would recommend using Hi-Fi TTS. The process of downloading dataset has been detailed [here](../../datasets/README.md). ### Configuration @@ -29,10 +29,12 @@ After downloading the dataset, you can set the dataset paths in `exp_config.jso ```json "dataset": [ "LJSpeech", + //"hifitts" ], "dataset_path": { // TODO: Fill in your dataset path "LJSpeech": "[LJSpeech dataset path]", + //"hifitts": "[Hi-Fi TTS dataset path] }, ``` @@ -40,21 +42,28 @@ After downloading the dataset, you can set the dataset paths in `exp_config.jso ### Configuration -Specify the `processed_dir` and the `log_dir` and for saving the processed data and the checkpoints in `exp_config.json`: +In `exp_config.json`, specify the `log_dir` for saving the checkpoints and logs, and specify the `processed_dir` for saving processed data. For preprocessing the multi-speaker TTS dataset, set `extract_audio` and `use_spkid` to `true`: ```json // TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts" "log_dir": "ckpts/tts", "preprocess": { + //"extract_audio": true, + "use_phone": true, + // linguistic features + "extract_phone": true, + "phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)" // TODO: Fill in the output data path. The default value is "Amphion/data" "processed_dir": "data", - ... + "sample_rate": 22050, //target sampling rate + "valid_file": "valid.json", //validation set + //"use_spkid": true, //use speaker ID to train multi-speaker TTS model }, ``` ### Run -Run the `run.sh` as the preproces stage (set `--stage 1`): +Run the `run.sh` as the preprocess stage (set `--stage 1`): ```bash sh egs/tts/VITS/run.sh --stage 1 @@ -66,17 +75,22 @@ sh egs/tts/VITS/run.sh --stage 1 ### Configuration -We provide the default hyparameters in the `exp_config.json`. They can work on single NVIDIA-24g GPU. You can adjust them based on your GPU machines. +We provide the default hyparameters in the `exp_config.json`. They can work on a single NVIDIA-24g GPU. You can adjust them based on your GPU machines. +For training the multi-speaker TTS model, specify the `n_speakers` value to be greater (used for new speaker fine-tuning) than or equal to the number of speakers in your dataset(s) and set `multi_speaker_training` to `true`. -``` -"train": { - "batch_size": 16, - } +```json + "model": { + //"n_speakers": 10 //Number of speakers in the dataset(s) used. The default value is 0 if not specified. + }, + "train": { + "batch_size": 16, + //"multi_speaker_training": true, + } ``` ### Train From Scratch -Run the `run.sh` as the training stage (set `--stage 2`). Specify a experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`. +Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`. ```bash sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] @@ -139,12 +153,35 @@ For inference, you need to specify the following configurations when running `ru | `--infer_expt_dir` | The experimental directory which contains `checkpoint` | `Amphion/ckpts/tts/[YourExptName]` | | `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/tts/[YourExptName]/result` | | `--infer_mode` | The inference mode, e.g., "`single`", "`batch`". | "`single`" to generate a clip of speech, "`batch`" to generate a batch of speech at a time. | -| `--infer_dataset` | The dataset used for inference. | For LJSpeech dataset, the inference dataset would be `LJSpeech`. | -| `--infer_testing_set` | The subset of the inference dataset used for inference, e.g., train, test, golden_test | For LJSpeech dataset, the testing set would be  "`test`" split from LJSpeech at the feature extraction, or "`golden_test`" cherry-picked from test set as template testing set. | +| `--infer_dataset` | The dataset used for inference. | For LJSpeech dataset, the inference dataset would be `LJSpeech`.
For Hi-Fi TTS dataset, the inference dataset would be `hifitts`. | +| `--infer_testing_set` | The subset of the inference dataset used for inference, e.g., train, test, golden_test | For LJSpeech dataset, the testing set would be  "`test`" split from LJSpeech at the feature extraction, or "`golden_test`" cherry-picked from the test set as template testing set.
For Hi-Fi TTS dataset, the testing set would be "`test`" split from Hi-Fi TTS during the feature extraction process. | | `--infer_text` | The text to be synthesized. | "`This is a clip of generated speech with the given text from a TTS model.`" | +| `--infer_speaker_name` | The target speaker's voice is to be synthesized.
(***Note: only applicable to multi-speaker TTS model***) | For Hi-Fi TTS dataset, the list of available speakers includes: "`hifitts_11614`", "`hifitts_11697`", "`hifitts_12787`", "`hifitts_6097`", "`hifitts_6670`", "`hifitts_6671`", "`hifitts_8051`", "`hifitts_9017`", "`hifitts_9136`", "`hifitts_92`".
You may find the list of available speakers from `spk2id.json` file generated in ```log_dir/[YourExptName]``` that you have specified in `exp_config.json`. | ### Run -For example, if you want to generate speech of all testing set split from LJSpeech, just run: +#### Single text inference: +For the single-speaker TTS model, if you want to generate a single clip of speech from a given text, just run: + +```bash +sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \ + --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \ + --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \ + --infer_mode "single" \ + --infer_text "This is a clip of generated speech with the given text from a TTS model." +``` + +For the multi-speaker TTS model, in addition to the above-mentioned arguments, you need to add ```infer_speaker_name``` argument, and run: +```bash +sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \ + --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \ + --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \ + --infer_mode "single" \ + --infer_text "This is a clip of generated speech with the given text from a TTS model." \ + --infer_speaker_name "hifitts_92" +``` + +#### Batch inference: +For the single-speaker TTS model, if you want to generate speech of all testing sets split from LJSpeech, just run: ```bash sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \ @@ -154,18 +191,18 @@ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \ --infer_dataset "LJSpeech" \ --infer_testing_set "test" ``` - -Or, if you want to generate a single clip of speech from a given text, just run: - +For the multi-speaker TTS model, if you want to generate speech of all testing sets split from Hi-Fi TTS, the same procedure follows from above, with ```LJSpeech``` replaced by ```hifitts```. ```bash sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \ --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \ --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \ - --infer_mode "single" \ - --infer_text "This is a clip of generated speech with the given text from a TTS model." + --infer_mode "batch" \ + --infer_dataset "hifitts" \ + --infer_testing_set "test" ``` -We released a pre-trained Amphion VITS model trained on LJSpeech. So you can download the pre-trained model [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the above inference instruction. + +We released a pre-trained Amphion VITS model trained on LJSpeech. So you can download the pre-trained model [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the above inference instruction. Meanwhile, the pre-trained multi-speaker VITS model trained on Hi-Fi TTS will be released soon. Stay tuned. ```bibtex @@ -176,4 +213,4 @@ We released a pre-trained Amphion VITS model trained on LJSpeech. So you can dow pages={5530--5540}, year={2021}, } -``` \ No newline at end of file +``` diff --git a/egs/tts/VITS/exp_config.json b/egs/tts/VITS/exp_config.json index b210a265..3a2332f2 100644 --- a/egs/tts/VITS/exp_config.json +++ b/egs/tts/VITS/exp_config.json @@ -2,26 +2,33 @@ "base_config": "config/vits.json", "model_type": "VITS", "dataset": [ - "LJSpeech" + "LJSpeech", + //"hifitts" ], "dataset_path": { // TODO: Fill in your dataset path - "LJSpeech": "[LJSpeech dataset path]" + "LJSpeech": "[LJSpeech dataset path]", + //"hifitts": "[Hi-Fi TTS dataset path] }, // TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts" "log_dir": "ckpts/tts", "preprocess": { + //"extract_audio":true, "use_phone": true, // linguistic features "extract_phone": true, - "phone_extractor": "lexicon", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)" + "phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)" // TODO: Fill in the output data path. The default value is "Amphion/data" "processed_dir": "data", - - "sample_rate": 22050, - "valid_file": "test.json", // validattion set + "sample_rate": 22050, // target sampling rate + "valid_file": "valid.json", // validation set + //"use_spkid": true // use speaker ID to train multi-speaker TTS model + }, + "model":{ + //"n_speakers": 10 // number of speakers, greater than or equal to the number of speakers in the dataset(s) used. The default value is 0 if not specified. }, "train": { "batch_size": 16, + //"multi_speaker_training": true } -} \ No newline at end of file +} diff --git a/egs/tts/VITS/run.sh b/egs/tts/VITS/run.sh index ad63b425..dd702795 100644 --- a/egs/tts/VITS/run.sh +++ b/egs/tts/VITS/run.sh @@ -18,7 +18,7 @@ cd $work_dir ######## Parse the Given Parameters from the Commond ########### # options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir:,name:,stage: -- "$@") -options=$(getopt -o c:n:s --long gpu:,config:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,name:,stage: -- "$@") +options=$(getopt -o c:n:s --long gpu:,config:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,infer_speaker_name:,name:,stage: -- "$@") eval set -- "$options" while true; do @@ -43,14 +43,16 @@ while true; do --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;; # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result" --infer_output_dir) shift; infer_output_dir=$1 ; shift ;; - # [Only for Inference] The inference mode. It can be "batch" to generate speech by batch, or "single" to generage a single clip of speech. + # [Only for Inference] The inference mode. It can be "batch" to generate speech by batch, or "single" to generate a single clip of speech. --infer_mode) shift; infer_mode=$1 ; shift ;; - # [Only for Inference] The inference dataset. It is only used when the inference model is "batch". + # [Only for Inference] The inference dataset. It is only used when the inference mode is "batch". --infer_dataset) shift; infer_dataset=$1 ; shift ;; - # [Only for Inference] The inference testing set. It is only used when the inference model is "batch". It can be "test" set split from the dataset, or "golden_test" carefully selected from the testing set. + # [Only for Inference] The inference testing set. It is only used when the inference mode is "batch". It can be "test" set split from the dataset, or "golden_test" carefully selected from the testing set. --infer_testing_set) shift; infer_testing_set=$1 ; shift ;; - # [Only for Inference] The text to be synthesized from. It is only used when the inference model is "single". + # [Only for Inference] The text to be synthesized from. It is only used when the inference mode is "single". --infer_text) shift; infer_text=$1 ; shift ;; + # [Only for Inference] The chosen speaker's voice to be synthesized. It is only used when the inference mode is "single" for multi-speaker VITS. + --infer_speaker_name) shift; infer_speaker_name=$1 ; shift ;; --) shift ; break ;; *) echo "Invalid option: $1" exit 1 ;; @@ -67,7 +69,7 @@ fi if [ -z "$exp_config" ]; then exp_config="${exp_dir}"/exp_config.json fi -echo "Exprimental Configuration File: $exp_config" +echo "Experimental Configuration File: $exp_config" if [ -z "$gpu" ]; then gpu="0" @@ -86,7 +88,7 @@ if [ $running_stage -eq 2 ]; then echo "[Error] Please specify the experiments name" exit 1 fi - echo "Exprimental Name: $exp_name" + echo "Experimental Name: $exp_name" # add default value if [ -z "$resume_from_ckpt_path" ]; then @@ -153,6 +155,12 @@ if [ $running_stage -eq 3 ]; then elif [ "$infer_mode" = "batch" ]; then infer_text='' fi + + if [ -z "$infer_speaker_name" ]; then + infer_speaker_name=None + fi + + CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/tts/inference.py \ @@ -163,6 +171,7 @@ if [ $running_stage -eq 3 ]; then --dataset $infer_dataset \ --testing_set $infer_testing_set \ --text "$infer_text" \ + --speaker_name $infer_speaker_name \ --log_level debug diff --git a/models/tts/base/tts_dataset.py b/models/tts/base/tts_dataset.py index b3f6ac7c..fc85afb9 100644 --- a/models/tts/base/tts_dataset.py +++ b/models/tts/base/tts_dataset.py @@ -209,6 +209,9 @@ def __init__(self, cfg, dataset, is_valid=False): phon_id_collator = phoneIDCollation(cfg, dataset=dataset) sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq) + if cfg.preprocess.add_blank: + sequence = intersperse(sequence, 0) + self.utt2seq[utt] = sequence def __getitem__(self, index): diff --git a/models/tts/base/tts_inferece.py b/models/tts/base/tts_inferece.py index cb09a4d4..f49ace0f 100644 --- a/models/tts/base/tts_inferece.py +++ b/models/tts/base/tts_inferece.py @@ -12,6 +12,7 @@ from tqdm import tqdm from accelerate.logging import get_logger from torch.utils.data import DataLoader +from safetensors.torch import load_file from abc import abstractmethod @@ -162,7 +163,16 @@ def _load_model( ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) checkpoint_path = ls[0] - self.accelerator.load_state(str(checkpoint_path)) + if ( + Path(os.path.join(checkpoint_path, "model.safetensors")).exists() + and accelerate.__version__ < "0.25" + ): + self.model.load_state_dict( + load_file(os.path.join(checkpoint_path, "model.safetensors")), + strict=False, + ) + else: + self.accelerator.load_state(str(checkpoint_path)) return str(checkpoint_path) def inference(self): diff --git a/models/tts/vits/vits_dataset.py b/models/tts/vits/vits_dataset.py index cd596894..e3a1444b 100644 --- a/models/tts/vits/vits_dataset.py +++ b/models/tts/vits/vits_dataset.py @@ -27,6 +27,26 @@ def __getitem__(self, index): def __len__(self): return super().__len__() + def get_metadata(self): + metadata_filter = [] + with open(self.metafile_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + for utt_info in metadata: + duration = utt_info["Duration"] + frame_len = ( + duration + * self.cfg.preprocess.sample_rate + // self.cfg.preprocess.hop_size + ) + if ( + frame_len + < self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size + ): + continue + metadata_filter.append(utt_info) + + return metadata_filter + class VITSCollator(TTSCollator): """Zero-pads model inputs and targets based on number of frames per step""" @@ -42,11 +62,8 @@ def __call__(self, batch): class VITSTestDataset(TTSTestDataset): def __init__(self, args, cfg): super().__init__(args, cfg) - + processed_data_dir = os.path.join(cfg.preprocess.processed_dir, args.dataset) if cfg.preprocess.use_spkid: - processed_data_dir = os.path.join( - cfg.preprocess.processed_dir, args.dataset - ) spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id) with open(spk2id_path, "r") as f: self.spk2id = json.load(f) diff --git a/models/tts/vits/vits_inference.py b/models/tts/vits/vits_inference.py index 6c3d385a..5e28858a 100644 --- a/models/tts/vits/vits_inference.py +++ b/models/tts/vits/vits_inference.py @@ -14,6 +14,7 @@ from models.tts.vits.vits import SynthesizerTrn from processors.phone_extractor import phoneExtractor from text.text_token_collation import phoneIDCollation +from utils.data_utils import * class VitsInference(TTSInference): @@ -120,6 +121,9 @@ def inference_for_single_utterance( ) phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq) + if self.cfg.preprocess.add_blank: + phone_id_seq = intersperse(phone_id_seq, 0) + # convert phone sequence to phone id sequence phone_id_seq = np.array(phone_id_seq) phone_id_seq = torch.from_numpy(phone_id_seq) @@ -130,8 +134,15 @@ def inference_for_single_utterance( spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) with open(spk2id_file, "r") as f: spk2id = json.load(f) - speaker_id = spk2id[self.args.speaker_name] - speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32)) + speaker_name = self.args.speaker_name + assert ( + speaker_name in spk2id + ), f"Speaker {speaker_name} not found in the spk2id keys. \ + Please make sure you've specified the correct speaker name in infer_speaker_name." + speaker_id = spk2id[speaker_name] + speaker_id = torch.from_numpy( + np.array([speaker_id], dtype=np.int32) + ).unsqueeze(0) with torch.no_grad(): x_tst = phone_id_seq.to(self.device).unsqueeze(0) diff --git a/preprocessors/hifitts.py b/preprocessors/hifitts.py index bb5b2317..78b069ca 100644 --- a/preprocessors/hifitts.py +++ b/preprocessors/hifitts.py @@ -60,7 +60,7 @@ def main(output_path, dataset_path): entry = json.loads(line) utt_path = entry.get("audio_filepath") chosen_book = utt_path.split("/")[-2] - chosen_uid = utt_path.split("/")[-1] + chosen_uid = utt_path.split("/")[-1].split(".")[0] duration = entry.get("duration") text = entry.get("text_normalized") path = os.path.join(hifitts_path, utt_path) diff --git a/preprocessors/processor.py b/preprocessors/processor.py index 1a1d0362..037ac6c5 100644 --- a/preprocessors/processor.py +++ b/preprocessors/processor.py @@ -29,6 +29,7 @@ vocalist, ljspeech_vocoder, librilight, + hifitts, ) @@ -93,6 +94,8 @@ def preprocess_dataset( vocalist.main(output_path, dataset_path) if dataset == "librilight": librilight.main(output_path, dataset_path, cfg) + if dataset == "hifitts": + hifitts.main(output_path, dataset_path) def prepare_align(dataset, dataset_path, cfg, output_path): diff --git a/processors/phone_extractor.py b/processors/phone_extractor.py index 676e919c..f5508192 100644 --- a/processors/phone_extractor.py +++ b/processors/phone_extractor.py @@ -47,7 +47,7 @@ def __init__(self, cfg, dataset_name=None, phone_symbol_file=None): assert cfg.preprocess.lexicon_path != "" self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path) else: - print("No suppert to", cfg.preprocess.phone_extractor) + print("No support to", cfg.preprocess.phone_extractor) raise def extract_phone(self, text): @@ -95,16 +95,17 @@ def save_dataset_phone_symbols_to_table(self): phone_symbol_dict.to_file(self.phone_symbols_file) -def extract_utt_phone_sequence(cfg, metadata): +def extract_utt_phone_sequence(dataset, cfg, metadata): """ Extract phone sequence from text Args: + dataset (str): name of dataset, e.g. opencpop cfg: config metadata: list of dict, each dict contains "Uid", "Text" """ - dataset_name = cfg.dataset[0] + dataset_name = dataset # output path out_path = os.path.join( diff --git a/utils/data_utils.py b/utils/data_utils.py index 7976d050..8c0bc2ff 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -12,6 +12,19 @@ from sklearn.preprocessing import StandardScaler +def intersperse(lst, item): + """ + Insert an item in between any two consecutive elements of the given list, including beginning and end of list + + Example: + >>> intersperse(0, [1, 74, 5, 31]) + [0, 1, 0, 74, 0, 5, 0, 31, 0] + """ + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + def load_content_feature_path(meta_data, processed_dir, feat_dir): utt2feat_path = {} for utt_info in meta_data: