From 8a019bf525f1e2b5c559694d56356cb74bb9b186 Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Wed, 20 Dec 2023 22:50:51 +0800 Subject: [PATCH] fix wandb and logging --- .gitignore | 4 +- scripts/finetune_speech_pretraining.sh | 57 +++++++++++-------- src/llama_recipes/configs/__init__.py | 1 + src/llama_recipes/configs/log.py | 15 +++++ src/llama_recipes/configs/training.py | 2 - src/llama_recipes/datasets/speech_dataset.py | 8 ++- .../model_checkpointing/checkpoint_handler.py | 4 +- src/llama_recipes/pipeline/finetune.py | 30 +++++++--- src/llama_recipes/utils/train_utils.py | 31 +++++++--- 9 files changed, 104 insertions(+), 48 deletions(-) create mode 100644 src/llama_recipes/configs/log.py diff --git a/.gitignore b/.gitignore index 0e0e2506..f6082dfe 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,5 @@ debug.py .idea/* transformers wandb/ -*.log -log \ No newline at end of file +log/ +*.log \ No newline at end of file diff --git a/scripts/finetune_speech_pretraining.sh b/scripts/finetune_speech_pretraining.sh index 53e9787d..db5d5abd 100644 --- a/scripts/finetune_speech_pretraining.sh +++ b/scripts/finetune_speech_pretraining.sh @@ -1,6 +1,7 @@ #!/bin/bash -#export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1 export CUDA_LAUNCH_BLOCKING=1 export OMP_NUM_THREADS=1 @@ -14,7 +15,7 @@ cd /root/SLAM-LLM speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt # speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-qformer64-proj2048-lr1e-5-whisper-test +output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt-renew5-finetunepeft-test # -m debugpy --listen 5678 --wait-for-client if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then @@ -27,34 +28,39 @@ python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/fin --encoder_name whisper \ --encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ ---encoder_projector q-former \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ --dataset custom_dataset \ ---custom_dataset.fix_length_audio 64 \ --custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \ --custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.trans.jsonl \ --custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ --batching_strategy custom \ --num_epochs 100 \ ---batch_size_training 4 \ ---val_batch_size 4 \ +--batch_size_training 16 \ +--val_batch_size 16 \ +--num_workers_dataloader 4 \ --lr 1e-5 \ --output_dir $output_dir \ ---run_test_during_validation \ ---run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \ ---run_test_during_validation_prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \ --metric acc \ -# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ -# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ +--use_wandb \ +--wandb_dir $output_dir \ +--wandb_entity_name zym22 \ +--wandb_project_name slam-llm \ +--wandb_exp_name test \ +--log_file /$output_dir/test.log \ +--log_interval 5 \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \ # --use_peft --peft_method lora \ else torchrun \ --nnodes 1 \ ---nproc_per_node 4 \ +--nproc_per_node 2 \ src/llama_recipes/pipeline/finetune.py \ --model_name asr \ --freeze_encoder \ ---use_peft --peft_method lora \ +--freeze_llm \ --use_fp16 \ --enable_fsdp \ --llm_name llama-2-7b-hf \ @@ -70,18 +76,23 @@ src/llama_recipes/pipeline/finetune.py \ --custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ --batching_strategy custom \ --num_epochs 100 \ ---batch_size_training 16 \ ---val_batch_size 16 \ +--batch_size_training 8 \ +--val_batch_size 8 \ --num_workers_dataloader 4 \ --lr 1e-5 \ --output_dir $output_dir \ ---run_test_during_validation \ ---run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \ ---run_test_during_validation_prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \ --metric acc \ -# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \ -# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \ -# --freeze_llm \ +--use_wandb \ +--wandb_dir $output_dir \ +--wandb_entity_name zym22 \ +--wandb_project_name slam-llm \ +--wandb_exp_name test \ +--log_file /$output_dir/test.log \ +--log_interval 5 \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \ +# --use_peft --peft_method lora \ fi -# {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} \ No newline at end of file +# {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} +# {"key": "1688-142285-0005", "prompt": "", "source": "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0005.wav", "target": "YOU WHO WERE ALWAYS ACCUSING PEOPLE OF BEING SHOPPY AT HELSTONE", "target_len": 11, "source_len": 220, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} \ No newline at end of file diff --git a/src/llama_recipes/configs/__init__.py b/src/llama_recipes/configs/__init__.py index 58d0851c..679b022d 100644 --- a/src/llama_recipes/configs/__init__.py +++ b/src/llama_recipes/configs/__init__.py @@ -5,3 +5,4 @@ from llama_recipes.configs.fsdp import fsdp_config from llama_recipes.configs.training import train_config from llama_recipes.configs.model import model_config +from llama_recipes.configs.log import log_config diff --git a/src/llama_recipes/configs/log.py b/src/llama_recipes/configs/log.py new file mode 100644 index 00000000..662a1a1e --- /dev/null +++ b/src/llama_recipes/configs/log.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass + + +@dataclass +class log_config: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name : str = "project_name" + wandb_project_name : str = "project_name" + wandb_exp_name : str = "exp_name" + log_file: str="/root/test.log" + log_interval: int = 5 \ No newline at end of file diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 65c13d8c..1239c7f5 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -36,10 +36,8 @@ class train_config: dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP save_optimizer: bool=False # will be used if using FSDP use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels - log_file: str="PATH/to/Log_File" run_test_during_validation: bool = False run_test_during_validation_file: str = "test.wav" run_test_during_validation_prompt: str = "<|ASR|>" freeze_llm: bool = False freeze_encoder: bool = False - log_interval: int = 5 diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py index 606a78ed..ee8c6ded 100644 --- a/src/llama_recipes/datasets/speech_dataset.py +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -46,10 +46,14 @@ def __init__(self, self.data_list.append(data_dict) # # debug + # with open(dataset_config.train_data_path, encoding='utf-8') as fin: + # for line in fin: + # data_dict = json.loads(line.strip()) + # self.data_list.append(data_dict) # if split == "train": - # self.data_list = contents[:80] + # self.data_list = self.data_list[:80] # else: - # self.data_list = contents[80:100] + # self.data_list = self.data_list[80:100] def get_source_len(self, data_dict): return data_dict["source_len"] diff --git a/src/llama_recipes/model_checkpointing/checkpoint_handler.py b/src/llama_recipes/model_checkpointing/checkpoint_handler.py index 020b4d00..d12b74c4 100644 --- a/src/llama_recipes/model_checkpointing/checkpoint_handler.py +++ b/src/llama_recipes/model_checkpointing/checkpoint_handler.py @@ -166,7 +166,7 @@ def save_model_checkpoint( def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): logger.info(f"--> saving model ...") - save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch)) + save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1)) os.makedirs(save_dir, exist_ok=True) if not cfg.freeze_llm: model.llm.save_pretrained(save_dir) @@ -183,7 +183,7 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0): project_dict[key] = cpu_state[key] torch.save(project_dict, save_full_path) - logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + logger.info(f"model checkpoint saved for epoch {epoch+1} at {save_full_path}\n") diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index d070594a..ebe67916 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -21,6 +21,7 @@ from llama_recipes.configs import fsdp_config as FSDP_CONFIG from llama_recipes.configs import train_config as TRAIN_CONFIG from llama_recipes.configs import model_config as MODEL_CONFIG +from llama_recipes.configs import log_config as LOG_CONFIG from llama_recipes.data.concatenator import ConcatDataset # util @@ -48,14 +49,12 @@ def main(**kwargs): # Update the configuration for the training and sharding process - train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() - update_config((train_config, fsdp_config, model_config), **kwargs) - - # Set wandb - wandb_config={"train_config":vars(train_config), "fsdp_config":vars(fsdp_config), "model_config":vars(model_config)} - wandb.init(project="project_name",name="exp_name",config=wandb_config) #记录参数 + train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG() + update_config((train_config, fsdp_config, model_config, log_config), **kwargs) # Set log + if not os.path.exists(os.path.dirname(log_config.log_file)): + os.makedirs(os.path.dirname(log_config.log_file)) logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -66,7 +65,7 @@ def main(**kwargs): logger = logging.getLogger() logger.setLevel(logging.INFO) - file_handler = logging.FileHandler(filename=train_config.log_file, mode='w') + file_handler = logging.FileHandler(filename=log_config.log_file, mode='w') file_handler.setLevel(logging.INFO) file_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') file_handler.setFormatter(file_formatter) @@ -100,6 +99,14 @@ def main(**kwargs): clear_gpu_cache(local_rank) setup_environ_flags(rank) + # Set wandb + if not train_config.enable_fsdp or rank == 0: + if log_config.use_wandb: + if not os.path.exists(log_config.wandb_dir): + os.makedirs(log_config.wandb_dir) + wandb_config={"train_config":vars(train_config), "fsdp_config":vars(fsdp_config), "model_config":vars(model_config), "log_config":vars(log_config)} + wandb.init(dir=log_config.wandb_dir, entity=log_config.wandb_entity_name, project=log_config.wandb_project_name,name=log_config.wandb_exp_name ,config=wandb_config) + model, tokenizer = model_factory(train_config, model_config, **kwargs) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. model.to(device) @@ -137,7 +144,9 @@ def main(**kwargs): dataset_config = generate_dataset_config(train_config, kwargs) logger.info("dataset_config: {}".format(dataset_config)) - wandb.config.update( {"dataset_config": vars(dataset_config)} ) + if not train_config.enable_fsdp or rank == 0: + if log_config.use_wandb: + wandb.config.update( {"dataset_config": vars(dataset_config)} ) # Load and preprocess the dataset for training and validation dataset_train = get_preprocessed_dataset( @@ -209,6 +218,7 @@ def main(**kwargs): scheduler, train_config.gradient_accumulation_steps, train_config, + log_config, fsdp_config if train_config.enable_fsdp else None, local_rank if train_config.enable_fsdp else None, rank if train_config.enable_fsdp else None, @@ -216,7 +226,9 @@ def main(**kwargs): if not train_config.enable_fsdp or rank==0: [logger.info(f'Key: {k}, Value: {v}') for k, v in results.items()] - wandb.finish() + if not train_config.enable_fsdp or rank == 0: + if log_config.use_wandb: + wandb.finish() if __name__ == "__main__": fire.Fire(main) \ No newline at end of file diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index dbc2978b..75a9c57d 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -41,7 +41,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer): def byte2mb(x): return int(x / 2**20) -def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None): +def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, log_config,fsdp_config=None, local_rank=None, rank=None): """ Trains the model on the given dataloader @@ -54,6 +54,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche num_epochs: The number of epochs to train for local_rank: The rank of the current node in a distributed setting train_config: The training configuration + log_config: The logging configuration eval_dataloader: The dataloader containing the eval data tokenizer: tokenizer used in the eval for decoding the predicitons @@ -88,7 +89,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) for step, batch in enumerate(train_dataloader): for key in batch.keys(): - if type(batch[key])==bool: #train的时候是true infer的时候是false + if type(batch[key])==bool: continue if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank) @@ -102,8 +103,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche loss = loss / gradient_accumulation_steps acc = acc / gradient_accumulation_steps - if step % train_config.log_interval == 0: - wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}) + if log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp: + if rank==0: + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step)) + else: + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step)) total_loss += loss.detach().float() total_acc += acc @@ -143,7 +148,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche train_loss.append(train_epoch_loss) train_acc.append(train_epoch_acc) - wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) + if log_config.use_wandb: + if train_config.enable_fsdp: + if rank==0: + wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) + else: + wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) if train_config.enable_fsdp: if rank==0: @@ -214,7 +224,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") logger.info("=====================================================") - j(model, rank, train_config) + save_model_and_optimizer_sharded(model, rank, train_config) if train_config.save_optimizer: save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") @@ -244,7 +254,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche else: val_acc.append(-1) - wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) + if log_config.use_wandb: + if train_config.enable_fsdp: + if rank==0: + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) + else: + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) if train_config.run_test_during_validation: if train_config.enable_fsdp: @@ -315,7 +330,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): with MemoryTrace() as memtrace: for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): for key in batch.keys(): - if type(batch[key])==bool: #train的时候是true infer的时候是false + if type(batch[key])==bool: continue if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank)