From e57a2b129ce6a9eb287bfb842e58d0117b97fef7 Mon Sep 17 00:00:00 2001
From: ddlBoJack
Date: Sat, 27 Jan 2024 00:41:46 +0800
Subject: [PATCH] fix bugs in inference and saving ckpt using ddp

---
 scripts/compute_wer.sh                        |  4 +-
 scripts/finetune_asr_vicuna.sh                | 39 +++++++--------
 scripts/inference_asr.sh                      | 48 +++++++++++--------
 scripts/inference_asr_batch.sh                |  2 +-
 .../model_checkpointing/checkpoint_handler.py |  5 +-
 src/llama_recipes/models/slam_model.py        |  7 ++-
 src/llama_recipes/pipeline/finetune.py        |  2 +-
 src/llama_recipes/pipeline/inference.py       |  6 +++
 src/llama_recipes/utils/train_utils.py        |  3 +-
 9 files changed, 70 insertions(+), 46 deletions(-)

diff --git a/scripts/compute_wer.sh b/scripts/compute_wer.sh
index 849ef210..f88ffc22 100644
--- a/scripts/compute_wer.sh
+++ b/scripts/compute_wer.sh
@@ -1,7 +1,7 @@
 #cd /root/SLAM-LLM
 
-trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
-preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"
+trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
+preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"
 
 # python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
 # python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
diff --git a/scripts/finetune_asr_vicuna.sh b/scripts/finetune_asr_vicuna.sh
index f5056d9a..4d18cf58 100644
--- a/scripts/finetune_asr_vicuna.sh
+++ b/scripts/finetune_asr_vicuna.sh
@@ -21,7 +21,7 @@ speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
 # speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
 
 # llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
-# lm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
+# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
 # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
 # llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
 llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
@@ -31,7 +31,7 @@ output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126
 
 # -m debugpy --listen 5678 --wait-for-client
 if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
-python src/llama_recipes/pipeline/finetune.py \
+python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
 --config-path "/root/SLAM-LLM/scripts/conf" \
 --config-name "asr_vicuna_lora.yaml" \
 hydra.run.dir=$output_dir \
@@ -42,8 +42,8 @@
 ++model_config.encoder_ds_rate=2 \
 ++model_config.encoder_path=$speech_encoder_path \
 ++model_config.encoder_dim=1280 \
-++model_config.encoder_projector=linear \
-++model_config.encoder_projector_ds_rate=5 \
+++model_config.encoder_projector=q-former \
+++dataset_config.fix_length_audio=64 \
 ++dataset_config.dataset=speech_dataset \
 ++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
 ++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
@@ -58,23 +58,23 @@ hydra.run.dir=$output_dir \
 ++train_config.batch_size_training=4 \
 ++train_config.val_batch_size=4 \
 ++train_config.num_workers_dataloader=4 \
-++train_config.lr=1e-4 \
 ++train_config.output_dir=$output_dir \
-++train_config.use_peft=true \
-++train_config.peft_config.peft_method=lora \
 ++metric=acc \
+# ++model_config.encoder_projector=linear \
+# ++model_config.encoder_projector_ds_rate=5 \
+# ++train_config.use_peft=true \
+# ++train_config.peft_config.peft_method=lora \
 #++log_config.log_file=/$output_dir/train.log \
 #++log_config.use_wandb=true \
 #++log_config.wandb_dir=$output_dir \
 #++log_config.wandb_entity_name=zym22 \
 #++log_config.wandb_project_name=slam-llm \
-#++log_config.wandb_exp_name=test \
+#++log_config.wandb_exp_name=${0##*/%.*} \
 #++log_config.log_interval 5 \
 # --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \
 # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \
-# --use_peft --peft_method lora \
-##vicuna-7b-v1.5
+
 
 else
 torchrun \
 --nnodes 1 \
@@ -91,11 +91,10 @@ hydra.run.dir=$output_dir \
 ++model_config.encoder_path=$speech_encoder_path \
 ++model_config.encoder_dim=1280 \
 ++model_config.encoder_projector=q-former \
-++model_config.encoder_projector_ds_rate=5 \
+++dataset_config.fix_length_audio=64 \
 ++dataset_config.dataset=speech_dataset \
 ++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
 ++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
-++dataset_config.fix_length_audio=64 \
 ++train_config.model_name=asr \
 ++train_config.freeze_encoder=true \
 ++train_config.freeze_llm=true \
@@ -112,13 +111,15 @@ hydra.run.dir=$output_dir \
 ++train_config.enable_ddp=true \
 ++train_config.use_fp16=true \
 ++metric=acc \
-++log_config.log_file=/$output_dir/train.log \
-++log_config.use_wandb=true \
-++log_config.wandb_dir=$output_dir \
-++log_config.wandb_entity_name=zym22 \
-++log_config.wandb_project_name=slam-llm \
-++log_config.wandb_exp_name=${0##*/%.*} \
-++log_config.log_interval=5 \
+# ++log_config.log_file=/$output_dir/train.log \
+# ++log_config.use_wandb=true \
+# ++log_config.wandb_dir=$output_dir \
+# ++log_config.wandb_entity_name=zym22 \
+# ++log_config.wandb_project_name=slam-llm \
+# ++log_config.wandb_exp_name=${0##*/%.*} \
+# ++log_config.log_interval=5 \
+# ++model_config.encoder_projector=linear \
+# ++model_config.encoder_projector_ds_rate=5 \
 # ++train_config.use_peft=true \
 # ++train_config.peft_config.peft_method=lora \
 # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
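
Note: the projector swap above (linear -> q-former, paired with ++dataset_config.fix_length_audio=64) means the LLM always receives 64 audio embeddings: a Q-Former-style projector lets a fixed set of learned queries cross-attend over the variable-length Whisper features. A minimal sketch of that idea, assuming a single cross-attention layer (illustrative only, not the repo's module, which also consumes a frame mask, as the slam_model.py hunk below shows):

    import torch
    import torch.nn as nn

    class QFormerProjector(nn.Module):
        def __init__(self, encoder_dim=1280, llm_dim=4096, num_queries=64, num_heads=8):
            super().__init__()
            # 64 learned queries -> output length matches fix_length_audio=64
            self.queries = nn.Parameter(torch.randn(num_queries, encoder_dim))
            self.cross_attn = nn.MultiheadAttention(encoder_dim, num_heads, batch_first=True)
            self.proj = nn.Linear(encoder_dim, llm_dim)  # 1280 -> 4096, matching encoder_dim/llm_dim above

        def forward(self, encoder_outs):  # (B, T, 1280), variable T
            q = self.queries.unsqueeze(0).expand(encoder_outs.size(0), -1, -1)
            attended, _ = self.cross_attn(q, encoder_outs, encoder_outs)
            return self.proj(attended)  # (B, 64, 4096), fixed length
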
diff --git a/scripts/inference_asr.sh b/scripts/inference_asr.sh
index 65677ff9..35ae0d43 100644
--- a/scripts/inference_asr.sh
+++ b/scripts/inference_asr.sh
@@ -6,33 +6,43 @@ export TOKENIZERS_PARALLELISM=false
 
 cd /root/SLAM-LLM
 
-speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
+# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt
+# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
+# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/small.pt
+# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
+speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
 # speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
 
+# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
+# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
 # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
+# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
 llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
 
-output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106
-ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106/asr/2/model.pt
+output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126
+ckpt_path=$output_dir/asr/2
 # peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1
 
 # -m debugpy --listen 5678 --wait-for-client
 python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/inference.py \
---model_name asr \
---freeze_encoder \
---freeze_llm \
---llm_name vicuna-7b-v1.5 \
---llm_path $llm_path \
---llm_dim 4096 \
---encoder_name whisper \
---encoder_ds_rate 2 \
---encoder_path $speech_encoder_path \
---encoder_dim 1280 \
---encoder_projector linear \
---encoder_projector_ds_rate 5 \
---output_dir $output_dir \
---ckpt_path $ckpt_path \
---wav_path "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \
---prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
+--config-path "/root/SLAM-LLM/scripts/conf" \
+--config-name "asr_vicuna_lora.yaml" \
+++model_config.llm_name="vicuna-7b-v1.5" \
+++model_config.llm_path=$llm_path \
+++model_config.llm_dim=4096 \
+++model_config.encoder_name=whisper \
+++model_config.encoder_ds_rate=2 \
+++model_config.encoder_path=$speech_encoder_path \
+++model_config.encoder_dim=1280 \
+++model_config.encoder_projector=q-former \
+++dataset_config.fix_length_audio=64 \
+++ckpt_path=$ckpt_path/model.pt \
+++wav_path="/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \
+++prompt="Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
+++train_config.model_name=asr \
+++train_config.freeze_encoder=true \
+++train_config.freeze_llm=true \
+# ++model_config.encoder_projector=linear \
+# ++model_config.encoder_projector_ds_rate=5 \
 # --peft_ckpt $peft_ckpt \
 # --use_peft --peft_method lora \
\ No newline at end of file
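
Note: inference_asr.sh now goes through the same Hydra entrypoint style as finetuning, replacing the old argparse-style "--flag value" arguments with "++key=value" overrides (in Hydra's override grammar, ++ sets the key whether or not it already exists in the YAML config). Roughly how such an entrypoint consumes them, as a minimal sketch assuming a recent Hydra (not the repo's actual main()):

    import hydra
    from omegaconf import DictConfig, OmegaConf

    @hydra.main(version_base=None, config_path="scripts/conf", config_name="asr_vicuna_lora")
    def main(cfg: DictConfig):
        # ++model_config.encoder_projector=q-former etc. are merged into cfg
        # before main() runs; the code only ever reads the resolved config.
        print(OmegaConf.to_yaml(cfg.model_config))

    if __name__ == "__main__":
        main()
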
diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh
index 2d6d752e..23fac068 100644
--- a/scripts/inference_asr_batch.sh
+++ b/scripts/inference_asr_batch.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 #export PYTHONPATH=/root/whisper:$PYTHONPATH
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=3
 export TOKENIZERS_PARALLELISM=false
 # export CUDA_LAUNCH_BLOCKING=1
 
diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py
index df8f3da6..4970bd68 100644
--- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py
+++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -176,7 +176,10 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
         logger.info(f"llm saved at {save_dir}")
 
     save_full_path = os.path.join(save_dir, "model.pt")
-    cpu_state = model.state_dict()
+    if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP
+        cpu_state = model.module.state_dict()
+    else:
+        cpu_state = model.state_dict()
     encoder_dict = {}
     if not cfg.freeze_encoder:
         for key in cpu_state.keys():
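
Note: the checkpoint_handler.py change above is the standard DDP-unwrapping pattern. DistributedDataParallel keeps the real network under .module, so saving the wrapper's state_dict() prefixes every key with "module." and the file can no longer be loaded into an unwrapped model. The same idea isolated as a helper (a sketch; the nn.Linear stands in for the SLAM model):

    import torch
    import torch.nn as nn

    def unwrap_model(model: nn.Module) -> nn.Module:
        # Plain modules pass through; DDP and similar wrappers expose .module
        return model.module if hasattr(model, "module") else model

    net = nn.Linear(4, 2)  # stand-in for the (possibly DDP-wrapped) model
    state = unwrap_model(net).state_dict()
    torch.save(state, "model.pt")  # keys "weight"/"bias", no "module." prefix
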
" \ +++train_config.model_name=asr \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +# ++model_config.encoder_projector=linear \ +# ++model_config.encoder_projector_ds_rate=5 \ # --peft_ckpt $peft_ckpt \ # --use_peft --peft_method lora \ \ No newline at end of file diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh index 2d6d752e..23fac068 100644 --- a/scripts/inference_asr_batch.sh +++ b/scripts/inference_asr_batch.sh @@ -1,6 +1,6 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=3 export TOKENIZERS_PARALLELISM=false # export CUDA_LAUNCH_BLOCKING=1 diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index df8f3da6..4970bd68 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -176,7 +176,10 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): logger.info(f"llm saved at {save_dir}") save_full_path = os.path.join(save_dir, "model.pt") - cpu_state = model.state_dict() + if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP + cpu_state = model.module.state_dict() + else: + cpu_state = model.state_dict() encoder_dict = {} if not cfg.freeze_encoder: for key in cpu_state.keys(): diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index 602760a5..f76bd541 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -322,7 +322,12 @@ def inference( audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1,0)[None, :, :].to(device) encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) - encoder_outs = self.encoder_projector(encoder_outs) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones(encoder_outs.size()[:-1], dtype=torch.long).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) else: # Text QA encoder_outs = torch.empty(1, 0, self.llm.model.embed_tokens.embedding_dim).to(device) diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index a49b9a83..205b891f 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -131,7 +131,7 @@ def main(kwargs: DictConfig): clear_gpu_cache(local_rank) setup_environ_flags(rank) - if (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: logger.info("train_config: {}".format(train_config)) logger.info("fsdp_config: {}".format(fsdp_config)) logger.info("model_config: {}".format(model_config)) diff --git a/src/llama_recipes/pipeline/inference.py b/src/llama_recipes/pipeline/inference.py index cc7bbd16..44bd4edb 100644 --- a/src/llama_recipes/pipeline/inference.py +++ b/src/llama_recipes/pipeline/inference.py @@ -48,6 +48,12 @@ def main(kwargs: DictConfig): kwargs.log_config, \ kwargs.dataset_config + del kwargs.train_config + del kwargs.fsdp_config + del kwargs.model_config + del kwargs.log_config + del kwargs.dataset_config + # Set the seeds for reproducibility torch.cuda.manual_seed(train_config.seed) torch.manual_seed(train_config.seed) diff 
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index b3ec8390..fd397e00 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -162,8 +162,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     pbar.update(1)
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})")
-
-                dist.barrier()
+
             if (epoch * total_length + step + 1) % train_config.validation_interval == 0 and train_config.run_validation:
                 eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
                 eval_epoch_acc = rest[0] if rest else -1
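
Note: the train_utils.py hunk removes an unconditional per-step dist.barrier(), presumably to avoid synchronizing all ranks on every training step, and keeps validation on a global-step schedule. That trigger condition, isolated as a sketch with illustrative names (total_length in the diff is the number of steps per epoch):

    def should_validate(epoch: int, step: int, steps_per_epoch: int,
                        validation_interval: int, run_validation: bool = True) -> bool:
        # mirrors (epoch * total_length + step + 1) % validation_interval == 0
        global_step = epoch * steps_per_epoch + step + 1
        return run_validation and global_step % validation_interval == 0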