Merge pull request #173 from X-LANCE/ygr_pr2
for CTC-assisted LLM-based contextual ASR code PR
ddlBoJack authored Nov 17, 2024
2 parents f32b8a2 + 52fab27 commit 80cc33f
Showing 15 changed files with 1,385 additions and 0 deletions.
61 changes: 61 additions & 0 deletions examples/contextual_asr/README.md
@@ -0,0 +1,61 @@
# CTC-Assisted LLM-Based Contextual ASR

## Guides

[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, and then incorporates them into the LLM prompt input to improve recognition of those hotwords.
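
To make the pipeline concrete, here is a minimal sketch of the filtering step, assuming simple word-level fuzzy matching. The function, threshold, and word cap are illustrative (cf. `probability_threshold` and `word_num` in `contextual_asr_config.py` below), not the repository's implementation, which also supports character/phoneme-level filtering:

```
from difflib import SequenceMatcher

def filter_hotwords(ctc_hypothesis, biasing_list, threshold=0.9, max_words=15):
    """Keep hotwords that fuzzily match some word of the CTC decoding result."""
    hyp_words = ctc_hypothesis.lower().split()
    selected = []
    for hotword in biasing_list:
        best = max((SequenceMatcher(None, hotword.lower(), w).ratio()
                    for w in hyp_words), default=0.0)
        if best >= threshold:
            selected.append(hotword)
        if len(selected) >= max_words:
            break
    return selected

# Toy example: the CTC pass misspells a rare word, but the fuzzy match
# still recovers "Genevieve" from the biasing list.
print(filter_hotwords("she met genevive at dawn",
                      ["Genevieve", "Larkspur", "Throckmorton"],
                      threshold=0.8))  # -> ['Genevieve']
```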

## Model Architecture

We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple linear projector, consisting of a 1-D convolution layer followed by two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.

![](docs/model.png)
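
As a mental model, a minimal PyTorch sketch of such a projector follows. The hidden width, activation, and WavLM-Large's 1024-dim output are assumptions for illustration, not the repository's exact module (the real settings live in `contextual_asr_config.py` below, e.g. `encoder_projector_ds_rate: 5`):

```
import torch
import torch.nn as nn

class LinearProjector(nn.Module):
    """1-D convolution that downsamples the encoder output in time,
    followed by two linear layers mapping to the LLM embedding dim."""
    def __init__(self, encoder_dim=1024, llm_dim=4096, ds_rate=5):
        super().__init__()
        self.conv = nn.Conv1d(encoder_dim, encoder_dim,
                              kernel_size=ds_rate, stride=ds_rate)
        self.proj = nn.Sequential(
            nn.Linear(encoder_dim, 2 * encoder_dim),
            nn.ReLU(),
            nn.Linear(2 * encoder_dim, llm_dim),
        )

    def forward(self, x):                 # x: (batch, time, encoder_dim)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.proj(x)               # (batch, time // ds_rate, llm_dim)
```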

## Checkpoints
We only train the linear projector in this recipe.
|Encoder|Projector|LLM|
|---|---|---|
|[CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link) (~315.45M)|[Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link) (~15.74M)|[vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) (~6.7B)|

## Performance
![](docs/performance.png)


## Data preparation
The artificial biasing lists constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) are used for contextual ASR testing; refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias).
The authors categorize the 5,000 most frequent words in the LibriSpeech training corpus as common words, and the remainder as rare words. The biasing list generated for each test utterance consists of two parts: the rare words in its transcription, plus distractors sampled from a vocabulary of 209.2K rare words. Biasing lists of varying lengths are generated by adding N = {100, 500, 1000, 2000} distractors to each list.
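
The construction is easy to picture with a short sketch (illustrative only; use the official repo's released lists so results stay comparable):

```
import random

def make_biasing_list(transcript, common_words, rare_vocab, n_distractors):
    """Rare words from the reference transcription plus N sampled distractors."""
    rare_in_ref = [w for w in transcript.lower().split() if w not in common_words]
    pool = [w for w in rare_vocab if w not in rare_in_ref]
    return rare_in_ref + random.sample(pool, n_distractors)
```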



## Decoding with checkpoints
LLM-based ASR Inference script:
```
bash decode_wavlm_libri960_ft_char.sh
```
LLM-based Contextual ASR Inference script, with different biasing list sizes:
```
bash decode_wavlm_libri960_ft_char_hotwords.sh
```


## Training the model
LLM-based ASR Training script: using CTC fine-tuned WavLM as encoder and "Transcribe speech to text." as prompt.
```
bash finetune_wavlm_libri960_ft_char.sh
```
LLM-based Contextual ASR Training script: using CTC fine-tuned WavLM as encoder and "Transcribe speech to text. Some hotwords might help. The hotwords are {}." as prompt.
```
bash finetune_wavlm_libri960_ft_char_hotwords.sh
```
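
For concreteness, the `{}` slot in the contextual prompt is presumably filled with the hotwords selected by the CTC-assisted filter; the space-separated join below is an assumption:

```
PROMPT = "Transcribe speech to text. Some hotwords might help. The hotwords are {}."
hotwords = ["Genevieve", "Larkspur"]  # output of the CTC-assisted filtering step
print(PROMPT.format(" ".join(hotwords)))
# Transcribe speech to text. Some hotwords might help. The hotwords are Genevieve Larkspur.
```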


## Citation
You can refer to the paper for more results.
```
@article{yang2024ctc,
title={CTC-Assisted LLM-Based Contextual ASR},
author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
journal={Proc. SLT},
year={2024}
}
```
19 changes: 19 additions & 0 deletions examples/contextual_asr/conf/ds_config.json
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
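
This config enables fp16 training and ZeRO stage 3 with the optimizer state offloaded to CPU. A minimal sketch of how such a file is typically consumed (the placeholder model stands in for the speech+LLM model that the recipe's scripts actually build):

```
import deepspeed
import torch.nn as nn

model = nn.Linear(10, 10)  # placeholder; the recipe builds the speech+LLM model instead
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/contextual_asr/conf/ds_config.json",
)
```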
4 changes: 4 additions & 0 deletions examples/contextual_asr/conf/prompt.yaml
@@ -0,0 +1,4 @@
dataset_config:
    # we put the prompt here because the hydra override in the shell script only supports a small subset of characters
    # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
    prompt: "Transcribe speech to text. "
137 changes: 137 additions & 0 deletions examples/contextual_asr/contextual_asr_config.py
@@ -0,0 +1,137 @@
from dataclasses import dataclass, field
from typing import Optional, List

@dataclass
class ModelConfig:
    file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory"
    llm_name: str = "vicuna-13b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether model is only pretrained or finetuned, used for models such as hubert"
    })

@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    # target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False
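
# --- illustrative note (not part of the original file) -----------------------
# These fields mirror the arguments of peft's LoraConfig; the mapping below is
# an assumption about how the recipe wires them up, shown for orientation only:
#
#     from peft import LoraConfig
#     lora_cfg = LoraConfig(
#         r=8, lora_alpha=32,
#         target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
#         bias="none", task_type="CAUSAL_LM",
#         lora_dropout=0.05, inference_mode=False,
#     )
# ------------------------------------------------------------------------------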

@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the llm when finetuning; should be true when using peft finetuning"
    })
    freeze_encoder: bool = False

@dataclass
class DataConfig:
    dataset: str = "speech_dataset"
    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = False
    input_type: str = field(default="raw", metadata={
        "help": "use raw when the input is wav, mel for whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    infer_type: str = "bias"
    infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
    ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
    filter_type: str = "char"
    phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
    common_words_5k_dir: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
    probability_threshold: float = 0.9
    word_num: int = 15
    filter_infer_sentence: bool = False
    filter_infer_sentence_few: bool = False
    first: int = 1

@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: str = "NO_SHARD"  # ShardingStrategy.NO_SHARD  # MZY: set NO_SHARD when using DDP
    checkpoint_type: str = "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size; alternative: FULL_STATE_DICT
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"

@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/root/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/root/test.log"
    log_interval: int = 5