Merge pull request #60 from ddlBoJack/main
sync
Showing 24 changed files with 3,876 additions and 14 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# An unofficial reproduction of VALL-E-X | ||
This is an unofficial reproduction of VALL-E-X, based on the repository https://github.com/Plachtaa/VALL-E-X. | ||
|
||
## Checkpoints | ||
Pretrained models can be found on [Google Drive](https://drive.google.com/drive/folders/1wCTffPnSsiHpthaX-yzUne1dTBAMI_JV?usp=drive_link). | ||
|
||
|
||
## Decode with checkpoints | ||
```shell | ||
# First, set model_home in the following script to the location of the downloaded/pretrained models. | ||
# Second, customize prompt_txt, prompt_audio, and target_txt, each with the corresponding language ID. | ||
bash examples/vallex/scripts/inference.sh | ||
``` | ||
|
||
## Data preparation | ||
VALL-E-X is trained on a dataset of discrete speech tokens and text tokens. | ||
|
||
|
||
* Prepare an "info.tsv" file in the following format (path \t duration), listing the path and duration of each utterance. | ||
```txt | ||
SPEECH_PATH1 DURATION1 | ||
SPEECH_PATH2 DURATION2 | ||
SPEECH_PATH3 DURATION3 | ||
...... | ||
``` | ||
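One way to produce such a file is sketched below. It is a minimal example using only the Python standard library, so it assumes the audio is stored as PCM WAV files and that durations are expressed in seconds (the unit is not stated above). The helper names `wav_duration_seconds` and `write_info_tsv` are hypothetical and not part of this repository.

```python
import wave
from pathlib import Path


def wav_duration_seconds(path):
    """Return the duration of a PCM WAV file in seconds."""
    with wave.open(str(path), "rb") as wf:
        return wf.getnframes() / wf.getframerate()


def write_info_tsv(wav_dir, out_path):
    """Write one "path<TAB>duration" row per .wav file found under wav_dir.

    Assumption: duration is in seconds; adjust if your pipeline expects another unit.
    Returns the number of rows written.
    """
    rows = 0
    with open(out_path, "w", encoding="utf-8") as out:
        for wav_path in sorted(Path(wav_dir).rglob("*.wav")):
            out.write(f"{wav_path}\t{wav_duration_seconds(wav_path):.3f}\n")
            rows += 1
    return rows
```

For other audio formats, the duration would have to come from an audio library instead of the `wave` module.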
* Extract codec tokens for the files listed in "info.tsv": | ||
```shell | ||
bash examples/vallex/data_pretreatment/extract_codec.sh | ||
``` | ||
This yields 8 "codec[i].tsv" files: the codes of the i-th codec layer (i = 0 to 7) are saved separately in "codec[i].tsv". | ||
```txt | ||
304 123 453 255 256 345 124 666 543 ... | ||
654 662 543 463 674 537 273 473 973 ... | ||
355 345 766 255 234 768 275 785 102 ... | ||
...... | ||
``` | ||
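The count of 8 layers follows from the EnCodec settings in the extraction script (24 kHz model, target bandwidth 6.0 kbps). The arithmetic can be sketched as follows, assuming the published EnCodec 24 kHz parameters of 75 frames per second and 1024 entries per codebook; the function name `num_codebooks` is hypothetical.

```python
import math

FRAME_RATE_HZ = 75     # assumption: EnCodec 24 kHz model, 24000 samples/s with a hop of 320
CODEBOOK_SIZE = 1024   # assumption: entries per codebook, i.e. 10 bits per code


def num_codebooks(bandwidth_kbps):
    """Number of residual codebooks (layers) that fit in the target bandwidth."""
    bits_per_codebook_per_second = FRAME_RATE_HZ * math.log2(CODEBOOK_SIZE)  # 750 bits/s
    return max(1, int(bandwidth_kbps * 1000 // bits_per_codebook_per_second))


# Layer counts for the bandwidths listed in the extraction script.
for kbps in (1.5, 3, 6, 12, 24):
    print(kbps, num_codebooks(kbps))
```

Changing the target bandwidth therefore changes the number of "codec[i].tsv" files, and the hard-coded `range(8)` in the extraction script would need to change with it.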
* Prepare the transcript file ("trans.tsv"), with each line corresponding to one utterance: | ||
```txt | ||
Text for SPEECH1 | ||
Text for SPEECH2 | ||
Text for SPEECH3 | ||
...... | ||
``` | ||
Next, convert the text into tokens with a tool such as BPE or G2P, and save the result as "st.tsv": | ||
```txt | ||
1521 467 885 2367 242 ... | ||
2362 3261 356 167 1246 2364 ... | ||
1246 123 432 134 53 13 ... | ||
...... | ||
``` | ||
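The tokenizer itself is not specified here. To illustrate only the file format, the sketch below uses a toy character-level vocabulary as a stand-in for BPE or G2P; it is not the tokenizer used by this repository, and `build_vocab` and `encode_line` are hypothetical names.

```python
def build_vocab(texts):
    """Assign an integer id to every character in the corpus (sorted for determinism)."""
    chars = sorted({ch for text in texts for ch in text})
    return {ch: idx for idx, ch in enumerate(chars)}


def encode_line(text, vocab):
    """Turn one transcript line into a space-separated row of token ids."""
    return " ".join(str(vocab[ch]) for ch in text)


def write_st_tsv(texts, out_path):
    """Write one row of token ids per transcript line, in the same order as the input."""
    vocab = build_vocab(texts)
    with open(out_path, "w", encoding="utf-8") as out:
        for text in texts:
            out.write(encode_line(text, vocab) + "\n")
    return vocab
```

Whatever tokenizer is used, row k of "st.tsv" has to describe the same utterance as row k of each "codec[i].tsv".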
* Convert the data (codec[i].tsv and st.tsv) into binary files for fast reading: | ||
```shell | ||
# We use the fairseq preprocessing tool for this conversion | ||
python /home/wangtianrui/codes/fairseq/fairseq_cli/preprocess.py \ | ||
    --only-source \ | ||
    --trainpref /home/wangtianrui/develop_dataset/st.tsv \ | ||
    --destdir /home/wangtianrui/develop_dataset/data_bin \ | ||
    --thresholdsrc 0 \ | ||
    --srcdict /home/wangtianrui/develop_dataset/dict.st.txt \ | ||
    --workers `cat /proc/cpuinfo| grep "processor"| wc -l` | ||
# NOTE: $outdir is not defined in this snippet; set it to the destination directory first | ||
for ((i=0;i<=7;i++)) | ||
do | ||
    echo $i | ||
    outname=train.at${i}.zh | ||
    python /home/wangtianrui/codes/fairseq/fairseq_cli/preprocess.py \ | ||
        --only-source \ | ||
        --trainpref codec${i}.tsv \ | ||
        --destdir $outdir \ | ||
        --thresholdsrc 0 \ | ||
        --srcdict /home/wangtianrui/develop_dataset/dict.at.txt \ | ||
        --workers `cat /proc/cpuinfo| grep "processor"| wc -l` | ||
done | ||
``` | ||
where dict.at.txt and dict.st.txt are simple idx-to-idx rows of discrete speech tokens and text tokens, respectively, as shown in examples/vallex/data_pretreatment. | ||
You can then train VALL-E-X with dataset_config.train_data_path set to the directory containing the binary files. We also release a tiny dataset for reference on [Google Drive](https://drive.google.com/drive/folders/1wCTffPnSsiHpthaX-yzUne1dTBAMI_JV?usp=drive_link). | ||
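The dictionary files mentioned above could be generated along the following lines. This is a sketch that assumes "idx-to-idx" means each row is `<idx> <idx>`; that reading should be checked against the files in examples/vallex/data_pretreatment. The helper `write_identity_dict` is hypothetical, and the size 1024 assumes the EnCodec codebook size for the speech-token dictionary.

```python
def write_identity_dict(path, num_tokens):
    """Write rows "0 0", "1 1", ... up to num_tokens - 1 (assumed dictionary layout)."""
    with open(path, "w", encoding="utf-8") as f:
        for idx in range(num_tokens):
            f.write(f"{idx} {idx}\n")


if __name__ == "__main__":
    # Assumption: 1024 speech-token ids (EnCodec codebook size).
    write_identity_dict("dict.at.txt", 1024)
```

The text-token dictionary would be written the same way, with `num_tokens` set to the vocabulary size of the chosen tokenizer.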
## Train a new AR model | ||
After preprocessing the dataset, set "train_data_path" in the following script to start training or fine-tuning. | ||
```shell | ||
bash examples/vallex/scripts/vallex.sh | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
|
||
model_config: | ||
  llm_name: "" | ||
  ar_at_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.at.txt" | ||
  ar_st_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.st.txt" | ||
  nar_at_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.at.txt" | ||
  nar_st_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.st.txt" | ||
  only_ar: true | ||
  only_nar: false | ||
|
||
train_config: | ||
  model_name: "" | ||
  enable_ddp: false | ||
  enable_fsdp: false | ||
  low_cpu_fsdp: false | ||
  run_validation: true | ||
  batch_size_training: 4 | ||
  batching_strategy: "packing" # alternative: padding | ||
  context_length: 4096 | ||
  gradient_accumulation_steps: 1 | ||
  num_epochs: 100 | ||
  num_workers_dataloader: 1 | ||
  warmup_steps: 1000 | ||
  total_steps: 100000 | ||
  validation_interval: 1000 | ||
  lr: 1e-4 | ||
  weight_decay: 0.0 | ||
  gamma: 0.85 | ||
  seed: 42 | ||
  use_fp16: false | ||
  mixed_precision: true | ||
  val_batch_size: 1 | ||
|
||
  output_dir: "PATH/to/save/PEFT/model" | ||
  freeze_layers: false | ||
  num_freeze_layers: 1 | ||
  quantization: false | ||
  one_gpu: false | ||
  save_model: true | ||
  dist_checkpoint_root_folder: "PATH/to/save/FSDP/model" # will be used if using FSDP | ||
  dist_checkpoint_folder: "fine-tuned" # will be used if using FSDP | ||
  save_optimizer: false # will be used if using FSDP | ||
|
||
dataset_config: | ||
  dataset: "speech_dataset" | ||
  file: "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset" | ||
  train_data_path: null | ||
  val_data_path: null | ||
|
||
fsdp_config: | ||
  mixed_precision: true | ||
  use_fp16: false | ||
  # sharding_strategy: "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD | ||
  sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when using DDP | ||
  checkpoint_type: "SHARDED_STATE_DICT" # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size. | ||
  fsdp_activation_checkpointing: true | ||
  fsdp_cpu_offload: false | ||
  pure_bf16: false | ||
  optimizer: "AdamW" | ||
|
||
log_config: | ||
  use_wandb: false | ||
  wandb_dir: "/Work20/2023/wangtianrui/codes/test/debug/test_wandb" | ||
  wandb_entity_name : "project_name" | ||
  wandb_project_name : "project_name" | ||
  wandb_exp_name : "exp_name" | ||
  log_file: "/Work20/2023/wangtianrui/codes/test/debug/test.log" | ||
  log_interval: 5 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import os | ||
import argparse | ||
import numpy as np | ||
from tqdm import tqdm | ||
from encodec import EncodecModel | ||
import math | ||
import torch | ||
import torchaudio | ||
import sys | ||
from fairseq import search, utils | ||
import pandas as pd | ||
import librosa as lib | ||
from multiprocessing import Process | ||
|
||
def get_codec(model, audio_path, device, resampleers): | ||
    with torch.no_grad(): | ||
        audio, sr = torchaudio.load(audio_path) | ||
        if audio.size(1) < 16000: | ||
            return None | ||
        # audio, sr = lib.load(audio_path, sr=16000) | ||
        # audio = torch.tensor([audio]) | ||
        # duration = len(audio[0]) / sr | ||
        en_audio = audio.unsqueeze(0).to(device) | ||
        en_audio = convert_audio(en_audio, sr, model.channels, resampleers) | ||
        encoded_frames = model.encode(en_audio) | ||
        # dim, nframe | ||
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0).detach().cpu().numpy() | ||
        return codes # dim, nframe | ||
|
||
def convert_audio(wav: torch.Tensor, sr: int, target_channels: int, resampleers): | ||
    # wav arrives batched as [B, C, T], so the channel axis is -2, not 0. | ||
    assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." | ||
    *shape, channels, length = wav.shape | ||
    if target_channels == 1: | ||
        # Downmix to mono by averaging over the channel axis. | ||
        wav = wav.mean(-2, keepdim=True) | ||
    elif target_channels == 2: | ||
        wav = wav.expand(*shape, target_channels, length) | ||
    elif channels == 1: | ||
        wav = wav.expand(target_channels, -1) | ||
    # wav = torchaudio.transforms.Resample(sr, target_sr)(wav) | ||
    # Use the cached resampler for this source rate (raises KeyError if sr is not in resampleers). | ||
    wav = resampleers[sr](wav) | ||
    return wav | ||
|
||
if __name__ == "__main__": | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument('--save-home', type=str, help='output directory for the codec tsv files') | ||
    parser.add_argument('--tsv', type=str, help='path to the info tsv (path \\t duration)') | ||
    parser.add_argument('--pro-idx', type=int, help='index of this process (0-based)') | ||
    parser.add_argument('--pro-total', type=int, help='total number of processes') | ||
    args = parser.parse_args() | ||
|
||
    model = EncodecModel.encodec_model_24khz().cuda().eval() | ||
    model.set_target_bandwidth(6.0) # 1.5, 3, 6, 12, 24 | ||
    device = next(model.parameters()).device | ||
|
||
    tsv = os.path.join(args.tsv) | ||
    infos = utils.read_file(tsv) # path, dur | ||
|
||
    # One cached resampler per supported source sample rate. | ||
    resampleers = { | ||
        16000: torchaudio.transforms.Resample(16000, model.sample_rate).to(device).eval(), | ||
        44100: torchaudio.transforms.Resample(44100, model.sample_rate).to(device).eval(), | ||
        48000: torchaudio.transforms.Resample(48000, model.sample_rate).to(device).eval(), | ||
        24000: torchaudio.transforms.Resample(24000, model.sample_rate).to(device).eval(), | ||
        8000: torchaudio.transforms.Resample(8000, model.sample_rate).to(device).eval(), | ||
        22050: torchaudio.transforms.Resample(22050, model.sample_rate).to(device).eval(), | ||
    } | ||
|
||
    # Split the rows evenly across processes; the last process also takes the remainder. | ||
    slice_len = len(infos) // args.pro_total | ||
    start = args.pro_idx * slice_len | ||
    if args.pro_idx == args.pro_total - 1: | ||
        infos = tqdm(infos[start:]) | ||
    else: | ||
        end = (args.pro_idx + 1) * slice_len | ||
        infos = tqdm(infos[start:end]) | ||
    print("start:%d, len:%d"%(start, len(infos))) | ||
|
||
    # One output file per codec layer, prefixed with this process's index. | ||
    wfs = [open(os.path.join(args.save_home, "%d_codec%d.tsv"%(args.pro_idx, i)), "w") for i in range(8)] | ||
|
||
    for step, row in enumerate(infos): | ||
        row = row.strip() | ||
        temp_path, dur = row.split("\t") | ||
        codec = get_codec(model, temp_path, device, resampleers) | ||
        if codec is None: | ||
            continue | ||
        for i in range(8): | ||
            codecli = codec[i, :].flatten().astype(str) | ||
            codecli = " ".join(codecli) | ||
            print("%s\t%s"%(temp_path, codecli), file=wfs[i]) | ||
|
||
    for i in wfs: | ||
        i.close() |
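The way the script above divides the rows of the tsv among processes can be isolated into a small function. The sketch below mirrors that logic (integer division, with the last process taking the remainder); the name `shard_bounds` is hypothetical and does not appear in the script.

```python
def shard_bounds(num_items, shard_idx, num_shards):
    """Return the [start, end) row range handled by process shard_idx."""
    slice_len = num_items // num_shards
    start = shard_idx * slice_len
    # The last process also picks up the rows left over by integer division.
    end = num_items if shard_idx == num_shards - 1 else start + slice_len
    return start, end


if __name__ == "__main__":
    # 10 rows over 3 processes: the last shard is longer than the others.
    for idx in range(3):
        print(idx, shard_bounds(10, idx, 3))
```

With this scheme the last shard can be up to `num_shards - 1` rows longer than the others, which is negligible for large datasets.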
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
GPU_NUM=8 | ||
NUM_P=16 | ||
save_home="/home/develop_dataset/codec" | ||
tsv="/home/develop_dataset/info_bilibili.tsv" | ||
GPU_IDX=0 | ||
mkdir -p ${save_home} | ||
for ((i=0; i<NUM_P; i++)); | ||
do | ||
    temp_file=$ori_home/${i}_codec0.tsv  # NOTE: unused; ori_home is not defined in this script | ||
    echo run $i $GPU_IDX | ||
    # Assign processes to GPUs round-robin | ||
    DEVICEIDX=$((GPU_IDX % GPU_NUM)) | ||
    CUDA_VISIBLE_DEVICES=$DEVICEIDX \ | ||
    python -u extract_codec.py \ | ||
        --tsv ${tsv} \ | ||
        --save-home ${save_home} \ | ||
        --pro-idx ${i} \ | ||
        --pro-total ${NUM_P} & | ||
    GPU_IDX=$((GPU_IDX + 1)) | ||
done | ||
wait | ||
|
||
|
||
|
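The launcher above leaves one set of files per process, named `{pro_idx}_codec{i}.tsv`, with each row holding the audio path and the codes. A merge step along the following lines could combine them into the single "codec[i].tsv" files. This is a sketch: the function `merge_codec_shards` is hypothetical, and dropping the path column is an assumption made here to obtain rows that contain only codes.

```python
import os


def merge_codec_shards(save_home, num_procs, num_layers=8, keep_path=False):
    """Concatenate per-process shards into one codec{layer}.tsv per layer.

    Shards are read in process order, so rows stay aligned across layers.
    Assumption: the leading "path<TAB>" column is dropped unless keep_path is True.
    """
    for layer in range(num_layers):
        merged_path = os.path.join(save_home, f"codec{layer}.tsv")
        with open(merged_path, "w", encoding="utf-8") as out:
            for proc in range(num_procs):
                shard_path = os.path.join(save_home, f"{proc}_codec{layer}.tsv")
                with open(shard_path, encoding="utf-8") as shard:
                    for line in shard:
                        line = line.rstrip("\n")
                        if not keep_path:
                            # Keep only the codes that follow the first tab.
                            line = line.split("\t", 1)[-1]
                        out.write(line + "\n")
```

Because the extraction script skips clips shorter than 16000 samples, the merged files can have fewer rows than the info tsv; the transcript and token files need to be filtered to match.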
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from slam_llm.pipeline.finetune import main as train | ||
|
||
import hydra | ||
import logging | ||
from dataclasses import dataclass, field | ||
from omegaconf import DictConfig, ListConfig, OmegaConf | ||
from vallex_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig | ||
|
||
@dataclass | ||
class RunConfig: | ||
    dataset_config: DataConfig = field(default_factory=DataConfig) | ||
    model_config: ModelConfig = field(default_factory=ModelConfig) | ||
    train_config: TrainConfig = field(default_factory=TrainConfig) | ||
    log_config: LogConfig = field(default_factory=LogConfig) | ||
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) | ||
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) | ||
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) | ||
|
||
@hydra.main(config_name=None, version_base=None) | ||
def main_hydra(cfg: DictConfig): | ||
    run_config = RunConfig() | ||
    cfg = OmegaConf.merge(run_config, cfg) | ||
    def to_plain_list(cfg_item): | ||
        if isinstance(cfg_item, ListConfig): | ||
            return OmegaConf.to_container(cfg_item, resolve=True) | ||
        elif isinstance(cfg_item, DictConfig): | ||
            return {k: to_plain_list(v) for k, v in cfg_item.items()} | ||
        else: | ||
            return cfg_item | ||
|
||
    # kwargs = to_plain_list(cfg) | ||
    kwargs = cfg | ||
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) | ||
|
||
    logging.basicConfig(level=log_level) | ||
|
||
    if kwargs.get("debug", False): | ||
        import pdb | ||
        pdb.set_trace() | ||
|
||
    train(kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
    main_hydra() |