not tested update including model_factory, slam_model, speech_text_dataset, finetune, model config
ddlBoJack committed Nov 7, 2023
1 parent 7840635 commit a0c1e2f
Showing 7 changed files with 498 additions and 0 deletions.
28 changes: 28 additions & 0 deletions scripts/finetune.sh
@@ -0,0 +1,28 @@
#!/bin/bash
export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

audio_encoder_path=/home/oss/maziyang.mzy/models/AudioMAE/finetuned.pth
speech_encoder_path=/nfs/zhifu.gzf/init_model/whisper/large-v2.pt
llm_path=/home/oss/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/models/llama-2-hf-finetune

# -m debugpy --listen 5678 --wait-for-client
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name echat \
--quantization \
--llm_name llama-2-7b-hf \
--llm_path $llm_path \
--encoder_name whisper \
--encoder_path $speech_encoder_path \
--encoder_projector linear \
--dataset custom_dataset \
--custom_dataset.file src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset \
--batching_strategy padding \
--max_words 2596 \
--num_epochs 1 \
--batch_size_training 2 \
--output_dir $output_dir
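
Note: the --custom_dataset.file flag follows the llama-recipes convention of a Python file path and a factory function name separated by a colon. Below is a minimal sketch of how such a file:function spec can be resolved into a callable; the helper name load_dataset_factory is hypothetical and not part of this commit.

import importlib.util
from pathlib import Path

def load_dataset_factory(spec: str):
    # Resolve a "path/to/module.py:function_name" spec into a callable (hypothetical helper).
    path, _, func_name = spec.partition(":")
    module_name = Path(path).stem
    module_spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# Example:
# factory = load_dataset_factory("src/llama_recipes/datasets/speech_text_dataset.py:get_audio_dataset")
# dataset = factory(dataset_config, tokenizer, split="train")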
1 change: 1 addition & 0 deletions src/llama_recipes/configs/__init__.py
@@ -4,3 +4,4 @@
from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
from llama_recipes.configs.fsdp import fsdp_config
from llama_recipes.configs.training import train_config
from llama_recipes.configs.model import model_config
10 changes: 10 additions & 0 deletions src/llama_recipes/configs/model.py
@@ -0,0 +1,10 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class model_config:
    llm_name: str = "llama-2-7b-hf"
    llm_path: str = "PATH/to/LLAMA/7B"
    encoder_name: Optional[str] = None
    encoder_path: Optional[str] = None
    encoder_projector: str = "linear"
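
Note: these fields mirror the --llm_name, --llm_path, --encoder_name, --encoder_path, and --encoder_projector flags passed in scripts/finetune.sh. A minimal sketch of overriding the defaults in code, assuming placeholder paths (dataclasses.replace is used here only for illustration, not the project's own config plumbing):

from dataclasses import replace

from llama_recipes.configs.model import model_config

cfg = replace(
    model_config(),
    llm_path="/path/to/Llama-2-7b-hf",            # placeholder path
    encoder_name="whisper",
    encoder_path="/path/to/whisper/large-v2.pt",  # placeholder path
)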
135 changes: 135 additions & 0 deletions src/llama_recipes/datasets/speech_text_dataset.py
@@ -0,0 +1,135 @@
import os.path as osp
import random
import json, yaml
import copy

import numpy as np
from scipy import signal
import soundfile as sf

import torch
import torchaudio
from torch.utils.data import Dataset
import whisper


PROMPT_TEMPLATE = (
    "USER: <Speech><SpeechHere></Speech> {prompt}\n ASSISTANT:"
)

def apply_prompt_template(prompt):
    return PROMPT_TEMPLATE.format(prompt=prompt)

class AudioDataset(Dataset):
def __init__(
self,
dataset_config,
tokenizer=None,
split='train'
):
super().__init__()
        self.data = torch.randn(100, 160000)  # placeholder data: 100 random 10 s waveforms (16 kHz assumed)

self.dataset_config = dataset_config
self.max_words = dataset_config.max_words
self.tokenizer = tokenizer
self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss

def __len__(self) -> int:
return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]

        # the placeholder data is already a raw waveform; pad/trim it to fit 30 seconds
        audio_raw = whisper.pad_or_trim(item)
        # make log-Mel spectrogram
        audio_feats = whisper.log_mel_spectrogram(audio_raw)

        prompt = """
Please provide an emotional response based on the emotional speech you hear.
Remember to format your answer as follows: <|EMOTION|><|DEGREE|><|REPLY|>.
<|EMOTION|> is a standalone adjective.
<|DEGREE|> is a number ranging from 0 to 2.
<|REPLY|> is a reply based on the speech.
"""
answer="""
<|happy|><2><|The moon looks so beautiful tonight.|>
"""

prompt = apply_prompt_template(prompt=prompt)
example = prompt + answer
        prompt_ids = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )

        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[: self.max_words]
        labels = copy.deepcopy(example)
        labels[: len(prompt_ids)] = -1  # mask prompt tokens out of the loss
example_mask = example.ge(0)
label_mask = labels.ge(0)
example[~example_mask] = 0
labels[~label_mask] = self.IGNORE_INDEX

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
            "audio_feats": audio_feats,
        }


    def _wav2feat(self, data):
        # Kaldi fbank feature path (25 ms window, 10 ms shift); expects self.target_length to be set.
        wav = data.reshape(1, -1)

feats = torchaudio.compliance.kaldi.fbank( # 25ms and 10ms
wav, htk_compat=True, sample_frequency=16000, use_energy=False,
window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10
)
n_frames = feats.shape[0]

p = self.target_length - n_frames

# cut and pad
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
feats = m(feats)
elif p < 0:
feats = feats[0:self.target_length, :]

return feats.unsqueeze(0) # channels, frames, dim


def pad(self, sequence, max_length, padding_idx=0):
if len(sequence) < max_length:
sequence = sequence + [padding_idx] * (max_length - len(sequence))
else:
sequence = sequence[:max_length]
return sequence

def collator(self, samples):
assert samples is not None
input_ids = torch.stack([s['input_ids'] for s in samples])
labels = torch.stack([s['labels'] for s in samples])
attention_mask = torch.stack([s['attention_mask'] for s in samples])

audio_feats = torch.stack([s['audio_feats'] for s in samples])
return {
'input_ids': input_ids,
'labels': labels,
'attention_mask': attention_mask,
'audio_feats': audio_feats,
}


def get_audio_dataset(dataset_config, tokenizer, split):
dataset = AudioDataset(dataset_config, tokenizer, split)

return dataset
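
Note: get_audio_dataset is the factory that finetune.sh points at via --custom_dataset.file. A rough usage sketch, assuming a placeholder tokenizer path and a SimpleNamespace standing in for the real dataset config:

from types import SimpleNamespace

from torch.utils.data import DataLoader
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/path/to/Llama-2-7b-hf")  # placeholder path
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset_config = SimpleNamespace(max_words=2596)  # mirrors --max_words in finetune.sh
dataset = get_audio_dataset(dataset_config, tokenizer, split="train")

loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collator)
batch = next(iter(loader))  # dict with input_ids, labels, attention_mask, audio_feats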
134 changes: 134 additions & 0 deletions src/llama_recipes/models/slam_model.py
@@ -0,0 +1,134 @@
import os
import types
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
LlamaConfig,
)
import whisper
import librosa

from llama_recipes.utils.config_utils import generate_peft_config
from llama_recipes.utils.train_utils import print_model_size


def setup_model(tokenizer, train_config, model_config, **kwargs):
    return slam_model(tokenizer, train_config, model_config, **kwargs)


def setup_tokenizer(train_config, model_config, **kwargs):
# Load the tokenizer and add special tokens
if model_config.llm_name=="llama-2-7b-hf":
tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer


def extract_variable_length_features(self, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)

# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

for block in self.blocks:
x = block(x)

x = self.ln_post(x)
return x

def setup_llm(train_config, model_config, **kwargs):
from pkg_resources import packaging
    use_cache = False if train_config.enable_fsdp else None
    rank = int(os.environ.get("RANK", 0))  # distributed rank; 0 when not launched via torchrun
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
"""
for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
this avoids cpu oom when loading large models like llama 70B, in which case
model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
overhead and currently requires latest nightly.
"""
v = packaging.version.parse(torch.__version__)
verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
if not verify_latest_nightly:
raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
"please install latest nightly.")
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
)
else:
llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
llama_config.use_cache = use_cache
with torch.device("meta"):
model = LlamaForCausalLM(llama_config)

else:
model = LlamaForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
)
if train_config.enable_fsdp and train_config.use_fast_kernels:
"""
For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
using of Flash Attention or Xformer memory-efficient kernels
based on the hardware being used. This would speed up fine-tuning.
"""
try:
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
if train_config.quantization:
model = prepare_model_for_kbit_training(model)

if train_config.use_peft:
peft_config = generate_peft_config(train_config, kwargs)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

return model


class slam_model(nn.Module):
def __init__(
self,
tokenizer,
train_config,
model_config,
**kwargs
):
super().__init__()
        # whisper
        whisper_model = whisper.load_model(model_config.encoder_path)
        self.speech_encoder = whisper_model.encoder
        self.speech_encoder.extract_features = types.MethodType(extract_variable_length_features, self.speech_encoder)
        for name, param in self.speech_encoder.named_parameters():
            param.requires_grad = False
        self.speech_encoder.eval()
        encoder_dim = whisper_model.dims.n_audio_state
        self.ln_speech = nn.LayerNorm(encoder_dim)

        # llama
        self.llm = setup_llm(train_config, model_config, **kwargs)

        # projector
        self.speech_encoder_projector = nn.Linear(encoder_dim, self.llm.config.hidden_size)

def forward(self):
pass #TODO
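
Note: the forward pass is left as a TODO in this commit. Purely as an illustration of how the pieces above could compose (frozen Whisper encoder, layer norm, linear projector, then the LLM consuming the projected speech features alongside the text embeddings), and not the author's eventual implementation, a hypothetical sketch might look like this:

# Hypothetical sketch only; the real forward() is not part of this commit.
def _sketch_forward(self, audio_feats, input_ids, attention_mask, labels=None):
    with torch.no_grad():
        speech_feats = self.speech_encoder.extract_features(audio_feats)         # (B, T, n_audio_state)
    speech_embeds = self.speech_encoder_projector(self.ln_speech(speech_feats))  # (B, T, hidden)

    text_embeds = self.llm.get_input_embeddings()(input_ids)                     # (B, L, hidden)
    inputs_embeds = torch.cat([speech_embeds, text_embeds], dim=1)

    # Grow the attention mask to cover the speech span, and ignore that span in the loss.
    speech_mask = torch.ones(speech_embeds.shape[:2], dtype=attention_mask.dtype, device=attention_mask.device)
    attention_mask = torch.cat([speech_mask, attention_mask], dim=1)
    if labels is not None:
        ignore = torch.full(speech_embeds.shape[:2], -100, dtype=labels.dtype, device=labels.device)
        labels = torch.cat([ignore, labels], dim=1)

    return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)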