diff --git a/.gitignore b/.gitignore index c72a47be..0e0e2506 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,10 @@ .DS_Store __pycache__ .ipynb_checkpoints +.vscode +debug.py .idea/* transformers +wandb/ +*.log +log \ No newline at end of file diff --git a/examples/vllm/inference.py b/examples/vllm/inference.py index e587bc03..8b93d56e 100644 --- a/examples/vllm/inference.py +++ b/examples/vllm/inference.py @@ -7,6 +7,9 @@ from vllm import LLM from vllm import LLM, SamplingParams +import logging +logger = logging.getLogger(__name__) + torch.cuda.manual_seed(42) torch.manual_seed(42) @@ -27,15 +30,15 @@ def main( if user_prompt is None: user_prompt = input("Enter your prompt: ") - print(f"User prompt:\n{user_prompt}") + logger.info(f"User prompt:\n{user_prompt}") - print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request") + logger.info(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request") sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens) outputs = model.generate(user_prompt, sampling_params=sampling_param) - print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}") + logger.info(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}") user_prompt = input("Enter next prompt (press Enter to exit): ") if not user_prompt: break diff --git a/scripts/finetune_avsr.sh b/scripts/finetune_avsr.sh new file mode 100644 index 00000000..171e3c4d --- /dev/null +++ b/scripts/finetune_avsr.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export CUDA_LAUNCH_BLOCKING=1 + +cd /root/SLAM-LLM + +audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth +speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt +llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune + +# -m debugpy --listen 5678 --wait-for-client +python src/llama_recipes/pipeline/finetune.py \ +--model_name avsr \ +--use_peft --peft_method lora \ +--quantization \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--dataset avsr_dataset \ +--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ +--batching_strategy custom \ +--num_epochs 1 \ +--batch_size_training 2 \ +--output_dir $output_dir \ No newline at end of file diff --git a/scripts/finetune_avsr_debug.sh b/scripts/finetune_avsr_debug.sh new file mode 100644 index 00000000..f6728478 --- /dev/null +++ b/scripts/finetune_avsr_debug.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export CUDA_LAUNCH_BLOCKING=1 + +cd /root/SLAM-LLM + +audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth +speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt + +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf #/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/yangguanrou.ygr/ckpts/llama-2-hf-finetune #/home/oss/yangguanrou.ygr/ckpts/llama-2-hf-finetune + +# -m debugpy --listen 5680 --wait-for-client +python -m debugpy --listen 5680 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name avsr \ +--use_peft --peft_method lora \ +--quantization \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--dataset 
avsr_dataset \ +--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ +--batching_strategy custom \ +--num_epochs 1 \ +--batch_size_training 2 \ +--output_dir $output_dir \ +--stepSize 10 \ +--log_file "/root/SLAM-LLM/log/test.log" \ +--valid_subset "LRS3/val_debug.txt" \ \ No newline at end of file diff --git a/scripts/finetune_avsr_debug_1214.sh b/scripts/finetune_avsr_debug_1214.sh new file mode 100644 index 00000000..d44bef08 --- /dev/null +++ b/scripts/finetune_avsr_debug_1214.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +cd /root/SLAM-LLM + + +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/yangguanrou.ygr/llama-2-hf-finetune + +# -m debugpy --listen 5678 --wait-for-client +python -m debugpy --listen 5679 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name avsr \ +--freeze_llm \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset avsr_dataset \ +--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ +--batching_strategy custom \ +--num_epochs 1 \ +--batch_size_training 4 \ +--lr 1e-5 \ +--output_dir $output_dir \ +--metric acc \ +--log_file "/root/SLAM-LLM/log/test.log" \ + + +# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ \ No newline at end of file diff --git a/scripts/finetune_avsr_debug_1218.sh b/scripts/finetune_avsr_debug_1218.sh new file mode 100644 index 00000000..01ac0f9a --- /dev/null +++ b/scripts/finetune_avsr_debug_1218.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +cd /root/SLAM-LLM + + +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +output_dir=/nfs/yangguanrou.ygr/llama-2-hf-finetune + +# -m debugpy --listen 5678 --wait-for-client +python src/llama_recipes/pipeline/finetune.py \ +--model_name avsr \ +--freeze_llm \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset avsr_dataset \ +--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ +--batching_strategy custom \ +--num_epochs 1 \ +--batch_size_training 4 \ +--lr 1e-5 \ +--output_dir $output_dir \ +--metric acc \ +--log_file "/root/SLAM-LLM/log/test.log" \ + + +# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ \ No newline at end of file diff --git a/scripts/finetune_speech_pretraining_my.sh b/scripts/finetune_speech_pretraining_my.sh new file mode 100644 index 00000000..27e4501b --- /dev/null +++ b/scripts/finetune_speech_pretraining_my.sh @@ -0,0 +1,89 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=1 +export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf 
+output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048 + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \ +--custom_dataset.train_data_path /nfs/beinian.lzr/workspace/datasets/speech_llm/train_dataset/data_wav_json/asr/librispeech_train_960h_wav_speech_llm_train_data.json \ +--custom_dataset.val_data_path /nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/dev_other/librispeech_dev_other.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 4 \ +--val_batch_size 4 \ +--lr 1e-5 \ +--output_dir $output_dir \ +--run_test_during_validation \ +--run_test_during_validation_file "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav" \ +--run_test_during_validation_prompt "<|ASR|>" \ +--metric acc \ +--log_file "/root/SLAM-LLM/log/test.log" \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ +# --use_peft --peft_method lora \ + +else +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--use_fp16 \ +--enable_fsdp \ +--llm_name llama-2-7b-hf \ +--llm_path $llm_path \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset custom_dataset \ +--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \ +--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h_wav_speech_llm_train_data.json \ +--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 8 \ +--val_batch_size 8 \ +--num_workers_dataloader 4 \ +--lr 1e-5 \ +--output_dir $output_dir \ +--run_test_during_validation \ +--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \ +--run_test_during_validation_prompt "<|ASR|>" \ +--metric acc \ +--log_file "/root/SLAM-LLM/log/test.log" \ +# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ +# --use_peft --peft_method lora \ +fi + +# {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. 
How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""}
\ No newline at end of file
diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py
index bc754832..12da6da4 100644
--- a/src/llama_recipes/configs/datasets.py
+++ b/src/llama_recipes/configs/datasets.py
@@ -32,9 +32,33 @@ class custom_dataset:
     file: str = "examples/custom_dataset.py"
     train_split: str = "train"
     test_split: str = "validation"
-    data_path: str = NotImplemented
-    train_data_path: str = NotImplemented
-    val_data_path: str = NotImplemented
-    max_words: int = NotImplemented
-    max_mel: int = NotImplemented
-    fix_length_audio: int = -1
\ No newline at end of file
+    data_path: str = None
+    train_data_path: str = None
+    val_data_path: str = None
+    max_words: int = None
+    max_mel: int = None
+    fix_length_audio: int = -1
+
+
+
+@dataclass
+class avsr_dataset:
+    dataset: str = "avsr_dataset"
+    file: str = "examples/avsr_dataset.py"
+    train_split: str = "train"
+    test_split: str = "val"
+    data_path: str = "/nfs/yangguanrou.ygr/" #"/home/oss/yangguanrou.ygr/"
+    h5file: str = "/nfs/yangguanrou.ygr/LRS3/LRS3.h5" # "/home/oss/yangguanrou.ygr/LRS3/LRS3.h5"
+    noiseFile : str = "/nfs/yangguanrou.ygr/AVSR/LRS3/Noise.h5" #"/home/oss/yangguanrou.ygr/AVSR/LRS3/Noise.h5"
+    noiseProb: float = 0.
+    noiseSNR: float = 5
+    stepSize: int = 16384
+    # charToIx={" ": 1, "'": 22, "1": 30, "0": 29, "3": 37, "2": 32, "5": 34, "4": 38, "7": 36, "6": 35, "9": 31, "8": 33, "A": 5, "C": 17,
+    #           "B": 20, "E": 2, "D": 12, "G": 16, "F": 19, "I": 6, "H": 9, "K": 24, "J": 25, "M": 18, "L": 11, "O": 4, "N": 7, "Q": 27,
+    #           "P": 21, "S": 8, "R": 10, "U": 13, "T": 3, "W": 15, "V": 23, "Y": 14, "X": 26, "Z": 28, "<EOS>": 39}
+    charToIx : str = "x" # probably unused now; the earlier "TypeError: Object of type NotImplementedType is not JSON serializable" came from the NotImplemented defaults above
+    modal: str = "AV"
+    pretrain_subset: str = "LRS3/pretrain.txt"
+    train_subset: str = "LRS3/train.txt"
+    valid_subset: str = "LRS3/val.txt"
+    test_subset: str = "LRS3/test.txt"
diff --git a/src/llama_recipes/configs/model.py b/src/llama_recipes/configs/model.py
index 46c8668c..0efef860 100644
--- a/src/llama_recipes/configs/model.py
+++ b/src/llama_recipes/configs/model.py
@@ -9,4 +9,33 @@ class model_config:
     encoder_ds_rate: int = 2
     encoder_path: str = None
    encoder_projector: str = "linear"
-    encoder_projector_ds_rate: int = 5
\ No newline at end of file
+    encoder_projector_ds_rate: int = 5
+
+    DMODEL: int = 512
+    FRONTEND_DMODEL: int = 1024 # frontend dim; this one is specific to the MoCo visual frontend
+    TX_ATTENTION_HEADS: int = 8
+    TX_NUM_LAYERS: int = 6
+    PE_MAX_LENGTH: int = 500
+    AUDIO_FEATURE_SIZE: int = 1024
+    VIDEO_FEATURE_SIZE: int = 2048
+    TX_FEEDFORWARD_DIM: int = 2048
+    TX_DROPOUT: float = 0.1
+    CHAR_NUM_CLASSES: int = 40
+
+    WORD_NUM_CLASSES: int = 500
+    FRAME_LENGTH: int = 29
+    MOCO_FRONTEND_FILE: str = "/nfs/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt"
+    WAV2VEC_FILE: str = "/nfs/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt"
+    MAIN_REQ_INPUT_LENGTH: int = 80
+    modal: str = "AV"
+    TRAIN_LRS3_MODEL_FILE: str = "/nfs/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # "/home/oss/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # checkpoint for the single-modality model
+    TRAINED_AO_FILE : str = 
"/nfs/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" + TRAINED_VO_FILE: str = "/nfs/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" + + + + + + + + diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 1239c7f5..65c13d8c 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -36,8 +36,10 @@ class train_config: dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP save_optimizer: bool=False # will be used if using FSDP use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + log_file: str="PATH/to/Log_File" run_test_during_validation: bool = False run_test_during_validation_file: str = "test.wav" run_test_during_validation_prompt: str = "<|ASR|>" freeze_llm: bool = False freeze_encoder: bool = False + log_interval: int = 5 diff --git a/src/llama_recipes/datasets/avsr_dataset.py b/src/llama_recipes/datasets/avsr_dataset.py new file mode 100644 index 00000000..a40284ec --- /dev/null +++ b/src/llama_recipes/datasets/avsr_dataset.py @@ -0,0 +1,527 @@ +import h5py +import numpy as np +from torch.utils.data import Dataset +from torchvision import transforms + +import random +import torch + +import cv2 as cv +from torch.nn.utils.rnn import pad_sequence + +import logging +logger = logging.getLogger(__name__) + +class AVSRDataset(Dataset): + def __init__(self, dataset_config, tokenizer=None, split='train'): + super().__init__() + + self.tokenizer = tokenizer + self.modal = dataset_config.modal + self.dataset = split #train|val|test + self.data_path = dataset_config.data_path + self.h5file = dataset_config.h5file + self.noiseFile = dataset_config.noiseFile + self.noiseSNR = dataset_config.noiseSNR + self.noiseProb = dataset_config.noiseProb + self.stepSize = dataset_config.stepSize #16384 + self.charToIx = dataset_config.charToIx + self.pretrain_subset = dataset_config.pretrain_subset + self.train_subset = dataset_config.train_subset + self.valid_subset = dataset_config.valid_subset + self.test_subset= dataset_config.test_subset + + if self.dataset == "train": + pretrain_dir = self.data_path + self.pretrain_subset # "LRS3/pretrain.txt" + train_dir = self.data_path + self.train_subset # "LRS3/train.txt" + + with open(pretrain_dir, "r") as f: + lines = f.readlines() + pretrain_datalist = [self.data_path + line.strip()[3:] for line in lines] #长度:118516 + + with open(train_dir, "r") as f: + lines = f.readlines() + train_datalist = [self.data_path + line.strip()[3:] for line in lines] #长度:31662 + + self.datalist = pretrain_datalist+ train_datalist + lrs3Aug=True + + elif self.dataset == "val": + val_dir = self.data_path + self.valid_subset # "LRS3/val.txt" + with open(val_dir, "r") as f: + lines = f.readlines() + val_datalist = [self.data_path + line.strip()[3:] for line in lines] + self.datalist = val_datalist + lrs3Aug=False + + else: + test_dir = self.data_path + self.test_subset # "LRS3/test.txt" + with open(test_dir, "r") as f: + lines = f.readlines() + test_datalist = [self.data_path + line.strip()[3:] for line in lines] + self.datalist = test_datalist + lrs3Aug=False + + with h5py.File(self.noiseFile, "r") as f: #{'noiseFile': '/home/xcpan/LRS2/mvlrs_v1/Noise.h5', 'noiseProb': 0.25, 'noiseSNR': 5} + self.noise = f["noise"][0] 
#ndarray:57600000
+
+        if lrs3Aug:
+            self.transform = transforms.Compose([
+                ToTensor(),
+                RandomCrop(112),
+                RandomHorizontalFlip(0.5),
+                Normalize(mean=[0.4161], std=[0.1688])
+            ])
+        else:
+            self.transform = transforms.Compose([
+                ToTensor(),
+                CenterCrop(112),
+                Normalize(mean=[0.4161], std=[0.1688])
+            ])
+
+    def open_h5(self):
+        self.h5 = h5py.File(self.h5file, "r")
+
+    def __getitem__(self, index): # the AVSR dataloader is shuffled; echat seems to default to shuffle=False, so index starts from 0
+        """
+        LRS3 : pretrain 118516 train 31662 val 320 test 1321
+        LRS2 : pretrain 96318 train 45839 val 1082 test 1243; 142157 = 96318 + 45839 = pretrain + train; 143239 = 96318 + 45839 + 1082 = pretrain + train + val
+
+        index goes from 0 to stepSize-1
+        dividing the dataset into partitions of size equal to stepSize and selecting a random partition
+        fetch the sample at position 'index' in this randomly selected partition
+        """
+
+        if not hasattr(self, 'h5'):
+            self.open_h5()
+
+        if self.dataset == "train": # e.g. index=610
+            base = self.stepSize * np.arange(int(len(self.datalist) / self.stepSize) + 1) # datalist has 118516 entries, all built from pretrain.txt # stepSize 16384
+            ixs = base + index # [ 0 16384 32768 49152 65536 81920 98304 114688 131072 147456]
+            ixs = ixs[ixs < len(self.datalist)] # [ 610 16994 33378 49762 66146 82530 98914 115298]
+            index = ixs[0] if len(ixs) == 1 else np.random.choice(ixs) # pick one partition at random, e.g. 33378
+
+        if index==99639 or index== 71740 or index==19753 or index==14116 or index==49729 or index==26726: #dirty data
+            index+=1
+
+        # passing the sample files and the target file paths to the prepare function to obtain the input tensors
+        targetFile = self.datalist[index] + ".txt"
+        if self.dataset == "val":
+            index += 150178 # originally 142157
+        elif self.dataset == "test":
+            index += 150498 # originally 143239
+
+        if np.random.choice([True, False], p=[self.noiseProb, 1 - self.noiseProb]):
+            noise = self.noise
+        else:
+            noise = None
+
+        if index < 118516: # originally 96318; verified that this is indeed the LRS2 line (i.e. file) count; these samples should go through the pretrain path - some used to fall into the main path, which does no cropping, so targets could exceed length 500
+            #inp, trgtin, trgtout, trgtLen, trgttext = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)
+            inp, trgtin, trgtout, trgtLen = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)
+            if inp==0 and trgtin ==0 and trgtout ==0 and trgtLen==0:
+                index+=1
+                targetFile = self.datalist[index] + ".txt"
+                #inp, trgtin, trgtout, trgtLen, trgttext = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile,self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)
+                inp, trgtin, trgtout, trgtLen = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile,self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160) # crude fallback: just move on to the next sample
+        else:
+            #inp, trgtin, trgtout, trgtLen, trgttext = self.prepare_main_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR)
+            inp, trgtin, trgtout, trgtLen = self.prepare_main_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR)
+
+        return inp, trgtin, trgtout, trgtLen #, trgttext #VO (none,(72,1,112,112) )
+
+    def __len__(self):
+        """
+        each iteration covers only a random subset of all the training samples whose size is given by the step size (this is where stepSize comes into play, though it does not seem to matter much)
+        this is done only for the pretrain set, while the whole val/test set is considered
+        """
+
+        if self.dataset == "train":
+            return self.stepSize
+ else: + return len(self.datalist) + + def collator(self, dataBatch): + # audio & mask + if not self.modal == "VO": + aud_seq_list = [data[0][0] for data in dataBatch] + aud_padding_mask = torch.zeros((len(aud_seq_list), len(max(aud_seq_list, key=len))), dtype=torch.bool) + for i, seq in enumerate(aud_seq_list): + aud_padding_mask[i, len(seq):] = True + aud_seq_list = pad_sequence(aud_seq_list, batch_first=True) #可以通过设置 batch_first=True 参数来指定输出的tensor中是否将batch维度放在第一维度 + else: + aud_seq_list = None + aud_padding_mask = None + # visual & len + if not self.modal == "AO": + vis_seq_list = pad_sequence([data[0][1] for data in dataBatch], batch_first=True) #(4,147,1,112,112) #pad_sequence((none,62,1,112,112)) + vis_len = torch.tensor([len(data[0][1]) for data in dataBatch]) #就是这四个句子每一个的长度 tensor([ 62, 62, 97, 147]) #时间帧上pad + else: + vis_seq_list = None + vis_len = None + + inputBatch = (aud_seq_list, aud_padding_mask, vis_seq_list, vis_len) #!!! + + targetinBatch = pad_sequence([data[1] for data in dataBatch], batch_first=True) + targetoutBatch = pad_sequence([data[2] for data in dataBatch], batch_first=True) + targetLenBatch = torch.stack([data[3] for data in dataBatch]) + + if self.modal == "AO": + inputBatch = (inputBatch[0].float(), inputBatch[1], None, None) + elif self.modal == "VO": + inputBatch = (None, None, inputBatch[2].float(), inputBatch[3].int()) + else: + inputBatch = (inputBatch[0].float(), inputBatch[1], inputBatch[2].float(), inputBatch[3].int()) + + targetinBatch = targetinBatch.int() + targetoutBatch = targetoutBatch.int() + targetLenBatch = targetLenBatch.int() + targetMask = torch.zeros_like(targetoutBatch, device=targetoutBatch.device) + targetMask[(torch.arange(targetMask.shape[0]), targetLenBatch.long() - 1)] = 1 + targetMask = (1 - targetMask.flip([-1]).cumsum(-1).flip([-1])).bool() + + return { + "inputBatch0": inputBatch[0], + "inputBatch1": inputBatch[1], + "inputBatch2": inputBatch[2], + "inputBatch3": inputBatch[3], + + "targetoutBatch": targetoutBatch, + "targetLenBatch": targetLenBatch.long(), + 'maskw2v': True, + } + + def prepare_pretrain_input(self,index, modal, h5, targetFile, charToIx, transform, noise, noiseSNR, numWordsRange, maxLength): #(3,21) 160 + """ + Function to convert the data sample in the pretrain dataset into appropriate tensors. 
+ """ + + try: + with open(targetFile, "r") as f: + lines = f.readlines() + except: + logger.info("error") + logger.info(targetFile) + logger.info(index) + return 0, 0, 0, 0 + + lines = [line.strip() for line in lines] + + trgt = lines[0][7:] + + coun = trgt.count("{") + for i in range(coun): + left = trgt.find("{") + if left != -1: + right = trgt.find("}") + trgt = trgt.replace(trgt[left:right + 2], "") + + trgt=trgt.strip() + words = trgt.split(" ") + + numWords = len(words) // 3 + if numWords < numWordsRange[0]: #3 #(numwordsRange 是个tuple(3,21) + numWords = numWordsRange[0] + elif numWords > numWordsRange[1]: #21 + numWords = numWordsRange[1] + + while True: + # if number of words in target is less than the required number of words, consider the whole target + if len(words) <= numWords: + trgtNWord = trgt + + # audio file + if not modal == "VO": + audInp = np.array(h5["flac"][index]) + audInp = (audInp - audInp.mean()) / audInp.std() + if noise is not None: + pos = np.random.randint(0, len(noise) - len(audInp) + 1) + noise = noise[pos:pos + len(audInp)] + noise = noise / np.max(np.abs(noise)) + gain = 10 ** (noiseSNR / 10) + noise = noise * np.sqrt(np.sum(audInp ** 2) / (gain * np.sum(noise ** 2))) + audInp = audInp + noise + audInp = torch.from_numpy(audInp) + else: + audInp = None + + # visual file + if not modal == "AO": + try: + vidInp = cv.imdecode(h5["png"][index], cv.IMREAD_COLOR) + vidInp = np.array(np.split(vidInp, range(120, len(vidInp[0]), 120), axis=1))[:, :, :, 0] + vidInp = torch.tensor(vidInp).unsqueeze(1) + vidInp = transform(vidInp) + except: + logger.info("error") + logger.info(targetFile) + logger.info(index) + return 0,0,0,0 + else: + vidInp = None + else: + # make a list of all possible sub-sequences with required number of words in the target + nWords = [" ".join(words[i:i + numWords]) + for i in range(len(words) - numWords + 1)] + nWordLens = np.array( + [len(nWord) + 1 for nWord in nWords]).astype(float) + + # choose the sub-sequence for target according to a softmax distribution of the lengths + # this way longer sub-sequences (which are more diverse) are selected more often while + # the shorter sub-sequences (which appear more frequently) are not entirely missed out + ix = np.random.choice(np.arange(len(nWordLens)), p=nWordLens / nWordLens.sum()) + trgtNWord = nWords[ix] + + # reading the start and end times in the video corresponding to the selected sub-sequence + startTime = float(lines[4 + ix].split(" ")[1]) + endTime = float(lines[4 + ix + numWords - 1].split(" ")[2]) + + # audio file + if not modal == "VO": + samplerate = 16000 + audInp = np.array(h5["flac"][index]) #(81920,) + audInp = (audInp - audInp.mean()) / audInp.std() + if noise is not None: + pos = np.random.randint(0, len(noise) - len(audInp) + 1) + noise = noise[pos:pos + len(audInp)] + noise = noise / np.max(np.abs(noise)) + gain = 10 ** (noiseSNR / 10) + noise = noise * np.sqrt(np.sum(audInp ** 2) / (gain * np.sum(noise ** 2))) + audInp = audInp + noise + audInp = torch.from_numpy(audInp) + audInp = audInp[int(samplerate * startTime):int(samplerate * endTime)] #!!!!!!! 
+ else: + audInp = None + + # visual file + if not modal == "AO": + videoFPS = 25 + try: + vidInp = cv.imdecode(h5["png"][index], cv.IMREAD_COLOR) + vidInp = np.array(np.split(vidInp, range(120, len(vidInp[0]), 120), axis=1))[:, :, :, 0] ##这一句报错x + vidInp = torch.tensor(vidInp).unsqueeze(1) + vidInp = transform(vidInp) + vidInp = vidInp[int(np.floor(videoFPS * startTime)): int(np.ceil(videoFPS * endTime))] + except: + logger.info("error") + logger.info(targetFile) + logger.info(index) + return 0, 0, 0, 0 + + else: + vidInp = None + + """ + trgtin = [charToIx[item] for item in trgtNWord] #trgtNWord: 'POPULATION BY PROVIDING THEM A SAFE SPACE WHERE THESE GIRLS COULD COME AND MEET OTHER GIRLS READ SOME BOOKS PLAY SOME' + trgtout = [charToIx[item] for item in trgtNWord] + trgtin.insert(0, charToIx[""]) + trgtout.append(charToIx[""]) + """ + + # 替换成 + trgtin = self.tokenizer.encode(trgtNWord) #[1, 349, 4590, 13309, 8098, 6770, 13756, 13044, 4214, 6093, 29924, 319, 317, 5098, ...] + trgtout = self.tokenizer.encode(trgtNWord) + trgtin.insert(0, self.tokenizer.eos_token_id ) #[2,xxx] + trgtout.append(self.tokenizer.eos_token_id ) #[] + + trgtin = np.array(trgtin) + trgtout = np.array(trgtout) + trgtLen = len(trgtout) + + inp = (audInp, vidInp) + trgtin = torch.from_numpy(trgtin) + trgtout = torch.from_numpy(trgtout) + trgtLen = torch.tensor(trgtLen) + inpLen = len(vidInp) if not self.modal == "AO" else len(audInp) / 640 + if inpLen <= maxLength: #maxlength:160 + break + elif inpLen > maxLength + 80: + numWords -= 2 + else: + numWords -= 1 + + return inp, trgtin, trgtout, trgtLen #, trgtNWord + + def prepare_main_input(self, index, modal, h5, targetFile, charToIx, transform, noise, noiseSNR): + """ + Function to convert the data sample in the main dataset into appropriate tensors. + """ + with open(targetFile, "r") as f: + trgt = f.readline().strip()[7:] #'SO WE NEED YOU TO HELP US IN OUR REVIVAL CAMPAIGN' 'YOU ARE A HEALER IN A STONE AGE VILLAGE' + + coun = trgt.count("{") + for i in range(coun): + left = trgt.find("{") + if left != -1: + right = trgt.find("}") + trgt = trgt .replace(trgt [left:right + 2], "") + + """ + trgtin = [charToIx[item] for item in trgt] #[8, 4, 1, 15, 2, 1, 7, 2, 2, 12, 1, 14, 4, 13, 1, 3, 4, 1, 9, 2, 11, + trgtin.insert(0, charToIx[""]) #[39,8,4,...] 
+ trgtout = [charToIx[item] for item in trgt] + trgtout.append(charToIx[""]) #[..,39] 在最后面加39 + """ + + trgtin = self.tokenizer.encode(trgt) + trgtout = self.tokenizer.encode(trgt) + trgtin.insert(0, self.tokenizer.eos_token_id ) + trgtout.append(self.tokenizer.eos_token_id ) + + trgtin = np.array(trgtin) + trgtout = np.array(trgtout) + trgtLen = len(trgtout) #50 + + # audio file + if not modal == "VO": + audInp = np.array(h5["flac"][index]) # ndarray(22528,) + audInp = (audInp - audInp.mean()) / audInp.std() + if noise is not None: + pos = np.random.randint(0, len(noise) - len(audInp) + 1) + noise = noise[pos:pos + len(audInp)] + noise = noise / np.max(np.abs(noise)) + gain = 10 ** (noiseSNR / 10) + noise = noise * np.sqrt(np.sum(audInp ** 2) / (gain * np.sum(noise ** 2))) + audInp = audInp + noise + audInp = torch.from_numpy(audInp) + else: + audInp = None + + # visual file + if not modal == "AO": + vidInp = cv.imdecode(h5["png"][index], cv.IMREAD_COLOR) #(120,2040,3) + vidInp = np.array(np.split(vidInp, range(120, len(vidInp[0]), 120), axis=1))[:, :, :, 0] #(17,120,120) + vidInp = torch.tensor(vidInp).unsqueeze(1) #(17,1,120,120) + vidInp = transform(vidInp) #(17,1,112,112) + else: + vidInp = None + + inp = (audInp, vidInp) + trgtin = torch.from_numpy(trgtin) + trgtout = torch.from_numpy(trgtout) + trgtLen = torch.tensor(trgtLen) + + return inp, trgtin, trgtout, trgtLen#,trgt #'THE FIRST TIME WHEN IT TOOK ME FIVE MONTHS FROM THE DECISION OF' + + +def get_audio_dataset(dataset_config, tokenizer, split): + dataset = AVSRDataset(dataset_config, tokenizer, split) + return dataset + +class ToTensor: + """Applies the :class:`~torchvision.transforms.ToTensor` transform to a batch of images. + """ + + def __init__(self): + self.max = 255 + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be tensorized. + Returns: + Tensor: Tensorized Tensor. + """ + return tensor.float().div_(self.max) + + +class Normalize: + """Applies the :class:`~torchvision.transforms.Normalize` transform to a batch of images. + .. note:: + This transform acts out of place by default, i.e., it does not mutate the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. + dtype (torch.dtype,optional): The data type of tensors to which the transform will be applied. + device (torch.device,optional): The device of tensors to which the transform will be applied. + """ + + def __init__(self, mean, std, inplace=False, dtype=torch.float, device='cpu'): + self.mean = torch.as_tensor(mean, dtype=dtype, device=device)[None, :, None, None] + self.std = torch.as_tensor(std, dtype=dtype, device=device)[None, :, None, None] + self.inplace = inplace + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor. + """ + if not self.inplace: + tensor = tensor.clone() + + tensor.sub_(self.mean).div_(self.std) + return tensor + + +class RandomCrop: + """Applies the :class:`~torchvision.transforms.RandomCrop` transform to a batch of images. + Args: + size (int): Desired output size of the crop. + device (torch.device,optional): The device of tensors to which the transform will be applied. 
+ """ + + def __init__(self, size, device='cpu'): + self.size = size + self.device = device + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be cropped. + Returns: + Tensor: Randomly cropped Tensor. + """ + margin = tensor.shape[-1] - self.size + hcrop = random.randint(0, margin - 1) + wcrop = random.randint(0, margin - 1) + tensor = tensor[:, :, hcrop:-(margin - hcrop), wcrop:-(margin - wcrop)] + return tensor + + +class CenterCrop: + + def __init__(self, size, device='cpu'): + self.size = size + self.device = device + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be cropped. + Returns: + Tensor: Randomly cropped Tensor. + """ + crop = (tensor.shape[-1] - self.size) // 2 + tensor = tensor[:, :, crop:-crop, crop:-crop] + return tensor + + +class RandomHorizontalFlip: + """Applies the :class:`~torchvision.transforms.RandomHorizontalFlip` transform to a batch of images. + .. note:: + This transform acts out of place by default, i.e., it does not mutate the input tensor. + Args: + p (float): probability of an image being flipped. + inplace(bool,optional): Bool to make this operation in-place. + """ + + def __init__(self, p=0.5, inplace=False): + self.p = p + self.inplace = inplace + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be flipped. + Returns: + Tensor: Randomly flipped Tensor. + """ + if not self.inplace: + tensor = tensor.clone() + + if random.random() < self.p: + tensor = torch.flip(tensor, dims=(3,)) + return tensor \ No newline at end of file diff --git a/src/llama_recipes/datasets/echat_dataset.py b/src/llama_recipes/datasets/echat_dataset.py index 65a67519..20f9c766 100644 --- a/src/llama_recipes/datasets/echat_dataset.py +++ b/src/llama_recipes/datasets/echat_dataset.py @@ -13,6 +13,8 @@ import whisper from llama_recipes.utils.compute_utils import calculate_output_length_1d +import logging +logger = logging.getLogger(__name__) class EChatDataset(Dataset): def __init__( @@ -47,7 +49,7 @@ def __init__( sentence_list.append(sentence_dict) total_sentence = len(sentence_list) - print(f"Using {total_sentence} sentence totally.") + logger.info(f"Using {total_sentence} sentence totally.") if split == "train": self.data = sentence_list[:int(total_sentence * 0.9)] else: @@ -105,7 +107,7 @@ def __getitem__(self, index): label_mask = labels_ids.ge(0) #[False,False,True,True] example_ids[~example_mask] = 0 #[speech,prompt,answer,eos] labels_ids[~label_mask] = self.IGNORE_INDEX #[-100,answer,eos,-100] - + return { "input_ids": example_ids, "labels": labels_ids, diff --git a/src/llama_recipes/datasets/vision_transform.py b/src/llama_recipes/datasets/vision_transform.py new file mode 100644 index 00000000..38004fbf --- /dev/null +++ b/src/llama_recipes/datasets/vision_transform.py @@ -0,0 +1,122 @@ +import random + +import torch + + +class ToTensor: + """Applies the :class:`~torchvision.transforms.ToTensor` transform to a batch of images. + """ + + def __init__(self): + self.max = 255 + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be tensorized. + Returns: + Tensor: Tensorized Tensor. + """ + return tensor.float().div_(self.max) + + +class Normalize: + """Applies the :class:`~torchvision.transforms.Normalize` transform to a batch of images. + .. note:: + This transform acts out of place by default, i.e., it does not mutate the input tensor. + Args: + mean (sequence): Sequence of means for each channel. 
+ std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. + dtype (torch.dtype,optional): The data type of tensors to which the transform will be applied. + device (torch.device,optional): The device of tensors to which the transform will be applied. + """ + + def __init__(self, mean, std, inplace=False, dtype=torch.float, device='cpu'): + self.mean = torch.as_tensor(mean, dtype=dtype, device=device)[None, :, None, None] + self.std = torch.as_tensor(std, dtype=dtype, device=device)[None, :, None, None] + self.inplace = inplace + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor. + """ + if not self.inplace: + tensor = tensor.clone() + + tensor.sub_(self.mean).div_(self.std) + return tensor + + +class RandomCrop: + """Applies the :class:`~torchvision.transforms.RandomCrop` transform to a batch of images. + Args: + size (int): Desired output size of the crop. + device (torch.device,optional): The device of tensors to which the transform will be applied. + """ + + def __init__(self, size, device='cpu'): + self.size = size + self.device = device + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be cropped. + Returns: + Tensor: Randomly cropped Tensor. + """ + margin = tensor.shape[-1] - self.size + hcrop = random.randint(0, margin - 1) + wcrop = random.randint(0, margin - 1) + tensor = tensor[:, :, hcrop:-(margin - hcrop), wcrop:-(margin - wcrop)] + return tensor + + +class CenterCrop: + + def __init__(self, size, device='cpu'): + self.size = size + self.device = device + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be cropped. + Returns: + Tensor: Randomly cropped Tensor. + """ + crop = (tensor.shape[-1] - self.size) // 2 + tensor = tensor[:, :, crop:-crop, crop:-crop] + return tensor + + +class RandomHorizontalFlip: + """Applies the :class:`~torchvision.transforms.RandomHorizontalFlip` transform to a batch of images. + .. note:: + This transform acts out of place by default, i.e., it does not mutate the input tensor. + Args: + p (float): probability of an image being flipped. + inplace(bool,optional): Bool to make this operation in-place. + """ + + def __init__(self, p=0.5, inplace=False): + self.p = p + self.inplace = inplace + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor of size (N, C, H, W) to be flipped. + Returns: + Tensor: Randomly flipped Tensor. 
+ """ + if not self.inplace: + tensor = tensor.clone() + + if random.random() < self.p: + tensor = torch.flip(tensor, dims=(3,)) + return tensor diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index d1d6eed6..020b4d00 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -31,12 +31,16 @@ import torch.distributed as dist +import logging +logger = logging.getLogger(__name__) + + def get_date_of_run(): """create date and time for file save uniqueness example: 2022-05-07-08:31:12_PM' """ date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") - print(f"--> current date and time of run = {date_of_run}") + logger.info(f"--> current date and time of run = {date_of_run}") return date_of_run @@ -58,29 +62,29 @@ def load_model_sharded(model, rank, cfg): if not load_dir.exists(): if rank == 0: - print(f"No sharded_state_dict checkpoint directory found...skipping") + logger.info(f"No sharded_state_dict checkpoint directory found...skipping") return if rank == 0: - print(f"loading model from model path: {load_dir} ") + logger.info(f"loading model from model path: {load_dir} ") reader = FileSystemReader(load_dir) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): checkpoint = {"model": model.state_dict()} if rank == 0: ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}") dist_cp.load_state_dict( state_dict=checkpoint, storage_reader=reader, ) if rank == 0: - print(f"checkpoint after load_state_dict()") + logger.info(f"checkpoint after load_state_dict()") ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}") model.load_state_dict(checkpoint["model"]) if rank == 0: - print(f"Sharded state checkpoint loaded from {load_dir}") + logger.info(f"Sharded state checkpoint loaded from {load_dir}") def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): @@ -96,7 +100,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): save_dir = Path.cwd() / folder_name if rank == 0: - print(f"Saving model to {save_dir}") + logger.info(f"Saving model to {save_dir}") distributed_writer = dist_cp.FileSystemWriter( save_dir, @@ -118,8 +122,8 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): dist.barrier() t1 = time.perf_counter() if rank == 0: - print(f"Sharded state checkpoint saved to {save_dir}") - print( + logger.info(f"Sharded state checkpoint saved to {save_dir}") + logger.info( f"Checkpoint Time = {t1-t0:.4f}\n" ) def save_model_checkpoint( @@ -136,11 +140,11 @@ def save_model_checkpoint( ): cpu_state = model.state_dict() - print(f"saving process: rank {rank} done w model state_dict\n") + logger.info(f"saving process: rank {rank} done w model state_dict\n") if rank == 0: - print(f"--> saving model ...") + logger.info(f"--> saving model ...") # create save path folder_name = ( cfg.dist_checkpoint_root_folder @@ -158,10 +162,10 @@ def save_model_checkpoint( torch.save(cpu_state, save_full_path) - print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): - print(f"--> saving model ...") + logger.info(f"--> saving model ...") 
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch)) os.makedirs(save_dir, exist_ok=True) if not cfg.freeze_llm: @@ -179,7 +183,7 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): project_dict[key] = cpu_state[key] torch.save(project_dict, save_full_path) - print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") @@ -196,7 +200,7 @@ def load_model_checkpoint(model, rank, cfg): ) # is it present... if not full_state_dict_model_path.is_file(): - print( + logger.info( f"model checkpoint {full_state_dict_model_path} not present. Returning..." ) return @@ -207,21 +211,21 @@ def load_model_checkpoint(model, rank, cfg): model.load_state_dict(model_checkpoint) - print(f"model checkpoint loaded to rank0 cpu") + logger.info(f"model checkpoint loaded to rank0 cpu") def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): """save optimizer state via full state dict""" - print(f"--> optim state call on rank {rank}\n") + logger.info(f"--> optim state call on rank {rank}\n") # pull all sharded optimizer states to rank0 cpu... optim_state = FSDP.full_optim_state_dict(model, optimizer) - print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") + logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") if rank == 0: folder_name = ( @@ -239,11 +243,11 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): ) opt_save_full_path = save_dir / opt_save_name - print(f"--> saving optimizer state...") + logger.info(f"--> saving optimizer state...") torch.save(optim_state, opt_save_full_path) - print(f"--> saved {opt_save_full_path} to disk") + logger.info(f"--> saved {opt_save_full_path} to disk") def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): @@ -253,7 +257,7 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): if not optimizer_checkpoint_path.is_file(): - print( + logger.info( f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. 
" ) return @@ -266,7 +270,7 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): # called from all ranks, though only rank0 has a valid param for full_osd sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) - print(f"optimizer shard loaded on rank {rank}") + logger.info(f"optimizer shard loaded on rank {rank}") def load_sharded_model_single_gpu(model,model_path): @@ -284,5 +288,5 @@ def load_sharded_model_single_gpu(model,model_path): model.load_state_dict(state_dict["model"]) - print(f"Sharded state checkpoint loaded from {model_path}") + logger.info(f"Sharded state checkpoint loaded from {model_path}") return model \ No newline at end of file diff --git a/src/llama_recipes/models/av_net.py b/src/llama_recipes/models/av_net.py new file mode 100644 index 00000000..f8742ccc --- /dev/null +++ b/src/llama_recipes/models/av_net.py @@ -0,0 +1,217 @@ +from fairseq.checkpoint_utils import load_model_ensemble_and_task +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm + +from .moco_visual_frontend import MoCoVisualFrontend +from .utils import PositionalEncoding, conv1dLayers, outputConv, MaskedLayerNorm, generate_square_subsequent_mask + +from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel + + +class AVNet(nn.Module): + + def __init__(self, model_config): + super(AVNet, self).__init__() + + self.modal = model_config.modal + self.numClasses = model_config.CHAR_NUM_CLASSES + self.reqInpLen = model_config.MAIN_REQ_INPUT_LENGTH + self.dModel= model_config.DMODEL #!!! + self.nHeads = model_config.TX_ATTENTION_HEADS + self.numLayers = model_config.TX_NUM_LAYERS + self.peMaxLen= model_config.PE_MAX_LENGTH + self.audinSize = model_config.AUDIO_FEATURE_SIZE + self.vidinSize = model_config.VIDEO_FEATURE_SIZE + self.fcHiddenSize = model_config.TX_FEEDFORWARD_DIM + self.dropout = model_config.TX_DROPOUT + self.MoCofile = model_config.MOCO_FRONTEND_FILE + self.W2Vfile = model_config.WAV2VEC_FILE + + # A & V Modal + tx_norm = nn.LayerNorm(self.dModel) + self.maskedLayerNorm = MaskedLayerNorm() + if self.modal == "AV": + self.ModalityNormalization = nn.LayerNorm(self.dModel) + self.EncoderPositionalEncoding = PositionalEncoding(dModel=self.dModel, maxLen=self.peMaxLen) #512,500 + + # audio + if not self.modal == "VO": + # front-end + wav2vecModel, cfg, task = load_model_ensemble_and_task([self.W2Vfile], arg_overrides={ + "apply_mask": True, + "mask_prob": 0.5, + "mask_channel_prob": 0.25, + "mask_channel_length": 64, + "layerdrop": 0.1, + "activation_dropout": 0.1, + "feature_grad_mult": 0.0, + }) + wav2vecModel = wav2vecModel[0] + wav2vecModel.remove_pretraining_modules() + self.wav2vecModel = wav2vecModel + # back-end + self.audioConv = conv1dLayers(self.maskedLayerNorm, self.audinSize, self.dModel, self.dModel, downsample=True) + audioEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout) + self.audioEncoder = nn.TransformerEncoder(audioEncoderLayer, num_layers=self.numLayers, norm=tx_norm) + else: + self.wav2vecModel = None #主要是这三个 + self.audioConv = None + self.audioEncoder = None + + # visual + if not self.modal == "AO": + # front-end + visualModel = MoCoVisualFrontend(model_config) + if self.MoCofile is not None: + visualModel.load_state_dict(torch.load(self.MoCofile, map_location="cpu"), strict=False) + self.visualModel = visualModel + # back-end + self.videoConv = 
conv1dLayers(self.maskedLayerNorm, self.vidinSize, self.dModel, self.dModel)
+            videoEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout)
+            self.videoEncoder = nn.TransformerEncoder(videoEncoderLayer, num_layers=self.numLayers, norm=tx_norm)
+        else:
+            self.visualModel = None # these three are the key visual submodules
+            self.videoConv = None
+            self.videoEncoder = None
+
+        # JointConv for fusion
+        if self.modal == "AV":
+            self.jointConv = conv1dLayers(self.maskedLayerNorm, 2 * self.dModel, self.dModel, self.dModel)
+            jointEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout)
+            self.jointEncoder = nn.TransformerEncoder(jointEncoderLayer, num_layers=self.numLayers, norm=tx_norm)
+
+        # self.jointOutputConv = outputConv(self.maskedLayerNorm, self.dModel, self.numClasses)
+        # self.decoderPositionalEncoding = PositionalEncoding(dModel=self.dModel, maxLen=self.peMaxLen)
+        # self.embed = torch.nn.Sequential(
+        #     nn.Embedding(self.numClasses, self.dModel),
+        #     self.decoderPositionalEncoding
+        # )
+        # jointDecoderLayer = nn.TransformerDecoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout)
+        # self.jointAttentionDecoder = nn.TransformerDecoder(jointDecoderLayer, num_layers=self.numLayers, norm=tx_norm)
+        # self.jointAttentionOutputConv = outputConv("LN", self.dModel, self.numClasses)
+
+    def forward(self, inputBatch, maskw2v):
+        audioBatch, audMask, videoBatch, vidLen = inputBatch # torch.Size([2, 32480]), torch.Size([2, 32480]), torch.Size([2, 52, 1, 112, 112]), [52, 47]; trailing True values in audMask mark padded positions, everything else is False
+        if not self.modal == "VO":
+            result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) # new_version
+            audioBatch, audMask = result["x"], result["padding_mask"] # torch.Size([2, 101, 1024]), torch.Size([2, 101]); the shapes change here, so the mask must stay consistent with them
+            if audMask is None:
+                audMask = torch.full((audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device) # TODO
+
+            audLen = torch.sum(~audMask, dim=1) # tensor([101, 90], device='cuda:0')
+        else:
+            audLen = None
+
+        if not self.modal == "AO":
+            videoBatch = videoBatch.transpose(1, 2)
+            videoBatch = self.visualModel(videoBatch, vidLen.long()) # torch.Size([99, 2048])
+            videoBatch = list(torch.split(videoBatch, vidLen.tolist(), dim=0)) # split into a list, e.g. [(52, 2048), (47, 2048)]
+
+        audioBatch, videoBatch, inputLenBatch, mask = self.makePadding(audioBatch, audLen, videoBatch, vidLen) # [2, 160, 1024], torch.Size([2, 80, 2048]), tensor([80, 80]), (2, 80); this padding/alignment step is the crucial one
+
+        if isinstance(self.maskedLayerNorm, MaskedLayerNorm):
+            self.maskedLayerNorm.SetMaskandLength(mask, inputLenBatch)
+
+        if not self.modal == "VO":
+            audioBatch = audioBatch.transpose(1, 2)
+ audioBatch = self.audioConv(audioBatch) #[2, 1024, 80] + audioBatch = audioBatch.transpose(1, 2).transpose(0, 1) + audioBatch = self.EncoderPositionalEncoding(audioBatch) + audioBatch = self.audioEncoder(audioBatch, src_key_padding_mask=mask) #[80,2,1024] + + if not self.modal == "AO": + videoBatch = videoBatch.transpose(1, 2) + videoBatch = self.videoConv(videoBatch) #[2, 1024, 80] + videoBatch = videoBatch.transpose(1, 2).transpose(0, 1) + videoBatch = self.EncoderPositionalEncoding(videoBatch) + videoBatch = self.videoEncoder(videoBatch, src_key_padding_mask=mask) #[80, 2, 1024] + + if self.modal == "AO": + jointBatch = audioBatch + elif self.modal == "VO": + jointBatch = videoBatch + else: + jointBatch = torch.cat([self.ModalityNormalization(audioBatch), self.ModalityNormalization(videoBatch)], dim=2) #torch.Size([80, 2, 2048]) + jointBatch = jointBatch.transpose(0, 1).transpose(1, 2) #(2,2048,80) + jointBatch = self.jointConv(jointBatch) #(2,1024,80) + jointBatch = jointBatch.transpose(1, 2).transpose(0, 1) + jointBatch = self.EncoderPositionalEncoding(jointBatch) + jointBatch = self.jointEncoder(jointBatch, src_key_padding_mask=mask) #[80, 2, 1024] + + return jointBatch, inputLenBatch, mask #[80, 2, 1024], [80,80], [2,80] mask全是false + + def makeMaskfromLength(self, maskShape, maskLength, maskDevice): + mask = torch.zeros(maskShape, device=maskDevice) + mask[(torch.arange(mask.shape[0]), maskLength - 1)] = 1 + mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool() + return mask + + def makePadding(self, audioBatch, audLen, videoBatch, vidLen): + if self.modal == "AO": + audPadding = audLen % 2 + mask = (audPadding + audLen) > 2 * self.reqInpLen + audPadding = mask * audPadding + (~mask) * (2 * self.reqInpLen - audLen) + audLeftPadding = torch.floor(torch.div(audPadding, 2)).int() + audRightPadding = torch.ceil(torch.div(audPadding, 2)).int() + + audioBatch = audioBatch.unsqueeze(1).unsqueeze(1) + audioBatch = list(audioBatch) + for i, _ in enumerate(audioBatch): + pad = nn.ReplicationPad2d(padding=(0, 0, audLeftPadding[i], audRightPadding[i])) + audioBatch[i] = pad(audioBatch[i][:, :, :audLen[i]]).squeeze(0).squeeze(0) + + audioBatch = pad_sequence(audioBatch, batch_first=True) + inputLenBatch = ((audLen + audPadding) // 2).long() + mask = self.makeMaskfromLength([audioBatch.shape[0]] + [audioBatch.shape[1] // 2], inputLenBatch, audioBatch.device) + + elif self.modal == "VO": + vidPadding = torch.zeros(len(videoBatch)).long().to(vidLen.device) + + mask = (vidPadding + vidLen) > self.reqInpLen + vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen) + + vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int() + vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int() + + for i, _ in enumerate(videoBatch): + pad = nn.ReplicationPad2d(padding=(0, 0, vidLeftPadding[i], vidRightPadding[i])) + videoBatch[i] = pad(videoBatch[i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) + + videoBatch = pad_sequence(videoBatch, batch_first=True) + inputLenBatch = (vidLen + vidPadding).long() + mask = self.makeMaskfromLength(videoBatch.shape[:-1], inputLenBatch, videoBatch.device) + + else: + dismatch = audLen - 2 * vidLen + vidPadding = torch.ceil(torch.div(dismatch, 2)).int() + vidPadding = vidPadding * (vidPadding > 0) + audPadding = 2 * vidPadding - dismatch + + mask = (vidPadding + vidLen) > self.reqInpLen + vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen) + mask = (audPadding + audLen) > 2 * self.reqInpLen + audPadding = mask * audPadding + (~mask) * 
(2 * self.reqInpLen - audLen) + + vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int() + vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int() + audLeftPadding = torch.floor(torch.div(audPadding, 2)).int() + audRightPadding = torch.ceil(torch.div(audPadding, 2)).int() + + audioBatch = audioBatch.unsqueeze(1).unsqueeze(1) + audioBatch = list(audioBatch) + for i, _ in enumerate(audioBatch): + pad = nn.ReplicationPad2d(padding=(0, 0, audLeftPadding[i], audRightPadding[i])) + audioBatch[i] = pad(audioBatch[i][:, :, :audLen[i]]).squeeze(0).squeeze(0) + pad = nn.ReplicationPad2d(padding=(0, 0, vidLeftPadding[i], vidRightPadding[i])) + videoBatch[i] = pad(videoBatch[i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) + + audioBatch = pad_sequence(audioBatch, batch_first=True) + videoBatch = pad_sequence(videoBatch, batch_first=True) + inputLenBatch = (vidLen + vidPadding).long() + mask = self.makeMaskfromLength(videoBatch.shape[:-1], inputLenBatch, videoBatch.device) + + return audioBatch, videoBatch, inputLenBatch, mask diff --git a/src/llama_recipes/models/avsr_model.py b/src/llama_recipes/models/avsr_model.py new file mode 100644 index 00000000..a8b9af99 --- /dev/null +++ b/src/llama_recipes/models/avsr_model.py @@ -0,0 +1,122 @@ +import types +import torch +import soundfile as sf +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Optional, Tuple, Union +from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training +from transformers import ( + LlamaForCausalLM, + LlamaTokenizer, + LlamaConfig, +) +import whisper + +from llama_recipes.utils.config_utils import generate_peft_config +from llama_recipes.utils.train_utils import print_model_size + +from .av_net import AVNet +from .slam_model import setup_llm +from torch.nn.utils.rnn import pad_sequence +import copy +from llama_recipes.utils.metric import compute_accuracy + +def setupavsr_model(tokenizer, train_config, model_config, **kwargs): + return avsrllm_model(tokenizer, train_config, model_config, **kwargs) + +class avsrllm_model(nn.Module): + def __init__( + self, + tokenizer, + train_config, + model_config, + **kwargs + ): + super().__init__() + + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + + # audio-visual + self.avnet=AVNet(model_config) + + # load_ckpt + checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE) + self.avnet.load_state_dict(checkpoint['state_dict'],strict=False) # 最终输出ctc/attention的模块没有用到 + + # freeze + for name, param in self.avnet.named_parameters(): + param.requires_grad = False + self.avnet.eval() + + # llama + self.llm = setup_llm(train_config, model_config, **kwargs) + + # projector + self.feature_projector = nn.Linear(model_config.DMODEL, self.llm.config.hidden_size) #(512,4096) + + # tokenizer + self.tokenizer = tokenizer #tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path) 不需要保存 + self.metric = kwargs.get("metric", "acc") + + def forward(self, inputBatch0,inputBatch1,inputBatch2,inputBatch3, targetoutBatch, targetLenBatch, maskw2v, **kwargs): + inputBatch=(inputBatch0, inputBatch1,inputBatch2,inputBatch3) # targetinBatch是前面加 + + jointBatch, inputLenBatch, mask = self.avnet(inputBatch, maskw2v) #[129, 2, 1024], [129,125], [2,129] mask false的地方是不mask的,mask的位置是true , 就mask[1]末尾4个true #输出应该是 bs,l,dim + jointBatch = jointBatch.transpose(0, 1) #(2,129,1024) + + # project + feature_tokens = self.feature_projector(jointBatch) #(2,129,4096) + + if hasattr(self.llm.model, "embed_tokens"): + texts_embeds = 
self.llm.model.embed_tokens(targetoutBatch) + else: # + texts_embeds = self.llm.model.model.embed_tokens(targetoutBatch) #(2,37)-> (2,37,4096) + + #restore each item's true length: take its features and text embeddings, concatenate them, then pad again + + #input_list=[torch.cat( (jointBatch[i, ~mask[i]] , targetoutBatch[i][:targetLenBatch[i]]), dim=1) for i in range(jointBatch.size(0) )] + # for i in range(jointBatch.size(0)): + # a= feature_tokens[i, ~mask[i]] #(129,4096) (125,4096) + # b= texts_embeds[i][:targetLenBatch[i]][:] #(37,4096) (26,4096) + # input= torch.cat( (a,b), dim=0) #(166,4096) (151,4096) + + input_lists=[torch.cat( (feature_tokens[i, ~mask[i]], texts_embeds[i][:targetLenBatch[i]][:] ) , dim=0 ) for i in range(jointBatch.size(0)) ] + inputs_embeds = pad_sequence(input_lists, batch_first=True, padding_value=0) #(2,166,4096) + + lengths=[item.size(0) for item in input_lists] #[166, 151] + max_length=max(lengths) #166 + mask2 = torch.zeros(len(input_lists),max_length,dtype=torch.bool) #(2,166) + for i,length in enumerate(lengths): + mask2[i,:length]=1 #masked positions are False, the rest are True; only the last 15 entries of mask2[1] are False + mask2=mask2.to(inputs_embeds.device) + + + # labels_list=[] + # for i in range(jointBatch.size(0)): + # labels= torch.cat(( torch.full((inputLenBatch[i],),self.IGNORE_INDEX , device=targetoutBatch.device) , targetoutBatch[i][:targetLenBatch[i]]) ,dim=0) + # labels_list.append((labels)) + labels_list= [ torch.cat(( torch.full((inputLenBatch[i],),self.IGNORE_INDEX , device=targetoutBatch.device) , targetoutBatch[i][:targetLenBatch[i]]) ,dim=0) for i in range(jointBatch.size(0)) ] #[166,151] + labels = pad_sequence(labels_list, batch_first=True, padding_value=self.IGNORE_INDEX) #(2,166) + + + model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask = mask2, labels=labels) #the input/label shift is handled inside PeftModelForCausalLM + + acc = -1 + if self.metric: + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100) + + return model_outputs, acc #logits:[2,292,32000] #loss:6.9475 + + def save_pretrained(self, output_dir): + save_dir= output_dir+'/avsrmodel.pt' + self.llm.save_pretrained(output_dir) + modules_to_save={ + 'avnet': self.avnet.state_dict(), + 'feature_projector':self.feature_projector.state_dict(), + } + + torch.save(modules_to_save,save_dir) + + diff --git a/src/llama_recipes/models/moco_visual_frontend.py b/src/llama_recipes/models/moco_visual_frontend.py new file mode 100644 index 00000000..d53e786a --- /dev/null +++ b/src/llama_recipes/models/moco_visual_frontend.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torchvision.models as models + + +class MoCoVisualFrontend(nn.Module): + # def __init__(self, dModel=args["FRONTEND_DMODEL"], nClasses=args["WORD_NUM_CLASSES"], frameLen=args["FRAME_LENGTH"], + # vidfeaturedim=args["VIDEO_FEATURE_SIZE"]): + def __init__(self, model_config): + + super(MoCoVisualFrontend, self).__init__() + self.dModel = model_config.FRONTEND_DMODEL + self.nClasses = model_config.WORD_NUM_CLASSES + self.frameLen = model_config.FRAME_LENGTH + self.vidfeaturedim = model_config.VIDEO_FEATURE_SIZE + + + # Conv3D + self.frontend3D = nn.Sequential( + nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(True), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + ) + # moco + MoCoModel = models.__dict__['resnet50']() #essentially just a ResNet-50 backbone + MoCoModel.fc = nn.Identity() + MoCoModel.conv1 = nn.Identity() + MoCoModel.bn1 = nn.Identity() + 
MoCoModel.relu = nn.Identity() + MoCoModel.maxpool = nn.Identity() + self.MoCoModel = MoCoModel + + def forward(self, x, x_len): # x: 8,1,149,112,112 + x = self.frontend3D(x) #[2, 64, 52, 28, 28] + x = x.transpose(1, 2) + mask = torch.zeros(x.shape[:2], device=x.device) #(8,149) + mask[(torch.arange(mask.shape[0], device=x.device), x_len - 1)] = 1 + mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool() #boolean padding mask: True marks padded frames + x = x[~mask] + x = self.MoCoModel(x) # torch.Size([99, 2048]) + return x diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index f9dc1cd3..64b65b86 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -20,6 +20,8 @@ from torch.nn import CrossEntropyLoss from llama_recipes.utils.metric import compute_accuracy +import logging +logger = logging.getLogger(__name__) from llama_recipes.models.projector import EncoderProjectorConcat, EncoderProjectorCov1d, EncoderProjectorQFormer @@ -119,7 +121,7 @@ def setup_llm(train_config, model_config, **kwargs): from optimum.bettertransformer import BetterTransformer model = BetterTransformer.transform(model) except ImportError: - print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") + logger.warning("Module 'optimum' not found. Please install 'optimum' before proceeding.") print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) @@ -138,7 +140,7 @@ def setup_llm(train_config, model_config, **kwargs): model.print_trainable_parameters() if kwargs.get("peft_ckpt", None): - print("loading peft_ckpt from: ", kwargs.get("peft_ckpt")) + logger.info(f"loading peft_ckpt from: {kwargs.get('peft_ckpt')}") model = PeftModel.from_pretrained(model, kwargs.get("peft_ckpt")) print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) diff --git a/src/llama_recipes/models/utils.py b/src/llama_recipes/models/utils.py new file mode 100644 index 00000000..74969681 --- /dev/null +++ b/src/llama_recipes/models/utils.py @@ -0,0 +1,140 @@ +import math + +import torch +import torch.nn as nn + + +class PositionalEncoding(nn.Module): + """ + A layer to add positional encodings to the inputs of a Transformer model. 
+ Formula: + PE(pos,2i) = sin(pos/10000^(2i/d_model)) + PE(pos,2i+1) = cos(pos/10000^(2i/d_model)) + """ + + def __init__(self, dModel, maxLen): + super(PositionalEncoding, self).__init__() + pe = torch.zeros(maxLen, dModel) #(500,512) + position = torch.arange(0, maxLen, dtype=torch.float).unsqueeze(dim=-1) #(500,1) + denominator = torch.exp(torch.arange(0, dModel, 2).float() * (math.log(10000.0) / dModel)) #(256,) + pe[:, 0::2] = torch.sin(position / denominator) + pe[:, 1::2] = torch.cos(position / denominator) + pe = pe.unsqueeze(dim=0).transpose(0, 1) #(500,1,512) + self.register_buffer("pe", pe) + + def forward(self, inputBatch): #(152,8,512) decoder 输入的时候(92,8,512) + outputBatch = inputBatch + self.pe[:inputBatch.shape[0], :, :] #(152,8,512) # 报错的inputbatch[622,8,512] #self.pe [500,1,512] + return outputBatch + + +class conv1dLayers(nn.Module): + def __init__(self, MaskedNormLayer, inD, dModel, outD, downsample=False): + super(conv1dLayers, self).__init__() #inD=1024, dModel=512, outD=512 + if downsample: + kernel_stride = 2 + else: + kernel_stride = 1 + self.conv = nn.Sequential( + nn.Conv1d(inD, dModel, kernel_size=(kernel_stride,), stride=(kernel_stride,), padding=(0,)), + TransposeLayer(1, 2), + MaskedNormLayer, + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel, outD, kernel_size=(1,), stride=(1,), padding=(0,)) + ) + + def forward(self, inputBatch): + return self.conv(inputBatch) + + +class outputConv(nn.Module): #这个就是decoder了 最后output dim是 numClasses + def __init__(self, MaskedNormLayer, dModel, numClasses): + super(outputConv, self).__init__() + if MaskedNormLayer == "LN": # 区别是normlayer不同 正常的layer normaliztion + self.outputconv = nn.Sequential( + nn.Conv1d(dModel, dModel, kernel_size=(1,), stride=(1,), padding=(0,)), + TransposeLayer(1, 2), + nn.LayerNorm(dModel), + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)), + TransposeLayer(1, 2), + nn.LayerNorm(dModel // 2), + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel // 2, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)), + TransposeLayer(1, 2), + nn.LayerNorm(dModel // 2), + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel // 2, numClasses, kernel_size=(1,), stride=(1,), padding=(0,)) + ) + else: + self.outputconv = nn.Sequential( # MaskedNormLayer + nn.Conv1d(dModel, dModel, kernel_size=(1,), stride=(1,), padding=(0,)), + TransposeLayer(1, 2), + MaskedNormLayer, + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)), + TransposeLayer(1, 2), + MaskedNormLayer, + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel // 2, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)), + TransposeLayer(1, 2), + MaskedNormLayer, + TransposeLayer(1, 2), + nn.ReLU(True), + nn.Conv1d(dModel // 2, numClasses, kernel_size=(1,), stride=(1,), padding=(0,)) + ) + + def forward(self, inputBatch): + return self.outputconv(inputBatch) + + +class MaskedLayerNorm(nn.Module): + def __init__(self, eps=1e-5): + super(MaskedLayerNorm, self).__init__() + self.register_buffer('mask', None, persistent=False) + self.register_buffer('inputLenBatch', None, persistent=False) + self.eps = eps + + def SetMaskandLength(self, mask, inputLenBatch): + self.mask = mask + self.inputLenBatch = inputLenBatch + + def expand2shape(self, inputBatch, expandedShape): + return inputBatch.unsqueeze(-1).unsqueeze(-1).expand(expandedShape) + + def forward(self, inputBatch): + dModel = 
inputBatch.shape[-1] + maskBatch = ~self.mask.unsqueeze(-1).expand(inputBatch.shape) + + meanBatch = (inputBatch * maskBatch).sum((1, 2)) / (self.inputLenBatch * dModel) + stdBatch = ((inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) ** 2 * maskBatch).sum((1, 2)) + stdBatch = stdBatch / (self.inputLenBatch * dModel) + + # Norm the input + normed = (inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) / \ + (torch.sqrt(self.expand2shape(stdBatch + self.eps, inputBatch.shape))) + return normed + + +class TransposeLayer(nn.Module): + def __init__(self, dim1, dim2): + super(TransposeLayer, self).__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, inputBatch): + return inputBatch.transpose(self.dim1, self.dim2) + + +def generate_square_subsequent_mask(sz: int, device): + r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ # 三角矩阵 为了infer的时候的decode + mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask diff --git a/src/llama_recipes/models/visual_encoder.py b/src/llama_recipes/models/visual_encoder.py new file mode 100644 index 00000000..a1d3412d --- /dev/null +++ b/src/llama_recipes/models/visual_encoder.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from config import args + + +class VisualEncoder(nn.Module): + # def __init__(self, dModel=args["FRONTEND_DMODEL"], nClasses=args["WORD_NUM_CLASSES"], frameLen=args["FRAME_LENGTH"], + # vidfeaturedim=args["VIDEO_FEATURE_SIZE"]): + def __init__(self, model_config): + + super(VisualEncoder, self).__init__() + self.dModel = model_config.FRONTEND_DMODEL + self.nClasses = model_config.WORD_NUM_CLASSES + self.frameLen = model_config.FRAME_LENGTH + self.vidfeaturedim = model_config.VIDEO_FEATURE_SIZE + + + # Conv3D + self.frontend3D = nn.Sequential( + nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(True), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + ) + # moco + MoCoModel = models.__dict__['resnet50']() #就当搞了个ResNet + MoCoModel.fc = nn.Identity() + MoCoModel.conv1 = nn.Identity() + MoCoModel.bn1 = nn.Identity() + MoCoModel.relu = nn.Identity() + MoCoModel.maxpool = nn.Identity() #有点意思 + self.MoCoModel = MoCoModel + + self.MoCoModel.load_state_dict(torch.load(MoCofile, map_location="cpu"), strict=False) + + + # AV + self.peMaxLen = model_config.PE_MAX_LENGTH + tx_norm = nn.LayerNorm(dModel) + self.maskedLayerNorm = MaskedLayerNorm() + self.EncoderPositionalEncoding = PositionalEncoding(dModel=self.dModel, maxLen= self.peMaxLen) #512,500 + + # visual backend + self.nHeads = model_config.X_ATTENTION_HEADS + self.fcHiddenSize = model_config.TX_FEEDFORWARD_DIM + self.dropout = model_config.TX_DROPOUT + self.num_layers = model_config.TX_NUM_LAYERS + + self.videoConv = conv1dLayers(self.maskedLayerNorm, self.vidfeaturedim, self.dModel, self.dModel) + videoEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout) + self.videoEncoder = nn.TransformerEncoder(videoEncoderLayer, num_layers=self.num_layers, norm=tx_norm) + + def forward(self, x, x_len): # x: 8,1,149,112,112 + x = self.frontend3D(x) #(8,64,149,28,28) + x = x.transpose(1, 2) #(8,149,64,28,28) + mask = 
torch.zeros(x.shape[:2], device=x.device) #(8,149) + mask[(torch.arange(mask.shape[0], device=x.device), x_len - 1)] = 1 + mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool() #一堆true false + x = x[~mask] #(739,64,28,28) + x = self.MoCoModel(x) #(739,2048) + return x + + +class MaskedLayerNorm(nn.Module): + def __init__(self, eps=1e-5): + super(MaskedLayerNorm, self).__init__() + self.register_buffer('mask', None, persistent=False) + self.register_buffer('inputLenBatch', None, persistent=False) + self.eps = eps + + def SetMaskandLength(self, mask, inputLenBatch): + self.mask = mask + self.inputLenBatch = inputLenBatch + + def expand2shape(self, inputBatch, expandedShape): + return inputBatch.unsqueeze(-1).unsqueeze(-1).expand(expandedShape) + + def forward(self, inputBatch): + dModel = inputBatch.shape[-1] + maskBatch = ~self.mask.unsqueeze(-1).expand(inputBatch.shape) + + meanBatch = (inputBatch * maskBatch).sum((1, 2)) / (self.inputLenBatch * dModel) + stdBatch = ((inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) ** 2 * maskBatch).sum((1, 2)) + stdBatch = stdBatch / (self.inputLenBatch * dModel) + + # Norm the input + normed = (inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) / \ + (torch.sqrt(self.expand2shape(stdBatch + self.eps, inputBatch.shape))) + return normed diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index 32347daf..d070594a 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -42,13 +42,46 @@ ) from model_factory import model_factory - +import sys +import logging +import wandb def main(**kwargs): # Update the configuration for the training and sharding process train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() update_config((train_config, fsdp_config, model_config), **kwargs) + # Set wandb + wandb_config={"train_config":vars(train_config), "fsdp_config":vars(fsdp_config), "model_config":vars(model_config)} + wandb.init(project="project_name",name="exp_name",config=wandb_config) #记录参数 + + # Set log + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + filemode='w' + ) + + logger = logging.getLogger() + logger.setLevel(logging.INFO) + + file_handler = logging.FileHandler(filename=train_config.log_file, mode='w') + file_handler.setLevel(logging.INFO) + file_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + file_handler.setFormatter(file_formatter) + + logger.handlers[0].setLevel(logging.INFO) + console_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + logger.handlers[0].setFormatter(console_formatter) + + logger.addHandler(file_handler) + + logger.info("train_config: {}".format(train_config)) + logger.info("fsdp_config: {}".format(fsdp_config)) + logger.info("model_config: {}".format(model_config)) + + # Set the seeds for reproducibility torch.cuda.manual_seed(train_config.seed) torch.manual_seed(train_config.seed) @@ -60,7 +93,7 @@ def main(**kwargs): local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) - print(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}") + logger.info(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}") if torch.distributed.is_initialized(): torch.cuda.set_device(local_rank) @@ 
-103,6 +136,8 @@ def main(**kwargs): model.to("cuda") dataset_config = generate_dataset_config(train_config, kwargs) + logger.info("dataset_config: {}".format(dataset_config)) + wandb.config.update( {"dataset_config": vars(dataset_config)} ) # Load and preprocess the dataset for training and validation dataset_train = get_preprocessed_dataset( @@ -111,14 +146,14 @@ def main(**kwargs): split="train", ) if not train_config.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(dataset_train)}") + logger.info(f"--> Training Set Length = {len(dataset_train)}") dataset_val = get_preprocessed_dataset( tokenizer, dataset_config, split="val", ) if not train_config.enable_fsdp or rank == 0: - print(f"--> Validation Set Length = {len(dataset_val)}") + logger.info(f"--> Validation Set Length = {len(dataset_val)}") if train_config.batching_strategy == "packing": dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) @@ -179,7 +214,9 @@ def main(**kwargs): rank if train_config.enable_fsdp else None, ) if not train_config.enable_fsdp or rank==0: - [print(f'Key: {k}, Value: {v}') for k, v in results.items()] + [logger.info(f'Key: {k}, Value: {v}') for k, v in results.items()] + + wandb.finish() if __name__ == "__main__": fire.Fire(main) \ No newline at end of file diff --git a/src/llama_recipes/pipeline/model_factory.py b/src/llama_recipes/pipeline/model_factory.py index 6d7a41c7..3de78e53 100644 --- a/src/llama_recipes/pipeline/model_factory.py +++ b/src/llama_recipes/pipeline/model_factory.py @@ -1,15 +1,23 @@ import torch from llama_recipes.models.slam_model import setup_model, setup_tokenizer +from llama_recipes.models.avsr_model import setupavsr_model from llama_recipes.utils.train_utils import print_model_size import os +import logging +logger = logging.getLogger(__name__) + def model_factory(train_config, model_config, **kwargs): tokenizer = setup_tokenizer(train_config, model_config, **kwargs) - model = setup_model(tokenizer, train_config, model_config, **kwargs) + if train_config.model_name=="avsr": + model = setupavsr_model(tokenizer, train_config, model_config, **kwargs) + else: + model = setup_model(tokenizer, train_config, model_config, **kwargs) + ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) if ckpt_path is not None: - print("loading other parts from: ", ckpt_path) + logger.info(f"loading other parts from: {ckpt_path}") ckpt_dict = torch.load(ckpt_path, map_location="cpu") model.load_state_dict(ckpt_dict, strict=False) diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py index 81ab9680..d8115139 100644 --- a/src/llama_recipes/utils/config_utils.py +++ b/src/llama_recipes/utils/config_utils.py @@ -18,6 +18,8 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler from llama_recipes.utils.dataset_utils import DATASET_PREPROC +import logging +logger = logging.getLogger(__name__) def update_config(config, **kwargs): if isinstance(config, (tuple, list)): @@ -35,9 +37,9 @@ def update_config(config, **kwargs): setattr(config, param_name, v) else: # In case of specialized config we can warm user - print(f"Warning: {config_name} does not accept parameter: {k}") + logger.warning(f"Warning: {config_name} does not accept parameter: {k}") elif isinstance(config, train_config): - print(f"Warning: unknown parameter {k}") + logger.warning(f"Warning: unknown 
parameter {k}") def generate_peft_config(train_config, kwargs): @@ -106,6 +108,6 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs["batch_size"] = batch_size kwargs["drop_last"] = True kwargs["collate_fn"] = dataset.collator - print(f"Using batching strategy: {train_config.batching_strategy}") + logger.info(f"Using batching strategy: {train_config.batching_strategy}") return kwargs diff --git a/src/llama_recipes/utils/dataset_utils.py b/src/llama_recipes/utils/dataset_utils.py index 47baeefe..af4eca45 100644 --- a/src/llama_recipes/utils/dataset_utils.py +++ b/src/llama_recipes/utils/dataset_utils.py @@ -13,6 +13,9 @@ get_samsum_dataset, ) +import logging +logger = logging.getLogger(__name__) + def load_module_from_py_file(py_file: str) -> object: """ @@ -45,15 +48,15 @@ def get_custom_dataset(dataset_config, tokenizer, split: str): try: return getattr(module, func_name)(dataset_config, tokenizer, split) except AttributeError as e: - print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") + logger.info(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") raise e - DATASET_PREPROC = { "alpaca_dataset": partial(get_alpaca_dataset), "grammar_dataset": get_grammar_dataset, "samsum_dataset": get_samsum_dataset, "custom_dataset": get_custom_dataset, + "avsr_dataset": get_custom_dataset, } diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index d3851387..dbc2978b 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -28,6 +28,10 @@ from llama_recipes.utils.memory_utils import MemoryTrace from llama_recipes.utils.metric import compute_accuracy +import wandb +import logging +logger = logging.getLogger(__name__) + def set_tokenizer_params(tokenizer: LlamaTokenizer): tokenizer.pad_token_id = 0 @@ -63,7 +67,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.enable_fsdp: world_size = int(os.environ["WORLD_SIZE"]) autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext - + train_prep = [] train_loss = [] train_acc = [] @@ -84,6 +88,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) for step, batch in enumerate(train_dataloader): for key in batch.keys(): + if type(batch[key])==bool: #train的时候是true infer的时候是false + continue if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank) else: @@ -95,6 +101,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche loss = loss / gradient_accumulation_steps acc = acc / gradient_accumulation_steps + + if step % train_config.log_interval == 0: + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}) + total_loss += loss.detach().float() total_acc += acc if train_config.use_fp16: @@ -133,19 +143,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche train_loss.append(train_epoch_loss) train_acc.append(train_epoch_acc) + wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) + if train_config.enable_fsdp: if rank==0: - print(f"Max CUDA memory allocated was {memtrace.peak} GB") - print(f"Max CUDA memory reserved was 
{memtrace.max_reserved} GB") - print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") - print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}") - print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") + logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") + logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") + logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") + logger.info(f"Cuda Malloc retries : {memtrace.cuda_malloc_retires}") + logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") else: - print(f"Max CUDA memory allocated was {memtrace.peak} GB") - print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") - print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") - print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}") - print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") + logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") + logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") + logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") + logger.info(f"Cuda Malloc retries : {memtrace.cuda_malloc_retires}") + logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") # Update the learning rate as needed lr_scheduler.step() @@ -159,9 +171,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if train_config.use_peft: if train_config.enable_fsdp: if rank==0: - print(f"we are about to save the PEFT modules") + logger.info(f"we are about to save the PEFT modules") else: - print(f"we are about to save the PEFT modules") + logger.info(f"we are about to save the PEFT modules") if train_config.enable_fsdp: if rank==0: save_model_checkpoint_peft( @@ -175,12 +187,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ) if train_config.enable_fsdp: if rank==0: - print(f"PEFT modules are saved in {train_config.output_dir} directory") + logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") else: - print(f"PEFT modules are saved in {train_config.output_dir} directory") + logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") elif not train_config.use_peft and train_config.freeze_llm: - print(f"llm is frozen, we are about to save other parts.") + logger.info(f"llm is frozen, we are about to save other parts.") if train_config.enable_fsdp: if rank==0: save_model_checkpoint_peft( @@ -199,21 +211,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche model, optimizer, rank, train_config, epoch=epoch ) elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") - print("=====================================================") + logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") + logger.info("=====================================================") - save_model_and_optimizer_sharded(model, rank, train_config) + save_model_and_optimizer_sharded(model, rank, train_config) if train_config.save_optimizer: save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) - print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") - 
print("=====================================================") + logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") + logger.info("=====================================================") if not train_config.use_peft and train_config.save_optimizer: save_optimizer_checkpoint( model, optimizer, rank, train_config, epoch=epoch ) - print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") - print("=====================================================") + logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") + logger.info("=====================================================") if train_config.enable_fsdp: dist.barrier() checkpoint_end_time = time.perf_counter() - checkpoint_start_time @@ -222,35 +234,38 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche best_val_loss = eval_epoch_loss if train_config.enable_fsdp: if rank==0: - print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") else: - print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") val_loss.append(eval_epoch_loss) val_prep.append(eval_ppl) if rest: val_acc.append(rest[0]) else: val_acc.append(-1) + + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) + if train_config.run_test_during_validation: if train_config.enable_fsdp: if rank==0: - print("=====================================") - print(f"Test the file {train_config.run_test_during_validation_file} during validation:") + logger.info("=====================================") + logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") with autocast(): - print(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) - print("=====================================") + logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) + logger.info("=====================================") dist.barrier() else: - print("=====================================") - print(f"Test the file {train_config.run_test_during_validation_file} during validation:") + logger.info("=====================================") + logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") with autocast(): - print(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) - print("=====================================") + logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) + logger.info("=====================================") if train_config.enable_fsdp: if rank==0: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") else: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time 
{epoch_end_time}s") avg_epoch_time = sum(epoch_times)/ len(epoch_times) avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 avg_train_prep = sum(train_prep)/len(train_prep) @@ -300,6 +315,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): with MemoryTrace() as memtrace: for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): for key in batch.keys(): + if type(batch[key])==bool: #train的时候是true infer的时候是false + continue if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank) else: @@ -336,9 +353,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): # Print evaluation metrics if train_config.enable_fsdp: if local_rank==0: - print(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") + logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") else: - print(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") + logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") return eval_ppl, eval_epoch_loss, eval_epoch_acc @@ -352,7 +369,7 @@ def freeze_transformer_layers(model, num_layer): def check_frozen_layers_peft_model(model): for i, layer in enumerate(model.base_model.model.model.layers): for name, param in layer.named_parameters(): - print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") + logger.info(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") def setup(): @@ -369,7 +386,7 @@ def setup_environ_flags(rank): # Note this is only availble in PyTorch Nighlies (as of July 30 2023) # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' if rank == 0: - print(f"--> Running with torch dist debug set to detail") + logger.info(f"--> Running with torch dist debug set to detail") def cleanup(): @@ -380,7 +397,7 @@ def cleanup(): def clear_gpu_cache(rank=None): """Clear the GPU cache for all ranks""" if rank == 0: - print(f"Clearing GPU cache for all ranks") + logger.info(f"Clearing GPU cache for all ranks") torch.cuda.empty_cache() @@ -393,7 +410,7 @@ def get_parameter_dtypes(model): def print_model_size(model, config, rank: int = 0) -> None: """ - Print model name, the number of trainable parameters and initialization time. + log model name, the number of trainable parameters and initialization time. Args: model: The PyTorch model. @@ -403,9 +420,9 @@ def print_model_size(model, config, rank: int = 0) -> None: rank (int, optional): Current process's rank. Defaults to 0. """ if rank == 0: - print(f"--> Model {config.model_name}") + logger.info(f"--> Model {config.model_name}") total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n") + logger.info(f"--> {config.model_name} has {total_params / 1e6} Million params\n") def print_module_size(module, module_name, rank: int = 0) -> None: """ @@ -417,9 +434,9 @@ def print_module_size(module, module_name, rank: int = 0) -> None: rank (int, optional): Current process's rank. Defaults to 0. 
""" if rank == 0: - print(f"--> Module {module_name}") + logger.info(f"--> Module {module_name}") total_params = sum(p.numel() for p in module.parameters() if p.requires_grad) - print(f"\n--> {module_name} has {total_params / 1e6} Million params\n") + logger.info(f"--> {module_name} has {total_params / 1e6} Million params\n") def get_policies(cfg, rank): @@ -444,13 +461,13 @@ def get_policies(cfg, rank): if bf16_ready and not cfg.use_fp16: mixed_precision_policy = bfSixteen_mixed if rank == 0: - print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") + logger.info(f"bFloat16 enabled for mixed precision - using bfSixteen policy") elif cfg.use_fp16: mixed_precision_policy = fpSixteen if rank == 0: - print(f"FP16 enabled") + logger.info(f"FP16 enabled") else: - print(f"bFloat16 support not present. Using FP32, and not mixed precision") + logger.info(f"bFloat16 support not present. Using FP32, and not mixed precision") wrapping_policy = get_llama_wrapper() return mixed_precision_policy, wrapping_policy @@ -485,10 +502,10 @@ def save_train_params(train_config, fsdp_config, rank): # Check if there's a directory with the same name as the file if os.path.isdir(file_name): - print(f"Error: {file_name} is a directory, not a file.") + logger.info(f"Error: {file_name} is a directory, not a file.") else: # Write the YAML string to the file with open(file_name, 'w') as f: f.write(config_yaml) if rank==0: - print(f"training params are saved in {file_name}") + logger.info(f"training params are saved in {file_name}")