Commit 25f5eea

Merge pull request #37 from ddlBoJack/dev-mzy

fix bugs in inference and saving ckpt using ddp

ddlBoJack authored Jan 26, 2024
2 parents 17746df + e57a2b1 commit 25f5eea
Showing 9 changed files with 70 additions and 46 deletions.
4 changes: 2 additions & 2 deletions scripts/compute_wer.sh
@@ -1,7 +1,7 @@
#cd /root/SLAM-LLM

trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"
trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"

# python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
# python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
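For context, the commented pipeline above feeds the ground-truth log (trans) and the prediction log (preds) through a text preprocessor and then a WER scorer. A minimal word-error-rate sketch, independent of the repo's own compute_wer.py (which may normalize and tokenize differently):

# Minimal WER sketch: word-level edit distance over whitespace tokens.
# The repo's compute_wer.py may apply extra text normalization first.
def wer(ref: str, hyp: str) -> float:
    r, h = ref.split(), hyp.split()
    # d[i][j] = edits to turn the first i ref words into the first j hyp words
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = d[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return d[len(r)][len(h)] / max(len(r), 1)
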
39 changes: 20 additions & 19 deletions scripts/finetune_asr_vicuna.sh
@@ -21,7 +21,7 @@ speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# lm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
@@ -31,7 +31,7 @@ output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwa

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python src/llama_recipes/pipeline/finetune.py \
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$output_dir \
@@ -42,8 +42,8 @@ hydra.run.dir=$output_dir \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=5 \
++model_config.encoder_projector=q-former \
++dataset_config.fix_length_audio=64 \
++dataset_config.dataset=speech_dataset \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
@@ -58,23 +58,23 @@ hydra.run.dir=$output_dir \
++train_config.batch_size_training=4 \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=4 \
++train_config.lr=1e-4 \
++train_config.output_dir=$output_dir \
++train_config.use_peft=true \
++train_config.peft_config.peft_method=lora \
++metric=acc \
# ++model_config.encoder_projector=linear \
# ++model_config.encoder_projector_ds_rate=5 \
# ++train_config.use_peft=true \
# ++train_config.peft_config.peft_method=lora \
#++log_config.log_file=/$output_dir/train.log \
#++log_config.use_wandb=true \
#++log_config.wandb_dir=$output_dir \
#++log_config.wandb_entity_name=zym22 \
#++log_config.wandb_project_name=slam-llm \
#++log_config.wandb_exp_name=test \
#++log_config.wandb_exp_name=${0##*/%.*} \
#++log_config.log_interval 5 \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \
# --use_peft --peft_method lora \

##vicuna-7b-v1.5

else
torchrun \
--nnodes 1 \
@@ -91,11 +91,10 @@ hydra.run.dir=$output_dir \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=q-former \
++model_config.encoder_projector_ds_rate=5 \
++dataset_config.fix_length_audio=64 \
++dataset_config.dataset=speech_dataset \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++dataset_config.fix_length_audio=64 \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
@@ -112,13 +111,15 @@ hydra.run.dir=$output_dir \
++train_config.enable_ddp=true \
++train_config.use_fp16=true \
++metric=acc \
++log_config.log_file=/$output_dir/train.log \
++log_config.use_wandb=true \
++log_config.wandb_dir=$output_dir \
++log_config.wandb_entity_name=zym22 \
++log_config.wandb_project_name=slam-llm \
++log_config.wandb_exp_name=${0##*/%.*} \
++log_config.log_interval=5 \
# ++log_config.log_file=/$output_dir/train.log \
# ++log_config.use_wandb=true \
# ++log_config.wandb_dir=$output_dir \
# ++log_config.wandb_entity_name=zym22 \
# ++log_config.wandb_project_name=slam-llm \
# ++log_config.wandb_exp_name=${0##*/%.*} \
# ++log_config.log_interval=5 \
# ++model_config.encoder_projector=linear \
# ++model_config.encoder_projector_ds_rate=5 \
# ++train_config.use_peft=true \
# ++train_config.peft_config.peft_method=lora \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
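A note on the ++key=value overrides used throughout this script: in Hydra, key=value changes an existing config entry, +key=value adds a new one, and ++key=value does either. A sketch of composing the same config programmatically, assuming the script's conf directory and config name (illustrative only, not how the repo invokes Hydra):

# Illustrative Hydra composition mirroring two of the script's overrides.
# "++" forces the value whether or not the key already exists in the YAML.
from hydra import compose, initialize

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="asr_vicuna_lora",
        overrides=[
            "++model_config.encoder_projector=q-former",
            "++dataset_config.fix_length_audio=64",
        ],
    )
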
48 changes: 29 additions & 19 deletions scripts/inference_asr.sh
@@ -6,33 +6,43 @@ export TOKENIZERS_PARALLELISM=false

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/small.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5

output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106
ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106/asr/2/model.pt
output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126
ckpt_path=$output_dir/asr/2
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1

# -m debugpy --listen 5678 --wait-for-client
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/inference.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--llm_name vicuna-7b-v1.5 \
--llm_path $llm_path \
--llm_dim 4096 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--output_dir $output_dir \
--ckpt_path $ckpt_path \
--wav_path "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \
--prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=whisper \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=q-former \
++dataset_config.fix_length_audio=64 \
++ckpt_path=$ckpt_path/model.pt \
++wav_path="/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \
++prompt="Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
# ++model_config.encoder_projector=linear \
# ++model_config.encoder_projector_ds_rate=5 \
# --peft_ckpt $peft_ckpt \
# --use_peft --peft_method lora \
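The rewritten inference script above now composes its config the same way as the finetuning script and points ++ckpt_path at model.pt inside the epoch directory. A hedged sketch of restoring such a partial checkpoint, where only trainable parts (e.g. the projector) were saved on top of a frozen encoder/LLM; the helper name and printout are illustrative, not the repo's API:

# Hypothetical helper for loading a partial checkpoint such as model.pt.
# strict=False lets projector-only weights load even though the frozen
# encoder/LLM parameters were never written to the checkpoint.
import torch
from torch import nn

def load_partial_checkpoint(model: nn.Module, path: str) -> nn.Module:
    state = torch.load(path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"missing: {len(missing)} keys, unexpected: {len(unexpected)} keys")
    return model
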
2 changes: 1 addition & 1 deletion scripts/inference_asr_batch.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=3
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

5 changes: 4 additions & 1 deletion src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -176,7 +176,10 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
logger.info(f"llm saved at {save_dir}")

save_full_path = os.path.join(save_dir, "model.pt")
cpu_state = model.state_dict()
if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP
cpu_state = model.module.state_dict()
else:
cpu_state = model.state_dict()
encoder_dict = {}
if not cfg.freeze_encoder:
for key in cpu_state.keys():
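The fix above works because DistributedDataParallel wraps the network and exposes the original one as .module; calling state_dict() on the wrapper prefixes every key with "module." and breaks a later load_state_dict() into an unwrapped model. A generic unwrap helper, as a sketch (this helper is not part of the repo):

from torch import nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # DDP (and similar wrappers) keep the original network in `.module`;
    # its state_dict() has clean keys without the "module." prefix.
    return model.module if hasattr(model, "module") else model

# cpu_state = unwrap_model(model).state_dict()
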
7 changes: 6 additions & 1 deletion src/llama_recipes/models/slam_model.py
@@ -322,7 +322,12 @@ def inference(
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1,0)[None, :, :].to(device)

encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))
encoder_outs = self.encoder_projector(encoder_outs)

if self.model_config.encoder_projector == "q-former":
audio_mel_post_mask = torch.ones(encoder_outs.size()[:-1], dtype=torch.long).to(encoder_outs.device)
encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
if self.model_config.encoder_projector == "linear":
encoder_outs = self.encoder_projector(encoder_outs)
else: # Text QA
encoder_outs = torch.empty(1, 0, self.llm.model.embed_tokens.embedding_dim).to(device)

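The inference fix dispatches on the projector type because the two projectors have different call signatures: the Q-Former projector consumes the encoder features plus a padding mask over the frames, while the linear projector takes the features alone. A sketch of that dispatch, with a hypothetical helper name:

import torch

def apply_projector(encoder_outs: torch.Tensor, projector, projector_type: str):
    # Hypothetical helper mirroring the fix above.
    if projector_type == "q-former":
        # Every encoder frame is valid at single-utterance inference,
        # so the post mask is all ones.
        mask = torch.ones(encoder_outs.size()[:-1], dtype=torch.long,
                          device=encoder_outs.device)
        return projector(encoder_outs, mask)
    return projector(encoder_outs)  # linear projector: features only
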
2 changes: 1 addition & 1 deletion src/llama_recipes/pipeline/finetune.py
@@ -131,7 +131,7 @@ def main(kwargs: DictConfig):
clear_gpu_cache(local_rank)
setup_environ_flags(rank)

if (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0:
if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0:
logger.info("train_config: {}".format(train_config))
logger.info("fsdp_config: {}".format(fsdp_config))
logger.info("model_config: {}".format(model_config))
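The added "not" restores the intended predicate: dump the configs when running single-process (neither FSDP nor DDP enabled), or when this process is rank 0 of a distributed run; the old condition logged on every rank. Restated as a standalone check, assuming rank comes from the usual RANK environment variable:

import os

def should_log(enable_fsdp: bool, enable_ddp: bool) -> bool:
    # Single-process runs always log; distributed runs log only on rank 0,
    # so the config dump is printed exactly once.
    rank = int(os.environ.get("RANK", 0))
    return not (enable_fsdp or enable_ddp) or rank == 0
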
6 changes: 6 additions & 0 deletions src/llama_recipes/pipeline/inference.py
@@ -48,6 +48,12 @@ def main(kwargs: DictConfig):
kwargs.log_config, \
kwargs.dataset_config

del kwargs.train_config
del kwargs.fsdp_config
del kwargs.model_config
del kwargs.log_config
del kwargs.dataset_config

# Set the seeds for reproducibility
torch.cuda.manual_seed(train_config.seed)
torch.manual_seed(train_config.seed)
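The added del statements drop the sub-configs that were just unpacked into locals, so only the inference-specific overrides (e.g. ++ckpt_path, ++wav_path, ++prompt) remain in kwargs when it is passed along. OmegaConf supports this attribute-style deletion; a small sketch with made-up values:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"train_config": {"seed": 42}, "wav_path": "a.wav"})
train_config = cfg.train_config  # keep a local handle first
del cfg.train_config             # then drop the key from the config
assert "train_config" not in cfg
assert train_config.seed == 42   # the extracted node remains usable
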
3 changes: 1 addition & 2 deletions src/llama_recipes/utils/train_utils.py
@@ -162,8 +162,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
pbar.update(1)

pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})")

dist.barrier()

if (epoch * total_length + step + 1) % train_config.validation_interval == 0 and train_config.run_validation:
eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
eval_epoch_acc = rest[0] if rest else -1
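The deleted dist.barrier() forced a cluster-wide sync after every training step, and calling it when no process group has been initialized, as in a single-process run, raises an error. The commit simply removes the call; if a sync were ever needed at that point, a guarded variant would be safer (a sketch, not part of this commit):

import torch.distributed as dist

def maybe_barrier() -> None:
    # Synchronize only when a process group actually exists; calling
    # dist.barrier() in a single-process run raises an error.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
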
