Skip to content

Commit

Permalink
Merge pull request espnet#5849 from Tsukasane/tts2_aishell3
Browse files Browse the repository at this point in the history
New Recipe of tts2+aishell3
  • Loading branch information
ftshijt authored Aug 22, 2024
2 parents 5df1271 + b8d56cf commit 38cc9e8
Show file tree
Hide file tree
Showing 36 changed files with 860 additions and 11 deletions.
8 changes: 8 additions & 0 deletions egs2/TEMPLATE/asr1/pyscripts/feats/dump_km_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def get_parser():
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--audio_sample_rate",
type=int,
default=16000,
help="input audio sampling rate (could be different from fs used in SSL)",
)
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
Expand Down Expand Up @@ -116,6 +122,7 @@ def __call__(self, x):
def dump_label(
rspecifier,
in_filetype,
audio_sample_rate,
wspecifier,
out_filetype,
km_path,
Expand Down Expand Up @@ -152,6 +159,7 @@ def dump_label(
)
if reader_conf.get("layer", None):
reader_conf["layer"] = int(reader_conf["layer"])
reader_conf["audio_sample_rate"] = audio_sample_rate

reader = reader_class(**reader_conf)
iterator = build_data_iterator(
Expand Down
7 changes: 7 additions & 0 deletions egs2/TEMPLATE/asr1/pyscripts/feats/dump_ssl_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def get_parser():
default=None,
help="Specify the utt2num_samples file.",
)
parser.add_argument(
"--audio_sample_rate",
type=int,
default=16000,
help="input audio sampling rate (could be different from fs used in SSL)",
)
parser.add_argument(
"--write_num_frames", type=str, help="Specify wspecifer for utt2num_frames"
)
Expand Down Expand Up @@ -83,6 +89,7 @@ def main(args):
reader_conf["multilayer_feature"] = str2bool(reader_conf["multilayer_feature"])
if reader_conf.get("layer", None):
reader_conf["layer"] = int(reader_conf["layer"])
reader_conf["audio_sample_rate"] = args.audio_sample_rate
reader = reader_class(use_gpu=args.use_gpu, **reader_conf)

dump_feature(
Expand Down
65 changes: 62 additions & 3 deletions egs2/TEMPLATE/asr1/pyscripts/feats/ssl_feature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from typing import List, Optional, Tuple, Union

import librosa
import numpy as np
import soundfile as sf
import torch
Expand Down Expand Up @@ -97,7 +98,15 @@ def __init__(self):

def load_audio(self, path: str, ref_len: Optional[int] = None):
wav, sr = sf.read(path)
assert sr == self.sample_rate, sr
# assert sr == self.sample_rate, sr
if sr != self.sample_rate:
logging.warning(
"sampling rate mismatch between "
"the requirements of feature extractor {} "
"and source wav {},"
"conduct resampling".format(self.sample_rate, sr)
)
wav = librosa.resample(wav, sr, self.sample_rate, scale=True)
if wav.ndim == 2:
wav = wav.mean(-1)
if ref_len is not None and abs(ref_len - len(wav)) > 160:
Expand Down Expand Up @@ -134,9 +143,18 @@ class MfccFeatureReader(BaseFeatureReader):
def __init__(
    self,
    sample_rate: int = 16000,
    audio_sample_rate: int = 16000,
    **kwargs,  # placeholder for unused arguments
):
    """MFCC feature reader.

    Args:
        sample_rate: sampling rate expected by the MFCC extractor.
        audio_sample_rate: sampling rate of the source audio; when it
            differs from ``sample_rate`` the waveform is resampled in
            ``get_feats`` before MFCC extraction.
        **kwargs: ignored; accepted so one shared reader config can be
            passed to every reader class.
    """
    # Cast defensively: sibling readers (e.g. HubertFeatureReader) receive
    # these values as strings from the config parser and cast with int().
    self.sample_rate = int(sample_rate)
    self.audio_sample_rate = int(audio_sample_rate)
    if self.sample_rate != self.audio_sample_rate:
        logging.warning("The audio sample rate is different from feat extractor")
        self.resample = torchaudio.transforms.Resample(
            orig_freq=self.audio_sample_rate, new_freq=self.sample_rate
        )
    else:
        self.resample = None
    # 25 ms window and 10 ms shift, expressed in samples.
    self.frame_length = 25 * self.sample_rate / 1000
    self.frame_shift = 10 * self.sample_rate / 1000

Expand All @@ -149,6 +167,9 @@ def get_feats(
feats, feats_lens = [], []
with torch.no_grad():
x, x_lens = self.preprocess_data(data, data_lens)
if self.resample is not None:
x = self.resample(x)
x_lens = x_lens * self.sample_rate // self.audio_sample_rate
batch_size = x.shape[0]
for i in range(batch_size):
mfcc = torchaudio.compliance.kaldi.mfcc(
Expand Down Expand Up @@ -177,10 +198,19 @@ def __init__(
hubert_dir_path,
layer,
sample_rate=16000,
audio_sample_rate=16000,
max_chunk=1600000,
use_gpu=True,
):
self.sample_rate = sample_rate
self.sample_rate = int(sample_rate)
self.audio_sample_rate = audio_sample_rate
if self.sample_rate != self.audio_sample_rate:
logging.warning("The audio sample rate is different from feat extractor")
self.resample = torchaudio.transforms.Resample(
orig_freq=audio_sample_rate, new_freq=self.sample_rate
)
else:
self.resample = None

self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
from espnet2.asr.encoder.hubert_encoder import FairseqHubertEncoder
Expand All @@ -200,6 +230,9 @@ def get_feats(
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
x, x_lens = self.preprocess_data(data, data_lens)
if self.resample is not None:
x = self.resample(x)
x_lens = x_lens * self.sample_rate // self.audio_sample_rate
x = x.to(self.device)
mask = x.zeros_like(x, dtype=torch.long)
for i in range(x.shape[0]):
Expand Down Expand Up @@ -229,10 +262,19 @@ def __init__(
hubert_model_path,
layer,
sample_rate=16000,
audio_sample_rate=16000,
max_chunk=1600000,
use_gpu=True,
):
self.sample_rate = sample_rate
self.sample_rate = int(sample_rate) # str->int
self.audio_sample_rate = audio_sample_rate
if self.sample_rate != self.audio_sample_rate:
logging.warning("The audio sample rate is different from feat extractor")
self.resample = torchaudio.transforms.Resample(
orig_freq=audio_sample_rate, new_freq=self.sample_rate
)
else:
self.resample = None

self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
from espnet2.tasks.hubert import HubertTask
Expand All @@ -256,6 +298,9 @@ def get_feats(
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.inference_mode():
x, x_lens = self.preprocess_data(data, data_lens)
if self.resample is not None:
x = self.resample(x)
x_lens = x_lens * self.sample_rate // self.audio_sample_rate
x = x.to(self.device)
x_lens = x_lens.to(self.device)

Expand All @@ -272,6 +317,7 @@ class S3PRLFeatureReader(BaseFeatureReader):
def __init__(
self,
fs: Union[int, str] = 16000,
audio_sample_rate: int = 16000,
s3prl_conf: Optional[dict] = None,
download_dir: str = None,
multilayer_feature: bool = False,
Expand All @@ -285,6 +331,16 @@ def __init__(
multilayer_feature=multilayer_feature,
layer=layer,
)
self.sample_rate = fs
self.audio_sample_rate = audio_sample_rate
if self.sample_rate != self.audio_sample_rate:
logging.warning("The audio sample rate is different from feat extractor")
self.resample = torchaudio.transforms.Resample(
orig_freq=audio_sample_rate, new_freq=fs
)
else:
self.resample = None

self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
self.model = self.model.to(self.device)

Expand All @@ -296,6 +352,9 @@ def get_feats(
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
x, x_lens = self.preprocess_data(data, data_lens)
if self.resample is not None:
x = self.resample(x)
x_lens = x_lens * self.sample_rate // self.audio_sample_rate
x = x.to(self.device)

feats, feats_lens = self.model(x, x_lens)
Expand Down
3 changes: 3 additions & 0 deletions egs2/TEMPLATE/asr1/scripts/feats/perform_kmeans.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ upsample= # Upsampling rate of pseudo-labels to measure the pseudo-lab
use_gpu=false # Whether to use gpu in feature extraction
suffix= # A suffix to distinguish the feature dump directory. Empty in usual cases.
audio_format="wav" # The audio format of the source speech (flac, wav, *_ark, etc)
audio_sample_rate=16000 # the sample rate of input audio

skip_train_kmeans=false # Whether to skip the kmeans model training
nclusters=100 # Number of clusters of kmeans model
Expand Down Expand Up @@ -152,6 +153,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && ! [[ " ${skip_stages} " =~ [
${_cmd} JOB=1:${_nj} ${_logdir}/dump_features.JOB.log \
${python} pyscripts/feats/dump_ssl_feature.py \
--feature_conf "'${feature_conf}'" \
--audio_sample_rate "${audio_sample_rate}" \
--use_gpu ${use_gpu} \
--in_filetype "${_in_filetype}" \
--out_filetype "mat" \
Expand Down Expand Up @@ -267,6 +269,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ] && ! [[ " ${skip_stages} " =~ [
${_cmd} JOB=1:${_nj} "${_dump_dir}"/logdir/inference_pseudo_labels_km${nclusters}.JOB.log \
${python} pyscripts/feats/dump_km_label.py \
${_opts} \
--audio_sample_rate "${audio_sample_rate}" \
--km_path "${km_dir}/km_${nclusters}.mdl" \
--out_filetype "mat" \
--use_gpu ${use_gpu} \
Expand Down
8 changes: 6 additions & 2 deletions egs2/TEMPLATE/tts2/tts2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,14 @@ if ! "${skip_data_prep}"; then

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
log "Stage 6: Discrete TTS discrete unit extraction"

# (en hubert), the original arguments
s3prl_conf="{upstream=${s3prl_upstream_name}}"
kmeans_feature_type=s3prl
kmeans_feature_conf="{type=${kmeans_feature_type},conf={s3prl_conf=${s3prl_conf},download_dir=ckpt,multilayer_feature=False,layer=${feature_layer}}}"

# (zh hubert), the arguments we used on aishell3
# s3prl_conf="{upstream=${s3prl_upstream_name},path_or_url=TencentGameMate/chinese-hubert-large}"
# kmeans_feature_type=s3prl
# kmeans_feature_conf={type=${kmeans_feature_type},conf={s3prl_conf=${s3prl_conf},download_dir=ckpt,multilayer_feature=False,layer=${feature_layer}}}
scripts/feats/perform_kmeans.sh \
--stage ${discrete_stage} \
--stop_stage ${discrete_stop_stage} \
Expand All @@ -606,6 +609,7 @@ if ! "${skip_data_prep}"; then
--datadir "${dumpdir}/raw" \
--featdir "${feature_dir}" \
--audio_format "${audio_format}" \
--audio_sample_rate "${fs}" \
--feature_type ${kmeans_feature_type} \
--layer "${feature_layer}" \
--feature_conf "${kmeans_feature_conf}" \
Expand Down
9 changes: 5 additions & 4 deletions egs2/aishell3/tts1/local/data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ fi
db_root=${AISHELL3}

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    mkdir -p "${db_root}"
    # Fix: the guard tests stage 0, but the message said "stage -1".
    log "stage 0: download data from openslr"
    local/download_and_untar.sh "${db_root}" "https://www.openslr.org/resources/93/data_aishell3.tgz" data_aishell3.tgz
fi
Expand Down Expand Up @@ -78,19 +79,19 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
utils/fix_data_dir.sh data/${x}
done
fi

# use the {dset}_phn naming convention here, to be consistent with MFA
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "stage 3: split for development set"
    # Hold out 250 utterances from each training set as the dev set.
    utils/subset_data_dir.sh data/train 250 data/dev
    utils/subset_data_dir.sh data/train_phn 250 data/dev_phn
    # NOTE(review): the diff rendering kept both the old (train_phn_no_dev)
    # and new (train_no_dev_phn) destination lines; keep only the new
    # {dset}_phn-suffixed naming so each step runs exactly once.
    utils/copy_data_dir.sh data/train data/train_no_dev
    utils/copy_data_dir.sh data/train_phn data/train_no_dev_phn
    # Remove dev utterances from the no-dev training sets.
    utils/filter_scp.pl --exclude data/dev/wav.scp \
        data/train/wav.scp > data/train_no_dev/wav.scp
    utils/filter_scp.pl --exclude data/dev_phn/wav.scp \
        data/train_phn/wav.scp > data/train_no_dev_phn/wav.scp
    utils/fix_data_dir.sh data/train_no_dev
    utils/fix_data_dir.sh data/train_no_dev_phn
fi

log "Successfully finished. [elapsed=${SECONDS}s]"
Loading

0 comments on commit 38cc9e8

Please sign in to comment.