diff --git a/scripts/finetune_avsr.sh b/scripts/finetune_avsr.sh index 171e3c4d..278fe777 100644 --- a/scripts/finetune_avsr.sh +++ b/scripts/finetune_avsr.sh @@ -1,28 +1,107 @@ #!/bin/bash # export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 -export CUDA_LAUNCH_BLOCKING=1 +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1,2,3 +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO cd /root/SLAM-LLM -audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth -speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt -llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune +# speech_encoder_path= TODO! + + +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/vicuna-13b-v1.5-finetune-avsr-20230115 # -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then python src/llama_recipes/pipeline/finetune.py \ --model_name avsr \ ---use_peft --peft_method lora \ ---quantization \ ---llm_name llama-2-7b-hf \ +--freeze_encoder \ +--freeze_llm \ +--llm_name vicuna-13b-v1.5 \ +--llm_path $llm_path \ +--llm_dim 4096 \ +--encoder_name moco_wav2vec2 \ +--encoder_ds_rate 2 \ +--encoder_dim 512 \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset avsr_dataset \ +--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ +--batching_strategy custom \ +--num_epochs 20 \ +--batch_size_training 6 \ +--val_batch_size 2 \ +--num_workers_dataloader 2 \ +--lr 1e-4 \ +--output_dir $output_dir \ +--metric acc \ +--log_file "/root/SLAM-LLM/log/second_try.log" \ +--use_wandb \ +--wandb_dir $output_dir \ +--wandb_entity_name yanghaha \ +--wandb_project_name slam-llm \ +--wandb_exp_name avsr \ +--log_interval 5 \ + +else +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +src/llama_recipes/pipeline/finetune.py \ +--model_name avsr \ +--freeze_encoder \ +--freeze_llm \ +--use_fp16 \ +--enable_fsdp \ +--llm_name vicuna-13b-v1.5 \ --llm_path $llm_path \ ---encoder_name whisper \ ---encoder_path $speech_encoder_path \ +--llm_dim 4096 \ +--encoder_name moco_wav2vec2 \ +--encoder_ds_rate 2 \ +--encoder_dim 512 \ --encoder_projector linear \ +--encoder_projector_ds_rate 5 \ --dataset avsr_dataset \ --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ --batching_strategy custom \ ---num_epochs 1 \ +--num_epochs 20 \ --batch_size_training 2 \ ---output_dir $output_dir \ No newline at end of file +--val_batch_size 2 \ +--num_workers_dataloader 2 \ +--lr 1e-4 \ +--output_dir $output_dir \ +--metric acc \ +--log_file "/root/SLAM-LLM/log/second_try.log" \ +--use_wandb \ +--wandb_dir $output_dir \ +--wandb_entity_name yanghaha \ +--wandb_project_name slam-llm \ +--wandb_exp_name avsr \ +--log_interval 5 \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \ +# --use_peft --peft_method lora \ +# --master_port=29501 \ +fi + +# {"key": "1001-134707-0000_ASR", "prompt": "", "source": 
"/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} +# {"key": "1688-142285-0005", "prompt": "", "source": "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0005.wav", "target": "YOU WHO WERE ALWAYS ACCUSING PEOPLE OF BEING SHOPPY AT HELSTONE", "target_len": 11, "source_len": 220, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} + + + +# 没用 encoder_ds_rate + +# 1.15 + +# 7b batch size 开到2 ok的 + +# 6 2 0 可以 \ No newline at end of file diff --git a/scripts/finetune_avsr_debug.sh b/scripts/finetune_avsr_debug.sh index f6728478..ca9f8780 100644 --- a/scripts/finetune_avsr_debug.sh +++ b/scripts/finetune_avsr_debug.sh @@ -1,32 +1,104 @@ #!/bin/bash # export PYTHONPATH=/root/whisper:$PYTHONPATH +export PYTHONPATH=/root/fairseq:$PYTHONPATH export CUDA_VISIBLE_DEVICES=0 -export CUDA_LAUNCH_BLOCKING=1 +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO cd /root/SLAM-LLM -audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth -speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt + +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5 + +output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-paddingr-20240112 -llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf #/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/yangguanrou.ygr/ckpts/llama-2-hf-finetune #/home/oss/yangguanrou.ygr/ckpts/llama-2-hf-finetune +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--llm_name vicuna-13b-v1.5 \ +--llm_path $llm_path \ +--llm_dim 4096 \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_dim 1280 \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset speech_dataset \ +--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ +--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 4 \ +--val_batch_size 4 \ +--num_workers_dataloader 4 \ +--lr 1e-4 \ +--output_dir $output_dir \ +--metric acc \ +# --log_file $output_dir/test.log \ +# --use_wandb \ +# --wandb_dir $output_dir \ +# --wandb_entity_name zym22 \ +# --wandb_project_name slam-llm \ +# --wandb_exp_name test \ +# --log_interval 5 \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \ +# --use_peft --peft_method lora \ -# -m debugpy 
--listen 5680 --wait-for-client -python -m debugpy --listen 5680 --wait-for-client src/llama_recipes/pipeline/finetune.py \ ---model_name avsr \ ---use_peft --peft_method lora \ ---quantization \ ---llm_name llama-2-7b-hf \ +else +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--use_fp16 \ +--enable_fsdp \ +--llm_name vicuna-7b-v1.5 \ --llm_path $llm_path \ +--llm_dim 4096 \ --encoder_name whisper \ +--encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ +--encoder_dim 1280 \ --encoder_projector linear \ ---dataset avsr_dataset \ ---avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ +--encoder_projector_ds_rate 5 \ +--dataset speech_dataset \ +--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ +--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ --batching_strategy custom \ ---num_epochs 1 \ ---batch_size_training 2 \ +--num_epochs 100 \ +--batch_size_training 6 \ +--val_batch_size 6 \ +--num_workers_dataloader 4 \ +--lr 1e-4 \ --output_dir $output_dir \ ---stepSize 10 \ ---log_file "/root/SLAM-LLM/log/test.log" \ ---valid_subset "LRS3/val_debug.txt" \ \ No newline at end of file +--metric acc \ +--log_file /$output_dir/train.log \ +--use_wandb \ +--wandb_dir $output_dir \ +--wandb_entity_name zym22 \ +--wandb_project_name slam-llm \ +--wandb_exp_name test \ +--log_interval 5 \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \ +# --use_peft --peft_method lora \ +# --master_port=29501 \ +fi + +# {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. 
How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} +# {"key": "1688-142285-0005", "prompt": "", "source": "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0005.wav", "target": "YOU WHO WERE ALWAYS ACCUSING PEOPLE OF BEING SHOPPY AT HELSTONE", "target_len": 11, "source_len": 220, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} \ No newline at end of file diff --git a/scripts/finetune_avsr_debug_1214.sh b/scripts/finetune_avsr_debug_1214.sh deleted file mode 100644 index d44bef08..00000000 --- a/scripts/finetune_avsr_debug_1214.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 -export CUDA_LAUNCH_BLOCKING=1 -export OMP_NUM_THREADS=1 - -cd /root/SLAM-LLM - - -llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/yangguanrou.ygr/llama-2-hf-finetune - -# -m debugpy --listen 5678 --wait-for-client -python -m debugpy --listen 5679 --wait-for-client src/llama_recipes/pipeline/finetune.py \ ---model_name avsr \ ---freeze_llm \ ---llm_name llama-2-7b-hf \ ---llm_path $llm_path \ ---encoder_name whisper \ ---encoder_ds_rate 2 \ ---encoder_path $speech_encoder_path \ ---encoder_projector linear \ ---encoder_projector_ds_rate 5 \ ---dataset avsr_dataset \ ---avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ ---batching_strategy custom \ ---num_epochs 1 \ ---batch_size_training 4 \ ---lr 1e-5 \ ---output_dir $output_dir \ ---metric acc \ ---log_file "/root/SLAM-LLM/log/test.log" \ - - -# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ \ No newline at end of file diff --git a/scripts/finetune_avsr_debug_1218.sh b/scripts/finetune_avsr_debug_1218.sh deleted file mode 100644 index 01ac0f9a..00000000 --- a/scripts/finetune_avsr_debug_1218.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 -export CUDA_LAUNCH_BLOCKING=1 -export OMP_NUM_THREADS=1 - -cd /root/SLAM-LLM - - -llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/yangguanrou.ygr/llama-2-hf-finetune - -# -m debugpy --listen 5678 --wait-for-client -python src/llama_recipes/pipeline/finetune.py \ ---model_name avsr \ ---freeze_llm \ ---llm_name llama-2-7b-hf \ ---llm_path $llm_path \ ---encoder_name whisper \ ---encoder_ds_rate 2 \ ---encoder_path $speech_encoder_path \ ---encoder_projector linear \ ---encoder_projector_ds_rate 5 \ ---dataset avsr_dataset \ ---avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ ---batching_strategy custom \ ---num_epochs 1 \ ---batch_size_training 4 \ ---lr 1e-5 \ ---output_dir $output_dir \ ---metric acc \ ---log_file "/root/SLAM-LLM/log/test.log" \ - - -# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \ \ No newline at end of file diff --git a/scripts/finetune_avsr_vicuna_debug_0113.sh b/scripts/finetune_avsr_vicuna_debug_0113.sh new file mode 100644 index 00000000..ab135502 --- /dev/null +++ b/scripts/finetune_avsr_vicuna_debug_0113.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=1 +# export CUDA_LAUNCH_BLOCKING=1 +export 
OMP_NUM_THREADS=1
+
+# debug setting for multiple gpus
+# export NCCL_DEBUG=INFO
+# export NCCL_DEBUG_SUBSYS=ALL
+# export TORCH_DISTRIBUTED_DEBUG=INFO
+
+cd /root/SLAM-LLM
+
+# speech_encoder_path= TODO!
+
+
+llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
+# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5
+
+output_dir=/nfs/yangguanrou.ygr/vicuna-7b-v1.5-finetune-avsr
+
+# -m debugpy --listen 5678 --wait-for-client
+if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
+python -m debugpy --listen 5679 --wait-for-client src/llama_recipes/pipeline/finetune.py \
+--model_name avsr \
+--freeze_encoder \
+--freeze_llm \
+--llm_name vicuna-13b-v1.5 \
+--llm_path $llm_path \
+--llm_dim 4096 \
+--encoder_name moco_wav2vec2 \
+--encoder_ds_rate 2 \
+--encoder_dim 512 \
+--encoder_projector linear \
+--encoder_projector_ds_rate 5 \
+--dataset avsr_dataset \
+--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
+--batching_strategy custom \
+--num_epochs 1 \
+--batch_size_training 2 \
+--num_workers_dataloader 2 \
+--lr 1e-4 \
+--output_dir $output_dir \
+--metric acc \
+--log_file "/root/SLAM-LLM/log/first_try.log" \
+
+
+# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
+
+
+# --encoder_path $speech_encoder_path \ #TODO!
+# --encoder_dim 1280 \ #TODO!
\ No newline at end of file
diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py
index 5486e619..e8a6a05e 100644
--- a/src/llama_recipes/configs/datasets.py
+++ b/src/llama_recipes/configs/datasets.py
@@ -70,12 +70,10 @@ class avsr_dataset:
    noiseProb: float = 0.
    noiseSNR: float = 5
    stepSize: int = 16384
-    # charToIx={" ": 1, "'": 22, "1": 30, "0": 29, "3": 37, "2": 32, "5": 34, "4": 38, "7": 36, "6": 35, "9": 31, "8": 33, "A": 5, "C": 17,
-    #           "B": 20, "E": 2, "D": 12, "G": 16, "F": 19, "I": 6, "H": 9, "K": 24, "J": 25, "M": 18, "L": 11, "O": 4, "N": 7, "Q": 27,
-    #           "P": 21, "S": 8, "R": 10, "U": 13, "T": 3, "W": 15, "V": 23, "Y": 14, "X": 26, "Z": 28, "": 39}
    charToIx : str = "x"  # probably unused now; the "TypeError: Object of type NotImplementedType is not JSON serializable" came from the commented-out dict above
    modal: str = "AV"
    pretrain_subset: str = "LRS3/pretrain.txt"
    train_subset: str = "LRS3/train.txt"
    valid_subset: str = "LRS3/val.txt"
    test_subset: str = "LRS3/test.txt"
+    reqInpLen: int = 80
diff --git a/src/llama_recipes/datasets/avsr_dataset.py b/src/llama_recipes/datasets/avsr_dataset.py
index a40284ec..c16b5a77 100644
--- a/src/llama_recipes/datasets/avsr_dataset.py
+++ b/src/llama_recipes/datasets/avsr_dataset.py
@@ -5,6 +5,8 @@
 import random
 import torch
+import math
+import copy
 import cv2 as cv
 from torch.nn.utils.rnn import pad_sequence
@@ -12,6 +14,9 @@
 import logging
 logger = logging.getLogger(__name__)
+from llama_recipes.utils.compute_utils import calculate_output_length_1d
+import torch.nn as nn
+
 class AVSRDataset(Dataset):
     def __init__(self, dataset_config, tokenizer=None, split='train'):
         super().__init__()
@@ -79,6 +84,12 @@ def __init__(self, dataset_config, tokenizer=None, split='train'):
             Normalize(mean=[0.4161], std=[0.1688])
         ])
+        # LLM new
+        self.IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
+        self.prompt_template = "USER: {}\n ASSISTANT:"
+        self.answer_template = "{}"
+        self.reqInpLen = dataset_config.reqInpLen
+
     def open_h5(self):
         self.h5 = h5py.File(self.h5file, "r")
@@ -118,17 +129,107 @@ def __getitem__(self, index):  # avsr uses a shuffled dataloader; echat seems to default to fa
        if index < 118516:  # was 96318; checked: 118516 is indeed the LRS2 line count, i.e. the number of files. These samples were meant for the pretrain path, but some ended up in the main path without cropping, so they run past 500
            #inp, trgtin, trgtout, trgtLen, trgttext = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)
-            inp, trgtin, trgtout, trgtLen = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)
+            inp, trgtin, trgtout, trgtLen, target = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)
            if inp==0 and trgtin ==0 and trgtout ==0 and trgtLen==0:
                index+=1
                targetFile = self.datalist[index] + ".txt"
                #inp, trgtin, trgtout, trgtLen, trgttext = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile,self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)  # just shifts to the next index -- a weak fallback
-                inp, trgtin, trgtout, trgtLen = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile,self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)  # just shifts to the next index -- a weak fallback
+                inp, trgtin, trgtout, trgtLen,target = self.prepare_pretrain_input(index, self.modal, self.h5, targetFile,self.charToIx, self.transform, noise, self.noiseSNR, (3, 21), 160)  # just shifts to the next index -- a weak fallback
        else:
            #inp, trgtin, trgtout, trgtLen, trgttext = self.prepare_main_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR)
-            inp, trgtin, trgtout, trgtLen = self.prepare_main_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR)
+            inp, trgtin, trgtout, trgtLen,target = self.prepare_main_input(index, self.modal, self.h5, targetFile, self.charToIx, self.transform, noise, self.noiseSNR)
+
+
+        # new!
+        audio_raw = inp[0]   #cpu torch.Size([48800])
+        visual_raw = inp[1]  #cpu torch.Size([77, 1, 112, 112])
+
+        prompt = "Transcribe video to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
+
+        prompt = self.prompt_template.format(prompt)
+        answer = self.answer_template.format(target)
-        return inp, trgtin, trgtout, trgtLen #, trgttext #VO (none,(72,1,112,112) )
+        prompt_ids = self.tokenizer.encode(prompt)
+        prompt_length = len(prompt_ids)
+        #audio_length, visual_length,inputLen = self.calculate_output_length(audio_raw,visual_raw)
+        audio_length_pre = self.calculate_output_length(audio_raw,visual_raw)  #video #tensor(80)
+        audio_length = audio_length_pre // 5  # ad-hoc for 5x fc downsample #tensor(16)
+        audio_pseudo = torch.full((audio_length,), -1)  # placeholder
+
+        example = prompt + answer  # FIX(MZY): avoid putting a bos token before answer.
+        example_ids = self.tokenizer.encode(example)  # [prompt,answer]
+        example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]
+        example_ids = torch.tensor(
+            example_ids, dtype=torch.int64
+        )
+        example_ids = torch.cat((audio_pseudo, example_ids))  # [audio,prompt,answer,eos]
+
+        labels_ids = copy.deepcopy(example_ids)  # [audio,prompt,answer,eos]
+        labels_ids[:audio_length + prompt_length] = -1  # [-1,-1,answer,eos];
+        example_mask = example_ids.ge(-1)  # FIX(GZF): [True,True,True,True]
+
+        label_mask = labels_ids.ge(0)  # [False,False,True,True]
+        example_ids[~example_mask] = 0  # [audio,prompt,answer,eos]
+        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,answer,eos]
+
+        return {
+            "input_ids": example_ids,
+            "labels": labels_ids,
+            "attention_mask": example_mask,
+            # 'audio_mel': audio_mel,
+            'audio_length': audio_length,
+
+            'inp': inp,
+            'trgtin': trgtin,
+            'trgtout': trgtout,
+            'trgtLen': trgtLen,
+
+            'audio_length_pre': audio_length_pre,
+        }
+
+
+        #return inp, trgtin, trgtout, trgtLen #, trgttext #VO (none,(72,1,112,112) )
+
+    def calculate_output_length(self, audio_raw, visual_raw):
+        # pass through wav2vec2: the feature extractor downsamples the waveform by 320
+        audio_len = audio_raw.shape[0]
+        audio_len = math.floor(audio_len/320)  #152
+
+        # the visual length stays unchanged
+        visual_len = visual_raw.shape[0]  #77
+
+        audLen = torch.tensor(audio_len)
+        vidLen = torch.tensor(visual_len)
+
+        dismatch = audLen - 2 * vidLen  #tensor([0, 1, 0, 2], device='cuda:0')
+        vidPadding = torch.ceil(torch.div(dismatch, 2)).int()  #tensor([0.0000, 0.5000, 0.0000, 1.0000], device='cuda:0') tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
+        vidPadding = vidPadding * (vidPadding > 0)  #tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
+        audPadding = 2 * vidPadding - dismatch  #tensor([0, 1, 0, 0], device='cuda:0')
+
+        mask = (vidPadding + vidLen) > self.reqInpLen  #80 tensor([False, True, True, True], device='cuda:0')
+        vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen)  #tensor([21, 1, 0, 1], device='cuda:0', dtype=torch.int32)
+        mask = (audPadding + audLen) > 2 * self.reqInpLen  #tensor([False, True, True, True], device='cuda:0')
+        audPadding = mask * audPadding + (~mask) * (2 * self.reqInpLen - audLen)  #tensor([42, 1, 0, 0], device='cuda:0')
+
+        vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int()
+        vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int()
+        audLeftPadding = torch.floor(torch.div(audPadding, 2)).int()
+        audRightPadding = torch.ceil(torch.div(audPadding, 2)).int()
+
+        audioBatch = torch.randn(audLen, 1024)  #.to('cuda') # pseudo audio batch
+        videoBatch = torch.randn(vidLen, 2048)  #.to('cuda') # pseudo video batch
+
+        pad = nn.ReplicationPad2d(padding=(0, 0, audLeftPadding, audRightPadding))
+        audioBatch = pad(audioBatch.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+        pad = nn.ReplicationPad2d(padding=(0, 0, vidLeftPadding, vidRightPadding))
+        videoBatch = pad(videoBatch.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+
+        audio_length, visual_length = audioBatch.shape[0], videoBatch.shape[0]
+        inputLen = (vidLen + vidPadding).long()
+
+        # the conv layers after this do not change the length, so just use inputLen
+        return inputLen
+
     def __len__(self):
         """
@@ -141,10 +242,24 @@ def __len__(self):
         else:
             return len(self.datalist)
-    def collator(self, dataBatch):
+    def collator(self, samples):
+        assert samples is not None
+        input_ids_max_length = max([s['input_ids'].shape[0] for s in samples])
+        input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id)
+                                 for s in samples])
+        labels = torch.stack([self.pad(s['labels'],
input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + + audio_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + audio_mask[line, :sample['audio_length']] = 1 #downsample 再/5 + # audio & mask if not self.modal == "VO": - aud_seq_list = [data[0][0] for data in dataBatch] + #aud_seq_list = [data[0][0] for data in dataBatch] + aud_seq_list = [data['inp'][0] for data in samples] aud_padding_mask = torch.zeros((len(aud_seq_list), len(max(aud_seq_list, key=len))), dtype=torch.bool) for i, seq in enumerate(aud_seq_list): aud_padding_mask[i, len(seq):] = True @@ -154,17 +269,23 @@ def collator(self, dataBatch): aud_padding_mask = None # visual & len if not self.modal == "AO": - vis_seq_list = pad_sequence([data[0][1] for data in dataBatch], batch_first=True) #(4,147,1,112,112) #pad_sequence((none,62,1,112,112)) - vis_len = torch.tensor([len(data[0][1]) for data in dataBatch]) #就是这四个句子每一个的长度 tensor([ 62, 62, 97, 147]) #时间帧上pad + #vis_seq_list = pad_sequence([data[0][1] for data in dataBatch], batch_first=True) #(4,147,1,112,112) #pad_sequence((none,62,1,112,112)) + vis_seq_list = pad_sequence([data['inp'][1] for data in samples], batch_first=True) #(4,147,1,112,112) #pad_sequence((none,62,1,112,112)) + #vis_len = torch.tensor([len(data[0][1]) for data in dataBatch]) #就是这四个句子每一个的长度 tensor([ 62, 62, 97, 147]) #时间帧上pad + vis_len = torch.tensor([len(data['inp'][1]) for data in samples]) #就是这四个句子每一个的长度 tensor([ 62, 62, 97, 147]) #时间帧上pad + else: vis_seq_list = None vis_len = None inputBatch = (aud_seq_list, aud_padding_mask, vis_seq_list, vis_len) #!!! - targetinBatch = pad_sequence([data[1] for data in dataBatch], batch_first=True) - targetoutBatch = pad_sequence([data[2] for data in dataBatch], batch_first=True) - targetLenBatch = torch.stack([data[3] for data in dataBatch]) + # targetinBatch = pad_sequence([data[1] for data in dataBatch], batch_first=True) + # targetoutBatch = pad_sequence([data[2] for data in dataBatch], batch_first=True) + # targetLenBatch = torch.stack([data[3] for data in dataBatch]) + targetinBatch = pad_sequence([data['trgtin'] for data in samples], batch_first=True) + targetoutBatch = pad_sequence([data['trgtout'] for data in samples], batch_first=True) + targetLenBatch = torch.stack([data['trgtLen'] for data in samples]) if self.modal == "AO": inputBatch = (inputBatch[0].float(), inputBatch[1], None, None) @@ -180,17 +301,51 @@ def collator(self, dataBatch): targetMask[(torch.arange(targetMask.shape[0]), targetLenBatch.long() - 1)] = 1 targetMask = (1 - targetMask.flip([-1]).cumsum(-1).flip([-1])).bool() - return { - "inputBatch0": inputBatch[0], - "inputBatch1": inputBatch[1], - "inputBatch2": inputBatch[2], - "inputBatch3": inputBatch[3], + # return { + # "inputBatch0": inputBatch[0], + # "inputBatch1": inputBatch[1], + # "inputBatch2": inputBatch[2], + # "inputBatch3": inputBatch[3], - "targetoutBatch": targetoutBatch, - "targetLenBatch": targetLenBatch.long(), + # "targetoutBatch": targetoutBatch, + # "targetLenBatch": targetLenBatch.long(), + # 'maskw2v': True, + # } + return { + 'input_ids': input_ids, #torch.Size([4, 114]) + 'labels': labels, #torch.Size([4, 114]) + 'attention_mask': attention_mask, #torch.Size([4, 114]) + # 'audio_mel': audio_mel, + # 'audio_mel_post_mask': audio_mel_post_mask, + 'audio_mask': audio_mask, + + "audio": inputBatch[0], #torch.Size([4, 92800]) + "audiomask": inputBatch[1], 
#torch.Size([4, 92800]) + "visual": inputBatch[2], #torch.Size([4, 146, 1, 112, 112]) + "vis_len": inputBatch[3], #torch.Size([4]) + + "targetoutBatch": targetoutBatch, #torch.Size([4, 50]) + "targetLenBatch": targetLenBatch.long(), #torch.Size([4]) 'maskw2v': True, } + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def prepare_pretrain_input(self,index, modal, h5, targetFile, charToIx, transform, noise, noiseSNR, numWordsRange, maxLength): #(3,21) 160 """ Function to convert the data sample in the pretrain dataset into appropriate tensors. @@ -340,7 +495,7 @@ def prepare_pretrain_input(self,index, modal, h5, targetFile, charToIx, transfor else: numWords -= 1 - return inp, trgtin, trgtout, trgtLen #, trgtNWord + return inp, trgtin, trgtout, trgtLen , trgtNWord def prepare_main_input(self, index, modal, h5, targetFile, charToIx, transform, noise, noiseSNR): """ @@ -401,7 +556,7 @@ def prepare_main_input(self, index, modal, h5, targetFile, charToIx, transform, trgtout = torch.from_numpy(trgtout) trgtLen = torch.tensor(trgtLen) - return inp, trgtin, trgtout, trgtLen#,trgt #'THE FIRST TIME WHEN IT TOOK ME FIVE MONTHS FROM THE DECISION OF' + return inp, trgtin, trgtout, trgtLen,trgt #'THE FIRST TIME WHEN IT TOOK ME FIVE MONTHS FROM THE DECISION OF' def get_audio_dataset(dataset_config, tokenizer, split): diff --git a/src/llama_recipes/models/av_net.py b/src/llama_recipes/models/AV/av_net.py similarity index 80% rename from src/llama_recipes/models/av_net.py rename to src/llama_recipes/models/AV/av_net.py index f8742ccc..8a4fa658 100644 --- a/src/llama_recipes/models/av_net.py +++ b/src/llama_recipes/models/AV/av_net.py @@ -97,12 +97,18 @@ def __init__(self, model_config): def forward(self, inputBatch, maskw2v): audioBatch, audMask, videoBatch, vidLen = inputBatch #torch.Size([2, 32480]),torch.Size([2, 32480]),torch.Size([2, 52, 1, 112, 112]),[52,47] # audMask尾部有一堆true表示mask,其余都是false if not self.modal == "VO": - result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version - audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) #形状变了 所以还得跟形状保持一致 - if audMask==None: - audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO + try: + result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version 这一步/320 并向下取整 + audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) #形状变了 所以还得跟形状保持一致 + if audMask==None: + audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO + + audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0') + except Exception as e: + print(e) + print(audioBatch.shape) + print(audMask) - audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0') else: audLen = None @@ -111,8 +117,9 @@ def forward(self, inputBatch, maskw2v): videoBatch = self.visualModel(videoBatch, 
vidLen.long()) #torch.Size([99, 2048]) videoBatch = list(torch.split(videoBatch, vidLen.tolist(), dim=0)) #拆成一个list [(52,2048), (47, 2048)] + #print(audioBatch.shape,audLen,videoBatch[0].shape,videoBatch[1].shape, videoBatch[2].shape,videoBatch[3].shape,vidLen) audioBatch, videoBatch, inputLenBatch, mask = self.makePadding(audioBatch, audLen, videoBatch, vidLen) #[2, 160, 1024], torch.Size([2, 80, 2048]), tensor([80, 80], (2,80) #这一步比较关键 - + #print( max(max(vidLen).item()*2, max(audLen).item()), audioBatch.shape, videoBatch.shape, inputLenBatch, mask.shape) if isinstance(self.maskedLayerNorm, MaskedLayerNorm): self.maskedLayerNorm.SetMaskandLength(mask, inputLenBatch) @@ -142,8 +149,10 @@ def forward(self, inputBatch, maskw2v): jointBatch = self.EncoderPositionalEncoding(jointBatch) jointBatch = self.jointEncoder(jointBatch, src_key_padding_mask=mask) #[80, 2, 1024] + jointBatch = jointBatch.transpose(0, 1) #(2,129,1024) #new return jointBatch, inputLenBatch, mask #[80, 2, 1024], [80,80], [2,80] mask全是false + def makeMaskfromLength(self, maskShape, maskLength, maskDevice): mask = torch.zeros(maskShape, device=maskDevice) mask[(torch.arange(mask.shape[0]), maskLength - 1)] = 1 @@ -186,28 +195,30 @@ def makePadding(self, audioBatch, audLen, videoBatch, vidLen): mask = self.makeMaskfromLength(videoBatch.shape[:-1], inputLenBatch, videoBatch.device) else: - dismatch = audLen - 2 * vidLen - vidPadding = torch.ceil(torch.div(dismatch, 2)).int() - vidPadding = vidPadding * (vidPadding > 0) - audPadding = 2 * vidPadding - dismatch - - mask = (vidPadding + vidLen) > self.reqInpLen - vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen) - mask = (audPadding + audLen) > 2 * self.reqInpLen - audPadding = mask * audPadding + (~mask) * (2 * self.reqInpLen - audLen) - - vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int() - vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int() - audLeftPadding = torch.floor(torch.div(audPadding, 2)).int() - audRightPadding = torch.ceil(torch.div(audPadding, 2)).int() - - audioBatch = audioBatch.unsqueeze(1).unsqueeze(1) - audioBatch = list(audioBatch) + dismatch = audLen - 2 * vidLen #tensor([0, 1, 0, 2], device='cuda:0') + vidPadding = torch.ceil(torch.div(dismatch, 2)).int() #tensor([0.0000, 0.5000, 0.0000, 1.0000], device='cuda:0') tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) + vidPadding = vidPadding * (vidPadding > 0) #tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) + audPadding = 2 * vidPadding - dismatch #tensor([0, 1, 0, 0], device='cuda:0') + + mask = (vidPadding + vidLen) > self.reqInpLen #80 tensor([False, True, True, True], device='cuda:0') + vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen) #tensor([21, 1, 0, 1], device='cuda:0', dtype=torch.int32) + mask = (audPadding + audLen) > 2 * self.reqInpLen #tensor([False, True, True, True], device='cuda:0') + audPadding = mask * audPadding + (~mask) * (2 * self.reqInpLen - audLen) #tensor([42, 1, 0, 0], device='cuda:0') + + vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int() #tensor([10, 0, 0, 0], device='cuda:0', dtype=torch.int32) + vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int() #tensor([11, 1, 0, 1], device='cuda:0', dtype=torch.int32) + audLeftPadding = torch.floor(torch.div(audPadding, 2)).int() #tensor([21, 0, 0, 0], device='cuda:0', dtype=torch.int32) + audRightPadding = torch.ceil(torch.div(audPadding, 2)).int() #tensor([21, 1, 0, 0], device='cuda:0', dtype=torch.int32) + # input audioBatch, torch.Size([4, 284, 
1024])
+            audioBatch = audioBatch.unsqueeze(1).unsqueeze(1)  #torch.Size([4, 1, 1, 284, 1024])
+            audioBatch = list(audioBatch)  #torch.Size([1, 1, 284, 1024]) each, as a list
             for i, _ in enumerate(audioBatch):
                 pad = nn.ReplicationPad2d(padding=(0, 0, audLeftPadding[i], audRightPadding[i]))
-                audioBatch[i] = pad(audioBatch[i][:, :, :audLen[i]]).squeeze(0).squeeze(0)
+                audioBatch[i] = pad(audioBatch[i][:, :, :audLen[i]]).squeeze(0).squeeze(0)  #audioBatch[i].shape, torch.Size([1, 1, 284, 1024])
+                # print(i,audioBatch[i].shape)
                 pad = nn.ReplicationPad2d(padding=(0, 0, vidLeftPadding[i], vidRightPadding[i]))
                 videoBatch[i] = pad(videoBatch[i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+                # print(i,videoBatch[i].shape)
             audioBatch = pad_sequence(audioBatch, batch_first=True)
             videoBatch = pad_sequence(videoBatch, batch_first=True)
diff --git a/src/llama_recipes/models/avsr_model.py b/src/llama_recipes/models/AV/avsr_model.py
similarity index 96%
rename from src/llama_recipes/models/avsr_model.py
rename to src/llama_recipes/models/AV/avsr_model.py
index a8b9af99..808a4789 100644
--- a/src/llama_recipes/models/avsr_model.py
+++ b/src/llama_recipes/models/AV/avsr_model.py
@@ -15,7 +15,7 @@
 from llama_recipes.utils.config_utils import generate_peft_config
 from llama_recipes.utils.train_utils import print_model_size
-from .av_net import AVNet
+from .AV.av_net import AVNet
 from .slam_model import setup_llm
 from torch.nn.utils.rnn import pad_sequence
 import copy
@@ -35,15 +35,15 @@ def __init__(
         super().__init__()
         self.IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
-
-        # audio-visual
+
+        # audio-visual ↓
         self.avnet=AVNet(model_config)
-        # load_ckpt
+        # load_ckpt ↑
         checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE)
         self.avnet.load_state_dict(checkpoint['state_dict'],strict=False)  # the final CTC/attention output modules are not used
-        # freeze
+        # freeze (also frozen again by the outer freeze_encoder config)
         for name, param in self.avnet.named_parameters():
             param.requires_grad = False
         self.avnet.eval()
@@ -52,7 +52,7 @@ def __init__(
         self.llm = setup_llm(train_config, model_config, **kwargs)
         # projector
-        self.feature_projector = nn.Linear(model_config.DMODEL, self.llm.config.hidden_size)  #(512,4096)
+        self.feature_projector = nn.Linear(model_config.DMODEL, self.llm.config.hidden_size)  #(512,4096) seems to carry a leftover issue, TODO
         # tokenizer
         self.tokenizer = tokenizer  #tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path) no need to save it here
diff --git a/src/llama_recipes/models/moco_visual_frontend.py b/src/llama_recipes/models/AV/moco_visual_frontend.py
similarity index 100%
rename from src/llama_recipes/models/moco_visual_frontend.py
rename to src/llama_recipes/models/AV/moco_visual_frontend.py
diff --git a/src/llama_recipes/models/utils.py b/src/llama_recipes/models/AV/utils.py
similarity index 100%
rename from src/llama_recipes/models/utils.py
rename to src/llama_recipes/models/AV/utils.py
diff --git a/src/llama_recipes/models/visual_encoder.py b/src/llama_recipes/models/AV/visual_encoder.py
similarity index 100%
rename from src/llama_recipes/models/visual_encoder.py
rename to src/llama_recipes/models/AV/visual_encoder.py
diff --git a/src/llama_recipes/models/encoder.py b/src/llama_recipes/models/encoder.py
index 32cfda89..e66c494c 100644
--- a/src/llama_recipes/models/encoder.py
+++ b/src/llama_recipes/models/encoder.py
@@ -42,4 +42,16 @@ def load(cls, model_config):
         BEATs_model = BEATs(cfg)
         BEATs_model.load_state_dict(checkpoint['model'])
-        return BEATs_model
\ No newline at end of file
+        return BEATs_model
+
+
+class AVEncoder:
+
+    @classmethod
+    def load(cls, model_config):
from .AV.av_net import AVNet + avnet = AVNet(model_config) + checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE) + avnet.load_state_dict(checkpoint['state_dict'],strict=False) + + return avnet \ No newline at end of file diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index b440e820..98a04689 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -45,6 +45,9 @@ def setup_encoder(train_config, model_config, **kwargs): if encoder_name == "beats": from llama_recipes.models.encoder import BEATsEncoder encoder = BEATsEncoder.load(model_config) + if encoder_name == "moco_wav2vec2": + from llama_recipes.models.encoder import AVEncoder + encoder = AVEncoder.load(model_config) print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) if train_config.freeze_encoder: @@ -184,13 +187,25 @@ def forward(self, audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper audio_mask = kwargs.get("audio_mask", None) + audio = kwargs.get("audio", None) #torch.Size([2, 96480]) + audiomask = kwargs.get("audiomask", None) #删 #torch.Size([2, 96480]) + visual = kwargs.get("visual", None) #torch.Size([2, 151, 1, 112, 112]) + vis_len = kwargs.get("vis_len", None) #tensor([ 77, 151], device='cuda:0', dtype=torch.int32) + maskw2v = kwargs.get("maskw2v", None) #True + targetoutBatch = kwargs.get("targetoutBatch", None) #torch.Size([2, 29]) + targetLenBatch = kwargs.get("targetLenBatch", None) #tensor([18, 29], device='cuda:0') + + + encoder_outs = None - if audio_mel is not None: + if audio_mel is not None or audio is not None: if self.model_config.encoder_name == "whisper": encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim if self.model_config.encoder_name == "beats": encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim - + if self.model_config.encoder_name == "moco_wav2vec2": + encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audiomask, visual, vis_len) ,maskw2v) # bs*seq*dim + if self.model_config.encoder_projector == "q-former": encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) if self.model_config.encoder_projector == "linear": diff --git a/src/llama_recipes/pipeline/model_factory.py b/src/llama_recipes/pipeline/model_factory.py index 60a65276..93897d11 100644 --- a/src/llama_recipes/pipeline/model_factory.py +++ b/src/llama_recipes/pipeline/model_factory.py @@ -9,11 +9,8 @@ def model_factory(train_config, model_config, **kwargs): tokenizer = setup_tokenizer(train_config, model_config, **kwargs) - if train_config.model_name=="avsr": - from llama_recipes.models.avsr_model import setupavsr_model - model = setupavsr_model(tokenizer, train_config, model_config, **kwargs) - else: - model = setup_model(tokenizer, train_config, model_config, **kwargs) + + model = setup_model(tokenizer, train_config, model_config, **kwargs) ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) if ckpt_path is not None:
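
# Note (not part of the patch): the new __getitem__ above builds each training sequence as
# [audio placeholder tokens, prompt, answer, eos] and masks everything before the answer out of
# the loss. Below is a minimal, self-contained sketch of that masking scheme with made-up token
# ids; the real code uses the LLM tokenizer and an audio_length derived from the wav2vec2/video
# lengths, so the helper name build_example and the ids here are illustrative only.

import copy
import torch

IGNORE_INDEX = -100  # default ignore_index of nn.CrossEntropyLoss, as in the dataset

def build_example(audio_length, prompt_ids, answer_ids, eos_id):
    # -1 placeholders that are later replaced by projected audio-visual features
    audio_pseudo = torch.full((audio_length,), -1, dtype=torch.int64)
    # [audio, prompt, answer, eos]
    example_ids = torch.cat((audio_pseudo,
                             torch.tensor(prompt_ids + answer_ids + [eos_id], dtype=torch.int64)))
    labels = copy.deepcopy(example_ids)
    labels[:audio_length + len(prompt_ids)] = -1   # mask audio placeholders + prompt
    attention_mask = example_ids.ge(-1)            # all True here; padding is added in collator()
    label_mask = labels.ge(0)
    example_ids[~attention_mask] = 0
    labels[~label_mask] = IGNORE_INDEX             # only answer + eos contribute to the loss
    return example_ids, labels, attention_mask

# toy usage with made-up ids: 4 audio frames, 3 prompt tokens, 2 answer tokens
ids, labels, mask = build_example(4, [11, 12, 13], [21, 22], eos_id=2)
print(labels.tolist())  # [-100, -100, -100, -100, -100, -100, -100, 21, 22, 2]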