fix wandb and logging
ddlBoJack committed Dec 20, 2023
1 parent 8942893 commit 8a019bf
Showing 9 changed files with 104 additions and 48 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -6,5 +6,5 @@ debug.py
.idea/*
transformers
wandb/
*.log
log
log/
*.log
57 changes: 34 additions & 23 deletions scripts/finetune_speech_pretraining.sh
@@ -1,6 +1,7 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export PYTHONPATH=/root/fairseq:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

@@ -14,7 +15,7 @@ cd /root/SLAM-LLM
speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-qformer64-proj2048-lr1e-5-whisper-test
output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt-renew5-finetunepeft-test

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
@@ -27,34 +28,39 @@ python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/fin
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector q-former \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.fix_length_audio 64 \
--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \
--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.trans.jsonl \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--batch_size_training 16 \
--val_batch_size 16 \
--num_workers_dataloader 4 \
--lr 1e-5 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \
--run_test_during_validation_prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
--metric acc \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
--use_wandb \
--wandb_dir $output_dir \
--wandb_entity_name zym22 \
--wandb_project_name slam-llm \
--wandb_exp_name test \
--log_file /$output_dir/test.log \
--log_interval 5 \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \
# --use_peft --peft_method lora \

else
torchrun \
--nnodes 1 \
--nproc_per_node 4 \
--nproc_per_node 2 \
src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--use_peft --peft_method lora \
--freeze_llm \
--use_fp16 \
--enable_fsdp \
--llm_name llama-2-7b-hf \
@@ -70,18 +76,23 @@ src/llama_recipes/pipeline/finetune.py \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 16 \
--val_batch_size 16 \
--batch_size_training 8 \
--val_batch_size 8 \
--num_workers_dataloader 4 \
--lr 1e-5 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \
--run_test_during_validation_prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
--metric acc \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
# --freeze_llm \
--use_wandb \
--wandb_dir $output_dir \
--wandb_entity_name zym22 \
--wandb_project_name slam-llm \
--wandb_exp_name test \
--log_file /$output_dir/test.log \
--log_interval 5 \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \
# --use_peft --peft_method lora \
fi

# {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
# {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
# {"key": "1688-142285-0005", "prompt": "<ASR>", "source": "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0005.wav", "target": "YOU WHO WERE ALWAYS ACCUSING PEOPLE OF BEING SHOPPY AT HELSTONE", "target_len": 11, "source_len": 220, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
1 change: 1 addition & 0 deletions src/llama_recipes/configs/__init__.py
@@ -5,3 +5,4 @@
from llama_recipes.configs.fsdp import fsdp_config
from llama_recipes.configs.training import train_config
from llama_recipes.configs.model import model_config
from llama_recipes.configs.log import log_config
15 changes: 15 additions & 0 deletions src/llama_recipes/configs/log.py
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class log_config:
use_wandb: bool = False
wandb_dir: str = "/root/test_wandb"
wandb_entity_name : str = "project_name"
wandb_project_name : str = "project_name"
wandb_exp_name : str = "exp_name"
log_file: str="/root/test.log"
log_interval: int = 5
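Pulling the scattered finetune.py changes further down into one place, the intended consumption of this dataclass is roughly the following sketch; the rank handling is simplified and the helper name is ours, not the repository's:

import logging
import os

import wandb

def setup_logging(log_config, rank: int = 0):
    # Make sure the directory holding log_file exists before attaching a FileHandler.
    os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    logger = logging.getLogger()
    file_handler = logging.FileHandler(filename=log_config.log_file, mode="w")
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # wandb is opt-in, and in distributed runs only rank 0 should initialize it.
    if log_config.use_wandb and rank == 0:
        os.makedirs(log_config.wandb_dir, exist_ok=True)
        wandb.init(
            dir=log_config.wandb_dir,
            entity=log_config.wandb_entity_name,
            project=log_config.wandb_project_name,
            name=log_config.wandb_exp_name,
        )
    return logger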
2 changes: 0 additions & 2 deletions src/llama_recipes/configs/training.py
@@ -36,10 +36,8 @@ class train_config:
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
log_file: str="PATH/to/Log_File"
run_test_during_validation: bool = False
run_test_during_validation_file: str = "test.wav"
run_test_during_validation_prompt: str = "<|ASR|>"
freeze_llm: bool = False
freeze_encoder: bool = False
log_interval: int = 5
8 changes: 6 additions & 2 deletions src/llama_recipes/datasets/speech_dataset.py
@@ -46,10 +46,14 @@ def __init__(self,
self.data_list.append(data_dict)

# # debug
# with open(dataset_config.train_data_path, encoding='utf-8') as fin:
# for line in fin:
# data_dict = json.loads(line.strip())
# self.data_list.append(data_dict)
# if split == "train":
# self.data_list = contents[:80]
# self.data_list = self.data_list[:80]
# else:
# self.data_list = contents[80:100]
# self.data_list = self.data_list[80:100]

def get_source_len(self, data_dict):
return data_dict["source_len"]
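For reference, the commented-out debug block above mirrors how the dataset reads its JSONL manifest: one JSON object per line, in the format shown at the bottom of the fine-tuning script. A small standalone sketch of that loading step (the function name and debug slice sizes are illustrative):

import json

def load_manifest(path, split="train", debug=False):
    # Each line carries at least "source", "target" and "source_len" fields.
    data_list = []
    with open(path, encoding="utf-8") as fin:
        for line in fin:
            data_list.append(json.loads(line.strip()))
    if debug:
        # Same idea as the commented code: a tiny slice for quick debugging runs.
        data_list = data_list[:80] if split == "train" else data_list[80:100]
    return data_list

# get_source_len above then simply returns data_dict["source_len"] from these records.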
4 changes: 2 additions & 2 deletions src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -166,7 +166,7 @@ def save_model_checkpoint(

def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
logger.info(f"--> saving model ...")
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch))
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1))
os.makedirs(save_dir, exist_ok=True)
if not cfg.freeze_llm:
model.llm.save_pretrained(save_dir)
@@ -183,7 +183,7 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
project_dict[key] = cpu_state[key]
torch.save(project_dict, save_full_path)

logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
logger.info(f"model checkpoint saved for epoch {epoch+1} at {save_full_path}\n")



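The epoch+1 change means PEFT checkpoint folders are now numbered from 1 instead of 0, matching the 1-based epoch shown in the training progress bar and in the commented --ckpt_path examples in the shell script (.../asr/5/model.pt). A quick illustration with values taken from the script above; the model.pt filename is inferred from those commented paths:

import os

output_dir = "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt-renew5-finetunepeft-test"
model_name = "asr"
epoch = 0  # first epoch of training

save_dir = os.path.join(output_dir, model_name, str(epoch + 1))
save_full_path = os.path.join(save_dir, "model.pt")
# -> <output_dir>/asr/1/model.pt rather than <output_dir>/asr/0/model.pt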
30 changes: 21 additions & 9 deletions src/llama_recipes/pipeline/finetune.py
@@ -21,6 +21,7 @@
from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.configs import model_config as MODEL_CONFIG
from llama_recipes.configs import log_config as LOG_CONFIG
from llama_recipes.data.concatenator import ConcatDataset

# util
@@ -48,14 +49,12 @@

def main(**kwargs):
# Update the configuration for the training and sharding process
train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG()
update_config((train_config, fsdp_config, model_config), **kwargs)

# Set wandb
wandb_config={"train_config":vars(train_config), "fsdp_config":vars(fsdp_config), "model_config":vars(model_config)}
wandb.init(project="project_name",name="exp_name",config=wandb_config) # log the parameters
train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG()
update_config((train_config, fsdp_config, model_config, log_config), **kwargs)

# Set log
if not os.path.exists(os.path.dirname(log_config.log_file)):
os.makedirs(os.path.dirname(log_config.log_file))
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
@@ -66,7 +65,7 @@ def main(**kwargs):
logger = logging.getLogger()
logger.setLevel(logging.INFO)

file_handler = logging.FileHandler(filename=train_config.log_file, mode='w')
file_handler = logging.FileHandler(filename=log_config.log_file, mode='w')
file_handler.setLevel(logging.INFO)
file_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
file_handler.setFormatter(file_formatter)
@@ -100,6 +99,14 @@ def main(**kwargs):
clear_gpu_cache(local_rank)
setup_environ_flags(rank)

# Set wandb
if not train_config.enable_fsdp or rank == 0:
if log_config.use_wandb:
if not os.path.exists(log_config.wandb_dir):
os.makedirs(log_config.wandb_dir)
wandb_config={"train_config":vars(train_config), "fsdp_config":vars(fsdp_config), "model_config":vars(model_config), "log_config":vars(log_config)}
wandb.init(dir=log_config.wandb_dir, entity=log_config.wandb_entity_name, project=log_config.wandb_project_name,name=log_config.wandb_exp_name ,config=wandb_config)

model, tokenizer = model_factory(train_config, model_config, **kwargs)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device.
model.to(device)
@@ -137,7 +144,9 @@ def main(**kwargs):

dataset_config = generate_dataset_config(train_config, kwargs)
logger.info("dataset_config: {}".format(dataset_config))
wandb.config.update( {"dataset_config": vars(dataset_config)} )
if not train_config.enable_fsdp or rank == 0:
if log_config.use_wandb:
wandb.config.update( {"dataset_config": vars(dataset_config)} )

# Load and preprocess the dataset for training and validation
dataset_train = get_preprocessed_dataset(
@@ -209,14 +218,17 @@ def main(**kwargs):
scheduler,
train_config.gradient_accumulation_steps,
train_config,
log_config,
fsdp_config if train_config.enable_fsdp else None,
local_rank if train_config.enable_fsdp else None,
rank if train_config.enable_fsdp else None,
)
if not train_config.enable_fsdp or rank==0:
[logger.info(f'Key: {k}, Value: {v}') for k, v in results.items()]

wandb.finish()
if not train_config.enable_fsdp or rank == 0:
if log_config.use_wandb:
wandb.finish()

if __name__ == "__main__":
fire.Fire(main)
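A recurring pattern in this commit, visible above in finetune.py and below in train_utils.py, is the double guard around every wandb call: log only when log_config.use_wandb is set, and only from rank 0 when FSDP is enabled, so a single process talks to wandb. One way to factor that guard out, as a sketch rather than code from this repository:

import wandb

def is_logging_rank(train_config, rank=None) -> bool:
    # Without FSDP every process may log; with FSDP only rank 0 should.
    return (not train_config.enable_fsdp) or rank == 0

def maybe_wandb_log(metrics, train_config, log_config, rank=None, step=None):
    if log_config.use_wandb and is_logging_rank(train_config, rank):
        wandb.log(metrics, step=step)

Such a helper would collapse the repeated enable_fsdp / rank == 0 branches that train_utils.py now carries around each wandb.log call.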
31 changes: 23 additions & 8 deletions src/llama_recipes/utils/train_utils.py
@@ -41,7 +41,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
def byte2mb(x):
return int(x / 2**20)

def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, log_config,fsdp_config=None, local_rank=None, rank=None):
"""
Trains the model on the given dataloader
@@ -54,6 +54,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
num_epochs: The number of epochs to train for
local_rank: The rank of the current node in a distributed setting
train_config: The training configuration
log_config: The logging configuration
eval_dataloader: The dataloader containing the eval data
tokenizer: tokenizer used in the eval for decoding the predicitons
@@ -88,7 +89,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
for step, batch in enumerate(train_dataloader):
for key in batch.keys():
if type(batch[key])==bool: # True during training, False during inference
if type(batch[key])==bool:
continue
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
@@ -102,8 +103,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
loss = loss / gradient_accumulation_steps
acc = acc / gradient_accumulation_steps

if step % train_config.log_interval == 0:
wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc})
if log_config.use_wandb and step % log_config.log_interval == 0:
if train_config.enable_fsdp:
if rank==0:
wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
else:
wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))

total_loss += loss.detach().float()
total_acc += acc
@@ -143,7 +148,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
train_loss.append(train_epoch_loss)
train_acc.append(train_epoch_acc)

wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc})
if log_config.use_wandb:
if train_config.enable_fsdp:
if rank==0:
wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc})
else:
wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc})

if train_config.enable_fsdp:
if rank==0:
@@ -214,7 +224,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
logger.info("=====================================================")

j(model, rank, train_config)
save_model_and_optimizer_sharded(model, rank, train_config)
if train_config.save_optimizer:
save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
@@ -244,7 +254,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
else:
val_acc.append(-1)

wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]})
if log_config.use_wandb:
if train_config.enable_fsdp:
if rank==0:
wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]})
else:
wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]})

if train_config.run_test_during_validation:
if train_config.enable_fsdp:
@@ -315,7 +330,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
with MemoryTrace() as memtrace:
for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
for key in batch.keys():
if type(batch[key])==bool: # True during training, False during inference
if type(batch[key])==bool:
continue
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
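Earlier in this file's diff, the inner training loop logs with step=(epoch * total_length + step) so the wandb step axis keeps growing across epochs instead of restarting at zero. A small numeric check of that arithmetic; the steps-per-epoch value here is illustrative (in the training loop it is the per-epoch length used for the tqdm progress bar):

def global_step(epoch: int, step: int, steps_per_epoch: int) -> int:
    # Monotonically increasing wandb step across epochs.
    return epoch * steps_per_epoch + step

steps_per_epoch = 1000  # illustrative value
assert global_step(0, 999, steps_per_epoch) == 999    # last logged step of epoch 0
assert global_step(1, 0, steps_per_epoch) == 1000     # epoch 1 continues the axis
assert global_step(2, 500, steps_per_epoch) == 2500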
