diff --git a/requirements.txt b/requirements.txt index 5ce37523..65d68f53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,6 @@ transformers>=4.31.0 sentencepiece py7zr scipy -optimum \ No newline at end of file +optimum +wandb +hydra-core>=1.3.2 \ No newline at end of file diff --git a/scripts/conf/asr_vicuna_lora.yaml b/scripts/conf/asr_vicuna_lora.yaml new file mode 100644 index 00000000..ccc983df --- /dev/null +++ b/scripts/conf/asr_vicuna_lora.yaml @@ -0,0 +1,111 @@ + +model_config: + llm_name: "vicuna-13b-v1.5" + llm_path: "PATH/to/LLAMA/7B" + llm_dim: 4096 + encoder_name: null + encoder_ds_rate: 2 + encoder_path: null + encoder_dim: 1280 + encoder_projector: "linear" + encoder_projector_ds_rate: 5 + + DMODEL: 512 + FRONTEND_DMODEL: 1024 #这个是专门指moco的 + TX_ATTENTION_HEADS: 8 + TX_NUM_LAYERS: 6 + PE_MAX_LENGTH: 500 + AUDIO_FEATURE_SIZE: 1024 + VIDEO_FEATURE_SIZE: 2048 + TX_FEEDFORWARD_DIM: 2048 + TX_DROPOUT: 0.1 + CHAR_NUM_CLASSES: 40 + + WORD_NUM_CLASSES: 500 + FRAME_LENGTH: 29 + MOCO_FRONTEND_FILE: "/nfs/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" + WAV2VEC_FILE: "/nfs/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" + MAIN_REQ_INPUT_LENGTH: int = 80 + modal: "AV" + TRAIN_LRS3_MODEL_FILE: "/nfs/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # "/home/oss/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" #单一模态是这个 + TRAINED_AO_FILE: "/nfs/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" + TRAINED_VO_FILE: "/nfs/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" + + +train_config: + model_name: "PATH/to/LLAMA/7B" + enable_fsdp: false + low_cpu_fsdp: false + run_validation: true + batch_size_training: 4 + batching_strategy: "packing" 
#alternative: padding + context_length: 4096 + gradient_accumulation_steps: 1 + num_epochs: 3 + num_workers_dataloader: 1 + lr: 1e-4 + weight_decay: 0.0 + gamma: 0.85 + seed: 42 + use_fp16: false + mixed_precision: true + val_batch_size: 1 + + use_peft: false + peft_config: + peft_method: "lora" # None , llama_adapter, prefix + r: 8 + lora_alpha: 32 + target_modules: [ "q_proj", "v_proj" ] + bias: null + task_type: "CAUSAL_LM" + lora_dropout: 0.05 + inference_mode: false + output_dir: "PATH/to/save/PEFT/model" + freeze_layers: false + num_freeze_layers: 1 + quantization: false + one_gpu: false + save_model: true + dist_checkpoint_root_folder: "PATH/to/save/FSDP/model" # will be used if using FSDP + dist_checkpoint_folder: "fine-tuned" # will be used if using FSDP + save_optimizer: false # will be used if using FSDP + use_fast_kernels: false # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels + run_test_during_validation: false + run_test_during_validation_file: "test.wav" + run_test_during_validation_prompt: "<|ASR|>" + freeze_llm: false + freeze_encoder: false + +dataset_config: + dataset: "samsum_dataset" + file: "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset" + train_data_path: null + val_data_path: null + train_split: "train" + test_split: "validation" + data_path: null + max_words: null + max_mel: null + fix_length_audio: -1 + +fsdp_config: + mixed_precision: true + use_fp16: false + # sharding_strategy: "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP + checkpoint_type: "StateDictType.SHARDED_STATE_DICT" # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size; alternatively FULL_STATE_DICT can be used.
+ fsdp_activation_checkpointing: true + fsdp_cpu_offload: false + pure_bf16: false + optimizer: "AdamW" + +log_config: + use_wandb: false + wandb_dir: "/root/test_wandb" + wandb_entity_name : "project_name" + wandb_project_name : "project_name" + wandb_exp_name : "exp_name" + log_file: "/root/test.log" + log_interval: 5 + diff --git a/scripts/finetune_asr_vicuna.sh b/scripts/finetune_asr_vicuna.sh index 6e3b52b2..0f50d305 100644 --- a/scripts/finetune_asr_vicuna.sh +++ b/scripts/finetune_asr_vicuna.sh @@ -22,79 +22,88 @@ output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e- # -m debugpy --listen 5678 --wait-for-client if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then -python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ ---model_name asr \ ---freeze_encoder \ ---freeze_llm \ ---llm_name vicuna-13b-v1.5 \ ---llm_path $llm_path \ ---llm_dim 5120 \ ---encoder_name whisper \ ---encoder_ds_rate 2 \ ---encoder_path $speech_encoder_path \ ---encoder_dim 1280 \ ---encoder_projector linear \ ---encoder_projector_ds_rate 5 \ ---dataset speech_dataset \ ---speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ ---speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ ---batching_strategy custom \ ---num_epochs 100 \ ---batch_size_training 4 \ ---val_batch_size 4 \ ---num_workers_dataloader 4 \ ---lr 1e-4 \ ---output_dir $output_dir \ ---metric acc \ -# --log_file $output_dir/test.log \ -# --use_wandb \ -# --wandb_dir $output_dir \ -# --wandb_entity_name zym22 \ -# --wandb_project_name slam-llm \ -# --wandb_exp_name test \ -# --log_interval 5 \ +python src/llama_recipes/pipeline/finetune.py \ +--config-path "/root/SLAM-LLM/scripts/conf" \ +--config-name "asr_vicuna_lora.yaml" \ +++model_config.llm_name="vicuna-7b-v1.5" \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=whisper \ 
+++model_config.encoder_ds_rate=2 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1280 \ +++model_config.encoder_projector=linear \ +++model_config.encoder_projector_ds_rate=5 \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ +++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ +++train_config.model_name=asr \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.num_epochs=100 \ +++train_config.batch_size_training=4 \ +++train_config.val_batch_size=4 \ +++train_config.num_workers_dataloader=4 \ +++train_config.lr=1e-4 \ +++train_config.output_dir=$output_dir \ +++train_config.peft_config.peft_method=lora \ +++metric=acc \ +#++log_config.log_file=/$output_dir/train.log \ +#++log_config.use_wandb=true \ +#++log_config.wandb_dir=$output_dir \ +#++log_config.wandb_entity_name=zym22 \ +#++log_config.wandb_project_name=slam-llm \ +#++log_config.wandb_exp_name=test \ +#++log_config.log_interval 5 \ # --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \ # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \ # --use_peft --peft_method lora \ +##vicuna-7b-v1.5 else torchrun \ --nnodes 1 \ --nproc_per_node 4 \ --master_port=29502 \ src/llama_recipes/pipeline/finetune.py \ ---model_name asr \ ---freeze_encoder \ ---freeze_llm \ ---use_fp16 \ ---enable_fsdp \ ---llm_name vicuna-7b-v1.5 \ ---llm_path $llm_path \ ---llm_dim 4096 \ ---encoder_name whisper \ ---encoder_ds_rate 2 \ ---encoder_path $speech_encoder_path \ ---encoder_dim 1280 \ ---encoder_projector linear \ ---encoder_projector_ds_rate 5 \ ---dataset speech_dataset \ ---speech_dataset.train_data_path 
/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ ---speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ ---batching_strategy custom \ ---num_epochs 100 \ ---batch_size_training 4 \ ---val_batch_size 4 \ ---num_workers_dataloader 4 \ ---lr 1e-4 \ ---output_dir $output_dir \ ---metric acc \ ---log_file /$output_dir/train.log \ ---use_wandb \ ---wandb_dir $output_dir \ ---wandb_entity_name zym22 \ ---wandb_project_name slam-llm \ ---wandb_exp_name test \ ---log_interval 5 \ +--config-path "/root/SLAM-LLM/scripts/conf" \ +--config-name "asr_vicuna_lora.yaml" \ +++model_config.llm_name="vicuna-7b-v1.5" \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=whisper \ +++model_config.encoder_ds_rate=2 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1280 \ +++model_config.encoder_projector=linear \ +++model_config.encoder_projector_ds_rate=5 \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ +++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ +++train_config.model_name=asr \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.num_epochs=100 \ +++train_config.batch_size_training=4 \ +++train_config.val_batch_size=4 \ +++train_config.num_workers_dataloader=4 \ +++train_config.lr=1e-4 \ +++train_config.output_dir=$output_dir \ +++train_config.peft_config.peft_method=lora \ +++train_config.enable_fsdp=true \ +++train_config.enable_ddp=false \ +++train_config.use_fp16=true \ +++metric=acc \ +#++log_config.log_file=/$output_dir/train.log \ +#++log_config.use_wandb=true \ +#++log_config.wandb_dir=$output_dir \ +#++log_config.wandb_entity_name=zym22 \ +#++log_config.wandb_project_name=slam-llm \ 
+#++log_config.wandb_exp_name=test \ +#++log_config.log_interval 5 \ + # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \ # --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \ # --use_peft --peft_method lora \ diff --git a/src/llama_recipes/configs/__init__.py b/src/llama_recipes/configs/__init__.py deleted file mode 100644 index 679b022d..00000000 --- a/src/llama_recipes/configs/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config -from llama_recipes.configs.fsdp import fsdp_config -from llama_recipes.configs.training import train_config -from llama_recipes.configs.model import model_config -from llama_recipes.configs.log import log_config diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py deleted file mode 100644 index e8a6a05e..00000000 --- a/src/llama_recipes/configs/datasets.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
- -from dataclasses import dataclass - - -@dataclass -class samsum_dataset: - dataset: str = "samsum_dataset" - train_split: str = "train" - test_split: str = "validation" - - -@dataclass -class grammar_dataset: - dataset: str = "grammar_dataset" - train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" - test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv" - - -@dataclass -class alpaca_dataset: - dataset: str = "alpaca_dataset" - train_split: str = "train" - test_split: str = "val" - data_path: str = "src/llama_recipes/datasets/alpaca_data.json" - - -@dataclass -class speech_dataset: - dataset: str = "speech_dataset" - file: str = "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset" - train_split: str = "train" - test_split: str = "validation" - data_path: str = None - max_words: int = None - train_data_path: str = None - val_data_path: str = None - max_words: int = None - max_mel: int = None - fix_length_audio: int = -1 - - -@dataclass -class audio_dataset: - dataset: str = "audio_dataset" - file: str = "src/llama_recipes/datasets/audio_dataset.py:get_audio_dataset" - train_split: str = "train" - test_split: str = "validation" - data_path: str = None - fbank_mean: float = 15.41663 - fbank_std: float = 6.55582 - max_words: int = None - train_data_path: str = None - val_data_path: str = None - max_words: int = None - max_mel: int = None - fix_length_audio: int = -1 - - -@dataclass -class avsr_dataset: - dataset: str = "avsr_dataset" - file: str = "examples/avsr_dataset.py" - train_split: str = "train" - test_split: str = "val" - data_path: str = "/nfs/yangguanrou.ygr/" #"/home/oss/yangguanrou.ygr/" - h5file: str = "/nfs/yangguanrou.ygr/LRS3/LRS3.h5" # "/home/oss/yangguanrou.ygr/LRS3/LRS3.h5" - noiseFile : str = "/nfs/yangguanrou.ygr/AVSR/LRS3/Noise.h5" #"/home/oss/yangguanrou.ygr/AVSR/LRS3/Noise.h5" - noiseProb: float = 0. 
- noiseSNR: float = 5 - stepSize: int = 16384 - charToIx : str = "x" #应该没用了 TypeError: Object of type NotImplementedType is not JSON serializable 但这个是上面的问题 - modal: str = "AV" - pretrain_subset: str = "LRS3/pretrain.txt" - train_subset: str = "LRS3/train.txt" - valid_subset: str = "LRS3/val.txt" - test_subset: str = "LRS3/test.txt" - reqInpLen: str = 80 diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py deleted file mode 100644 index 7cf54289..00000000 --- a/src/llama_recipes/configs/fsdp.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass - -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType - -@dataclass -class fsdp_config: - mixed_precision: bool=True - use_fp16: bool=False - # sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD - sharding_strategy: ShardingStrategy = ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP - checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. - fsdp_activation_checkpointing: bool=True - fsdp_cpu_offload: bool=False - pure_bf16: bool = False - optimizer: str= "AdamW" - diff --git a/src/llama_recipes/configs/log.py b/src/llama_recipes/configs/log.py deleted file mode 100644 index 662a1a1e..00000000 --- a/src/llama_recipes/configs/log.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
- -from dataclasses import dataclass - - -@dataclass -class log_config: - use_wandb: bool = False - wandb_dir: str = "/root/test_wandb" - wandb_entity_name : str = "project_name" - wandb_project_name : str = "project_name" - wandb_exp_name : str = "exp_name" - log_file: str="/root/test.log" - log_interval: int = 5 \ No newline at end of file diff --git a/src/llama_recipes/configs/model.py b/src/llama_recipes/configs/model.py deleted file mode 100644 index 2080f9cd..00000000 --- a/src/llama_recipes/configs/model.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class model_config: - llm_name: str = "llama-2-7b-hf" - llm_path: str = "PATH/to/LLAMA/7B" - llm_dim: int = 4096 - encoder_name: str = None - encoder_ds_rate: int = 2 - encoder_path: str = None - encoder_dim: int = 1280 - encoder_projector: str = "linear" - encoder_projector_ds_rate: int = 5 - - DMODEL: int = 512 - FRONTEND_DMODEL: int = 1024 #这个是专门指moco的 - TX_ATTENTION_HEADS: int = 8 - TX_NUM_LAYERS: int = 6 - PE_MAX_LENGTH: int = 500 - AUDIO_FEATURE_SIZE: int = 1024 - VIDEO_FEATURE_SIZE: int = 2048 - TX_FEEDFORWARD_DIM: int= 2048 - TX_DROPOUT: int = 0.1 - CHAR_NUM_CLASSES: int = 40 - - WORD_NUM_CLASSES: int = 500 - FRAME_LENGTH: int = 29 - MOCO_FRONTEND_FILE: str = "/nfs/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" - WAV2VEC_FILE: str = "/nfs/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" - MAIN_REQ_INPUT_LENGTH: int = 80 - modal: str = "AV" - TRAIN_LRS3_MODEL_FILE: str = "/nfs/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # "/home/oss/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" #单一模态是这个 - TRAINED_AO_FILE : str = "/nfs/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" - TRAINED_VO_FILE: str = 
"/nfs/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" - - - - - - - - diff --git a/src/llama_recipes/configs/peft.py b/src/llama_recipes/configs/peft.py deleted file mode 100644 index 73de09f2..00000000 --- a/src/llama_recipes/configs/peft.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass, field -from typing import List - -@dataclass -class lora_config: - r: int=8 - lora_alpha: int=32 - target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) - bias= "none" - task_type: str= "CAUSAL_LM" - lora_dropout: float=0.05 - inference_mode: bool = False - -@dataclass -class llama_adapter_config: - adapter_len: int= 10 - adapter_layers: int= 30 - task_type: str= "CAUSAL_LM" - -@dataclass -class prefix_config: - num_virtual_tokens: int=30 - task_type: str= "CAUSAL_LM" \ No newline at end of file diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py deleted file mode 100644 index 1239c7f5..00000000 --- a/src/llama_recipes/configs/training.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
- -from dataclasses import dataclass - - -@dataclass -class train_config: - model_name: str="PATH/to/LLAMA/7B" - enable_fsdp: bool=False - low_cpu_fsdp: bool=False - run_validation: bool=True - batch_size_training: int=4 - batching_strategy: str="packing" #alternative: padding - context_length: int=4096 - gradient_accumulation_steps: int=1 - num_epochs: int=3 - num_workers_dataloader: int=1 - lr: float=1e-4 - weight_decay: float=0.0 - gamma: float= 0.85 - seed: int=42 - use_fp16: bool=False - mixed_precision: bool=True - val_batch_size: int=1 - dataset = "samsum_dataset" - peft_method: str = "lora" # None , llama_adapter, prefix - use_peft: bool=False - output_dir: str = "PATH/to/save/PEFT/model" - freeze_layers: bool = False - num_freeze_layers: int = 1 - quantization: bool = False - one_gpu: bool = False - save_model: bool = True - dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP - dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP - save_optimizer: bool=False # will be used if using FSDP - use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels - run_test_during_validation: bool = False - run_test_during_validation_file: str = "test.wav" - run_test_during_validation_prompt: str = "<|ASR|>" - freeze_llm: bool = False - freeze_encoder: bool = False diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py index 5a51d330..257627cb 100644 --- a/src/llama_recipes/datasets/speech_dataset.py +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -31,7 +31,7 @@ def __init__(self, self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss self.prompt_template = "USER: {}\n ASSISTANT:" self.answer_template = "{}" - self.fix_length_audio = dataset_config.fix_length_audio + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) self.data_list = [] if split 
== "train": @@ -45,15 +45,7 @@ def __init__(self, data_dict = json.loads(line.strip()) self.data_list.append(data_dict) - # # debug - # with open(dataset_config.train_data_path, encoding='utf-8') as fin: - # for line in fin: - # data_dict = json.loads(line.strip()) - # self.data_list.append(data_dict) - # if split == "train": - # self.data_list = self.data_list[:80] - # else: - # self.data_list = self.data_list[80:100] + def get_source_len(self, data_dict): return data_dict["source_len"] diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py deleted file mode 100644 index 2ec5c234..00000000 --- a/src/llama_recipes/finetuning.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import os -from pkg_resources import packaging - -import fire -import random -import torch -import torch.optim as optim -from peft import get_peft_model, prepare_model_for_int8_training -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, -) -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload -from torch.optim.lr_scheduler import StepLR -from transformers import ( - LlamaForCausalLM, - LlamaTokenizer, - LlamaConfig, -) -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.data.concatenator import ConcatDataset -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - generate_peft_config, - generate_dataset_config, - get_dataloader_kwargs, -) -from llama_recipes.utils.dataset_utils import get_preprocessed_dataset - -from llama_recipes.utils.train_utils import ( - train, - 
freeze_transformer_layers, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies -) - - -def main(**kwargs): - # Update the configuration for the training and sharding process - train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG() - update_config((train_config, fsdp_config), **kwargs) - - # Set the seeds for reproducibility - torch.cuda.manual_seed(train_config.seed) - torch.manual_seed(train_config.seed) - random.seed(train_config.seed) - - if train_config.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - # Load the pre-trained model and setup its configuration - use_cache = False if train_config.enable_fsdp else None - if train_config.enable_fsdp and train_config.low_cpu_fsdp: - """ - for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - this avoids cpu oom when loading large models like llama 70B, in which case - model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - overhead and currently requires latest nightly. 
- """ - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly: - raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - "please install latest nightly.") - if rank == 0: - model = LlamaForCausalLM.from_pretrained( - train_config.model_name, - load_in_8bit=True if train_config.quantization else None, - device_map="auto" if train_config.quantization else None, - use_cache=use_cache, - ) - else: - llama_config = LlamaConfig.from_pretrained(train_config.model_name) - llama_config.use_cache = use_cache - with torch.device("meta"): - model = LlamaForCausalLM(llama_config) - - else: - model = LlamaForCausalLM.from_pretrained( - train_config.model_name, - load_in_8bit=True if train_config.quantization else None, - device_map="auto" if train_config.quantization else None, - use_cache=use_cache, - ) - if train_config.enable_fsdp and train_config.use_fast_kernels: - """ - For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable - using of Flash Attention or Xformer memory-efficient kernels - based on the hardware being used. This would speed up fine-tuning. - """ - try: - from optimum.bettertransformer import BetterTransformer - model = BetterTransformer.transform(model) - except ImportError: - print("Module 'optimum' not found. 
Please install 'optimum' it before proceeding.") - - # Load the tokenizer and add special tokens - tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name) - tokenizer.pad_token_id = tokenizer.eos_token_id - - print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) - - # Prepare the model for int8 training if quantization is enabled - if train_config.quantization: - model = prepare_model_for_int8_training(model) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if train_config.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - if train_config.use_peft: - peft_config = generate_peft_config(train_config, kwargs) - model = get_peft_model(model, peft_config) - model.print_trainable_parameters() - - #setting up FSDP if enable_fsdp is enabled - if train_config.enable_fsdp: - if not train_config.use_peft and train_config.freeze_layers: - - freeze_transformer_layers(train_config.num_freeze_layers) - - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) - - model = FSDP( - model, - auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - device_id=torch.cuda.current_device(), - limit_all_gathers=True, - sync_module_states=train_config.low_cpu_fsdp, - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if train_config.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - elif not train_config.quantization and not train_config.enable_fsdp: - model.to("cuda") - - dataset_config = generate_dataset_config(train_config, kwargs) - - # Load and 
preprocess the dataset for training and validation - dataset_train = get_preprocessed_dataset( - tokenizer, - dataset_config, - split="train", - ) - - if not train_config.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(dataset_train)}") - - dataset_val = get_preprocessed_dataset( - tokenizer, - dataset_config, - split="test", - ) - if not train_config.enable_fsdp or rank == 0: - print(f"--> Validation Set Length = {len(dataset_val)}") - - if train_config.batching_strategy == "packing": - dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) - - train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train") - - # Create DataLoaders for the training and validation dataset - train_dataloader = torch.utils.data.DataLoader( - dataset_train, - num_workers=train_config.num_workers_dataloader, - pin_memory=True, - **train_dl_kwargs, - ) - - eval_dataloader = None - if train_config.run_validation: - if train_config.batching_strategy == "packing": - dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length) - - val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val") - - eval_dataloader = torch.utils.data.DataLoader( - dataset_val, - num_workers=train_config.num_workers_dataloader, - pin_memory=True, - **val_dl_kwargs, - ) - - # Initialize the optimizer and learning rate scheduler - if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": - optimizer = AnyPrecisionAdamW( - model.parameters(), - lr=train_config.lr, - momentum_dtype=torch.bfloat16, - variance_dtype=torch.bfloat16, - use_kahan_summation=False, - weight_decay=train_config.weight_decay, - ) - else: - optimizer = optim.AdamW( - model.parameters(), - lr=train_config.lr, - weight_decay=train_config.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) - - # Start the training process - results = train( - model, - train_dataloader, - 
eval_dataloader, - tokenizer, - optimizer, - scheduler, - train_config.gradient_accumulation_steps, - train_config, - fsdp_config if train_config.enable_fsdp else None, - local_rank if train_config.enable_fsdp else None, - rank if train_config.enable_fsdp else None, - ) - if not train_config.enable_fsdp or rank==0: - [print(f'Key: {k}, Value: {v}') for k, v in results.items()] - -if __name__ == "__main__": - fire.Fire(main) diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index ad3d1e8a..f686ba41 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -44,20 +44,20 @@ def setup_encoder(train_config, model_config, **kwargs): if encoder_name == "moco_wav2vec2": from llama_recipes.models.encoder import AVEncoder encoder = AVEncoder.load(model_config) - print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) if train_config.freeze_encoder: for name, param in encoder.named_parameters(): param.requires_grad = False encoder.eval() - print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) return encoder def setup_llm(train_config, model_config, **kwargs): from pkg_resources import packaging - use_cache = False if train_config.enable_fsdp else None - if train_config.enable_fsdp and train_config.low_cpu_fsdp: + use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None + if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp: """ for FSDP, we can save cpu memory by loading pretrained model on rank0 only. 
this avoids cpu oom when loading large models like llama 70B, in which case @@ -90,7 +90,7 @@ def setup_llm(train_config, model_config, **kwargs): device_map="auto" if train_config.quantization else None, use_cache=use_cache, ) - if train_config.enable_fsdp and train_config.use_fast_kernels: + if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels: """ For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable using of Flash Attention or Xformer memory-efficient kernels @@ -102,7 +102,7 @@ def setup_llm(train_config, model_config, **kwargs): except ImportError: logger.warning("Module 'optimum' not found. Please install 'optimum' it before proceeding.") - print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) # Prepare the model for int8 training if quantization is enabled if train_config.quantization: @@ -119,11 +119,11 @@ def setup_llm(train_config, model_config, **kwargs): model.print_trainable_parameters() elif train_config.use_peft: logger.info("setup peft...") - peft_config = generate_peft_config(train_config, kwargs) + peft_config = generate_peft_config(train_config) model = get_peft_model(model, peft_config) model.print_trainable_parameters() - print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) return model def setup_encoder_projector(train_config, model_config, **kwargs): @@ -136,7 +136,7 @@ def setup_encoder_projector(train_config, model_config, **kwargs): elif model_config.encoder_projector == "q-former": from llama_recipes.models.projector import EncoderProjectorQFormer encoder_projector = EncoderProjectorQFormer(model_config) - 
print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) return encoder_projector diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index 6ce5a71c..de7170aa 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -14,24 +14,22 @@ from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ) +from torch.nn.parallel import DistributedDataParallel as DDP + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing # config -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.configs import model_config as MODEL_CONFIG -from llama_recipes.configs import log_config as LOG_CONFIG +# from llama_recipes.configs import fsdp_config as FSDP_CONFIG +# from llama_recipes.configs import train_config as TRAIN_CONFIG +# from llama_recipes.configs import model_config as MODEL_CONFIG +# from llama_recipes.configs import log_config as LOG_CONFIG from llama_recipes.data.concatenator import ConcatDataset # util from llama_recipes.utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - generate_peft_config, - generate_dataset_config, - get_dataloader_kwargs, -) +from llama_recipes.utils.config_utils import get_dataloader_kwargs + from llama_recipes.utils.dataset_utils import get_preprocessed_dataset from llama_recipes.utils.train_utils import ( train, @@ -47,11 +45,49 @@ import logging import wandb -def main(**kwargs): - # Update the configuration for the training and sharding process - train_config, fsdp_config, model_config, 
log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG() - update_config((train_config, fsdp_config, model_config, log_config), **kwargs) +import hydra +from omegaconf import DictConfig, ListConfig, OmegaConf + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + main(kwargs) + +def main(kwargs: DictConfig): + # Update the configuration for the training and sharding process + # train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG() + # update_config((train_config, fsdp_config, model_config, log_config), **kwargs) + + train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \ + kwargs.fsdp_config, \ + kwargs.model_config, \ + kwargs.log_config, \ + kwargs.dataset_config + fsdp_config.use_fp16 = train_config.use_fp16 + del kwargs.train_config + del kwargs.fsdp_config + del kwargs.model_config + del kwargs.log_config + del kwargs.dataset_config + # Set log if not os.path.exists(os.path.dirname(log_config.log_file)): os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True) @@ -86,7 +122,7 @@ def main(**kwargs): torch.manual_seed(train_config.seed) random.seed(train_config.seed) - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: setup() # torchrun specific local_rank = int(os.environ["LOCAL_RANK"]) @@ -100,7 +136,7 @@ def main(**kwargs): setup_environ_flags(rank) # Set wandb - if not train_config.enable_fsdp or 
rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: if log_config.use_wandb: if not os.path.exists(log_config.wandb_dir): os.makedirs(log_config.wandb_dir, exist_ok=True) @@ -112,7 +148,7 @@ def main(**kwargs): # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if train_config.enable_fsdp and fsdp_config.pure_bf16: + if (train_config.enable_fsdp or train_config.enable_ddp) and fsdp_config.pure_bf16: model.to(torch.bfloat16) #setting up FSDP if enable_fsdp is enabled @@ -120,7 +156,8 @@ def main(**kwargs): if not train_config.use_peft and train_config.freeze_layers: freeze_transformer_layers(train_config.num_freeze_layers) - + from torch.distributed.fsdp import ShardingStrategy + fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy) mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) @@ -138,12 +175,16 @@ def main(**kwargs): ) if fsdp_config.fsdp_activation_checkpointing: apply_fsdp_checkpointing(model) - elif not train_config.quantization and not train_config.enable_fsdp: + elif train_config.enable_ddp: + model = model.cuda(local_rank) + model = DDP(model, device_ids=[local_rank], + find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) + elif not train_config.quantization: model.to(device) - dataset_config = generate_dataset_config(train_config, kwargs) + # dataset_config = generate_dataset_config(train_config, kwargs) logger.info("dataset_config: {}".format(dataset_config)) - if not train_config.enable_fsdp or rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: if log_config.use_wandb: wandb.config.update( {"dataset_config": vars(dataset_config)} ) @@ -153,14 +194,14 @@ def main(**kwargs): dataset_config, split="train", ) - if not train_config.enable_fsdp or rank == 0: + if not (train_config.enable_fsdp or 
train_config.enable_ddp) or rank == 0: logger.info(f"--> Training Set Length = {len(dataset_train)}") dataset_val = get_preprocessed_dataset( tokenizer, dataset_config, split="val", ) - if not train_config.enable_fsdp or rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: logger.info(f"--> Validation Set Length = {len(dataset_val)}") if train_config.batching_strategy == "packing": dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) @@ -219,15 +260,15 @@ def main(**kwargs): train_config, log_config, fsdp_config if train_config.enable_fsdp else None, - local_rank if train_config.enable_fsdp else None, - rank if train_config.enable_fsdp else None, + local_rank if train_config.enable_fsdp or train_config.enable_ddp else None, + rank if train_config.enable_fsdp or train_config.enable_ddp else None, ) - if not train_config.enable_fsdp or rank==0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank==0: [logger.info(f'Key: {k}, Value: {v}') for k, v in results.items()] - if not train_config.enable_fsdp or rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: if log_config.use_wandb: wandb.finish() if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + main_hydra() \ No newline at end of file diff --git a/src/llama_recipes/pipeline/inference.py b/src/llama_recipes/pipeline/inference.py index 704952c7..cc7bbd16 100644 --- a/src/llama_recipes/pipeline/inference.py +++ b/src/llama_recipes/pipeline/inference.py @@ -1,25 +1,52 @@ -import fire +# import fire +import logging import random import torch # import argparse from llama_recipes.models.slam_model import slam_model # config -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.configs import model_config as MODEL_CONFIG -from llama_recipes.utils.config_utils import ( - update_config, - 
generate_peft_config, - generate_dataset_config, - get_dataloader_kwargs, -) +# from llama_recipes.configs import fsdp_config as FSDP_CONFIG +# from llama_recipes.configs import train_config as TRAIN_CONFIG +# from llama_recipes.configs import model_config as MODEL_CONFIG + from llama_recipes.pipeline.model_factory import model_factory -def main(**kwargs): +import hydra +from omegaconf import DictConfig, ListConfig, OmegaConf + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + main(kwargs) + +def main(kwargs: DictConfig): # Update the configuration for the training and sharding process - train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() - update_config((train_config, fsdp_config, model_config), **kwargs) + # train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() + # update_config((train_config, fsdp_config, model_config), **kwargs) + train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \ + kwargs.fsdp_config, \ + kwargs.model_config, \ + kwargs.log_config, \ + kwargs.dataset_config # Set the seeds for reproducibility torch.cuda.manual_seed(train_config.seed) @@ -45,4 +72,4 @@ def main(**kwargs): if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + main_hydra() \ No newline at end of file diff --git a/src/llama_recipes/pipeline/inference_batch.py b/src/llama_recipes/pipeline/inference_batch.py index c413de86..8194348f 100644 
--- a/src/llama_recipes/pipeline/inference_batch.py +++ b/src/llama_recipes/pipeline/inference_batch.py @@ -1,30 +1,57 @@ -import fire +# import fire import random import torch +import logging # import argparse from llama_recipes.models.slam_model import slam_model # config -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.configs import model_config as MODEL_CONFIG -from llama_recipes.configs import log_config as LOG_CONFIG -from llama_recipes.utils.config_utils import ( - update_config, - generate_peft_config, - generate_dataset_config, - get_dataloader_kwargs, -) +# from llama_recipes.configs import fsdp_config as FSDP_CONFIG +# from llama_recipes.configs import train_config as TRAIN_CONFIG +# from llama_recipes.configs import model_config as MODEL_CONFIG +# from llama_recipes.configs import log_config as LOG_CONFIG +from llama_recipes.utils.config_utils import generate_dataset_config from llama_recipes.pipeline.model_factory import model_factory from llama_recipes.utils.dataset_utils import get_preprocessed_dataset import os import logging -def main(**kwargs): +import hydra +from omegaconf import DictConfig, ListConfig, OmegaConf + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + main(kwargs) - # Update the configuration for the training and sharding process - train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), 
MODEL_CONFIG(), LOG_CONFIG() - update_config((train_config, fsdp_config, model_config, log_config), **kwargs) +def main(kwargs: DictConfig): + + # Update the configuration for the training and sharding process + # train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG() + # update_config((train_config, fsdp_config, model_config, log_config), **kwargs) + train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \ + kwargs.fsdp_config, \ + kwargs.model_config, \ + kwargs.log_config, \ + kwargs.dataset_config # Set log if not os.path.exists(os.path.dirname(log_config.log_file)): os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True) @@ -71,7 +98,7 @@ def main(**kwargs): dataset_config, split="test", ) - if not train_config.enable_fsdp or rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: logger.info(f"--> Training Set Length = {len(dataset_test)}") test_dataloader = torch.utils.data.DataLoader( @@ -100,4 +127,4 @@ def main(**kwargs): if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + main_hydra() \ No newline at end of file diff --git a/src/llama_recipes/pipeline/model_factory.py b/src/llama_recipes/pipeline/model_factory.py index 93897d11..d88ef08c 100644 --- a/src/llama_recipes/pipeline/model_factory.py +++ b/src/llama_recipes/pipeline/model_factory.py @@ -18,5 +18,5 @@ def model_factory(train_config, model_config, **kwargs): ckpt_dict = torch.load(ckpt_path, map_location="cpu") model.load_state_dict(ckpt_dict, strict=False) - print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp else 0) + print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) return model, tokenizer diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py index d8115139..202b0c2c 100644 --- 
a/src/llama_recipes/utils/config_utils.py +++ b/src/llama_recipes/utils/config_utils.py @@ -2,7 +2,7 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import inspect -from dataclasses import asdict +# from dataclasses import asdict import torch.distributed as dist from torch.utils.data import DistributedSampler @@ -14,67 +14,74 @@ from transformers import default_data_collator from transformers.data import DataCollatorForSeq2Seq -from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config +# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler from llama_recipes.utils.dataset_utils import DATASET_PREPROC +from omegaconf import OmegaConf + import logging logger = logging.getLogger(__name__) -def update_config(config, **kwargs): - if isinstance(config, (tuple, list)): - for c in config: - update_config(c, **kwargs) - else: - for k, v in kwargs.items(): - if hasattr(config, k): - setattr(config, k, v) - elif "." in k: - # allow --some_config.some_param=True - config_name, param_name = k.split(".") - if type(config).__name__ == config_name: - if hasattr(config, param_name): - setattr(config, param_name, v) - else: - # In case of specialized config we can warm user - logger.warning(f"Warning: {config_name} does not accept parameter: {k}") - elif isinstance(config, train_config): - logger.warning(f"Warning: unknown parameter {k}") +# def update_config(config, **kwargs): +# if isinstance(config, (tuple, list)): +# for c in config: +# update_config(c, **kwargs) +# else: +# for k, v in kwargs.items(): +# if hasattr(config, k): +# setattr(config, k, v) +# elif "." 
in k: +# # allow --some_config.some_param=True +# config_name, param_name = k.split(".") +# if type(config).__name__ == config_name: +# if hasattr(config, param_name): +# setattr(config, param_name, v) +# else: +# # In case of specialized config we can warm user +# logger.warning(f"Warning: {config_name} does not accept parameter: {k}") +# elif isinstance(config, train_config): +# logger.warning(f"Warning: unknown parameter {k}") def generate_peft_config(train_config, kwargs): - configs = (lora_config, llama_adapter_config, prefix_config) - peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) - names = tuple(c.__name__.rstrip("_config") for c in configs) - - assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" - - config = configs[names.index(train_config.peft_method)]() - - update_config(config, **kwargs) - params = asdict(config) - peft_config = peft_configs[names.index(train_config.peft_method)](**params) + # configs = (lora_config, llama_adapter_config, prefix_config) + # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) + peft_configs = {"lora": LoraConfig, + "llama_adapter": AdaptionPromptConfig, + "prefix": PrefixTuningConfig + } + # names = tuple(c.__name__.rstrip("_config") for c in configs) + # + # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" + # + # config = configs[names.index(train_config.peft_method)]() + config = train_config.peft_config + + params = OmegaConf.to_container(config, resolve=True) + # peft_config = peft_configs[names.index(train_config.peft_method)](**params) + peft_config = peft_configs[config.get("peft_method", "lora")](**params) return peft_config -def generate_dataset_config(train_config, kwargs): - names = tuple(DATASET_PREPROC.keys()) - - assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" - - dataset_config = {k:v for k, v in 
inspect.getmembers(datasets)}[train_config.dataset]() - - update_config(dataset_config, **kwargs) - - return dataset_config +# def generate_dataset_config(train_config, kwargs): +# names = tuple(DATASET_PREPROC.keys()) +# +# assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" +# +# dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() +# +# update_config(dataset_config, **kwargs) +# +# return dataset_config def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs = {} batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size if train_config.batching_strategy == "padding": - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( dataset, batch_size=batch_size, @@ -86,7 +93,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer) elif train_config.batching_strategy == "packing": - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: kwargs["sampler"] = DistributedSampler( dataset, rank=dist.get_rank(), @@ -98,7 +105,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs["collate_fn"] = default_data_collator else: # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: kwargs["sampler"] = DistributedSampler( dataset, rank=dist.get_rank(), diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index 1fa73123..7a6884ee 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -63,11 +63,15 @@ def train(model, 
train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche Returns: results dictionary containing average training and validation perplexity and loss """ # Create a gradient scaler for fp16 - if train_config.use_fp16 and train_config.enable_fsdp: - scaler = ShardedGradScaler() - elif train_config.use_fp16 and not train_config.enable_fsdp: + # if train_config.use_fp16 and train_config.enable_fsdp: + # scaler = ShardedGradScaler() + # elif train_config.use_fp16 and not train_config.enable_fsdp: + # scaler = torch.cuda.amp.GradScaler() + if train_config.use_fp16: scaler = torch.cuda.amp.GradScaler() - if train_config.enable_fsdp: + if train_config.enable_fsdp: + scaler = ShardedGradScaler() + if train_config.enable_fsdp or train_config.enable_ddp: world_size = int(os.environ["WORLD_SIZE"]) autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext @@ -94,7 +98,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche for key in batch.keys(): if type(batch[key])==bool: continue - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: batch[key] = batch[key].to(local_rank) else: batch[key] = batch[key].to('cuda:0') @@ -107,7 +111,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche acc = acc / gradient_accumulation_steps if log_config.use_wandb and step % log_config.log_interval == 0: - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step)) else: @@ -137,12 +141,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche epoch_end_time = time.perf_counter()-epoch_start_time epoch_times.append(epoch_end_time) # Reducing total_loss across all devices if there's more than one CUDA device - if torch.cuda.device_count() > 1 and train_config.enable_fsdp: + if 
torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp): dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) dist.all_reduce(total_acc, op=dist.ReduceOp.SUM) train_epoch_loss = total_loss / len(train_dataloader) train_epoch_acc = total_acc / len(train_dataloader) - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: train_epoch_loss = train_epoch_loss/world_size train_epoch_acc = train_epoch_acc/world_size train_perplexity = torch.exp(train_epoch_loss) @@ -152,13 +156,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche train_acc.append(train_epoch_acc) if log_config.use_wandb: - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) else: wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") @@ -180,15 +184,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche eval_epoch_acc = rest[0] if rest else -1 checkpoint_start_time = time.perf_counter() if train_config.save_model and (eval_epoch_loss < best_val_loss or eval_epoch_acc > best_val_acc): - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: dist.barrier() if train_config.use_peft: - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: logger.info(f"we are about to save the PEFT modules") else: logger.info(f"we are about to save the PEFT modules") - if train_config.enable_fsdp: #(FIX:MZY):We now only 
support full_shard and no_shard. + if train_config.enable_fsdp or train_config.enable_ddp: #(FIX:MZY):We now only support full_shard and no_shard. if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: save_model_checkpoint_peft_full_shard( model, optimizer, rank, train_config, epoch=epoch @@ -204,7 +208,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche save_model_checkpoint_peft( model, optimizer, rank, train_config, epoch=epoch ) - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") else: @@ -212,7 +216,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche elif not train_config.use_peft and train_config.freeze_llm: logger.info(f"llm is frozen, we are about to save other parts.") - if train_config.enable_fsdp: #(FIX:MZY):We now only support full_shard and no_shard. + if train_config.enable_fsdp or train_config.enable_ddp: #(FIX:MZY):We now only support full_shard and no_shard. 
if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: save_model_checkpoint_peft_full_shard( model, optimizer, rank, train_config, epoch=epoch @@ -229,12 +233,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ) else: - if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: + if not train_config.use_peft and fsdp_config.checkpoint_type == "StateDictType.FULL_STATE_DICT": save_model_checkpoint( model, optimizer, rank, train_config, epoch=epoch ) - elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: + elif not train_config.use_peft and fsdp_config.checkpoint_type == "StateDictType.SHARDED_STATE_DICT": logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") logger.info("=====================================================") @@ -250,13 +254,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ) logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") logger.info("=====================================================") - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: dist.barrier() checkpoint_end_time = time.perf_counter() - checkpoint_start_time checkpoint_times.append(checkpoint_end_time) if eval_epoch_loss < best_val_loss: best_val_loss = eval_epoch_loss - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") else: @@ -269,14 +273,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche val_acc.append(-1) if log_config.use_wandb: - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, 
"valid/val_accuracy":val_acc[-1]}) else: wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) if train_config.run_test_during_validation: - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: logger.info("=====================================") logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") @@ -290,7 +294,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche with autocast(): logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) logger.info("=====================================") - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") else: @@ -316,7 +320,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche results["avg_checkpoint_time"] = avg_checkpoint_time #saving the training params including fsdp setting for reference. 
- if train_config.enable_fsdp and not train_config.use_peft: + if (train_config.enable_fsdp or train_config.enable_ddp)and not train_config.use_peft: save_train_params(train_config, fsdp_config, rank) return results @@ -333,7 +337,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): Returns: eval_ppl, eval_epoch_loss """ - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: world_size = int(os.environ["WORLD_SIZE"]) model.eval() eval_preds = [] @@ -346,7 +350,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): for key in batch.keys(): if type(batch[key])==bool: continue - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: batch[key] = batch[key].to(local_rank) else: batch[key] = batch[key].to('cuda:0') @@ -367,20 +371,20 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): ) # If there's more than one CUDA device, reduce evaluation loss across all devices - if torch.cuda.device_count() > 1 and train_config.enable_fsdp: + if torch.cuda.device_count() > 1 and train_config.enable_fsdp or train_config.enable_ddp: dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM) # Compute average loss and perplexity eval_epoch_loss = eval_loss / len(eval_dataloader) eval_epoch_acc = eval_acc / len(eval_dataloader) - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: eval_epoch_loss = eval_epoch_loss/world_size eval_epoch_acc = eval_epoch_acc/world_size eval_ppl = torch.exp(eval_epoch_loss) # Print evaluation metrics - if train_config.enable_fsdp: + if train_config.enable_fsdp or train_config.enable_ddp: if local_rank==0: logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") else: