Merge pull request #60 from ddlBoJack/main
sync
ddlBoJack authored May 3, 2024
2 parents 796fac9 + 78fa7f4 commit aa7ea00
Showing 24 changed files with 3,876 additions and 14 deletions.
89 changes: 89 additions & 0 deletions examples/vallex/README.md
@@ -0,0 +1,89 @@
# An unofficial reproduction of VALL-E-X
This example is an unofficial reproduction of VALL-E X, based on the repository https://github.com/Plachtaa/VALL-E-X.

## Checkpoints
Pretrained models can be found on [Google Drive](https://drive.google.com/drive/folders/1wCTffPnSsiHpthaX-yzUne1dTBAMI_JV?usp=drive_link).


## Decode with checkpoints
```shell
# First, set model_home in the following script to the location of the downloaded/pretrained models.
# Second, set prompt_txt, prompt_audio, and target_txt, each with its corresponding language id.
bash examples/vallex/scripts/inference.sh
```

## Data preparation
VALL-E X is trained on a dataset of discrete speech tokens and text tokens.


* Prepare an "info.tsv" file as follows (path \t duration), listing the path and duration of each utterance.
```txt
SPEECH_PATH1 DURATION1
SPEECH_PATH2 DURATION2
SPEECH_PATH3 DURATION3
......
```
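The "info.tsv" layout above can be produced with a short script. The following is only a minimal sketch: it assumes the audio is stored as PCM `.wav` files and that the duration is expressed in seconds (the README does not state the unit). The function names `wav_duration_seconds` and `write_info_tsv` are hypothetical and not part of this repository.

```python
import wave
from pathlib import Path


def wav_duration_seconds(path):
    """Return the duration of a PCM WAV file in seconds."""
    with wave.open(str(path), "rb") as wf:
        return wf.getnframes() / wf.getframerate()


def write_info_tsv(audio_dir, out_path):
    """Write one 'path<TAB>duration' row per .wav file found under audio_dir."""
    rows = []
    for wav_path in sorted(Path(audio_dir).rglob("*.wav")):
        # Assumption: duration is in seconds, rounded to milliseconds.
        rows.append(f"{wav_path}\t{wav_duration_seconds(wav_path):.3f}")
    with open(out_path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(row + "\n")
    return len(rows)
```

For other formats such as FLAC, the duration would have to come from an audio library instead of the standard-library `wave` module.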
* Extract the codec tokens according to "info.tsv"
```shell
bash examples/vallex/data_pretreatment/extract_codec.sh
```
This yields 8 "codec[i].tsv" files; the i-th codec layer (i = 0 to 7) is saved to "codec[i].tsv"
```txt
304 123 453 255 256 345 124 666 543 ...
654 662 543 463 674 537 273 473 973 ...
355 345 766 255 234 768 275 785 102 ...
......
```
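Note that `extract_codec.py` in this commit writes one file per worker process, named `<pro_idx>_codec<i>.tsv`, and each line holds the audio path, a tab, and the space-separated codes. The sketch below shows one way to merge those shards into the "codec[i].tsv" layout shown above. It assumes the path column should be dropped to match that layout; `merge_codec_shards` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def merge_codec_shards(shard_dir, out_dir, num_procs, num_layers=8):
    """Concatenate <p>_codec<i>.tsv shards into codec<i>.tsv, keeping only the codes."""
    shard_dir, out_dir = Path(shard_dir), Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    counts = []
    for layer in range(num_layers):
        n = 0
        with open(out_dir / f"codec{layer}.tsv", "w", encoding="utf-8") as out:
            # Workers take contiguous slices of info.tsv, so reading them in
            # process order keeps the original utterance order.
            for proc in range(num_procs):
                shard = shard_dir / f"{proc}_codec{layer}.tsv"
                with open(shard, encoding="utf-8") as f:
                    for line in f:
                        line = line.rstrip("\n")
                        if not line:
                            continue
                        # Each shard line is "<audio path>\t<codes>"; drop the path.
                        _path, codes = line.split("\t", 1)
                        out.write(codes + "\n")
                        n += 1
        counts.append(n)
    return counts
```

The extraction script skips audio shorter than 16000 samples, so the merged files can have fewer lines than "info.tsv"; the text files have to be filtered the same way to stay aligned.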
* Prepare the text file ("trans.tsv"), with each line holding the transcript of the corresponding utterance
```txt
Text for SPEECH1
Text for SPEECH2
Text for SPEECH3
......
```
Next, convert the text into tokens with a tool such as BPE or G2P, and save the result as "st.tsv"
```txt
1521 467 885 2367 242 ...
2362 3261 356 167 1246 2364 ...
1246 123 432 134 53 13 ...
......
```
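Because the text and codec files are matched line by line, a quick consistency check before binarization can catch misaligned data. The sketch below assumes the codec files contain only codes (the layout shown in this README) and that every codec layer of an utterance has the same number of frames; the helper names are hypothetical.

```python
from contextlib import ExitStack


def count_lines(path):
    """Count the non-empty lines of a text file."""
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def check_alignment(st_path, codec_paths):
    """Return {path: line_count}; raise ValueError if the counts differ."""
    counts = {str(p): count_lines(p) for p in [st_path, *codec_paths]}
    if len(set(counts.values())) > 1:
        raise ValueError(f"line counts differ: {counts}")
    return counts


def check_frame_counts(codec_paths):
    """Raise ValueError if the codec layers disagree on frames for any utterance."""
    with ExitStack() as stack:
        files = [stack.enter_context(open(p, encoding="utf-8")) for p in codec_paths]
        for lineno, rows in enumerate(zip(*files), start=1):
            frame_counts = {len(row.split()) for row in rows}
            if len(frame_counts) > 1:
                raise ValueError(
                    f"line {lineno}: layers disagree on frame count {sorted(frame_counts)}"
                )
```

A typical call would be `check_alignment("st.tsv", [f"codec{i}.tsv" for i in range(8)])` followed by `check_frame_counts` on the same codec list.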
* Convert the data (codec[i].tsv and st.tsv) into binary files for fast reading
```shell
# We use fairseq's preprocess tool for this conversion
python /home/wangtianrui/codes/fairseq/fairseq_cli/preprocess.py \
    --only-source \
    --trainpref /home/wangtianrui/develop_dataset/st.tsv \
    --destdir /home/wangtianrui/develop_dataset/data_bin \
    --thresholdsrc 0 \
    --srcdict /home/wangtianrui/develop_dataset/dict.st.txt \
    --workers `cat /proc/cpuinfo| grep "processor"| wc -l`
# NOTE: $outdir is not set in this snippet; define it before running the loop.
for ((i=0;i<=7;i++))
do
    echo $i
    outname=train.at${i}.zh  # not used by the command below
    python /home/wangtianrui/codes/fairseq/fairseq_cli/preprocess.py \
        --only-source \
        --trainpref codec${i}.tsv \
        --destdir $outdir \
        --thresholdsrc 0 \
        --srcdict /home/wangtianrui/develop_dataset/dict.at.txt \
        --workers `cat /proc/cpuinfo| grep "processor"| wc -l`
done
```
where dict.at.txt and dict.st.txt are simple idx-to-idx rows of the discrete speech tokens and text tokens, as shown in examples/vallex/data_pretreatment.

VALL-E X can then be trained with dataset_config.train_data_path set to the directory that contains the binary files. We also release a tiny dataset for reference on [Google Drive](https://drive.google.com/drive/folders/1wCTffPnSsiHpthaX-yzUne1dTBAMI_JV?usp=drive_link).
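
The dictionary files passed to `--srcdict` can be generated with a few lines of code. This is a rough sketch under one assumption: that "idx-to-idx rows" means each row repeats the token index twice (`0 0`, `1 1`, ...), which also fits fairseq's `<symbol> <count>` dictionary layout. The example files in examples/vallex/data_pretreatment remain the reference for the exact format. `max_token_id` and `write_idx_dict` are hypothetical helper names.

```python
def max_token_id(tsv_path):
    """Return the largest integer token found in a space-separated token file."""
    largest = -1
    with open(tsv_path, encoding="utf-8") as f:
        for line in f:
            for tok in line.split():
                largest = max(largest, int(tok))
    return largest


def write_idx_dict(dict_path, vocab_size):
    """Write rows '0 0', '1 1', ... up to vocab_size - 1 (assumed layout)."""
    with open(dict_path, "w", encoding="utf-8") as f:
        for idx in range(vocab_size):
            f.write(f"{idx} {idx}\n")
```

For example, `write_idx_dict("dict.st.txt", max_token_id("st.tsv") + 1)` covers every text token that appears in "st.tsv".
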
## Train a new AR model
After preparing the dataset, set "train_data_path" in the following script; you can then start training or fine-tuning.
```shell
bash examples/vallex/scripts/vallex.sh
```
69 changes: 69 additions & 0 deletions examples/vallex/conf/vallex.yaml
@@ -0,0 +1,69 @@

model_config:
  llm_name: ""
  ar_at_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.at.txt"
  ar_st_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.st.txt"
  nar_at_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.at.txt"
  nar_st_dict: "/Work20/2023/wangtianrui/datas/bilibli_7min_woman/data_bin/dict.st.txt"
  only_ar: true
  only_nar: false

train_config:
  model_name: ""
  enable_ddp: false
  enable_fsdp: false
  low_cpu_fsdp: false
  run_validation: true
  batch_size_training: 4
  batching_strategy: "packing" #alternative: padding
  context_length: 4096
  gradient_accumulation_steps: 1
  num_epochs: 100
  num_workers_dataloader: 1
  warmup_steps: 1000
  total_steps: 100000
  validation_interval: 1000
  lr: 1e-4
  weight_decay: 0.0
  gamma: 0.85
  seed: 42
  use_fp16: false
  mixed_precision: true
  val_batch_size: 1

  output_dir: "PATH/to/save/PEFT/model"
  freeze_layers: false
  num_freeze_layers: 1
  quantization: false
  one_gpu: false
  save_model: true
  dist_checkpoint_root_folder: "PATH/to/save/FSDP/model" # will be used if using FSDP
  dist_checkpoint_folder: "fine-tuned" # will be used if using FSDP
  save_optimizer: false # will be used if using FSDP

dataset_config:
  dataset: "speech_dataset"
  file: "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
  train_data_path: null
  val_data_path: null

fsdp_config:
  mixed_precision: true
  use_fp16: false
  # sharding_strategy: "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
  sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
  checkpoint_type: "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
  fsdp_activation_checkpointing: true
  fsdp_cpu_offload: false
  pure_bf16: false
  optimizer: "AdamW"

log_config:
  use_wandb: false
  wandb_dir: "/Work20/2023/wangtianrui/codes/test/debug/test_wandb"
  wandb_entity_name : "project_name"
  wandb_project_name : "project_name"
  wandb_exp_name : "exp_name"
  log_file: "/Work20/2023/wangtianrui/codes/test/debug/test.log"
  log_interval: 5

90 changes: 90 additions & 0 deletions examples/vallex/data_pretreatment/extract_codec.py
@@ -0,0 +1,90 @@
import os
import argparse
import numpy as np
from tqdm import tqdm
from encodec import EncodecModel
import math
import torch
import torchaudio
import sys
from fairseq import search, utils
import pandas as pd
import librosa as lib
from multiprocessing import Process

def get_codec(model, audio_path, device, resampleers):
    with torch.no_grad():
        audio, sr = torchaudio.load(audio_path)
        # Skip clips with fewer than 16000 samples
        if audio.size(1) < 16000:
            return None
        # audio, sr = lib.load(audio_path, sr=16000)
        # audio = torch.tensor([audio])
        # duration = len(audio[0]) / sr
        en_audio = audio.unsqueeze(0).to(device)  # (1, channels, length)
        en_audio = convert_audio(en_audio, sr, model.channels, resampleers)
        encoded_frames = model.encode(en_audio)
        # dim, nframe
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0).detach().cpu().numpy()
    return codes # dim, nframe

def convert_audio(wav: torch.Tensor, sr: int, target_channels: int, resampleers):
    # wav arrives batched as (1, channels, length), so the channel axis is -2
    assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        wav = wav.mean(-2, keepdim=True)
    elif target_channels == 2:
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[-2] == 1:
        wav = wav.expand(target_channels, -1)
    # wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    # Only the sample rates prepared in `resampleers` are supported
    wav = resampleers[sr](wav)
    return wav

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--save-home', type=str, help='output directory for the codec tsv files')
    parser.add_argument('--tsv', type=str, help='path to info.tsv (path \\t duration per line)')
    parser.add_argument('--pro-idx', type=int, help='index of this worker process')
    parser.add_argument('--pro-total', type=int, help='total number of worker processes')
    args = parser.parse_args()

    model = EncodecModel.encodec_model_24khz().cuda().eval()
    model.set_target_bandwidth(6.0) # 1.5, 3, 6, 12, 24
    device = next(model.parameters()).device

    tsv = os.path.join(args.tsv)
    infos = utils.read_file(tsv) # path, dur

    resampleers = {
        16000: torchaudio.transforms.Resample(16000, model.sample_rate).to(device).eval(),
        44100: torchaudio.transforms.Resample(44100, model.sample_rate).to(device).eval(),
        48000: torchaudio.transforms.Resample(48000, model.sample_rate).to(device).eval(),
        24000: torchaudio.transforms.Resample(24000, model.sample_rate).to(device).eval(),
        8000: torchaudio.transforms.Resample(8000, model.sample_rate).to(device).eval(),
        22050: torchaudio.transforms.Resample(22050, model.sample_rate).to(device).eval(),
    }

    # Each process handles a contiguous slice; the last one also takes the remainder
    slice_len = len(infos) // args.pro_total
    start = args.pro_idx * slice_len
    if args.pro_idx == args.pro_total - 1:
        infos = tqdm(infos[start:])
    else:
        end = (args.pro_idx + 1) * slice_len
        infos = tqdm(infos[start:end])
    print("start:%d, len:%d"%(start, len(infos)))

    # One output file per codec layer for this process
    wfs = [open(os.path.join(args.save_home, "%d_codec%d.tsv"%(args.pro_idx, i)), "w") for i in range(8)]

    for step, row in enumerate(infos):
        row = row.strip()
        temp_path, dur = row.split("\t")
        codec = get_codec(model, temp_path, device, resampleers)
        if codec is None:
            continue
        for i in range(8):
            codecli = codec[i, :].flatten().astype(str)
            codecli = " ".join(codecli)
            print("%s\t%s"%(temp_path, codecli), file=wfs[i])

    for i in wfs:
        i.close()
23 changes: 23 additions & 0 deletions examples/vallex/data_pretreatment/extract_codec.sh
@@ -0,0 +1,23 @@
GPU_NUM=8
NUM_P=16
save_home="/home/develop_dataset/codec"
tsv="/home/develop_dataset/info_bilibili.tsv"
GPU_IDX=0
mkdir -p ${save_home}
for ((i=0; i<NUM_P; i++));
do
    # NOTE: temp_file is unused below, and $ori_home is not defined in this script
    temp_file=$ori_home/${i}_codec0.tsv
    echo run $i $GPU_IDX
    # Assign the processes to the GPUs in round-robin order
    DEVICEIDX=$((GPU_IDX % GPU_NUM))
    CUDA_VISIBLE_DEVICES=$DEVICEIDX \
    python -u extract_codec.py \
        --tsv ${tsv} \
        --save-home ${save_home} \
        --pro-idx ${i} \
        --pro-total ${NUM_P} &
    GPU_IDX=$((GPU_IDX + 1))
done
wait



Binary file added examples/vallex/demo/en2en_test_out.wav
Binary file not shown.
Binary file added examples/vallex/demo/en2zh_test_out.wav
Binary file not shown.
Binary file added examples/vallex/demo/en_prompt.flac
Binary file not shown.
Binary file added examples/vallex/demo/zh2en_test_out.wav
Binary file not shown.
Binary file added examples/vallex/demo/zh2zh_test_out.wav
Binary file not shown.
Binary file added examples/vallex/demo/zh_prompt.wav
Binary file not shown.
45 changes: 45 additions & 0 deletions examples/vallex/finetune_vallex.py
@@ -0,0 +1,45 @@
from slam_llm.pipeline.finetune import main as train

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from vallex_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()