not tested update including model_factory, slam_model, speech_text_dataset, finetune, model config
Showing 7 changed files with 498 additions and 0 deletions.
@@ -0,0 +1,28 @@
#!/bin/bash
export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth
speech_encoder_path=/nfs/zhifu.gzf/init_model/whisper/large-v2.pt
llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
    --model_name echat \
    --quantization \
    --llm_name llama-2-7b-hf \
    --llm_path $llm_path \
    --encoder_name whisper \
    --encoder_path $speech_encoder_path \
    --encoder_projector linear \
    --dataset custom_dataset \
    --custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \
    --batching_strategy padding \
    --max_words 2596 \
    --num_epochs 1 \
    --batch_size_training 2 \
    --output_dir $output_dir
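The `--custom_dataset.file` flag above points the pipeline at `get_audio_dataset` via a `path/to/file.py:function` spec. As a rough, hypothetical illustration of how such a spec can be resolved (the repo's own loader may work differently), a helper like the following would do:

# Hypothetical helper: resolve a "file.py:function" spec like the one passed
# to --custom_dataset.file above. Illustrative only; not the repo's loader.
import importlib.util
from pathlib import Path

def load_dataset_factory(spec: str):
    path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location(Path(path).stem, path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# e.g.:
# get_audio_dataset = load_dataset_factory(
#     "src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset")
# dataset = get_audio_dataset(dataset_config, tokenizer, split="train")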
@@ -0,0 +1,10 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class model_config:
    llm_name: str = "llama-2-7b-hf"
    llm_path: str = "PATH/to/LLAMA/7B"
    encoder_name: Optional[str] = None
    encoder_path: Optional[str] = None
    encoder_projector: str = "linear"
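For orientation only, here is how the flags from the finetune script above would map onto these fields if the dataclass were filled in directly (the actual CLI-to-config override logic lives elsewhere in the pipeline):

# Illustrative only: model_config populated with the values from the script above.
cfg = model_config(
    llm_name="llama-2-7b-hf",
    llm_path="/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf",
    encoder_name="whisper",
    encoder_path="/nfs/zhifu.gzf/init_model/whisper/large-v2.pt",
    encoder_projector="linear",
)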
@@ -0,0 +1,135 @@
import os.path as osp
import random
import json, yaml
import copy

import numpy as np
from scipy import signal
import soundfile as sf

import torch
import torchaudio
from torch.utils.data import Dataset
import whisper


# Prompt template; the {prompt} placeholder is filled in by apply_prompt_template.
prompt_template = (
    "USER: <Speech><SpeechHere></Speech> {prompt}\n ASSISTANT:"
)


def apply_prompt_template(prompt):
    return prompt_template.format(prompt=prompt)

class AudioDataset(Dataset):
    def __init__(
            self,
            dataset_config,
            tokenizer=None,
            split='train'
    ):
        super().__init__()
        # placeholder data: 100 random 10-second waveforms at 16 kHz
        self.data = torch.randn(100, 160000)

        self.dataset_config = dataset_config
        self.max_words = dataset_config.max_words
        self.tokenizer = tokenizer
        self.IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]

        # pad/trim the waveform to fit 30 seconds; self.data already holds raw
        # waveforms (placeholders), so whisper.load_audio() is not needed here
        audio_raw = whisper.pad_or_trim(item)
        # make log-Mel spectrogram
        audio_mel = whisper.log_mel_spectrogram(audio_raw)

        prompt = """
        Please provide an emotional response based on the emotional speech you hear.
        Remember to format your answer as follows: <|EMOTION|><|DEGREE|><|REPLY|>.
        <|EMOTION|> is a standalone adjective.
        <|DEGREE|> is a number ranging from 0 to 2.
        <|REPLY|> is a reply based on the speech.
        """
        answer = """
        <|happy|><2><|The moon looks so beautiful tonight.|>
        """

        prompt = apply_prompt_template(prompt)
        example = prompt + answer
        prompt_ids = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )

        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[: self.max_words]
        labels = copy.deepcopy(example)
        # mask out the prompt tokens so the loss is only computed on the answer
        labels[: len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = self.IGNORE_INDEX

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
            "audio_mel": audio_mel,
        }

    def _wav2feat(self, data):
        wav = data.reshape(1, -1)

        feats = torchaudio.compliance.kaldi.fbank(  # 25ms and 10ms
            wav, htk_compat=True, sample_frequency=16000, use_energy=False,
            window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10
        )
        n_frames = feats.shape[0]

        # self.target_length is expected to be provided via dataset_config
        p = self.target_length - n_frames

        # cut and pad
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            feats = m(feats)
        elif p < 0:
            feats = feats[0:self.target_length, :]

        return feats.unsqueeze(0)  # channels, frames, dim

    def pad(self, sequence, max_length, padding_idx=0):
        if len(sequence) < max_length:
            sequence = sequence + [padding_idx] * (max_length - len(sequence))
        else:
            sequence = sequence[:max_length]
        return sequence

    def collator(self, samples):
        assert samples is not None
        # every sample is already padded/truncated to max_words, so plain stacking works
        input_ids = torch.stack([torch.tensor(s['input_ids'], dtype=torch.int64) for s in samples])
        labels = torch.stack([torch.tensor(s['labels'], dtype=torch.int64) for s in samples])
        attention_mask = torch.stack([torch.tensor(s['attention_mask'], dtype=torch.bool) for s in samples])

        audio_mel = torch.stack([s['audio_mel'] for s in samples])
        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'audio_mel': audio_mel,
        }


def get_audio_dataset(dataset_config, tokenizer, split):
    dataset = AudioDataset(dataset_config, tokenizer, split)

    return dataset
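To make the padding and label-masking scheme in `__getitem__` concrete, here is a small self-contained example with made-up token ids (the real ids depend on the tokenizer):

# Toy illustration of the masking above: prompt tokens and padding are excluded
# from the loss via -1 markers that become IGNORE_INDEX (-100).
import copy
import torch

IGNORE_INDEX = -100
prompt_ids = [5, 6, 7]                                                   # pretend-tokenized prompt
example = torch.tensor([5, 6, 7, 8, 9, 2])                               # prompt + answer + eos
example = torch.cat((example, torch.zeros(4, dtype=torch.int64) - 1))    # pad to max_words = 10
labels = copy.deepcopy(example)
labels[: len(prompt_ids)] = -1                                           # do not train on the prompt
example_mask = example.ge(0)                                             # True only for real tokens
example[~example_mask] = 0                                               # padding -> token id 0
labels[~labels.ge(0)] = IGNORE_INDEX                                     # padding + prompt -> ignored by the loss
print(labels)  # tensor([-100, -100, -100, 8, 9, 2, -100, -100, -100, -100])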
@@ -0,0 +1,134 @@
import types
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
)
import whisper
import librosa

from llama_recipes.utils.config_utils import generate_peft_config
from llama_recipes.utils.train_utils import print_model_size  # used below in setup_llm


def setup_model(tokenizer, train_config, model_config, **kwargs):
    return slam_model(tokenizer, train_config, model_config, **kwargs)

def setup_tokenizer(train_config, model_config, **kwargs):
    # Load the tokenizer and add special tokens
    if model_config.llm_name=="llama-2-7b-hf":
        tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path)
        tokenizer.pad_token_id = tokenizer.eos_token_id
        return tokenizer

def extract_variable_length_features(self, x: torch.Tensor):
    """
    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
        the mel spectrogram of the audio
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
    # x = (x + self.positional_embedding).to(x.dtype)
    # slice the positional embedding so inputs shorter than 30 s are accepted
    x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x

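A rough usage sketch of this patched encoder forward, assuming a Whisper checkpoint is available locally (shapes are illustrative: large-v2 uses 80 mel bins and a stride-2 conv, so 1000 mel frames become 500 encoder states):

# Illustrative only: bind the variable-length forward onto a Whisper encoder
# and run a ~10 s input instead of the fixed 30 s / 3000-frame window.
import types
import torch
import whisper

whisper_model = whisper.load_model("large-v2")            # assumes the checkpoint can be loaded
encoder = whisper_model.encoder
encoder.extract_features = types.MethodType(extract_variable_length_features, encoder)

mel = torch.randn(1, whisper_model.dims.n_mels, 1000)     # (batch, n_mels, n_frames)
with torch.no_grad():
    feats = encoder.extract_features(mel)                  # (1, 500, n_audio_state)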
def setup_llm(train_config, model_config, **kwargs):
    from pkg_resources import packaging
    rank = kwargs.get("rank", 0)  # distributed rank, if provided by the caller
    use_cache = False if train_config.enable_fsdp else None
    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        """
        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
        this avoids cpu oom when loading large models like llama 70B, in which case
        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
        overhead and currently requires latest nightly.
        """
        v = packaging.version.parse(torch.__version__)
        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        if not verify_latest_nightly:
            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
                            "please install latest nightly.")
        if rank == 0:
            model = LlamaForCausalLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
        else:
            llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
            llama_config.use_cache = use_cache
            with torch.device("meta"):
                model = LlamaForCausalLM(llama_config)

    else:
        model = LlamaForCausalLM.from_pretrained(
            model_config.llm_path,
            load_in_8bit=True if train_config.quantization else None,
            device_map="auto" if train_config.quantization else None,
            use_cache=use_cache,
        )
    if train_config.enable_fsdp and train_config.use_fast_kernels:
        """
        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
        the use of Flash Attention or Xformer memory-efficient kernels
        based on the hardware being used. This would speed up fine-tuning.
        """
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_kbit_training(model)

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    return model

class slam_model(nn.Module):
    def __init__(
            self,
            tokenizer,
            train_config,
            model_config,
            **kwargs
    ):
        super().__init__()
        # whisper
        whisper_model = whisper.load_model(model_config.encoder_path)
        self.speech_encoder = whisper_model.encoder
        self.speech_encoder.extract_features = types.MethodType(extract_variable_length_features, self.speech_encoder)
        # freeze the speech encoder
        for name, param in self.speech_encoder.named_parameters():
            param.requires_grad = False
        self.speech_encoder.eval()
        self.ln_speech = nn.LayerNorm(whisper_model.dims.n_audio_state)

        # llama
        self.llm = setup_llm(train_config, model_config, **kwargs)

        # Projector
        self.speech_encoder_projector = nn.Linear(whisper_model.dims.n_audio_state, self.llm.config.hidden_size)

    def forward(self):
        pass  # TODO
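The forward pass is left as a TODO in this commit. Purely as a hedged sketch of how the modules defined above could be chained for the speech branch (this is not the author's implementation), the projection path would look roughly like:

# Sketch only: run log-Mel features through the frozen Whisper encoder, the
# layer norm, and the linear projector to obtain embeddings in the LLM hidden size.
def encode_speech(self, audio_mel):
    # audio_mel: (batch, n_mels, n_frames) from the dataset's log_mel_spectrogram
    speech_feats = self.speech_encoder.extract_features(audio_mel)   # (B, T, n_audio_state)
    speech_feats = self.ln_speech(speech_feats)
    speech_embeds = self.speech_encoder_projector(speech_feats)      # (B, T, llm hidden size)
    return speech_embeds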