Skip to content

Commit

Permalink
Merge pull request #71 from juhayna-zh/hnz_mc
Browse files Browse the repository at this point in the history
add music caption recipe
  • Loading branch information
ddlBoJack authored May 20, 2024
2 parents 22e7fc6 + 1a06f2c commit 9ff440b
Show file tree
Hide file tree
Showing 13 changed files with 982 additions and 1 deletion.
31 changes: 31 additions & 0 deletions examples/music_caption/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Music Caption

## Performance and checkpoints
Here is a recipe for music captioning, using MusicFM as encoder. We only train the linear projector. For more about MusicFM and its checkpoints, please refer to [this repository](https://github.com/minzwon/musicfm).

The following results are obtained by training on the LP-MusicCaps-MC training set and evaluating on the LP-MusicCaps-MC test set.
Encoder | Projector | LLM | BLEU-1 | METEOR | SPICE | SPIDER
|---|---|---|---|---|---|---
[MusicFM(pretrained with MSD)](https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt) | [Linear](https://drive.google.com/file/d/1-9pob6QvJRoq5Dy-LZbiDfF6Q7QRO8Au/view?usp=sharing)(~18.88M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 25.6 | 10.0 | 8.7 | 6.9


## Data preparation
You need to prepare the data jsonl in this format. Note that you may need to pre-extract the sample rate and duration of audio files for better loading efficiency.
```
{"key": "[-0Gj8-vB1q4]-[30-40]", "source": "path/to/MusicCaps/wav/[-0Gj8-vB1q4]-[30-40].wav", "target": "The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.", "duration": 10.0, "sample_rate": 48000}
...
{"key": "[-0vPFx-wRRI]-[30-40]", "source": "path/to/MusicCaps/wav/[-0vPFx-wRRI]-[30-40].wav", "target": "a male voice is singing a melody with changing tempos while snipping his fingers rhythmically. The recording sounds like it has been recorded in an empty room. This song may be playing, practicing snipping and singing along.", "duration": 10.0, "sample_rate": 48000}
```

## Decode with checkpoints
```
bash decode_musicfm_linear_vicuna_7b_10s.sh
```
Modify the path including `music_encoder_path`, `music_encoder_stat_path`, `music_encoder_config_path`(if specified), `ckpt_path`, `val_data_path` and `decode_log` in the script when you run the shell script.

## Train a new model

### Use MusicFM as encoder for music modality.
```
finetune_musicfm_linear_vicuna_7b_10s.sh
```
19 changes: 19 additions & 0 deletions examples/music_caption/conf/ds_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
}
}
}
3 changes: 3 additions & 0 deletions examples/music_caption/conf/prompt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dataset_config:
# we put prompt here, because the hydra override in shell script only support a small subset of chars
prompt: "Describe this music."
47 changes: 47 additions & 0 deletions examples/music_caption/deepspeed_finetune_mir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from slam_llm.pipeline.finetune_deepspeed import main as train
from slam_llm.utils.deepspeed_utils import deepspeed_main_wrapper

import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig


@dataclass
class RunConfig:
dataset_config: DataConfig = field(default_factory=DataConfig)
model_config: ModelConfig = field(default_factory=ModelConfig)
train_config: TrainConfig = field(default_factory=TrainConfig)
log_config: LogConfig = field(default_factory=LogConfig)
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
deepspeed_config: str = field(default="examples/asr_librispeech/conf/ds_config.json", metadata={"help": "The metric for evaluation"})


@deepspeed_main_wrapper(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
def to_plain_list(cfg_item):
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item

# kwargs = to_plain_list(cfg)
kwargs = cfg
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

logging.basicConfig(level=log_level)

if kwargs.get("debug", False):
import pdb;
pdb.set_trace()

train(kwargs)


if __name__ == "__main__":
main_hydra()
45 changes: 45 additions & 0 deletions examples/music_caption/finetune_mir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from slam_llm.pipeline.finetune import main as train

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from mir_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig

@dataclass
class RunConfig:
dataset_config: DataConfig = field(default_factory=DataConfig)
model_config: ModelConfig = field(default_factory=ModelConfig)
train_config: TrainConfig = field(default_factory=TrainConfig)
log_config: LogConfig = field(default_factory=LogConfig)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
def to_plain_list(cfg_item):
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item

# kwargs = to_plain_list(cfg)
kwargs = cfg
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

logging.basicConfig(level=log_level)

if kwargs.get("debug", False):
import pdb;
pdb.set_trace()

train(kwargs)


if __name__ == "__main__":
main_hydra()
53 changes: 53 additions & 0 deletions examples/music_caption/inference_mir_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from slam_llm.pipeline.inference_batch import main as inference

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional
from mir_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
dataset_config: DataConfig = field(default_factory=DataConfig)
model_config: ModelConfig = field(default_factory=ModelConfig)
train_config: TrainConfig = field(default_factory=TrainConfig)
log_config: LogConfig = field(default_factory=LogConfig)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
decode_log: str = field(
default="output/decode_log",
metadata={"help": "The prefix for the decode output"},
)
ckpt_path: str = field(
default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
)
peft_ckpt: Optional[str] = field(
default=None,
metadata={
"help": "The path to peft checkpoint, should be a directory including adapter_config.json"
},
)


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
# kwargs = to_plain_list(cfg)
log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

logging.basicConfig(level=log_level)

if cfg.get("debug", False):
import pdb

pdb.set_trace()

inference(cfg)


if __name__ == "__main__":
main_hydra()
130 changes: 130 additions & 0 deletions examples/music_caption/mir_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
class ModelConfig:
file: str = "examples/music_caption/model/slam_model_mir.py:model_factory"
llm_name: str = "vicuna-13b-v1.5"
llm_path: str = "PATH/to/LLAMA/7B"
llm_type: str = "decoder_only"
llm_dim: int = 4096
encoder_name: Optional[str] = "mulan"
encoder_ds_rate: int = 2
encoder_path: Optional[str] = None
encoder_config_path: Optional[str] = None
encoder_stat_path: Optional[str] = None
encoder_layer_idx: Optional[int] = None
encoder_dim: int = 768
encoder_projector: str = "linear"
encoder_projector_ds_rate: int = 5
modal: str = "audio"
normalize: Optional[bool] = field(default=False, metadata={
"help": "whether inpit is normalized, used for models such as wavlm"
})

@dataclass
class PeftConfig:
peft_method: str = "lora" # None , llama_adapter, prefix
r: int = 8
lora_alpha: int = 32
target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
bias: str = "none"
task_type: str = "CAUSAL_LM"
lora_dropout: float = 0.05
inference_mode: bool = False

@dataclass
class TrainConfig:
model_name:str = "PATH/to/LLAMA/7B"
enable_ddp:bool = False
enable_deepspeed:bool = False
enable_fsdp:bool = False
low_cpu_fsdp:bool = False
run_validation:bool = True
batch_size_training:int = 4
batching_strategy:str = field(default="packing", metadata={
"help":"alternative: padding"
}) #
context_length:int = 4096
gradient_accumulation_steps:int = 1
num_epochs:int = 3
num_workers_dataloader:int = 1
warmup_steps:int = 1000
total_steps:int = 100000
validation_interval:int = 1000
lr:float = 1e-4
weight_decay:float = 0.0
gamma:float = 0.85
seed:int = 42
use_fp16:bool = False
mixed_precision:bool = True
val_batch_size:int = 1

use_peft:bool = False
peft_config:PeftConfig = field(default_factory=PeftConfig)
output_dir:str = "PATH/to/save/PEFT/model"
freeze_layers:bool = False
num_freeze_layers:int = 1
quantization:bool = False
one_gpu:bool = False
save_model:bool = True
dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
save_optimizer:bool = False # will be used if using FSDP
use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
run_test_during_validation:bool = False
run_test_during_validation_file:str = "test.wav"
run_test_during_validation_prompt:str = "<|ASR|>"
freeze_llm:bool = field(default=False, metadata={
"help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
})
freeze_encoder:bool = True # False

@dataclass
class DataConfig:
dataset: str = "mir_dataset"
file: str = "src/slam_llm/datasets/mir_dataset.py:get_mir_dataset"
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
fixed_duration: float = 10.0
audio_label_freq: int = 75
fixed_audio_token_num: Optional[int] = None
sample_rate: int = 24000
train_split: str = "train"
test_split:str = "validation"
prompt: Optional[str] = None
data_path: Optional[str] = None
max_words: Optional[int] = None
max_mel: Optional[float] = None
fix_length_audio: int = -1
inference_mode:bool = False
input_type: str = field(default="raw", metadata={
"help":"Use raw when input is wav, mel when for whisper"
})
mel_size: int = field(default=80, metadata={
"help": "80 for whisper large v1 and v2, 128 for v3"
})
normalize: Optional[bool] = field(default=False, metadata={
"help": "whether inpit is normalized, used for models such as wavlm"
})

@dataclass
class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
pure_bf16: bool = False
optimizer: str = "AdamW"

@dataclass
class LogConfig:
use_wandb: bool = False
wandb_dir: str = "/root/test_wandb"
wandb_entity_name: str = "project_name"
wandb_project_name: str = "project_name"
wandb_exp_name: str = "exp_name"
log_file: str = "/root/test.log"
log_interval: int = 5
Loading

0 comments on commit 9ff440b

Please sign in to comment.