#!/bin/bash
# Fine-tune MuPT (LoRA) with a frozen TinyLlama text encoder and a linear
# projector through the SLAM-LLM fine-tuning pipeline.
#
#   * one visible GPU    -> run under debugpy (listens on port 5678 and waits
#                           for a client to attach) with small debug settings
#   * several visible GPUs -> run DDP through torchrun, one process per GPU
#
# All paths below are site-specific; edit them before running elsewhere.
set -euo pipefail

# export PYTHONPATH=/root/whisper:$PYTHONPATH
# ${PYTHONPATH:-} keeps the original value while staying safe under `set -u`.
export PYTHONPATH="/root/fairseq:${PYTHONPATH:-}"
export CUDA_VISIBLE_DEVICES=0,1,2,3
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

# The relative script path used below only resolves from the repo root,
# so refuse to continue if the directory is missing.
cd /root/SLAM-LLM || {
  printf 'error: cannot cd to /root/SLAM-LLM\n' >&2
  exit 1
}

# Alternative speech encoders (unused by this text-encoder recipe):
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
# speech_encoder_path=//nfs/maziyang.mzy/models/Whisper/small.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Base.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Base+.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Large.pt
text_encoder_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5
llm_path=/nfs/maziyang.mzy/models/MuPT_v1_8192

output_dir=/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-symbol-ds1-proj2048-steplrwarmup1e-4decay-lora-20240224-test
# ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-20240131/asr/4

# W&B experiment name = this script's file name without directory or
# extension. Two expansions are required: "${0##*/%.*}" is a single prefix
# removal with the pattern "*/%.*", which never matches and leaves $0 as is.
script_name=${0##*/}
exp_name=${script_name%.*}

# Arguments shared by the debug and the distributed launch.
common_args=(
  --config-path "/root/SLAM-LLM/scripts/conf"
  --config-name "asr_vicuna_lora.yaml"
  "hydra.run.dir=${output_dir}"
  "++model_config.llm_name=MuPT_v1_8192"
  "++model_config.llm_path=${llm_path}"
  "++model_config.llm_dim=1536"
  "++model_config.encoder_name=TinyLlama-1.1B-Chat-v0.4"
  "++model_config.encoder_path=${text_encoder_path}"
  "++model_config.encoder_dim=2048"
  "++model_config.encoder_projector=linear"
  "++model_config.encoder_projector_ds_rate=1"
  "++dataset_config.dataset=text_dataset"
  "++dataset_config.file=src/llama_recipes/datasets/text_dataset.py:get_text_dataset"
  "++dataset_config.tokenizer_path=${text_encoder_path}"
  "++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl"
  "++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl"
  "++dataset_config.input_type=features"
  "++train_config.model_name=mupt"
  "++train_config.freeze_encoder=true"
  "++train_config.use_peft=true"
  "++train_config.peft_config.peft_method=lora"
  "++train_config.batching_strategy=custom"
  "++train_config.lr=1e-4"
  "++train_config.output_dir=${output_dir}"
  "++metric=acc"
  # Optional overrides, disabled by default:
  # "++train_config.freeze_llm=true"
  # "++ckpt_path=${ckpt_path}/model.pt"
  # "++model_config.encoder_projector=q-former"
  # "++dataset_config.fix_length_audio=64"
  # "++peft_ckpt" "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5"
)

# -m debugpy --listen 5678 --wait-for-client
if [[ "${CUDA_VISIBLE_DEVICES}" != *","* ]]; then
  # Single device: small schedule and batch size, attached to a debugger.
  python -m debugpy --listen 5678 --wait-for-client \
    src/llama_recipes/pipeline/finetune.py \
    "${common_args[@]}" \
    ++train_config.warmup_steps=1000 \
    ++train_config.total_steps=100000 \
    ++train_config.validation_interval=1000 \
    ++train_config.batch_size_training=1 \
    ++train_config.val_batch_size=1 \
    ++train_config.num_workers_dataloader=1
else
  # One worker per visible device instead of a hard-coded process count.
  IFS=',' read -r -a gpu_ids <<< "${CUDA_VISIBLE_DEVICES}"
  torchrun \
    --nnodes 1 \
    --nproc_per_node "${#gpu_ids[@]}" \
    --master_port=29501 \
    src/llama_recipes/pipeline/finetune.py \
    "${common_args[@]}" \
    ++train_config.warmup_steps=10000 \
    ++train_config.total_steps=1000000 \
    ++train_config.validation_interval=5000 \
    ++train_config.batch_size_training=6 \
    ++train_config.val_batch_size=6 \
    ++train_config.num_workers_dataloader=4 \
    ++train_config.enable_fsdp=false \
    ++train_config.enable_ddp=true \
    ++train_config.use_fp16=true \
    "++log_config.log_file=${output_dir}/train.log" \
    ++log_config.use_wandb=true \
    "++log_config.wandb_dir=${output_dir}" \
    ++log_config.wandb_entity_name=zym22 \
    ++log_config.wandb_project_name=slam-llm \
    "++log_config.wandb_exp_name=${exp_name}" \
    ++log_config.log_interval=5
fi

# Sample manifest records kept for reference (first record continues below):
# {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer.
class TextDatasetJsonl(torch.utils.data.Dataset):
    """Text-instruction dataset backed by a JSONL manifest.

    Each manifest line is a JSON object; ``instruct`` and ``target`` are read
    per item (with dummy fallbacks), ``source_len`` / ``target_len`` are
    exposed through the length helpers.

    ``input_type`` selects how the instruction reaches the model:

    * ``"raw"``      - the instruction text is prepended to the prompt and
      tokenized with the LLM tokenizer; no placeholder positions are added.
    * ``"features"`` - the instruction is tokenized with a separate tokenizer
      (``dataset_config.tokenizer_path``) and returned as ``instruct_ids``;
      the LLM input gets ``-1`` placeholders in front of the prompt, one per
      instruct position.
    """

    def __init__(self,
                 dataset_config,
                 tokenizer=None,
                 split='train',
                 ):
        """
        Args:
            dataset_config: config object supporting ``.get(key, default)`` and
                attribute access for ``train_data_path`` / ``val_data_path``.
            tokenizer: tokenizer of the LLM (needs ``encode``, ``eos_token_id``
                and ``pad_token_id``).
            split: ``"train"`` loads ``train_data_path``; any other value loads
                ``val_data_path``.
        """
        super().__init__()
        self.dataset_config = dataset_config
        self.tokenizer = tokenizer

        self.IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
        self.prompt = "X:1L:1/8Q:1/8=200M:4/4K:Gmin|:\"Gm\" BGdB"
        self.prompt_template = "{}"
        self.answer_template = "{}"
        self.fix_length_text = dataset_config.get("fix_length_text", -1)  # for Q-former
        self.inference_mode = dataset_config.get("inference_mode", False)
        self.input_type = dataset_config.get("input_type", None)
        assert self.input_type in ["raw", "features"], "input_type must be one of [raw, features]"
        if self.input_type == "features":
            from transformers import AutoTokenizer
            self.instruct_tokenizer = AutoTokenizer.from_pretrained(dataset_config.get("tokenizer_path", "Llama-2-7b-hf"))

        if split == "train":
            manifest_path = dataset_config.train_data_path
        else:
            manifest_path = dataset_config.val_data_path
        self.data_list = self._load_jsonl(manifest_path)

    @staticmethod
    def _load_jsonl(path):
        """Read a JSONL file into a list of dicts, skipping blank lines."""
        records = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # A trailing or stray empty line is not a record.
                    continue
                records.append(json.loads(line))
        return records

    @staticmethod
    def _pad_id(tokenizer):
        """Return a usable padding id for ``tokenizer``.

        Falls back to ``eos_token_id`` and finally ``0`` when the tokenizer has
        no pad token configured; padded positions are masked out by the
        attention masks built in ``collator``.
        """
        for candidate in (getattr(tokenizer, "pad_token_id", None),
                          getattr(tokenizer, "eos_token_id", None)):
            if candidate is not None:
                return candidate
        return 0

    def get_source_len(self, data_dict):
        """Return the ``source_len`` field of a manifest record."""
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        """Return the ``target_len`` field of a manifest record, or 0 if absent."""
        return data_dict.get("target_len", 0)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Build one example.

        Returns a dict with ``input_ids``, ``attention_mask``, ``instruct_ids``
        (``None`` for raw input) and ``instruct_length``; outside inference
        mode it also contains ``labels`` where every position before the
        answer is ``IGNORE_INDEX``.
        """
        data_dict = self.data_list[index]
        instruct = data_dict.get("instruct", "Dummy Instruct")
        target = data_dict.get("target", "Dummy Target")

        prompt = self.prompt_template.format(self.prompt)
        instruct_ids = None
        if self.input_type == "raw":
            instruct_length = 0
            prompt = instruct + prompt
        elif self.input_type == "features":
            instruct_ids = torch.tensor(self.instruct_tokenizer.encode(instruct), dtype=torch.int64)
            instruct_length = instruct_ids.shape[0]
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)

        if self.fix_length_text > 0:  # for Q-former
            instruct_length = self.fix_length_text
        instruct_pseudo = torch.full((instruct_length,), -1, dtype=torch.int64)  # placeholder

        if self.inference_mode:
            prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
            example_ids = torch.cat((instruct_pseudo, prompt_ids))  # [instruct,prompt]
            example_mask = example_ids.ge(-1)  # [True,True]

            return {
                "input_ids": example_ids,
                "attention_mask": example_mask,
                "instruct_ids": instruct_ids,
                "instruct_length": instruct_length,
            }

        answer = self.answer_template.format(target)
        example_ids = self.tokenizer.encode(prompt + answer)  # [prompt,answer]
        example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]
        example_ids = torch.tensor(example_ids, dtype=torch.int64)
        example_ids = torch.cat((instruct_pseudo, example_ids))  # [instruct,prompt,answer,eos]

        # NOTE(review): the label mask assumes the first ``prompt_length`` tokens
        # of encode(prompt + answer) are exactly encode(prompt) - confirm for
        # tokenizers that may merge tokens across the prompt/answer boundary.
        labels_ids = example_ids.clone()  # [instruct,prompt,answer,eos]
        labels_ids[:instruct_length + prompt_length] = -1  # [-1,-1,answer,eos]
        example_mask = example_ids.ge(-1)  # [True,True,True,True]

        label_mask = labels_ids.ge(0)  # [False,False,True,True]
        example_ids[~example_mask] = 0  # [instruct,prompt,answer,eos]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,answer,eos]

        return {
            "input_ids": example_ids,
            "labels": labels_ids,
            "attention_mask": example_mask,
            "instruct_ids": instruct_ids,
            "instruct_length": instruct_length,
        }

    def pad(self, sequence, max_length, padding_idx=0):
        """Pad or truncate ``sequence`` along its first axis to ``max_length``.

        Supports lists, tuples, ``torch.Tensor`` and ``np.ndarray``; the
        container type is preserved, and tensors/arrays keep their dtype
        (tensors also their device).

        Raises:
            Exception: if ``sequence`` is of any other type.
        """
        if isinstance(sequence, (list, tuple)):
            missing = max_length - len(sequence)
            if missing > 0:
                filler = [padding_idx] * missing
                # tuple + list is a TypeError, so match the container type.
                sequence = sequence + (tuple(filler) if isinstance(sequence, tuple) else filler)
            else:
                sequence = sequence[:max_length]
        elif isinstance(sequence, torch.Tensor):
            missing = max_length - len(sequence)
            if missing > 0:
                filler_shape = (missing,) + tuple(sequence.shape[1:])
                sequence = torch.cat((sequence, sequence.new_full(filler_shape, padding_idx)))
            else:
                sequence = sequence[:max_length]
        elif isinstance(sequence, np.ndarray):
            missing = max_length - len(sequence)
            if missing > 0:
                filler = np.full((missing,) + sequence.shape[1:], padding_idx, dtype=sequence.dtype)
                sequence = np.concatenate((sequence, filler))
            else:
                sequence = sequence[:max_length]
        else:
            raise Exception("Type mismatch during padding!")
        return sequence

    def collator(self, samples):
        """Batch a list of ``__getitem__`` outputs.

        Sequences are right-padded to the longest one in the batch.
        ``instruct_mask`` marks real instruct tokens, ``modality_mask`` marks
        the placeholder positions at the start of ``input_ids``. Both
        ``instruct_ids`` and ``instruct_mask`` are ``None`` for raw input.
        """
        assert samples is not None
        input_ids_max_length = max(s['input_ids'].shape[0] for s in samples)
        pad_id = self._pad_id(self.tokenizer)
        input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, pad_id)
                                 for s in samples])
        attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
                                      for s in samples])

        instruct_ids = None
        instruct_mask = None
        if self.input_type == "features":
            instruct_max_length = max(s['instruct_ids'].shape[0] for s in samples)
            instruct_pad_id = self._pad_id(self.instruct_tokenizer)
            instruct_ids = torch.stack([self.pad(s['instruct_ids'], instruct_max_length, instruct_pad_id)
                                        for s in samples])
            instruct_mask = torch.zeros(len(samples), instruct_max_length)
            for row, sample in enumerate(samples):
                # Use the real token count: ``instruct_length`` may have been
                # overridden by ``fix_length_text`` and would then flag padded
                # positions as valid.
                instruct_mask[row, :sample['instruct_ids'].shape[0]] = 1

        modality_mask = torch.zeros_like(attention_mask)
        for row, sample in enumerate(samples):
            modality_mask[row, :sample['instruct_length']] = 1

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "instruct_ids": instruct_ids,
            "instruct_mask": instruct_mask,
            "modality_mask": modality_mask,
        }
        if self.inference_mode:
            return batch

        batch["labels"] = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
                                       for s in samples])
        return batch


def get_text_dataset(dataset_config, tokenizer, split):
    """Factory entry point: build a ``TextDatasetJsonl`` for ``split``."""
    dataset = TextDatasetJsonl(dataset_config, tokenizer, split)

    return dataset
b/src/llama_recipes/models/encoder.py @@ -75,4 +75,12 @@ def load(cls, model_config): checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE) avnet.load_state_dict(checkpoint['state_dict'],strict=False) - return avnet \ No newline at end of file + return avnet + +class HfTextEncoder: + + @classmethod + def load(cls, model_config): + from transformers import AutoModel + model = AutoModel.from_pretrained(model_config.encoder_path) + return model \ No newline at end of file diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index e06c0c4f..d76c4c37 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -6,7 +6,7 @@ import torch.nn.functional as F import torch.distributed as dist from typing import List, Optional, Tuple, Union -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training from llama_recipes.utils.config_utils import generate_peft_config @@ -25,8 +25,13 @@ def setup_model(tokenizer, train_config, model_config, **kwargs): def setup_tokenizer(train_config, model_config, **kwargs): # Load the tokenizer and add special tokens - tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path) - tokenizer.pad_token_id = tokenizer.eos_token_id + if "mupt" in model_config.llm_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path, + trust_remote_code=True, + use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path) + tokenizer.pad_token_id = tokenizer.eos_token_id return tokenizer @@ -48,6 +53,9 @@ def setup_encoder(train_config, model_config, **kwargs): if encoder_name == "moco_wav2vec2": from llama_recipes.models.encoder import AVEncoder encoder = AVEncoder.load(model_config) + if "llama" in encoder_name.lower(): + from 
llama_recipes.models.encoder import HfTextEncoder + encoder = HfTextEncoder.load(model_config) print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) if train_config.freeze_encoder: @@ -185,7 +193,6 @@ def forward(self, audio_mel = kwargs.get("audio_mel", None) audio_mel_mask = kwargs.get("audio_mel_mask", None) audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper - modality_mask = kwargs.get("modality_mask", None) audio = kwargs.get("audio", None) audio_mask = kwargs.get("audio_mask", None) @@ -193,6 +200,11 @@ def forward(self, vis_len = kwargs.get("vis_len", None) maskw2v = kwargs.get("maskw2v", False) #(FIX:MZY) False for supervised learning and inference + # for text encoder + instruct_ids = kwargs.get("instruct_ids", None) + instruct_mask = kwargs.get("instruct_mask", None) + + modality_mask = kwargs.get("modality_mask", None) encoder_outs = None if audio_mel is not None or audio is not None: @@ -212,6 +224,16 @@ def forward(self, if self.model_config.encoder_projector == "linear": encoder_outs = self.encoder_projector(encoder_outs) + if instruct_ids is not None: + if self.encoder is not None: + encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state + + if self.model_config.encoder_projector == "q-former": + encoder_outs = self.encoder_projector(encoder_outs, instruct_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + + if input_ids is not None: input_ids[input_ids == -1] = 0 if hasattr(self.llm.model, "embed_tokens"): diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index 205b891f..dda72029 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -253,8 +253,8 @@ def main(kwargs: DictConfig): optimizer, lr_lambda=lambda step: ( min(step / 
train_config.warmup_steps, 1) if step < train_config.warmup_steps - else 1 - # else max(0.0, 1 - (step - train_config.warmup_steps) / (train_config.total_steps - train_config.warmup_steps)) + else max(0.0, 1 - (step - train_config.warmup_steps) / (train_config.total_steps - train_config.warmup_steps)) + # else 1 ) ) diff --git a/src/llama_recipes/utils/compute_ppl.py b/src/llama_recipes/utils/compute_ppl.py index a5521f3c..b9ac59a1 100644 --- a/src/llama_recipes/utils/compute_ppl.py +++ b/src/llama_recipes/utils/compute_ppl.py @@ -3,15 +3,17 @@ from tqdm import tqdm import json -MODEL_PATH = "/nfs/maziyang.mzy/models/vicuna-7b-v1.5" +# MODEL_PATH = "/nfs/maziyang.mzy/models/vicuna-7b-v1.5" +MODEL_PATH = "/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf" +# MODEL_PATH = "/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf" tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) -device = 'cuda:0' +device = 'cuda:7' model.to(device) model.eval() -corpus_path = "/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl" +corpus_path = "/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl" corpus = [] with open(corpus_path, encoding='utf-8') as fin: for line in fin: @@ -22,7 +24,7 @@ total_tokens = 0 for sentence in tqdm(corpus): - inputs = tokenizer(sentence, return_tensors="pt").to(device) + inputs = tokenizer(sentence.strip().lower(), return_tensors="pt").to(device) input_ids = inputs["input_ids"] # input_len = input_ids.size(1) diff --git a/src/llama_recipes/utils/dataset_utils.py b/src/llama_recipes/utils/dataset_utils.py index ba9ffe1c..fcbbc365 100644 --- a/src/llama_recipes/utils/dataset_utils.py +++ b/src/llama_recipes/utils/dataset_utils.py @@ -58,6 +58,7 @@ def get_custom_dataset(dataset_config, tokenizer, split: str): "custom_dataset": get_custom_dataset, "speech_dataset": get_custom_dataset, "audio_dataset": get_custom_dataset, + "text_dataset": get_custom_dataset, "avsr_dataset": 
get_custom_dataset, } diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index ac3dbeaa..2b1ad8e5 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -175,11 +175,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche else: logger.info(f"we are about to save the PEFT modules") if train_config.enable_fsdp: - if getattr(ShardingStrategy, fsdp_config.sharding_strategy) == ShardingStrategy.FULL_SHARD: + if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: save_model_checkpoint_peft_full_shard( model, optimizer, rank, train_config, epoch=epoch ) - elif getattr(ShardingStrategy, fsdp_config.sharding_strategy) == ShardingStrategy.NO_SHARD: + elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: if rank==0: save_model_checkpoint_peft( model, optimizer, rank, train_config, epoch=epoch @@ -205,11 +205,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche elif not train_config.use_peft and train_config.freeze_llm: logger.info(f"llm is frozen, we are about to save other parts.") if train_config.enable_fsdp: - if getattr(ShardingStrategy, fsdp_config.sharding_strategy) == ShardingStrategy.FULL_SHARD: + if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: save_model_checkpoint_peft_full_shard( model, optimizer, rank, train_config, epoch=epoch ) - elif getattr(ShardingStrategy, fsdp_config.sharding_strategy) == ShardingStrategy.NO_SHARD: + elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: if rank==0: save_model_checkpoint_peft( model, optimizer, rank, train_config, epoch=epoch