diff --git a/.gitignore b/.gitignore index 0c04a67d..c72a47be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .DS_Store __pycache__ .ipynb_checkpoints +.idea/* +transformers diff --git a/scripts/finetune.sh b/scripts/finetune.sh deleted file mode 100644 index 5c2fa513..00000000 --- a/scripts/finetune.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash -export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 -export CUDA_LAUNCH_BLOCKING=1 - -cd /root/SLAM-LLM - -audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth -speech_encoder_path=/nfs/zhifu.gzf/init_model/whisper/large-v2.pt -llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune - -# -m debugpy --listen 5678 --wait-for-client -python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ ---model_name echat \ ---quantization \ ---llm_name llama-2-7b-hf \ ---llm_path $llm_path \ ---encoder_name whisper \ ---encoder_path $speech_encoder_path \ ---encoder_projector linear \ ---dataset custom_dataset \ ---custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \ ---batching_strategy padding \ ---max_words 2596 \ ---num_epochs 1 \ ---batch_size_training 2 \ ---output_dir $output_dir \ No newline at end of file diff --git a/scripts/finetune_echat.sh b/scripts/finetune_echat.sh new file mode 100644 index 00000000..866df6b2 --- /dev/null +++ b/scripts/finetune_echat.sh @@ -0,0 +1,102 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_LAUNCH_BLOCKING=1 +# export OMP_NUM_THREADS=1 +# export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-echat-ds5-proj2048-debug + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name echat \ +--freeze_encoder \ +--freeze_llm \ +--use_fp16 \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \ +--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \ +--batching_strategy custom \ +--custom_dataset.max_words 1024 \ +--num_epochs 100 \ +--batch_size_training 2 \ +--val_batch_size 2 \ +--output_dir $output_dir \ +--run_test_during_validation \ +--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" +# --use_peft --peft_method lora \ + +# train +# {"trans": "Well, do you have your passport?\n", +# "emotion": "xxx", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_F009.wav"} +# {"trans": "No, I don't have a 
passport.\n", +# "emotion": "neu", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_M010.wav"} + +# val +# {"trans": "Yeah, well thanks for your help.\n", +# "emotion": "ang", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_M040.wav"} +# {"trans": "I'm sorry. Good luck, man.\n", +# "emotion": "xxx", +# "wav": "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_F038.wav"} + +else +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +src/llama_recipes/pipeline/finetune.py \ +--model_name echat \ +--freeze_encoder \ +--use_fp16 \ +--use_peft --peft_method lora \ +--enable_fsdp \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/echat_dataset.py:get_audio_dataset \ +--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 8 \ +--val_batch_size 8 \ +--output_dir $output_dir \ +--run_test_during_validation \ +--run_test_during_validation_file /nfs/zhifu.gzf/data/IEMOCAP_full_release/Session1/sentences/wav/Ses01M_impro01/Ses01M_impro01_F009.wav \ +--run_test_during_validation_prompt """ +Please provide an emotional response based on the emotional speech you hear. +Remember to format your answer as follows: <|EMOTION|><|REPLY|>. +<|EMOTION|> is a standalone adjective. +<|REPLY|> is a reply based on a the speech. +""" \ +--metric acc \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" +# --freeze_llm \ +fi diff --git a/scripts/finetune_speech_pretraining.sh b/scripts/finetune_speech_pretraining.sh new file mode 100644 index 00000000..471cc8a9 --- /dev/null +++ b/scripts/finetune_speech_pretraining.sh @@ -0,0 +1,87 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048 + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \ +--custom_dataset.train_data_path /nfs/beinian.lzr/workspace/datasets/speech_llm/train_dataset/data_wav_json/asr/librispeech_train_960h_wav_speech_llm_train_data.json \ +--custom_dataset.val_data_path /nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/dev_other/librispeech_dev_other.jsonl \ 
+--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 4 \ +--val_batch_size 4 \ +--lr 1e-5 \ +--output_dir $output_dir \ +--run_test_during_validation \ +--run_test_during_validation_file "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav" \ +--run_test_during_validation_prompt "<|ASR|>" \ +--metric acc \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ +# --use_peft --peft_method lora \ + +else +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--use_fp16 \ +--enable_fsdp \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \ +--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h_wav_speech_llm_train_data.json \ +--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 8 \ +--val_batch_size 8 \ +--num_workers_dataloader 4 \ +--lr 1e-5 \ +--output_dir $output_dir \ +--run_test_during_validation \ +--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \ +--run_test_during_validation_prompt "<|ASR|>" \ +--metric acc \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ +# --use_peft --peft_method lora \ +fi + +# {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. 
How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} \ No newline at end of file diff --git a/scripts/inference_asr.sh b/scripts/inference_asr.sh new file mode 100644 index 00000000..fc53658b --- /dev/null +++ b/scripts/inference_asr.sh @@ -0,0 +1,30 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export CUDA_LAUNCH_BLOCKING=1 + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048 + +# -m debugpy --listen 5678 --wait-for-client +python src/llama_recipes/pipeline/inference.py \ +--model_name asr \ +--freeze_llm \ +--freeze_encoder \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--output_dir $output_dir \ +--ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048/asr/13/model.pt" \ +--wav_path "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \ +--prompt "<|ASR|>" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" \ +# --use_peft --peft_method lora \ \ No newline at end of file diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh new file mode 100644 index 00000000..beaad4ce --- /dev/null +++ b/scripts/inference_asr_batch.sh @@ -0,0 +1,38 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=1 +export CUDA_LAUNCH_BLOCKING=1 + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048 +ckpt_path=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048/asr/10/model.pt +decode_log=/root/decode_log + +# -m debugpy --listen 5678 --wait-for-client +python src/llama_recipes/pipeline/inference_batch.py \ +--model_name asr \ +--freeze_llm \ +--freeze_encoder \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/speech_dataset_inference.py:get_audio_dataset \ +--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_test_other.jsonl \ +--batching_strategy custom \ +--num_epochs 1 \ +--val_batch_size 8 \ +--num_workers_dataloader 4 \ +--output_dir $output_dir \ +--ckpt_path $ckpt_path \ +--decode_log $decode_log \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" \ +# --use_peft --peft_method lora \ \ No newline at end of file diff --git a/scripts/inference_echat.sh b/scripts/inference_echat.sh new file mode 100644 index 00000000..f34726d2 --- /dev/null +++ b/scripts/inference_echat.sh @@ -0,0 +1,42 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=1 +export CUDA_LAUNCH_BLOCKING=1 + +cd /root/SLAM-LLM + +# 
speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/base.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune + +# -m debugpy --listen 5678 --wait-for-client +#python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +python src/llama_recipes/pipeline/inference.py \ +--model_name echat \ +--freeze_llm \ +--use_fp16 \ +--quantization \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \ +--custom_dataset.data_path /nfs/zhifu.gzf/data/IEMOCAP_full_release/datalist.jsonl \ +--batching_strategy custom \ +--custom_dataset.max_words 1024 \ +--num_epochs 1 \ +--batch_size_training 2 \ +--output_dir $output_dir \ +--ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1/model.pt" \ +--wav_path "/nfs/zhifu.gzf/data/IEMOCAP_full_release/Session5/sentences/wav/Ses05M_impro04/Ses05M_impro04_F035.wav" \ +--prompt """ + Please provide an emotional response based on the emotional speech you hear. + Remember to format your answer as follows: <|EMOTION|><|REPLY|>. + <|EMOTION|> is a standalone adjective. + <|REPLY|> is a reply based on a the speech. + """ \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/1" +# --use_peft --peft_method lora \ \ No newline at end of file diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py index 0c41d0a4..7232f159 100644 --- a/src/llama_recipes/configs/datasets.py +++ b/src/llama_recipes/configs/datasets.py @@ -31,4 +31,9 @@ class custom_dataset: dataset: str = "custom_dataset" file: str = "examples/custom_dataset.py" train_split: str = "train" - test_split: str = "validation" \ No newline at end of file + test_split: str = "validation" + data_path: str = NotImplemented + train_data_path: str = NotImplemented + val_data_path: str = NotImplemented + max_words: int = NotImplemented + max_mel: int = NotImplemented \ No newline at end of file diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py index c89aff1f..7cf54289 100644 --- a/src/llama_recipes/configs/fsdp.py +++ b/src/llama_recipes/configs/fsdp.py @@ -10,7 +10,8 @@ class fsdp_config: mixed_precision: bool=True use_fp16: bool=False - sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD + # sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: ShardingStrategy = ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. 
fsdp_activation_checkpointing: bool=True fsdp_cpu_offload: bool=False diff --git a/src/llama_recipes/configs/model.py b/src/llama_recipes/configs/model.py index 8b869b79..46c8668c 100644 --- a/src/llama_recipes/configs/model.py +++ b/src/llama_recipes/configs/model.py @@ -6,5 +6,7 @@ class model_config: llm_name: str = "llama-2-7b-hf" llm_path: str = "PATH/to/LLAMA/7B" encoder_name: str = None + encoder_ds_rate: int = 2 encoder_path: str = None - encoder_projector: str = "linear" \ No newline at end of file + encoder_projector: str = "linear" + encoder_projector_ds_rate: int = 5 \ No newline at end of file diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 354c534e..1239c7f5 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -36,3 +36,8 @@ class train_config: dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP save_optimizer: bool=False # will be used if using FSDP use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + run_test_during_validation: bool = False + run_test_during_validation_file: str = "test.wav" + run_test_during_validation_prompt: str = "<|ASR|>" + freeze_llm: bool = False + freeze_encoder: bool = False diff --git a/src/llama_recipes/datasets/echat_dataset.py b/src/llama_recipes/datasets/echat_dataset.py new file mode 100644 index 00000000..65a67519 --- /dev/null +++ b/src/llama_recipes/datasets/echat_dataset.py @@ -0,0 +1,186 @@ +import os.path as osp +import random +import json, yaml +import copy + +import numpy as np +from scipy import signal +import soundfile as sf + +import torch +import torchaudio +from torch.utils.data import Dataset +import whisper +from llama_recipes.utils.compute_utils import calculate_output_length_1d + + +class EChatDataset(Dataset): + def __init__( + self, + dataset_config, + tokenizer=None, + split='train' + ): + super().__init__() + + self.dataset_config = dataset_config + self.tokenizer = tokenizer + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt_template = "USER: {}\n ASSISTANT:" + self.answer_template = "<|{}|><|{}|>" + + with open(dataset_config.data_path, 'r') as file: + data = file.readlines() + + sentence_list = [] + + for item in data: + dialog_name, dialog = item.split('\t', 1) + dialog_list = eval(dialog) + for sentence_id in range(len(dialog_list)-2): + if 'emotion' in dialog_list[sentence_id].keys() and 'emotion' in dialog_list[sentence_id+1].keys(): + if dialog_list[sentence_id+1]['emotion'] != 'xxx': + sentence_dict = {} + sentence_dict['pre_wav'] = dialog_list[sentence_id]['wav'] + sentence_dict['post_emotion'] = dialog_list[sentence_id+1]['emotion'] + sentence_dict['post_trans'] = dialog_list[sentence_id+1]['trans'] + sentence_list.append(sentence_dict) + + total_sentence = len(sentence_list) + print(f"Using {total_sentence} sentence totally.") + if split == "train": + self.data = sentence_list[:int(total_sentence * 0.9)] + else: + self.data = sentence_list[int(total_sentence * 0.9):] + + # # debug + # if split == "train": + # self.data = sentence_list[:8] + # else: + # self.data = sentence_list[8:16] + + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, index): + item = self.data[index] + + speech_raw = whisper.load_audio(item['pre_wav']) + # speech_raw = whisper.pad_or_trim(speech_raw) + speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0) + + 
prompt=""" + Please provide an emotional response based on the emotional speech you hear. + Remember to format your answer as follows: <|EMOTION|><|REPLY|>. + <|EMOTION|> is a standalone adjective. + <|REPLY|> is a reply based on a the speech. + """ + answer=""" + <|happy|><|The moon looks so beautiful tonight.|> + """ + + prompt = self.prompt_template.format(prompt) + answer = self.answer_template.format(item['post_emotion'], item['post_trans']) + + prompt_ids = self.tokenizer.encode(prompt) + + prompt_length = len(prompt_ids) + speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + speech_length = speech_length // 5 # ad-hoc for 5x cov1d downsample + speech_pseudo = torch.full((speech_length,),-1) + + example = prompt + answer #FIX(MZY): avoid putting a bos token before answer. + example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((speech_pseudo, example_ids)) # [speech,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [speech,prompt,answer,eos] + labels_ids[:speech_length + prompt_length] = -1 #[-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) #FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) #[False,False,True,True] + example_ids[~example_mask] = 0 #[speech,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX #[-100,answer,eos,-100] + + return { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + 'speech_mel': speech_mel, + 'speech_length': speech_length, + + } + + + def _wav2feat(self, data): + wav = data.reshape(1, -1) + + feats = torchaudio.compliance.kaldi.fbank( # 25ms and 10ms + wav, htk_compat=True, sample_frequency=16000, use_energy=False, + window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10 + ) + n_frames = feats.shape[0] + + p = self.target_length - n_frames + + # cut and pad + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + feats = m(feats) + elif p < 0: + feats = feats[0:self.target_length, :] + + return feats.unsqueeze(0) # channels, frames, dim + + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat((sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + + speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples]) + speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0) + for s in samples]) + + speech_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + 
speech_mask[line, :sample['speech_length']] = 1 + + return { + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': attention_mask, + 'speech_mel': speech_mel, + 'speech_mask': speech_mask + } + + +def get_audio_dataset(dataset_config, tokenizer, split): + dataset = EChatDataset(dataset_config, tokenizer, split) + + return dataset diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py new file mode 100644 index 00000000..fd6be80b --- /dev/null +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -0,0 +1,161 @@ +import os.path as osp +import random +import json, yaml +import copy + +import numpy as np +from scipy import signal +import soundfile as sf + +import torch +import torchaudio +from torch.utils.data import Dataset +import whisper +from llama_recipes.utils.compute_utils import calculate_output_length_1d + + +class SpeechDatasetJsonl(torch.utils.data.Dataset): + + def __init__(self, + dataset_config, + tokenizer=None, + split='train', + ): + super().__init__() + self.dataset_config = dataset_config + self.tokenizer = tokenizer + # data_parallel_size = dist.get_world_size() + data_parallel_size = 1 + + # self.data_list = contents + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt_template = "USER: {}\n ASSISTANT:" + self.answer_template = "<|{}|>" + + self.data_list = [] + if split == "train": + with open(dataset_config.train_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + else: + with open(dataset_config.val_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + + # # debug + # if split == "train": + # self.data_list = contents[:80] + # else: + # self.data_list = contents[80:100] + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + + return data_dict["target_len"] if "target_len" in data_dict else 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + data_dict = self.data_list[index] + speech_path = data_dict.get("source") + target = data_dict.get("target", None) + task = data_dict.get("prompt", "ASR") + + speech_raw = whisper.load_audio(speech_path) + speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1, 0) + + prompt = """ + <|ASR|> + """ + answer = """ + <|The moon looks so beautiful tonight.|> + """ + + prompt = self.prompt_template.format(prompt) + answer = self.answer_template.format(target) + + prompt_ids = self.tokenizer.encode(prompt) + + prompt_length = len(prompt_ids) + speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + speech_length = speech_length // 5 # ad-hoc for 5x cov1d downsample + speech_pseudo = torch.full((speech_length,), -1) + + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 
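+        # encode prompt+answer as one string, append eos, then prepend speech placeholder ids; the placeholders are later overwritten with projected speech features via speech_mask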
+ example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((speech_pseudo, example_ids)) # [speech,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [speech,prompt,answer,eos] + labels_ids[:speech_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [speech,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,answer,eos,-100] + + return { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + 'speech_mel': speech_mel, + 'speech_length': speech_length, + + } + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + + speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples]) + speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0) + for s in samples]) + + speech_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + speech_mask[line, :sample['speech_length']] = 1 + + return { + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': attention_mask, + 'speech_mel': speech_mel, + 'speech_mask': speech_mask + } + + + +def get_audio_dataset(dataset_config, tokenizer, split): + dataset = SpeechDatasetJsonl(dataset_config, tokenizer, split) + + return dataset diff --git a/src/llama_recipes/datasets/speech_dataset_inference.py b/src/llama_recipes/datasets/speech_dataset_inference.py new file mode 100644 index 00000000..62f4c382 --- /dev/null +++ b/src/llama_recipes/datasets/speech_dataset_inference.py @@ -0,0 +1,144 @@ +import os.path as osp +import random +import json, yaml +import copy + +import numpy as np +from scipy import signal +import soundfile as sf + +import torch +import torchaudio +from torch.utils.data import Dataset +import whisper +from llama_recipes.utils.compute_utils import calculate_output_length_1d + + +class SpeechDatasetJsonl(torch.utils.data.Dataset): + + def __init__(self, + dataset_config, + tokenizer=None, + split='train', + ): + super().__init__() + self.dataset_config = dataset_config + self.tokenizer = tokenizer + # data_parallel_size = dist.get_world_size() + data_parallel_size = 1 + + # self.data_list = contents + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt_template = "USER: {}\n ASSISTANT:" + 
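+        # read the jsonl manifest line by line; each entry supplies the key/source/target (and optional prompt) fields consumed in __getitem__ and the collator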
+ self.data_list = [] + if split == "train": + with open(dataset_config.train_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + else: + with open(dataset_config.val_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + + # # debug + # if split == "train": + # self.data_list = contents[:80] + # else: + # self.data_list = contents[80:100] + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + + return data_dict["target_len"] if "target_len" in data_dict else 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + data_dict = self.data_list[index] + speech_path = data_dict.get("source") + target = data_dict.get("target", None) + task = data_dict.get("prompt", "ASR") + key = data_dict.get("key", None) + + speech_raw = whisper.load_audio(speech_path) + speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1, 0) + + speech_length = (speech_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + speech_length = speech_length // 5 # ad-hoc for 5x cov1d downsample + speech_pseudo = torch.full((speech_length,), -1) + + prompt = """ + <|ASR|> + """ + prompt = self.prompt_template.format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + + example_ids = torch.cat((speech_pseudo, prompt_ids)) # [speech,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "input_ids": example_ids, + "attention_mask": example_mask, + 'speech_mel': speech_mel, + 'speech_length': speech_length, + 'key': key, + 'target':target + } + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + + speech_mel_max_length = max([s['speech_mel'].shape[0] for s in samples]) + speech_mel = torch.stack([self.pad(s['speech_mel'], speech_mel_max_length, 0) + for s in samples]) + + speech_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + speech_mask[line, :sample['speech_length']] = 1 + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'speech_mel': speech_mel, + 'speech_mask': speech_mask, + 'keys': keys, + 'targets': targets + } + + + +def get_audio_dataset(dataset_config, tokenizer, split): + dataset = SpeechDatasetJsonl(dataset_config, tokenizer, split) + + return dataset diff --git a/src/llama_recipes/datasets/speech_text_dataset.py 
b/src/llama_recipes/datasets/speech_text_dataset.py deleted file mode 100644 index c8ed6375..00000000 --- a/src/llama_recipes/datasets/speech_text_dataset.py +++ /dev/null @@ -1,135 +0,0 @@ -import os.path as osp -import random -import json, yaml -import copy - -import numpy as np -from scipy import signal -import soundfile as sf - -import torch -import torchaudio -from torch.utils.data import Dataset -import whisper - - -prompt = ( - f"USER: {prompt}\n ASSISTANT:" - ) - -def apply_prompt_template(prompt, answer): - return prompt.format(prompt=prompt) - -class AudioDataset(Dataset): - def __init__( - self, - dataset_config, - tokenizer=None, - split='train' - ): - super().__init__() - self.data = torch.randn(100, 160000) - - self.dataset_config = dataset_config - self.max_words = dataset_config.max_words - self.tokenizer = tokenizer - self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss - - def __len__(self) -> int: - return len(self.data) - - def __getitem__(self, index): - item = self.data[index] - - # load audio and pad/trim it to fit 30 seconds - audio_raw = whisper.load_audio(item) - audio_raw = whisper.pad_or_trim(audio_raw) - # make log-Mel spectrogram - audio_feats = whisper.log_mel_spectrogram(audio_raw) - - prompt=""" - Please provide an emotional response based on the emotional speech you hear. - Remember to format your answer as follows: <|EMOTION|><|DEGREE|><|REPLY|>. - <|EMOTION|> is a standalone adjective. - <|DEGREE|> is an number ranging from 0 to 2. - <|REPLY|> is a reply based on a the speech. - """ - answer=""" - <|happy|><2><|The moon looks so beautiful tonight.|> - """ - - prompt = apply_prompt_template(prompt=prompt) - example = prompt + answer - prompt_ids = torch.tensor( - self.tokenizer.encode(prompt), dtype=torch.int64 - ) - - example = self.tokenizer.encode(example) - example.append(self.tokenizer.eos_token_id) - padding = self.max_words - example.shape[0] - if padding > 0: - example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) - elif padding < 0: - example = example[: self.max_words] - labels = copy.deepcopy(example) - labels[: len(prompt)] = -1 - example_mask = example.ge(0) - label_mask = labels.ge(0) - example[~example_mask] = 0 - labels[~label_mask] = self.IGNORE_INDEX - - return { - "input_ids": example.tolist(), - "labels": labels.tolist(), - "attention_mask":example_mask.tolist(), - 'audio_mel': audio_mel - } - - - def _wav2feat(self, data): - wav = data.reshape(1, -1) - - feats = torchaudio.compliance.kaldi.fbank( # 25ms and 10ms - wav, htk_compat=True, sample_frequency=16000, use_energy=False, - window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10 - ) - n_frames = feats.shape[0] - - p = self.target_length - n_frames - - # cut and pad - if p > 0: - m = torch.nn.ZeroPad2d((0, 0, 0, p)) - feats = m(feats) - elif p < 0: - feats = feats[0:self.target_length, :] - - return feats.unsqueeze(0) # channels, frames, dim - - - def pad(self, sequence, max_length, padding_idx=0): - if len(sequence) < max_length: - sequence = sequence + [padding_idx] * (max_length - len(sequence)) - else: - sequence = sequence[:max_length] - return sequence - - def collator(self, samples): - assert samples is not None - input_ids = torch.stack([s['input_ids'] for s in samples]) - labels = torch.stack([s['labels'] for s in samples]) - attention_mask = torch.stack([s['attention_mask'] for s in samples]) - - audio_feats = torch.stack([s['audio_feats'] for s in samples]) - return { - 'input_ids': input_ids, - 'labels': labels, - 
'attention_mask': attention_mask, - 'audio_feats': audio_feats, - } - - -def get_audio_dataset(dataset_config, tokenizer, split): - dataset = AudioDataset(dataset_config, tokenizer, split) - - return dataset diff --git a/src/llama_recipes/model_checkpointing/__init__.py b/src/llama_recipes/model_checkpointing/__init__.py index 9474f78c..76781b02 100644 --- a/src/llama_recipes/model_checkpointing/__init__.py +++ b/src/llama_recipes/model_checkpointing/__init__.py @@ -8,5 +8,6 @@ save_optimizer_checkpoint, save_model_and_optimizer_sharded, load_model_sharded, - load_sharded_model_single_gpu + load_sharded_model_single_gpu, + save_model_checkpoint_peft, ) diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index b097df97..d1d6eed6 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - +import os from pathlib import Path from datetime import datetime import torch @@ -160,7 +160,28 @@ def save_model_checkpoint( print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") - +def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): + print(f"--> saving model ...") + save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch)) + os.makedirs(save_dir, exist_ok=True) + if not cfg.freeze_llm: + model.llm.save_pretrained(save_dir) + + save_full_path = os.path.join(save_dir, "model.pt") + cpu_state = model.state_dict() + project_dict = {} + if not cfg.freeze_encoder: + for key in cpu_state.keys(): + if key.startswith("encoder."): + project_dict[key] = cpu_state[key] + for key in cpu_state.keys(): + if key.startswith("encoder_projector."): + project_dict[key] = cpu_state[key] + torch.save(project_dict, save_full_path) + + print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + + def load_model_checkpoint(model, rank, cfg): """load local checkpoint to rank0 cpu diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index 1c12235f..0c362786 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -1,8 +1,11 @@ +import os import types import torch import soundfile as sf import torch.nn as nn import torch.nn.functional as F +import torch.distributed as dist +from typing import List, Optional, Tuple, Union from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training from transformers import ( LlamaForCausalLM, @@ -10,12 +13,15 @@ LlamaConfig, ) import whisper -import librosa from llama_recipes.utils.config_utils import generate_peft_config +from llama_recipes.utils.train_utils import print_module_size +from peft import PeftModel, PeftConfig +from torch.nn import CrossEntropyLoss +from llama_recipes.utils.metric import compute_accuracy -def setup_model(train_config, model_config, **kwargs): +def setup_model(tokenizer, train_config, model_config, **kwargs): return slam_model(tokenizer, train_config, model_config, **kwargs) @@ -46,6 +52,25 @@ def extract_variable_length_features(self, x: torch.Tensor): x = self.ln_post(x) return x +def setup_encoder(train_config, model_config, **kwargs): + encoder_list = model_config.encoder_name.split(",") + if len(encoder_list) == 1: + encoder_name = encoder_list[0] + if 
encoder_name == "whisper" or "qwen-audio": + encoder = whisper.load_model(model_config.encoder_path).encoder + encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder) + if encoder_name == "audio-mae": #TODO + pass + print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + + if train_config.freeze_encoder: + for name, param in encoder.named_parameters(): + param.requires_grad = False + encoder.eval() + print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + + return encoder + def setup_llm(train_config, model_config, **kwargs): from pkg_resources import packaging use_cache = False if train_config.enable_fsdp else None @@ -56,11 +81,12 @@ def setup_llm(train_config, model_config, **kwargs): model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms overhead and currently requires latest nightly. """ - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly: - raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - "please install latest nightly.") + # v = packaging.version.parse(torch.__version__) + # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 + # if not verify_latest_nightly: + # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " + # "please install latest nightly.") + rank = int(os.environ["RANK"]) if rank == 0: model = LlamaForCausalLM.from_pretrained( model_config.llm_path, @@ -71,8 +97,8 @@ def setup_llm(train_config, model_config, **kwargs): else: llama_config = LlamaConfig.from_pretrained(model_config.llm_path) llama_config.use_cache = use_cache - with torch.device("meta"): - model = LlamaForCausalLM(llama_config) + # with torch.device("meta"): + model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta` else: model = LlamaForCausalLM.from_pretrained( @@ -93,19 +119,75 @@ def setup_llm(train_config, model_config, **kwargs): except ImportError: print("Module 'optimum' not found. 
Please install 'optimum' it before proceeding.") - print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) + print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) # Prepare the model for int8 training if quantization is enabled if train_config.quantization: model = prepare_model_for_kbit_training(model) + if train_config.freeze_llm: # TODO:to test offical `freeze_layers` and `num_freeze_layers` + for name, param in model.named_parameters(): + param.requires_grad = False + model.eval() + if train_config.use_peft: peft_config = generate_peft_config(train_config, kwargs) model = get_peft_model(model, peft_config) model.print_trainable_parameters() + + if kwargs.get("peft_ckpt", None): + print("loading peft_ckpt from: ", kwargs.get("peft_ckpt")) + model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt")) + print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) return model +def setup_encoder_projector(train_config, model_config, **kwargs): + if model_config.encoder_projector == "linear": + encoder_projector = EncoderProjectorConcat(model_config) + print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + return encoder_projector + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.encoder_projector_ds_rate + self.linear1 = nn.Linear(1280 * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, 4096) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + +class EncoderProjectorCov1d(nn.Module): + def __init__(self, config): + super(self).__init__() + self.conv1d = nn.Conv1d(in_channels=1280, out_channels=1280, kernel_size=config.encoder_projector_ds_rate, stride=config.encoder_projector_ds_rate, padding=0) + self.linear1 = nn.Linear(1280, 2048) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(2048, 4096) + self.relu2 = nn.ReLU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv1d(x) + x = x.transpose(1, 2) + x = self.relu1(x) + x = self.linear1(x) + x = self.relu2(x) + x = self.linear2(x) + return x + class slam_model(nn.Module): def __init__( @@ -116,19 +198,172 @@ def __init__( **kwargs ): super().__init__() - # whisper - self.speech_encoder = whisper.load_model(model_config.encoder_path).encoder - self.speech_encoder.extract_features = types.MethodType(extract_variable_length_features, self.speech_encoder) - for name, param in self.speech_encoder.named_parameters(): - param.requires_grad = False - self.speech_encoder.eval() - self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model) - - # llama - llm = setup_llm(train_config, model_config, **kwargs) - - # Projector - self.speech_encoder_projector = nn.Linear(self.speech_encoder.config.d_model ,self.llm.config.hidden_size) - - def forward(self): - pass #TODO \ No newline at end of file + # modality encoder + self.encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + self.llm = setup_llm(train_config, model_config, **kwargs) + + # projector + self.encoder_projector = setup_encoder_projector(train_config, model_config, **kwargs) + + # 
tokenizer + self.tokenizer = tokenizer + self.metric = kwargs.get("metric", "acc") + + self.train_config = train_config + self.model_config = model_config + + def forward(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + speech_mel = kwargs.get("speech_mel", None) + speech_mask = kwargs.get("speech_mask", None) + + encoder_outs = None + if speech_mel is not None: + encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) # bs*seq*dim + encoder_outs = self.encoder_projector(encoder_outs) + + if input_ids is not None: + input_ids[input_ids == -1] = 0 + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(input_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(input_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) + + if speech_mask is not None: + batch_size, token_num, dims = inputs_embeds.shape + _, l, _ = encoder_outs.shape + encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) + inputs_embeds = encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None]) + + model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels) + + acc = -1 + if self.metric: + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100) + + return model_outputs, acc + + @torch.no_grad() + def generate(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + speech_mel = kwargs.get("speech_mel", None) + speech_mask = kwargs.get("speech_mask", None) + + encoder_outs = None + if speech_mel is not None: + encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) # bs*seq*dim + encoder_outs = self.encoder_projector(encoder_outs) + + if input_ids is not None: + input_ids[input_ids == -1] = 0 + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(input_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(input_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) + + if speech_mask is not None: + batch_size, token_num, dims = inputs_embeds.shape + _, l, _ = encoder_outs.shape + encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) + inputs_embeds = encoder_outs_pad * speech_mask[:, :, None] + inputs_embeds * (~speech_mask[:, :, None]) + + model_outputs = self.llm.generate( + inputs_embeds=inputs_embeds, + max_length=kwargs.get("max_length", 
200), + num_beams=kwargs.get("num_beams", 4), + do_sample=kwargs.get("do_sample", False), + min_length=kwargs.get("min_length", 1), + top_p=kwargs.get("top_p", 0.9), + repetition_penalty=kwargs.get("repetition_penalty", 1.0), + length_penalty=kwargs.get("length_penalty", 1.0), + temperature=kwargs.get("temperature", 1.0), + attention_mask=attention_mask, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id + ) + + return model_outputs + + @torch.no_grad() + def inference( + self, + wav_path = None, + prompt = None, + generation_config = None, + logits_processor = None, + stopping_criteria = None, + prefix_allowed_tokens_fn = None, + synced_gpus = None, + assistant_model = None, + streamer = None, + negative_prompt_ids = None, + negative_prompt_attention_mask = None, + **kwargs, + ): + + device = kwargs.get("device", "cuda") + assert os.path.exists(wav_path) + speech_raw = whisper.load_audio(wav_path) + # speech_raw = whisper.pad_or_trim(speech_raw) + speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1,0)[None, :, :].to(device) + + encoder_outs = self.encoder.extract_variable_length_features(speech_mel.permute(0, 2, 1)) + encoder_outs = self.encoder_projector(encoder_outs) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat((encoder_outs, inputs_embeds[None, :, :]), dim=1) # [speech,prompt] + + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs + ) + + output_text = self.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) + + return output_text \ No newline at end of file diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index 3b5fbf94..32347daf 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -6,6 +6,7 @@ # nn import torch +from transformers.models.llama.modeling_llama import LlamaDecoderLayer # opt import torch.optim as optim @@ -37,7 +38,6 @@ setup, setup_environ_flags, clear_gpu_cache, - print_model_size, get_policies ) @@ -60,6 +60,7 @@ def main(**kwargs): local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) + print(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}") if torch.distributed.is_initialized(): torch.cuda.set_device(local_rank) @@ -67,6 +68,8 @@ def main(**kwargs): setup_environ_flags(rank) model, tokenizer = model_factory(train_config, model_config, **kwargs) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. 
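+    # encoder, LLM and projector are submodules of one nn.Module, so a single .to(device) moves them all before the optional FSDP wrapping below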
+ model.to(device) # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled @@ -84,7 +87,7 @@ def main(**kwargs): model = FSDP( model, - auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, + auto_wrap_policy= my_auto_wrapping_policy, #(FIX:MZY): Using my_auto_wrapping_policy whether peft or not. This will avoid model shard type check error of requires_grad mismatching. cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, sharding_strategy=fsdp_config.sharding_strategy, @@ -112,7 +115,7 @@ def main(**kwargs): dataset_val = get_preprocessed_dataset( tokenizer, dataset_config, - split="test", + split="val", ) if not train_config.enable_fsdp or rank == 0: print(f"--> Validation Set Length = {len(dataset_val)}") diff --git a/src/llama_recipes/pipeline/inference.py b/src/llama_recipes/pipeline/inference.py new file mode 100644 index 00000000..6d3ca6da --- /dev/null +++ b/src/llama_recipes/pipeline/inference.py @@ -0,0 +1,48 @@ +import fire +import random +import torch +# import argparse +from llama_recipes.models.slam_model import slam_model +# config +from llama_recipes.configs import fsdp_config as FSDP_CONFIG +from llama_recipes.configs import train_config as TRAIN_CONFIG +from llama_recipes.configs import model_config as MODEL_CONFIG +from llama_recipes.utils.config_utils import ( + update_config, + generate_peft_config, + generate_dataset_config, + get_dataloader_kwargs, +) +from llama_recipes.pipeline.model_factory import model_factory + +def main(**kwargs): + + # Update the configuration for the training and sharding process + train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() + update_config((train_config, fsdp_config, model_config), **kwargs) + + # Set the seeds for reproducibility + torch.cuda.manual_seed(train_config.seed) + torch.manual_seed(train_config.seed) + random.seed(train_config.seed) + + model, tokenizer = model_factory(train_config, model_config, **kwargs) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. 
+ model.to(device) + model.eval() + + while True: + print("=====================================") + wav_path = input("Your Wav Path:\n") + # prompt = input("Your Prompt:\n") + # wav_path = kwargs.get('wav_path') + prompt = kwargs.get('prompt') + try: + print(model.inference(wav_path, prompt)) + except: + continue + + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/src/llama_recipes/pipeline/inference_batch.py b/src/llama_recipes/pipeline/inference_batch.py new file mode 100644 index 00000000..841c5d18 --- /dev/null +++ b/src/llama_recipes/pipeline/inference_batch.py @@ -0,0 +1,68 @@ +import fire +import random +import torch +# import argparse +from llama_recipes.models.slam_model import slam_model +# config +from llama_recipes.configs import fsdp_config as FSDP_CONFIG +from llama_recipes.configs import train_config as TRAIN_CONFIG +from llama_recipes.configs import model_config as MODEL_CONFIG +from llama_recipes.utils.config_utils import ( + update_config, + generate_peft_config, + generate_dataset_config, + get_dataloader_kwargs, +) +from llama_recipes.pipeline.model_factory import model_factory +from llama_recipes.utils.dataset_utils import get_preprocessed_dataset + +def main(**kwargs): + + # Update the configuration for the training and sharding process + train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() + update_config((train_config, fsdp_config, model_config), **kwargs) + + # Set the seeds for reproducibility + torch.cuda.manual_seed(train_config.seed) + torch.manual_seed(train_config.seed) + random.seed(train_config.seed) + + model, tokenizer = model_factory(train_config, model_config, **kwargs) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. 
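+    # batch decoding keeps the whole model on a single device and runs it in eval mode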
+ model.to(device) + model.eval() + + dataset_config = generate_dataset_config(train_config, kwargs) + dataset_test = get_preprocessed_dataset( + tokenizer, + dataset_config, + split="test", + ) + if not train_config.enable_fsdp or rank == 0: + print(f"--> Test Set Length = {len(dataset_test)}") + + test_dataloader = torch.utils.data.DataLoader( + dataset_test, + num_workers=train_config.num_workers_dataloader, + pin_memory=True, + shuffle=False, + batch_size=train_config.val_batch_size, + drop_last=False, + collate_fn=dataset_test.collator + ) + + + print("=====================================") + with open(kwargs.get('decode_log'), "w") as decode_log: + for step, batch in enumerate(test_dataloader): + for key in batch.keys(): + batch[key] = batch[key].to(device) if key not in ["keys", "targets"] else batch[key] + model_outputs = model.generate(**batch) + output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) + for key, text, target in zip(batch["keys"], output_text, batch["targets"]): + decode_log.write(key + "\t" + text + "\n") + decode_log.write(key + "\t" + target + "\n") + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/src/llama_recipes/pipeline/model_factory.py b/src/llama_recipes/pipeline/model_factory.py index 2b9419e0..6d7a41c7 100644 --- a/src/llama_recipes/pipeline/model_factory.py +++ b/src/llama_recipes/pipeline/model_factory.py @@ -1,8 +1,17 @@ +import torch from llama_recipes.models.slam_model import setup_model, setup_tokenizer +from llama_recipes.utils.train_utils import print_model_size +import os def model_factory(train_config, model_config, **kwargs): tokenizer = setup_tokenizer(train_config, model_config, **kwargs) model = setup_model(tokenizer, train_config, model_config, **kwargs) + ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + print("loading other parts from: ", ckpt_path) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) return model, tokenizer diff --git a/src/llama_recipes/utils/compute_utils.py b/src/llama_recipes/utils/compute_utils.py new file mode 100644 index 00000000..14328b29 --- /dev/null +++ b/src/llama_recipes/utils/compute_utils.py @@ -0,0 +1,3 @@ + +def calculate_output_length_1d(L_in, kernel_size, stride, padding=0): + return (L_in + 2 * padding - kernel_size) // stride + 1 \ No newline at end of file diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py index 3f8c9428..81ab9680 100644 --- a/src/llama_recipes/utils/config_utils.py +++ b/src/llama_recipes/utils/config_utils.py @@ -95,6 +95,17 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs["drop_last"] = True kwargs["collate_fn"] = default_data_collator else: - raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") + # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") + if train_config.enable_fsdp: + kwargs["sampler"] = DistributedSampler( + dataset, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + shuffle=mode=="train", + ) + kwargs["batch_size"] = batch_size + kwargs["drop_last"] = True + kwargs["collate_fn"] = dataset.collator + print(f"Using batching strategy: 
diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py
index 3f8c9428..81ab9680 100644
--- a/src/llama_recipes/utils/config_utils.py
+++ b/src/llama_recipes/utils/config_utils.py
@@ -95,6 +95,17 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
             kwargs["drop_last"] = True
             kwargs["collate_fn"] = default_data_collator
         else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+            # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                    dataset,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = dataset.collator
+            print(f"Using batching strategy: {train_config.batching_strategy}")
 
         return kwargs
diff --git a/src/llama_recipes/utils/metric.py b/src/llama_recipes/utils/metric.py
new file mode 100644
index 00000000..2de2129a
--- /dev/null
+++ b/src/llama_recipes/utils/metric.py
@@ -0,0 +1,20 @@
+import torch
+
+def compute_accuracy(pad_outputs, pad_targets, ignore_label):
+    """Calculate accuracy.
+
+    Args:
+        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
+        pad_targets (LongTensor): Target label tensors (B, Lmax).
+        ignore_label (int): Ignore label id.
+
+    Returns:
+        float: Accuracy value (0.0 - 1.0).
+
+    """
+    mask = pad_targets != ignore_label
+    numerator = torch.sum(
+        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
+    )
+    denominator = torch.sum(mask)
+    return numerator.float() / denominator.float()  # (FIX:MZY): return torch.Tensor type
\ No newline at end of file
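A minimal sketch of how compute_accuracy might be called for next-token prediction. The real call site inside the model is not shown in this diff, and the use of -100 as the ignore/padding label is an assumption.

    import torch
    from llama_recipes.utils.metric import compute_accuracy

    logits = torch.randn(2, 10, 32000)         # (B, Lmax, vocab) from the LM head
    labels = torch.randint(0, 32000, (2, 10))  # (B, Lmax)
    labels[:, :3] = -100                       # e.g. prompt/audio positions to ignore

    preds = torch.argmax(logits, dim=-1)       # (B, Lmax)
    # shift so position t is scored against token t+1, as in causal LM training
    acc = compute_accuracy(preds[:, :-1], labels[:, 1:], ignore_label=-100)
    print(acc.item())                          # 0-dim torch.Tensor, as noted in the (FIX:MZY) comment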
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index f154f965..d3851387 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -18,9 +18,15 @@
 from transformers import LlamaTokenizer
 
-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from llama_recipes.model_checkpointing import(
+    save_model_checkpoint,
+    save_model_and_optimizer_sharded,
+    save_optimizer_checkpoint,
+    save_model_checkpoint_peft
+)
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
+from llama_recipes.utils.metric import compute_accuracy
 
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
@@ -60,8 +66,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     train_prep = []
     train_loss = []
+    train_acc = []
     val_prep = []
     val_loss =[]
+    val_acc = []
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -71,6 +79,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
+            total_acc = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             for step, batch in enumerate(train_dataloader):
@@ -80,9 +89,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     else:
                         batch[key] = batch[key].to('cuda:0')
                 with autocast():
-                    loss = model(**batch).loss
+                    outputs, *rest = model(**batch)
+                    acc = rest[0] if rest else -1
+                    loss = outputs.loss
+
                 loss = loss / gradient_accumulation_steps
+                acc = acc / gradient_accumulation_steps
                 total_loss += loss.detach().float()
+                total_acc += acc
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
@@ -99,7 +113,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
-                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})")
 
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
@@ -107,13 +121,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(total_acc, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
+        train_epoch_acc = total_acc / len(train_dataloader)
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
+            train_epoch_acc = train_epoch_acc/world_size
         train_perplexity = torch.exp(train_epoch_loss)
 
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
+        train_acc.append(train_epoch_acc)
 
         if train_config.enable_fsdp:
             if rank==0:
@@ -133,7 +151,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -144,12 +162,35 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            save_model_checkpoint_peft(
+                                model, optimizer, rank, train_config, epoch=epoch
+                            )
+                        dist.barrier()
+                    else:
+                        # model.save_pretrained(train_config.output_dir)
+                        save_model_checkpoint_peft(
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
                     if train_config.enable_fsdp:
                         if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+                elif not train_config.use_peft and train_config.freeze_llm:
+                    print(f"llm is frozen, we are about to save other parts.")
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            save_model_checkpoint_peft(
+                                model, optimizer, rank, train_config, epoch=epoch
+                            )
+                        dist.barrier()
+                    else:
+                        save_model_checkpoint_peft(
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
 
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
@@ -184,8 +225,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             else:
                 print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(best_val_loss)
+            val_loss.append(eval_epoch_loss)
             val_prep.append(eval_ppl)
+            if rest:
+                val_acc.append(rest[0])
+            else:
+                val_acc.append(-1)
+        if train_config.run_test_during_validation:
+            if train_config.enable_fsdp:
+                if rank==0:
+                    print("=====================================")
+                    print(f"Test the file {train_config.run_test_during_validation_file} during validation:")
+                    with autocast():
+                        print(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt))
+                    print("=====================================")
+                dist.barrier()
+            else:
+                print("=====================================")
+                print(f"Test the file {train_config.run_test_during_validation_file} during validation:")
+                with autocast():
+                    print(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt))
+                print("=====================================")
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
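save_model_checkpoint_peft comes from model_checkpointing/checkpoint_handler.py and its body is not part of this diff. A rough, assumption-labelled sketch of the idea the frozen-LLM branch relies on — persist only the parameters that still require gradients (projector, PEFT weights), which is also why model_factory reloads the checkpoint with strict=False:

    import os
    import torch

    def save_trainable_parts(model, output_dir, epoch):
        # hypothetical helper, not the actual save_model_checkpoint_peft implementation
        trainable = {
            name: param.detach().cpu()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        save_dir = os.path.join(output_dir, str(epoch))
        os.makedirs(save_dir, exist_ok=True)
        torch.save(trainable, os.path.join(save_dir, "model.pt"))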
@@ -195,15 +255,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
 
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
+    avg_train_acc = sum(train_acc)/len(train_acc)
     if train_config.run_validation:
         avg_eval_prep = sum(val_prep)/len(val_prep)
         avg_eval_loss = sum(val_loss)/len(val_loss)
+        avg_eval_acc = sum(val_acc)/len(val_acc)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
+    results['avg_train_acc'] = avg_train_acc
     if train_config.run_validation:
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
+        results['avg_eval_acc'] = avg_eval_acc
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
@@ -230,6 +294,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
+    eval_acc = 0.0
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext  # (Fix:MZY): fix expected scalar type mismatch in norm
+
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
             for key in batch.keys():
@@ -240,9 +307,13 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
-                outputs = model(**batch)
+                with autocast():  # (Fix:MZY): fix expected scalar type mismatch in norm
+                    outputs, *rest = model(**batch)
+                acc = rest[0] if rest else -1
                 loss = outputs.loss
+
                 eval_loss += loss.detach().float()
+                eval_acc += acc
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
@@ -252,21 +323,24 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM)
 
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
+    eval_epoch_acc = eval_acc / len(eval_dataloader)
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
+        eval_epoch_acc = eval_epoch_acc/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
     if train_config.enable_fsdp:
         if local_rank==0:
-            print(f" {eval_ppl=} {eval_epoch_loss=}")
+            print(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}")
     else:
-        print(f" {eval_ppl=} {eval_epoch_loss=}")
+        print(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}")
 
-    return eval_ppl, eval_epoch_loss
+    return eval_ppl, eval_epoch_loss, eval_epoch_acc
 
 def freeze_transformer_layers(model, num_layer):
     for i, layer in enumerate(model.model.layers):
@@ -333,7 +407,19 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
+def print_module_size(module, module_name, rank: int = 0) -> None:
+    """
+    Print the module name and its number of trainable parameters.
+    Args:
+        module: The PyTorch module.
+        module_name (str): Name of the module.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        print(f"--> Module {module_name}")
+        total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
+        print(f"\n--> {module_name} has {total_params / 1e6} Million params\n")
 
 def get_policies(cfg, rank):
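print_module_size reports trainable parameters per sub-module, complementing print_model_size, which covers the whole composite model. A small usage sketch with stand-in modules — the real call sites (Whisper encoder, projector, LLM) are not shown in this diff:

    import os
    import torch.nn as nn
    from llama_recipes.utils.train_utils import print_module_size

    rank = int(os.environ.get("RANK", 0))

    projector = nn.Linear(1280, 4096)            # stand-in for the encoder projector
    print_module_size(projector, "encoder_projector", rank)

    frozen = nn.Linear(4096, 32000)              # stand-in for a frozen LLM/encoder part
    for p in frozen.parameters():
        p.requires_grad = False
    print_module_size(frozen, "llm (frozen)", rank)  # prints 0.0 Million params: only trainable ones are counted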
diff --git a/tests/test_whisper.py b/tests/test_whisper.py
new file mode 100644
index 00000000..6c1d45af
--- /dev/null
+++ b/tests/test_whisper.py
@@ -0,0 +1,52 @@
+import whisper
+import torch
+import types
+import torch.nn.functional as F
+
+# model = whisper.load_model("/home/oss/maziyang.mzy/models/Whisper/base.pt")
+encoder = whisper.load_model("/home/oss/maziyang.mzy/models/Whisper/base.pt").encoder
+
+# load audio; the 30-second pad/trim is intentionally skipped so the clip keeps its true length
+audio = whisper.load_audio("/root/whisper/tests/jfk.flac")
+# audio = whisper.pad_or_trim(audio)
+
+# make log-Mel spectrogram and move it to the same device as the model
+mel = whisper.log_mel_spectrogram(audio).to("cuda")
+print(mel.shape)
+
+def extract_features(self, x: torch.Tensor):
+    """
+    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+        the mel spectrogram of the audio
+    """
+    x = F.gelu(self.conv1(x))
+    x = F.gelu(self.conv2(x))
+    x = x.permute(0, 2, 1)
+
+    # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+    # x = (x + self.positional_embedding).to(x.dtype)
+    x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
+
+    for block in self.blocks:
+        x = block(x)
+
+    x = self.ln_post(x)
+    return x
+
+encoder.extract_features = types.MethodType(extract_features, encoder)
+
+# # detect the spoken language
+# _, probs = model.detect_language(mel)
+# print(f"Detected language: {max(probs, key=probs.get)}")
+
+# get encoder output
+mel = mel.unsqueeze(0)
+encoder_out = encoder.extract_features(mel)
+print(encoder_out.shape)
+
+# # decode the audio
+# options = whisper.DecodingOptions()
+# result = whisper.decode(model, mel, options)
+
+# # print the recognized text
+# print(result.text)
\ No newline at end of file
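The monkey-patched extract_features differs from the stock AudioEncoder.forward only in how the positional embedding is applied: it is sliced to the actual frame count instead of being asserted against the full 1500-frame table (the assertion is quoted in the commented-out line above), so un-padded clips of up to 30 s can be encoded without padding. For comparison, the standard 30-second path the commented-out pad_or_trim line refers to could be appended to the test roughly like this (a sketch that continues the script above and uses only whisper APIs already imported there):

    # pad to 30 s and run the unmodified forward for comparison
    audio30 = whisper.pad_or_trim(whisper.load_audio("/root/whisper/tests/jfk.flac"))
    mel30 = whisper.log_mel_spectrogram(audio30).to("cuda").unsqueeze(0)  # (1, 80, 3000) for the base model
    out30 = encoder(mel30)   # stock AudioEncoder.forward: always 1500 encoder frames
    print(out30.shape)       # torch.Size([1, 1500, 512]) for base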