
Commit

蒄骰 committed Nov 17, 2024
1 parent 86dc310 commit 52fab27
Showing 5 changed files with 66 additions and 7 deletions.
11 changes: 7 additions & 4 deletions examples/contextual_asr/README.md
@@ -28,11 +28,16 @@ words, with the remainder classified as rare words. The biasing list generated f


## Decoding with checkpoints
LLM-based Contextual ASR Inference script, with different biasing list sizes and test sets.
LLM-based ASR Inference script.
```
bash decode_wavlm_libri960_ft_char.sh
```
LLM-based Contextual ASR Inference script, with different biasing list sizes.
```
bash decode_wavlm_libri960_ft_char_hotwords.sh
```


## Training the model
LLM-based ASR Training script: using CTC fine-tuned WavLM as the encoder and “Transcribe speech to text.” as the prompt.
```
@@ -53,6 +58,4 @@ You can refer to the paper for more results.
journal={Proc. SLT},
year={2024}
}
```
2 changes: 1 addition & 1 deletion examples/contextual_asr/contextual_asr_config.py
@@ -82,7 +82,7 @@ class TrainConfig:
@dataclass
class DataConfig:
dataset: str = "speech_dataset"
file: str = "examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset"
file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
train_split: str = "train"
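A note on the changed `file` field: it is a `path/to/file.py:function` locator that the training entry point resolves into a dataset factory at runtime. A minimal sketch of such a loader, assuming an importlib-based mechanism (the repository's actual helper is not shown in this diff):

```
# Hypothetical loader for a "path/to/file.py:function" spec such as
# DataConfig.file; illustrative only, not SLAM-LLM's actual implementation.
import importlib.util
from pathlib import Path

def load_dataset_factory(spec: str):
    """Import the module at the given file path and return the named function."""
    module_path, func_name = spec.rsplit(":", 1)
    module_spec = importlib.util.spec_from_file_location(Path(module_path).stem, module_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# Usage: get_speech_dataset = load_dataset_factory(
#     "examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset")
```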
@@ -32,7 +32,7 @@ hydra.run.dir=$output_dir \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.input_type=raw \
++dataset_config.dataset=hotwords_dataset \
++dataset_config.file=src/slam_llm/datasets/hotwords_dataset.py:get_speech_dataset \
++dataset_config.file=examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset \
++train_config.model_name=asr \
++train_config.num_epochs=5 \
++train_config.freeze_encoder=true \
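The `++dataset_config.file=...` override above uses Hydra's `++` (force add-or-override) syntax to rewrite fields of the dataclass configs from the command line. A minimal, self-contained sketch of that mechanism, with a hypothetical `demo.py` standing in for the repository's entry point:

```
# demo.py -- hypothetical Hydra app showing how ++key=value overrides land
# in the config; SLAM-LLM's real entry point wires these into dataclasses.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None)
def main(cfg: DictConfig) -> None:
    # With no base config, the CLI overrides are all the config there is.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()

# $ python demo.py ++dataset_config.dataset=hotwords_dataset \
#       ++dataset_config.input_type=raw
# dataset_config:
#   dataset: hotwords_dataset
#   input_type: raw
```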
@@ -0,0 +1,56 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=2
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

run_dir=/nfs/yangguanrou.ygr/codes/SLAM-LLM
cd $run_dir
code_dir=examples/contextual_asr

speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/wavlm_large_ft_libri960_char/wavlm_large_ft_libri960_char.pt
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5

output_dir=/nfs/yangguanrou.ygr/experiments_librispeech/vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-20240521
ckpt_path=$output_dir/asr_epoch_3_step_9780
N=100
for ref_split in test_clean test_other; do
split=librispeech_${ref_split}
val_data_path=/nfs/maziyang.mzy/data/librispeech/${split}.jsonl
decode_log=$ckpt_path/decode_${split}_beam4_debug
python $code_dir/inference_contextual_asr_batch.py \
--config-path "conf" \
--config-name "prompt.yaml" \
hydra.run.dir=$ckpt_path \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=wavlm \
++model_config.normalize=true \
++dataset_config.normalize=true \
++model_config.encoder_projector_ds_rate=5 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1024 \
++model_config.encoder_projector=cov1d-linear \
++dataset_config.dataset=speech_dataset \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.input_type=raw \
++dataset_config.inference_mode=true \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=1 \
++train_config.val_batch_size=1 \
++train_config.num_workers_dataloader=0 \
++train_config.output_dir=$output_dir \
++decode_log=$decode_log \
++ckpt_path=$ckpt_path/model.pt && \
python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc && \
python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc && \
python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer && \
python /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_score.py \
--refs /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/ref_score/${ref_split}.biasing_${N}.tsv \
--hyps ${decode_log}_pred.proc \
--output_file ${decode_log}.proc.wer
done
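For reference, the scoring stage above first text-normalizes the ground truth and predictions (`whisper_tn.py`), then computes WER (`compute_wer.py`) and biasing-aware scores (`my_score.py`). As an illustration of the core metric, a self-contained word error rate sketch (plain word-level edit distance; the actual scripts' file formats and options are not shown in this diff):

```
# Illustrative WER: Levenshtein distance over words, normalized by the
# reference length. Not the repository's compute_wer.py.
def wer(ref: str, hyp: str) -> float:
    r, h = ref.split(), hyp.split()
    # dp[i][j] = edits to turn the first i ref words into the first j hyp words
    dp = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        dp[i][0] = i  # deletions
    for j in range(len(h) + 1):
        dp[0][j] = j  # insertions
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = dp[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            dp[i][j] = min(sub, dp[i - 1][j] + 1, dp[i][j - 1] + 1)
    return dp[len(r)][len(h)] / max(len(r), 1)

print(wer("the cat sat", "the cat sat down"))  # 1 insertion / 3 ref words ≈ 0.333
```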
@@ -5,7 +5,7 @@ export TOKENIZERS_PARALLELISM=false
export CUDA_LAUNCH_BLOCKING=1
export HYDRA_FULL_ERROR=1

run_dir=/root/SLAM-LLM
run_dir=/nfs/yangguanrou.ygr/codes/SLAM-LLM
cd $run_dir
code_dir=examples/contextual_asr

