Merge pull request #173 from X-LANCE/ygr_pr2
for CTC-assisted LLM-based contextual ASR code PR
Showing 15 changed files with 1,385 additions and 0 deletions.
@@ -0,0 +1,61 @@
# CTC-Assisted LLM-Based Contextual ASR

## Guides

[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, and then incorporates them into the LLM prompt to improve recognition of those hotwords.
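The filtering step can be pictured with a minimal sketch (illustrative only, not the repository's implementation; the helper name, the character-level similarity scoring, and the example words are assumptions, loosely mirroring the `filter_type="char"`, `probability_threshold`, and `word_num` options in the config file further below):

```
from difflib import SequenceMatcher

def filter_hotwords(ctc_words, hotword_list, threshold=0.9, word_num=15):
    """Hypothetical helper: keep hotwords whose characters closely match
    some word in the CTC decoding result, then take the top scorers."""
    scored = []
    for hw in hotword_list:
        best = max(SequenceMatcher(None, hw, w).ratio() for w in ctc_words)
        if best >= threshold:
            scored.append((best, hw))
    scored.sort(reverse=True)
    return [hw for _, hw in scored[:word_num]]

# CTC output often misspells rare words; a close character match still recalls them.
print(filter_hotwords(["the", "mayor", "of", "tewksberry"],
                      ["tewksbury", "bramwell"], threshold=0.8))  # ['tewksbury']
```

The selected words are then inserted into the LLM prompt, as shown in the training section below.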
## Model Architecture

We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple-structured linear projector, consisting of a 1-D convolution layer and two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.
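A rough PyTorch sketch of such a projector (a minimal sketch under assumptions: the dimensions and downsampling rate follow the `encoder_dim=1280`, `llm_dim=4096`, and `encoder_projector_ds_rate=5` defaults in the config file below, not necessarily the released code):

```
import torch
import torch.nn as nn

class LinearProjector(nn.Module):
    """Hypothetical adapter: a strided 1-D conv for 5x temporal
    downsampling, followed by two linear layers into the LLM space."""
    def __init__(self, encoder_dim=1280, llm_dim=4096, ds_rate=5):
        super().__init__()
        self.conv = nn.Conv1d(encoder_dim, encoder_dim,
                              kernel_size=ds_rate, stride=ds_rate)
        self.linear1 = nn.Linear(encoder_dim, llm_dim)
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):  # x: (batch, time, encoder_dim)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.linear2(torch.relu(self.linear1(x)))
```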
![](docs/model.png)

## Checkpoints
We only train the linear projector in this recipe.

| Encoder | Projector | LLM |
|---|---|---|
| [CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link) (~315.45M) | [Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link) (~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) (~6.7B) |

## Performance
![](docs/performance.png)
## Data preparation
The artificial biasing lists constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) are used for contextual ASR testing; refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias). That work categorizes the 5,000 most frequent words in the LibriSpeech training corpus as common words, with the remainder classified as rare words. The biasing list generated for each test utterance consists of two parts: the rare words in its transcription, and distractors sampled from the 209.2K rare-word vocabulary. Biasing lists of varying lengths are generated by adding N = {100, 500, 1000, 2000} distractors.
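A minimal sketch of that construction (hypothetical helper based on the description above, not the fbai-speech code):

```
import random

def make_biasing_list(transcript_words, common_words, rare_vocab, n_distractors=100):
    """Rare words from the reference transcript plus N sampled distractors."""
    rare_in_ref = [w for w in transcript_words if w not in common_words]
    distractors = random.sample(
        [w for w in rare_vocab if w not in rare_in_ref], n_distractors)
    return rare_in_ref + distractors
```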

## Decoding with checkpoints
LLM-based ASR inference script:
```
bash decode_wavlm_libri960_ft_char.sh
```
LLM-based contextual ASR inference script, with different biasing list sizes:
```
bash decode_wavlm_libri960_ft_char_hotwords.sh
```

## Training the model
LLM-based ASR training script, using the CTC fine-tuned WavLM as encoder and "Transcribe speech to text." as prompt:
```
bash finetune_wavlm_libri960_ft_char.sh
```
LLM-based contextual ASR training script, using the CTC fine-tuned WavLM as encoder and "Transcribe speech to text. Some hotwords might help. The hotwords are {}." as prompt:
```
bash finetune_wavlm_libri960_ft_char_hotwords.sh
```
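At inference time, the `{}` placeholder in the contextual prompt is filled with the hotwords kept by the CTC filtering step; roughly (illustrative only, the exact formatting in the repo may differ):

```
hotwords = ["tewksbury", "bramwell"]  # example output of the CTC-based filter
prompt = ("Transcribe speech to text. Some hotwords might help. "
          "The hotwords are {}.").format(" ".join(hotwords))
```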

## Citation
You can refer to the paper for more results.
```
@article{yang2024ctc,
  title={CTC-Assisted LLM-Based Contextual ASR},
  author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
  journal={Proc. SLT},
  year={2024}
}
```
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
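This JSON is a DeepSpeed configuration: ZeRO stage-3 partitioning with the optimizer state offloaded to CPU and fp16 enabled. A minimal sketch of wiring it in (assuming the file is saved as `ds_config.json`, a hypothetical name; the training scripts above may pass it differently):

```
import deepspeed
import torch.nn as nn

model = nn.Linear(10, 10)  # stand-in for the SLAM model with only the projector trainable
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config="ds_config.json",
)
```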
@@ -0,0 +1,4 @@
dataset_config:
    # We put the prompt here because the hydra override in the shell script only supports a small subset of characters.
    # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
    prompt: "Transcribe speech to text. "
@@ -0,0 +1,137 @@
from dataclasses import dataclass, field
from typing import Optional, List


@dataclass
class ModelConfig:
    file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory"
    llm_name: str = "vicuna-13b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether the input is normalized; used for models such as WavLM"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether the model is only pre-trained or fine-tuned; used for models such as HuBERT"
    })

@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    # target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False

@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # used if using FSDP
    save_optimizer: bool = False  # used if using FSDP
    use_fast_kernels: bool = False  # enable SDPA from PyTorch Accelerated Transformers, using Flash Attention and xFormers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the LLM when fine-tuning; should be True when using PEFT fine-tuning"
    })
    freeze_encoder: bool = False

@dataclass
class DataConfig:
    dataset: str = "speech_dataset"
    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = False
    input_type: str = field(default="raw", metadata={
        "help": "use 'raw' when the input is a wav, 'mel' for Whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for Whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether the input is normalized; used for models such as WavLM"
    })
    infer_type: str = "bias"
    infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
    ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
    filter_type: str = "char"
    phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
    common_words_5k_dir: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
    probability_threshold: float = 0.9
    word_num: int = 15
    filter_infer_sentence: bool = False
    filter_infer_sentence_few: bool = False
    first: int = 1

@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy: str = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: str = "NO_SHARD"  # ShardingStrategy.NO_SHARD; set NO_SHARD when using DDP (MZY)
    checkpoint_type: str = "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"

@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/root/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/root/test.log"
    log_interval: int = 5
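For a quick sense of how these dataclasses fit together (hypothetical usage; in the recipe they are populated through hydra overrides from the shell scripts rather than constructed by hand):

```
model_cfg = ModelConfig(llm_name="vicuna-7b-v1.5", llm_path="lmsys/vicuna-7b-v1.5")
data_cfg = DataConfig(infer_type="bias", filter_type="char",
                      probability_threshold=0.9, word_num=15)
train_cfg = TrainConfig(freeze_llm=True, freeze_encoder=True)  # train only the projector
```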