Ygr avsr #16

Merged: 11 commits, Dec 18, 2023
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,5 +1,10 @@
.DS_Store
__pycache__
.ipynb_checkpoints
.vscode
debug.py
.idea/*
transformers
wandb/
*.log
log
9 changes: 6 additions & 3 deletions examples/vllm/inference.py
@@ -7,6 +7,9 @@
from vllm import LLM
from vllm import LLM, SamplingParams

import logging
logger = logging.getLogger(__name__)


torch.cuda.manual_seed(42)
torch.manual_seed(42)
@@ -27,15 +30,15 @@ def main(
if user_prompt is None:
user_prompt = input("Enter your prompt: ")

print(f"User prompt:\n{user_prompt}")
logger.info(f"User prompt:\n{user_prompt}")

print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
logger.info(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)


outputs = model.generate(user_prompt, sampling_params=sampling_param)

print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
logger.info(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
user_prompt = input("Enter next prompt (press Enter to exit): ")
if not user_prompt:
break
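The diff above routes inference messages through a module-level logger instead of print. With only logging.getLogger(__name__), INFO records are discarded until logging is configured somewhere in the process. A minimal, self-contained sketch of such a configuration (the basicConfig call and its format string are illustrative, not part of this PR):

import logging

# Without a configured handler the root logger defaults to WARNING,
# so logger.info(...) from inference.py would print nothing.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

logger = logging.getLogger(__name__)
logger.info("sampling params: top_p %s and temperature %s", 0.9, 0.7)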
28 changes: 28 additions & 0 deletions scripts/finetune_avsr.sh
@@ -0,0 +1,28 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth
speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt
llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/finetune.py \
--model_name avsr \
--use_peft --peft_method lora \
--quantization \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset avsr_dataset \
--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
--batching_strategy custom \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir
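The --avsr_dataset.file flag identifies the dataset factory with a "path/to/file.py:function_name" spec. A minimal sketch of how such a spec can be resolved into a callable (the helper name load_factory is an assumption for illustration; the repository's own loader may differ):

import importlib.util
from pathlib import Path

def load_factory(spec: str):
    """Resolve a 'path/to/file.py:function_name' spec to a callable."""
    file_path, _, func_name = spec.partition(":")
    module_spec = importlib.util.spec_from_file_location(Path(file_path).stem, file_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# get_audio_dataset = load_factory(
#     "src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset")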
32 changes: 32 additions & 0 deletions scripts/finetune_avsr_debug.sh
@@ -0,0 +1,32 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth
speech_encoder_path=/home/oss/maziyang.mzy/models/Whisper/base.pt

llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf #/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/yangguanrou.ygr/ckpts/llama-2-hf-finetune #/home/oss/yangguanrou.ygr/ckpts/llama-2-hf-finetune

# -m debugpy --listen 5680 --wait-for-client
python -m debugpy --listen 5680 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name avsr \
--use_peft --peft_method lora \
--quantization \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset avsr_dataset \
--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
--batching_strategy custom \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir \
--stepSize 10 \
--log_file "/root/SLAM-LLM/log/test.log" \
--valid_subset "LRS3/val_debug.txt" \
35 changes: 35 additions & 0 deletions scripts/finetune_avsr_debug_1214.sh
@@ -0,0 +1,35 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

cd /root/SLAM-LLM


llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/yangguanrou.ygr/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
python -m debugpy --listen 5679 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name avsr \
--freeze_llm \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset avsr_dataset \
--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
--batching_strategy custom \
--num_epochs 1 \
--batch_size_training 4 \
--lr 1e-5 \
--output_dir $output_dir \
--metric acc \
--log_file "/root/SLAM-LLM/log/test.log" \


# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
35 changes: 35 additions & 0 deletions scripts/finetune_avsr_debug_1218.sh
@@ -0,0 +1,35 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

cd /root/SLAM-LLM


llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/yangguanrou.ygr/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/finetune.py \
--model_name avsr \
--freeze_llm \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset avsr_dataset \
--avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
--batching_strategy custom \
--num_epochs 1 \
--batch_size_training 4 \
--lr 1e-5 \
--output_dir $output_dir \
--metric acc \
--log_file "/root/SLAM-LLM/log/test.log" \


# --avsr_dataset.file src/llama_recipes/datasets/avsr_dataset.py:get_audio_dataset \
89 changes: 89 additions & 0 deletions scripts/finetune_speech_pretraining_my.sh
@@ -0,0 +1,89 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=1
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \
--custom_dataset.train_data_path /nfs/beinian.lzr/workspace/datasets/speech_llm/train_dataset/data_wav_json/asr/librispeech_train_960h_wav_speech_llm_train_data.json \
--custom_dataset.val_data_path /nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/dev_other/librispeech_dev_other.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--lr 1e-5 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav" \
--run_test_during_validation_prompt "<|ASR|>" \
--metric acc \
--log_file "/root/SLAM-LLM/log/test.log" \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
# --use_peft --peft_method lora \

else
torchrun \
--nnodes 1 \
--nproc_per_node 4 \
src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--use_fp16 \
--enable_fsdp \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_dataset.py:get_audio_dataset \
--custom_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h_wav_speech_llm_train_data.json \
--custom_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 8 \
--val_batch_size 8 \
--num_workers_dataloader 4 \
--lr 1e-5 \
--output_dir $output_dir \
--run_test_during_validation \
--run_test_during_validation_file "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0000.wav" \
--run_test_during_validation_prompt "<|ASR|>" \
--metric acc \
--log_file "/root/SLAM-LLM/log/test.log" \
# --ckpt_path "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/models/llama-2-hf-finetune/echat/7" \
# --use_peft --peft_method lora \
fi

# {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
37 changes: 31 additions & 6 deletions src/llama_recipes/configs/datasets.py
@@ -32,9 +32,34 @@ class custom_dataset:
file: str = "examples/custom_dataset.py"
train_split: str = "train"
test_split: str = "validation"
data_path: str = NotImplemented
train_data_path: str = NotImplemented
val_data_path: str = NotImplemented
max_words: int = NotImplemented
max_mel: int = NotImplemented
fix_length_audio: int = -1
data_path: str = None
max_words: int = None
train_data_path: str = None
val_data_path: str = None
max_words: int = None
max_mel: int = None
fix_length_audio: int = -1



@dataclass
class avsr_dataset:
dataset: str = "avsr_dataset"
file: str = "examples/avsr_dataset.py"
train_split: str = "train"
test_split: str = "val"
data_path: str = "/nfs/yangguanrou.ygr/" #"/home/oss/yangguanrou.ygr/"
h5file: str = "/nfs/yangguanrou.ygr/LRS3/LRS3.h5" # "/home/oss/yangguanrou.ygr/LRS3/LRS3.h5"
noiseFile : str = "/nfs/yangguanrou.ygr/AVSR/LRS3/Noise.h5" #"/home/oss/yangguanrou.ygr/AVSR/LRS3/Noise.h5"
noiseProb: float = 0.
noiseSNR: float = 5
stepSize: int = 16384
# charToIx={" ": 1, "'": 22, "1": 30, "0": 29, "3": 37, "2": 32, "5": 34, "4": 38, "7": 36, "6": 35, "9": 31, "8": 33, "A": 5, "C": 17,
# "B": 20, "E": 2, "D": 12, "G": 16, "F": 19, "I": 6, "H": 9, "K": 24, "J": 25, "M": 18, "L": 11, "O": 4, "N": 7, "Q": 27,
# "P": 21, "S": 8, "R": 10, "U": 13, "T": 3, "W": 15, "V": 23, "Y": 14, "X": 26, "Z": 28, "<EOS>": 39}
charToIx : str = "x" # probably no longer needed; the "TypeError: Object of type NotImplementedType is not JSON serializable" error came from the fields above
modal: str = "AV"
pretrain_subset: str = "LRS3/pretrain.txt"
train_subset: str = "LRS3/train.txt"
valid_subset: str = "LRS3/val.txt"
test_subset: str = "LRS3/test.txt"
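The debug scripts override individual fields of this dataclass from the command line (for example --avsr_dataset.stepSize 10 and --valid_subset "LRS3/val_debug.txt"). A minimal sketch of what those overrides amount to, using a trimmed copy of the dataclass and plain dataclasses calls (the pipeline's actual config plumbing is not reproduced here):

from dataclasses import dataclass, replace, asdict

@dataclass
class avsr_dataset:  # trimmed to a few of the fields defined above
    dataset: str = "avsr_dataset"
    stepSize: int = 16384
    valid_subset: str = "LRS3/val.txt"
    modal: str = "AV"

# Equivalent of --avsr_dataset.stepSize 10 --valid_subset "LRS3/val_debug.txt"
cfg = replace(avsr_dataset(), stepSize=10, valid_subset="LRS3/val_debug.txt")
print(asdict(cfg))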
31 changes: 30 additions & 1 deletion src/llama_recipes/configs/model.py
@@ -9,4 +9,33 @@ class model_config:
encoder_ds_rate: int = 2
encoder_path: str = None
encoder_projector: str = "linear"
encoder_projector_ds_rate: int = 5
encoder_projector_ds_rate: int = 5

DMODEL: int = 512
FRONTEND_DMODEL: int = 1024 # this one refers specifically to the MoCo frontend
TX_ATTENTION_HEADS: int = 8
TX_NUM_LAYERS: int = 6
PE_MAX_LENGTH: int = 500
AUDIO_FEATURE_SIZE: int = 1024
VIDEO_FEATURE_SIZE: int = 2048
TX_FEEDFORWARD_DIM: int= 2048
TX_DROPOUT: int = 0.1
CHAR_NUM_CLASSES: int = 40

WORD_NUM_CLASSES: int = 500
FRAME_LENGTH: int = 29
MOCO_FRONTEND_FILE: str = "/nfs/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt"
WAV2VEC_FILE: str = "/nfs/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt"
MAIN_REQ_INPUT_LENGTH: int = 80
modal: str = "AV"
TRAIN_LRS3_MODEL_FILE: str = "/nfs/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # "/home/oss/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # this one is used for the single-modality case
TRAINED_AO_FILE : str = "/nfs/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt"
TRAINED_VO_FILE: str = "/nfs/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt"
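The TX_* fields added here describe a standard Transformer encoder: model width, attention heads, layer count, feed-forward size and dropout. Purely for orientation, a minimal sketch of how such values map onto torch.nn.TransformerEncoder; the AVSR model in this PR may assemble its encoder differently:

import torch.nn as nn

DMODEL, TX_ATTENTION_HEADS, TX_NUM_LAYERS = 512, 8, 6
TX_FEEDFORWARD_DIM, TX_DROPOUT = 2048, 0.1

encoder_layer = nn.TransformerEncoderLayer(
    d_model=DMODEL,
    nhead=TX_ATTENTION_HEADS,
    dim_feedforward=TX_FEEDFORWARD_DIM,
    dropout=TX_DROPOUT,
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=TX_NUM_LAYERS)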








2 changes: 2 additions & 0 deletions src/llama_recipes/configs/training.py
@@ -36,8 +36,10 @@ class train_config:
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
log_file: str="PATH/to/Log_File"
run_test_during_validation: bool = False
run_test_during_validation_file: str = "test.wav"
run_test_during_validation_prompt: str = "<|ASR|>"
freeze_llm: bool = False
freeze_encoder: bool = False
log_interval: int = 5
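train_config gains a log_file path next to the existing log_interval. A minimal sketch of wiring that path into the standard logging module via a FileHandler; how the training loop actually consumes log_file is not shown in this diff:

import logging
from pathlib import Path

def setup_file_logging(log_file: str) -> logging.Logger:
    """Send training-progress records to log_file."""
    Path(log_file).parent.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger("train")
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
    return logger

# logger = setup_file_logging("/root/SLAM-LLM/log/test.log")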