update text encoder for MuPT #40

Merged
merged 2 commits on Feb 24, 2024
4 changes: 2 additions & 2 deletions scripts/compute_wer.sh
@@ -1,7 +1,7 @@
#cd /root/SLAM-LLM

trans="/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"
trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds10-proj2048-steplrwarmup1e-4decay-fuyu-lora_qkvo_promptshort-lowergt-20240220/asr/6/decode_log_test_other_beam4_repetition_penalty1_bs1_gt"
preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds10-proj2048-steplrwarmup1e-4decay-fuyu-lora_qkvo_promptshort-lowergt-20240220/asr/6/decode_log_test_other_beam4_repetition_penalty1_bs1_pred"

# python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
# python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
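
For reference, the sketch below is a generic word error rate implementation, not the repository's src/llama_recipes/utils/compute_wer.py (whose code is not part of this diff); it only illustrates the metric the script reports.

# Generic reference sketch of word error rate (WER), not the repo's compute_wer.py.
# WER = (substitutions + insertions + deletions) / number of reference words.
def wer(reference: str, hypothesis: str) -> float:
    ref, hyp = reference.split(), hypothesis.split()
    # Levenshtein distance over words via dynamic programming.
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i
    for j in range(len(hyp) + 1):
        dist[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                             dist[i][j - 1] + 1,         # insertion
                             dist[i - 1][j - 1] + cost)  # substitution
    return dist[len(ref)][len(hyp)] / max(len(ref), 1)

print(wer("you who were always accusing people", "you were always accusing people"))  # ~0.167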
134 changes: 134 additions & 0 deletions scripts/finetune_llama_mupt.sh
@@ -0,0 +1,134 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export PYTHONPATH=/root/fairseq:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1,2,3
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

cd /root/SLAM-LLM

# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
# speech_encoder_path=//nfs/maziyang.mzy/models/Whisper/small.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Base.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Base+.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Large.pt
text_encoder_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5
llm_path=/nfs/maziyang.mzy/models/MuPT_v1_8192

output_dir=/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-symbol-ds1-proj2048-steplrwarmup1e-4decay-lora-20240224-test
# ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-20240131/asr/4

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$output_dir \
++model_config.llm_name="MuPT_v1_8192" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=1536 \
++model_config.encoder_name="TinyLlama-1.1B-Chat-v0.4" \
++model_config.encoder_path=$text_encoder_path \
++model_config.encoder_dim=2048 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=1 \
++dataset_config.dataset=text_dataset \
++dataset_config.file="src/llama_recipes/datasets/text_dataset.py:get_text_dataset" \
++dataset_config.tokenizer_path=$text_encoder_path \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++dataset_config.input_type=features \
++train_config.model_name=mupt \
++train_config.freeze_encoder=true \
++train_config.use_peft=true \
++train_config.peft_config.peft_method=lora \
++train_config.batching_strategy=custom \
++train_config.warmup_steps=1000 \
++train_config.total_steps=100000 \
++train_config.lr=1e-4 \
++train_config.validation_interval=1000 \
++train_config.batch_size_training=1 \
++train_config.val_batch_size=1 \
++train_config.num_workers_dataloader=1 \
++train_config.output_dir=$output_dir \
++metric=acc \
# ++train_config.freeze_llm=true \
# ++ckpt_path=$ckpt_path/model.pt \
# ++model_config.encoder_projector=q-former \
# ++dataset_config.fix_length_audio=64 \
# ++peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \


else
torchrun \
--nnodes 1 \
--nproc_per_node 4 \
--master_port=29501 \
src/llama_recipes/pipeline/finetune.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$output_dir \
++model_config.llm_name="MuPT_v1_8192" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=1536 \
++model_config.encoder_name="TinyLlama-1.1B-Chat-v0.4" \
++model_config.encoder_path=$text_encoder_path \
++model_config.encoder_dim=2048 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=1 \
++dataset_config.dataset=text_dataset \
++dataset_config.file="src/llama_recipes/datasets/text_dataset.py:get_text_dataset" \
++dataset_config.tokenizer_path=$text_encoder_path \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++dataset_config.input_type=features \
++train_config.model_name=mupt \
++train_config.freeze_encoder=true \
++train_config.use_peft=true \
++train_config.peft_config.peft_method=lora \
++train_config.batching_strategy=custom \
++train_config.warmup_steps=10000 \
++train_config.total_steps=1000000 \
++train_config.lr=1e-4 \
++train_config.validation_interval=5000 \
++train_config.batch_size_training=6 \
++train_config.val_batch_size=6 \
++train_config.num_workers_dataloader=4 \
++train_config.output_dir=$output_dir \
++train_config.enable_fsdp=false \
++train_config.enable_ddp=true \
++train_config.use_fp16=true \
++metric=acc \
++log_config.log_file=$output_dir/train.log \
++log_config.use_wandb=true \
++log_config.wandb_dir=$output_dir \
++log_config.wandb_entity_name=zym22 \
++log_config.wandb_project_name=slam-llm \
++log_config.wandb_exp_name=$(basename "$0" .sh) \
++log_config.log_interval=5 \
# ++train_config.freeze_llm=true \
# ++ckpt_path=$ckpt_path/model.pt \
# ++model_config.encoder_projector=q-former \
# ++dataset_config.fix_length_audio=64 \
# ++peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
fi

# {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
# {"key": "1688-142285-0005", "prompt": "<ASR>", "source": "/nfs/beinian.lzr/workspace/datasets/data/16k/opendata/librispeech/test_other/wav/1688-142285-0005.wav", "target": "YOU WHO WERE ALWAYS ACCUSING PEOPLE OF BEING SHOPPY AT HELSTONE", "target_len": 11, "source_len": 220, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
198 changes: 198 additions & 0 deletions src/llama_recipes/datasets/text_dataset.py
@@ -0,0 +1,198 @@
import os.path as osp
import random
import json, yaml
import copy

import numpy as np
from scipy import signal

import torch
from torch.utils.data import Dataset


class TextDatasetJsonl(torch.utils.data.Dataset):

    def __init__(self,
                 dataset_config,
                 tokenizer=None,
                 split='train',
                 ):
        super().__init__()
        self.dataset_config = dataset_config
        self.tokenizer = tokenizer

        # self.data_list = contents
        self.IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
        self.prompt = "X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB"
        self.prompt_template = "{}"
        self.answer_template = "{}"
        self.fix_length_text = dataset_config.get("fix_length_text", -1)  # for Q-former
        self.inference_mode = dataset_config.get("inference_mode", False)
        self.input_type = dataset_config.get("input_type", None)
        assert self.input_type in ["raw", "features"], "input_type must be one of [raw, features]"
        if self.input_type == "features":
            from transformers import AutoTokenizer
            self.instruct_tokenizer = AutoTokenizer.from_pretrained(dataset_config.get("tokenizer_path", "Llama-2-7b-hf"))

        self.data_list = []
        if split == "train":
            with open(dataset_config.train_data_path, encoding='utf-8') as fin:
                for line in fin:
                    data_dict = json.loads(line.strip())
                    self.data_list.append(data_dict)
        else:
            with open(dataset_config.val_data_path, encoding='utf-8') as fin:
                for line in fin:
                    data_dict = json.loads(line.strip())
                    self.data_list.append(data_dict)

        # # debug
        # with open(dataset_config.train_data_path, encoding='utf-8') as fin:
        #     for line in fin:
        #         data_dict = json.loads(line.strip())
        #         self.data_list.append(data_dict)
        # if split == "train":
        #     self.data_list = self.data_list[:80]
        # else:
        #     self.data_list = self.data_list[80:100]

    def get_source_len(self, data_dict):
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        return data_dict["target_len"] if "target_len" in data_dict else 0

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        data_dict = self.data_list[index]
        instruct = data_dict.get("instruct", "Dummy Instruct")
        target = data_dict.get("target", "Dummy Target")

        prompt = self.prompt
        prompt = self.prompt_template.format(prompt)
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)

        if self.input_type == "raw":
            instruct_length = 0
            prompt = instruct + prompt
            prompt_ids = self.tokenizer.encode(prompt)
            prompt_length = len(prompt_ids)
        elif self.input_type == "features":
            instruct_ids = self.instruct_tokenizer.encode(instruct)
            instruct_length = len(instruct_ids)
            instruct_ids = torch.tensor(instruct_ids, dtype=torch.int64) if instruct_ids is not None else None

        if self.fix_length_text > 0:  # for Q-former
            instruct_length = self.fix_length_text
        instruct_pseudo = torch.full((instruct_length,), -1)  # placeholder

        if self.inference_mode:
            prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
            example_ids = torch.cat((instruct_pseudo, prompt_ids))  # [instruct,prompt]
            example_mask = example_ids.ge(-1)  # [True,True]

            return {
                "input_ids": example_ids,
                "attention_mask": example_mask,
                "instruct_ids": instruct_ids if self.input_type == "features" else None,
                "instruct_length": instruct_length,
            }

        answer = self.answer_template.format(target)
        example = prompt + answer
        example_ids = self.tokenizer.encode(example)  # [prompt,answer]
        example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]
        example_ids = torch.tensor(
            example_ids, dtype=torch.int64
        )
        example_ids = torch.cat((instruct_pseudo, example_ids))  # [instruct,prompt,answer,eos]

        labels_ids = copy.deepcopy(example_ids)  # [instruct,prompt,answer,eos]
        labels_ids[:instruct_length + prompt_length] = -1  # [-1,-1,answer,eos]
        example_mask = example_ids.ge(-1)  # [True,True,True,True]

        label_mask = labels_ids.ge(0)  # [False,False,True,True]
        example_ids[~example_mask] = 0  # [instruct,prompt,answer,eos]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,answer,eos]

        return {
            "input_ids": example_ids,
            "labels": labels_ids,
            "attention_mask": example_mask,
            "instruct_ids": instruct_ids if self.input_type == "features" else None,
            "instruct_length": instruct_length,
        }

    def pad(self, sequence, max_length, padding_idx=0):
        if isinstance(sequence, (int, list, tuple)):
            if len(sequence) < max_length:
                sequence = sequence + [padding_idx] * (max_length - len(sequence))
            else:
                sequence = sequence[:max_length]
        elif isinstance(sequence, torch.Tensor):
            if len(sequence) < max_length:
                sequence = torch.cat(
                    (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx)))
            else:
                sequence = sequence[:max_length]
        elif isinstance(sequence, np.ndarray):
            if len(sequence) < max_length:
                sequence = np.concatenate(
                    (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx)))
            else:
                sequence = sequence[:max_length]
        else:
            raise Exception("Type mismatch during padding!")
        return sequence

    def collator(self, samples):
        assert samples is not None
        input_ids_max_length = max([s['input_ids'].shape[0] for s in samples])
        input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id)
                                 for s in samples])
        attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
                                      for s in samples])
        if self.input_type == "raw":
            instruct_max_length = 0
            instruct_ids = None
        elif self.input_type == "features":
            instruct_max_length = max([s['instruct_ids'].shape[0] for s in samples])
            instruct_ids = torch.stack([self.pad(s['instruct_ids'], instruct_max_length, self.instruct_tokenizer.pad_token_id)
                                        for s in samples])
            instruct_mask = torch.zeros(len(samples), instruct_max_length)
            for line, sample in enumerate(samples):
                instruct_mask[line, :sample['instruct_length']] = 1

        modality_mask = torch.zeros_like(attention_mask)
        for line, sample in enumerate(samples):
            modality_mask[line, :sample['instruct_length']] = 1

        if self.inference_mode:
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "instruct_ids": instruct_ids if self.input_type == "features" else None,
                "instruct_mask": instruct_mask if self.input_type == "features" else None,
                "modality_mask": modality_mask,
            }

        labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
                              for s in samples])
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "instruct_ids": instruct_ids if self.input_type == "features" else None,
            "instruct_mask": instruct_mask if self.input_type == "features" else None,
            "modality_mask": modality_mask
        }


def get_text_dataset(dataset_config, tokenizer, split):
    dataset = TextDatasetJsonl(dataset_config, tokenizer, split)

    return dataset
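
A minimal usage sketch (not part of the PR), assuming a Hydra/OmegaConf-style dataset_config (the class uses both attribute access and .get()) and local model paths; every path and value below is a placeholder:

# Usage sketch with placeholder paths; assumes an OmegaConf config and local checkpoints.
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

dataset_config = OmegaConf.create({
    "dataset": "text_dataset",
    "input_type": "features",
    "tokenizer_path": "/path/to/TinyLlama-1.1B-Chat-v0.4",  # instruct (text encoder) tokenizer
    "train_data_path": "/path/to/train.jsonl",
    "val_data_path": "/path/to/dev.jsonl",
    "inference_mode": False,
})

tokenizer = AutoTokenizer.from_pretrained("/path/to/MuPT_v1_8192")  # LLM tokenizer
if tokenizer.pad_token is None:
    # The collator pads input_ids with pad_token_id, which Llama-family tokenizers may not define.
    tokenizer.pad_token = tokenizer.eos_token

dataset = get_text_dataset(dataset_config, tokenizer, split="train")
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collator)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["modality_mask"].shape)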
10 changes: 9 additions & 1 deletion src/llama_recipes/models/encoder.py
@@ -75,4 +75,12 @@ def load(cls, model_config):
        checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE)
        avnet.load_state_dict(checkpoint['state_dict'], strict=False)

-        return avnet
+        return avnet
+
+class HfTextEncoder:
+
+    @classmethod
+    def load(cls, model_config):
+        from transformers import AutoModel
+        model = AutoModel.from_pretrained(model_config.encoder_path)
+        return model
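
As a rough illustration of how this encoder fits the configuration in finetune_llama_mupt.sh (encoder_dim=2048, llm_dim=1536, encoder_projector=linear), the standalone sketch below loads a Hugging Face text encoder and projects its hidden states to the LLM width. It is not SLAM-LLM's actual projector code, and the path and token ids are placeholders.

# Standalone sketch: text encoder hidden states projected from encoder_dim to llm_dim.
import torch
import torch.nn as nn
from types import SimpleNamespace

model_config = SimpleNamespace(encoder_path="/path/to/TinyLlama-1.1B-Chat-v0.4")
encoder = HfTextEncoder.load(model_config)   # returns a transformers AutoModel
projector = nn.Linear(2048, 1536)            # encoder_dim -> llm_dim (MuPT_v1_8192)

input_ids = torch.tensor([[1, 29871, 3148]])  # toy token ids
with torch.no_grad():
    hidden = encoder(input_ids).last_hidden_state  # [batch, seq, 2048]
proj = projector(hidden)                            # [batch, seq, 1536]
print(proj.shape)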