diff --git a/examples/sec_emotioncaps/README.md b/examples/sec_emotioncaps/README.md
new file mode 100644
index 00000000..d1aba1d2
--- /dev/null
+++ b/examples/sec_emotioncaps/README.md
@@ -0,0 +1,42 @@
+# Speech Emotion Caption
+
+## Model Architecture
+
+This recipe generates high-quality, human-like speech emotion descriptions. The model is based on the **q-former projector** and the **vicuna-7b-v1.5 LLM**. It is trained on **an unpublished dataset**, a large-scale corpus for speech emotion captioning.
+
+![](docs/model.png)
+
+## Performance and checkpoints
+
+We only train the q-former projector in this recipe.
+
+Encoder | Projector | LLM | Similarity Score
+---|---|---|---
+[emotion2vec_base](https://huggingface.co/emotion2vec/emotion2vec_base) | [Q-Former](to_do) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 71.10
+
+> **Note**: The baseline model [SECap](https://github.com/thuhcsi/SECap) was tested in our environment and achieved a similarity score of 71.52. Our model's score is slightly lower.
+
+## Data preparation
+You need to prepare the data as a jsonl file in this format:
+
+```
+{"key": "key_name", "source": "path_to_wav_file", "target": "corresponding_caption"}
+...
+```
+
+
+## Decode with checkpoints
+
+```
+bash decode_emotion2vec_qformer_vicuna_7b.sh
+```
+
+Modify the paths, including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path`, `val_data_path` and `decode_log`, in the script before you run it.
+
+## Train a new model
+
+If you have sufficient relevant data, you can train the model yourself.
+
+```
+bash finetune_emotion2vec_qformer_vicuna_7b.sh
+```
diff --git a/examples/sec_emotioncaps/conf/ds_config.json b/examples/sec_emotioncaps/conf/ds_config.json
new file mode 100644
index 00000000..7ea70e4a
--- /dev/null
+++ b/examples/sec_emotioncaps/conf/ds_config.json
@@ -0,0 +1,19 @@
+{
+    "train_micro_batch_size_per_gpu": 4,
+    "gradient_accumulation_steps": 1,
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-4
+        }
+    },
+    "fp16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu"
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/sec_emotioncaps/conf/prompt.yaml b/examples/sec_emotioncaps/conf/prompt.yaml
new file mode 100644
index 00000000..400cbfa2
--- /dev/null
+++ b/examples/sec_emotioncaps/conf/prompt.yaml
@@ -0,0 +1,3 @@
+dataset_config:
+    # we put the prompt here because the hydra override in the shell script only supports a small subset of characters
+    prompt: "请用中文用一句话描述上面给出的音频中说话人的情感。"  # "Describe, in one Chinese sentence, the emotion of the speaker in the audio above."
diff --git a/examples/sec_emotioncaps/docs/model.png b/examples/sec_emotioncaps/docs/model.png
new file mode 100644
index 00000000..033e321f
Binary files /dev/null and b/examples/sec_emotioncaps/docs/model.png differ
diff --git a/examples/sec_emotioncaps/finetune_sec.py b/examples/sec_emotioncaps/finetune_sec.py
new file mode 100644
index 00000000..f417daa9
--- /dev/null
+++ b/examples/sec_emotioncaps/finetune_sec.py
@@ -0,0 +1,49 @@
+from slam_llm.pipeline.finetune import main as train
+
+import hydra
+import logging
+from typing import Optional
+from dataclasses import dataclass, field
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from sec_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig
+
+@dataclass
+class RunConfig:
+    dataset_config: DataConfig = field(default_factory=DataConfig)
+    model_config: ModelConfig = field(default_factory=ModelConfig)
+    train_config: TrainConfig = field(default_factory=TrainConfig)
+
log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + ckpt_path: Optional[str] = field( + default=None, metadata={"help": "The path to projector checkpoint"} + ) + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/sec_emotioncaps/inference_sec_batch.py b/examples/sec_emotioncaps/inference_sec_batch.py new file mode 100644 index 00000000..49389782 --- /dev/null +++ b/examples/sec_emotioncaps/inference_sec_batch.py @@ -0,0 +1,53 @@ +from slam_llm.pipeline.inference_batch import main as inference + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional +from sec_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/sec_emotioncaps/model/slam_model_sec.py b/examples/sec_emotioncaps/model/slam_model_sec.py new file mode 100644 index 00000000..1edcdb3e --- /dev/null +++ b/examples/sec_emotioncaps/model/slam_model_sec.py @@ -0,0 +1,155 @@ +import torch +import os +import logging +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, 
**kwargs):
+    # return necessary components for training
+    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
+
+    encoder = setup_encoder(train_config, model_config, **kwargs)
+
+    # llm
+    llm = setup_llm(train_config, model_config, **kwargs)
+
+    # projector
+    encoder_projector = setup_encoder_projector(
+        train_config, model_config, **kwargs
+    )
+    model = slam_model_sec(
+        encoder,
+        llm,
+        encoder_projector,
+        tokenizer,
+        train_config,
+        model_config,
+        **kwargs,
+    )
+
+    ckpt_path = kwargs.get(
+        "ckpt_path", None
+    )  # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
+    if ckpt_path is not None:
+        logger.info("loading other parts from: {}".format(ckpt_path))
+        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
+        model.load_state_dict(ckpt_dict, strict=False)
+
+    print_model_size(
+        model,
+        train_config,
+        (
+            int(os.environ["RANK"])
+            if train_config.enable_fsdp or train_config.enable_ddp
+            else 0
+        ),
+    )
+    return model, tokenizer
+
+
+class slam_model_sec(slam_model):
+    def __init__(
+        self,
+        encoder,
+        llm,
+        encoder_projector,
+        tokenizer,
+        train_config,
+        model_config,
+        **kwargs,
+    ):
+        super().__init__(
+            encoder,
+            llm,
+            encoder_projector,
+            tokenizer,
+            train_config,
+            model_config,
+            **kwargs,
+        )
+
+
+    @torch.no_grad()
+    def inference(
+        self,
+        wav_path=None,
+        prompt=None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=None,
+        assistant_model=None,
+        streamer=None,
+        negative_prompt_ids=None,
+        negative_prompt_attention_mask=None,
+        **kwargs,
+    ):
+        # single-sample inference, adapted from the ASR recipe: the audio is loaded with whisper and fed to a Whisper-style mel encoder; batch decoding for this recipe goes through inference_sec_batch.py
+
+        device = kwargs.get("device", "cuda")
+        if os.path.exists(wav_path):  # Audio-Text QA
+            import whisper
+
+            audio_raw = whisper.load_audio(wav_path)
+            audio_raw = whisper.pad_or_trim(audio_raw)
+
+            mel_size = getattr(
+                self.dataset_config, "mel_size", 80
+            )  # 80 for large v1 and v2, 128 for large v3
+            audio_mel = (
+                whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size)
+                .permute(1, 0)[None, :, :]
+                .to(device)
+            )
+
+            encoder_outs = self.encoder.extract_variable_length_features(
+                audio_mel.permute(0, 2, 1)
+            )
+
+            if self.model_config.encoder_projector == "q-former":
+                audio_mel_post_mask = torch.ones(
+                    encoder_outs.size()[:-1], dtype=torch.long
+                ).to(encoder_outs.device)
+                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
+            elif self.model_config.encoder_projector == "linear":
+                encoder_outs = self.encoder_projector(encoder_outs)
+        else:  # Text QA
+            encoder_outs = torch.empty(
+                1, 0, self.llm.model.embed_tokens.embedding_dim
+            ).to(device)
+
+        prompt = "USER: {}\n ASSISTANT:".format(prompt)
+        prompt_ids = self.tokenizer.encode(prompt)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat(
+            (encoder_outs, inputs_embeds[None, :, :]), dim=1
+        )  # [audio,prompt]
+
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
+            inputs_embeds.device
+        )
+
+        # generate
+        model_outputs = self.generate(
+            inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs
+        )
+
+        return model_outputs
diff --git
a/examples/sec_emotioncaps/scripts/decode_emotion2vec_qformer_vicuna_7b.sh b/examples/sec_emotioncaps/scripts/decode_emotion2vec_qformer_vicuna_7b.sh new file mode 100644 index 00000000..05dbab7a --- /dev/null +++ b/examples/sec_emotioncaps/scripts/decode_emotion2vec_qformer_vicuna_7b.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=1 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +run_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM +cd $run_dir +code_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM/examples/sec_emotioncaps + +speech_encoder_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/emotion2vec_base.pt +llm_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/vicuna-7b-v1.5 +val_data_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/data/valid.jsonl + +encoder_fairseq_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/deps/emotion2vec/upstream + +output_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-decode-$(date +"%Y%m%d-%s") + +ckpt_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-finetune-20241001-1727786623/sec_epoch_1_step_3000/model.pt + +decode_log=$output_dir/decode_log + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=emotion2vec \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \ +++model_config.encoder_dim=768 \ +++model_config.encoder_projector=q-former \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.data_path=$val_data_path \ +++dataset_config.inference_mode=true \ +++dataset_config.input_type=raw \ +++train_config.model_name=sec \ +++train_config.num_epochs=1 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.val_batch_size=4 \ +++train_config.num_workers_dataloader=2 \ +++train_config.output_dir=$output_dir \ +++log_config.log_file=$output_dir/train.log \ +++ckpt_path=$ckpt_path \ +++decode_log=$decode_log +" + +# -m debugpy --listen 5678 --wait-for-client +python $code_dir/inference_sec_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args \ No newline at end of file diff --git a/examples/sec_emotioncaps/scripts/finetune_emotion2vec_qformer_vicuna_7b.sh b/examples/sec_emotioncaps/scripts/finetune_emotion2vec_qformer_vicuna_7b.sh new file mode 100644 index 00000000..b7cf6b0d --- /dev/null +++ b/examples/sec_emotioncaps/scripts/finetune_emotion2vec_qformer_vicuna_7b.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +run_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM +cd $run_dir +code_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM/examples/sec_emotioncaps + +speech_encoder_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/emotion2vec_base.pt 
+llm_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/vicuna-7b-v1.5 +train_data_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/data/train.jsonl +val_data_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/data/valid.jsonl + +encoder_fairseq_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/deps/emotion2vec/upstream + +output_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-finetune-$(date +"%Y%m%d-%s") + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=emotion2vec \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \ +++model_config.encoder_dim=768 \ +++model_config.encoder_projector=q-former \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.train_data_path=$train_data_path \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.input_type=raw \ +++train_config.model_name=sec \ +++train_config.num_epochs=6 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=100000 \ +++train_config.lr=1e-4 \ +++train_config.validation_interval=1000 \ +++train_config.batch_size_training=4 \ +++train_config.val_batch_size=4 \ +++train_config.num_workers_dataloader=2 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +++log_config.log_file=$output_dir/train.log \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python $code_dir/finetune_sec.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 4 \ + --master_port=29503 \ + $code_dir/finetune_sec.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + $hydra_args +fi \ No newline at end of file diff --git a/examples/sec_emotioncaps/sec_config.py b/examples/sec_emotioncaps/sec_config.py new file mode 100644 index 00000000..bc40b3d0 --- /dev/null +++ b/examples/sec_emotioncaps/sec_config.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass, field +from typing import Optional, List + +from torch.distributed.fsdp import ShardingStrategy + +@dataclass +class ModelConfig: + file: str = "examples/sec_emotioncaps/model/slam_model_sec.py:model_factory" + llm_name: str = "vicuna-13b-v1.5" + llm_path: str = "path/to/vicuna" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + encoder_name: Optional[str] = "emotion2vec" + encoder_ds_rate: int = 2 + encoder_path: Optional[str] = None + encoder_fairseq_dir: Optional[str] = None + encoder_dim: int = 768 + encoder_projector: str = "q-former" + encoder_projector_ds_rate: int = 5 + modal: str = "audio" + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether input is normalized, used for models such as wavlm" + }) + encoder_type: str = field(default="finetune", metadata={ + "help": "whether model is only pretrained or finetuned, used for models such as hubert" + }) + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None, llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + 
+@dataclass
+class TrainConfig:
+    model_name:str = "sec"
+    enable_ddp:bool = False
+    enable_deepspeed:bool = False
+    enable_fsdp:bool = False
+    low_cpu_fsdp:bool = False
+    run_validation:bool = True
+    batch_size_training:int = 4
+    batching_strategy:str = field(default="packing", metadata={
+        "help":"alternative: padding"
+    })
+    context_length:int = 4096
+    gradient_accumulation_steps:int = 1
+    num_epochs:int = 3
+    num_workers_dataloader:int = 1
+    warmup_steps:int = 1000
+    total_steps:int = 100000
+    validation_interval:int = 1000
+    lr:float = 1e-4
+    weight_decay:float = 0.0
+    gamma:float = 0.85
+    seed:int = 42
+    use_fp16:bool = False
+    mixed_precision:bool = True
+    val_batch_size:int = 1
+
+    use_peft:bool = False
+    peft_config:PeftConfig = field(default_factory=PeftConfig)
+    output_dir:str = "PATH/to/save/PEFT/model"
+    freeze_layers:bool = False
+    num_freeze_layers:int = 1
+    quantization:bool = False
+    one_gpu:bool = False
+    save_model:bool = True
+    dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
+    save_optimizer:bool = False # will be used if using FSDP
+    use_fast_kernels:bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
+    run_test_during_validation:bool = False
+    run_test_during_validation_file:str = "test.wav"
+    run_test_during_validation_prompt:str = "<|ASR|>"
+    freeze_llm:bool = field(default=False, metadata={
+        "help": "whether to freeze llm when finetuning, should be true when using peft finetuning"
+    })
+    freeze_encoder:bool = False
+
+@dataclass
+class DataConfig:
+    dataset: str = "speech_dataset"
+    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
+    train_data_path: Optional[str] = None
+    val_data_path: Optional[str] = None
+    train_split: str = "train"
+    test_split:str = "validation"
+    prompt: Optional[str] = None
+    data_path: Optional[str] = None
+    max_words: Optional[int] = None
+    max_mel: Optional[float] = None
+    fix_length_audio: int = -1
+    inference_mode: bool = False
+    input_type: str = field(default="raw", metadata={
+        "help":"Use raw when the input is wav, mel for whisper"
+    })
+    mel_size: int = field(default=80, metadata={
+        "help": "80 for whisper large v1 and v2, 128 for v3"
+    })
+    normalize: Optional[bool] = field(default=False, metadata={
+        "help": "whether input is normalized, used for models such as wavlm"
+    })
+
+@dataclass
+class FSDPConfig:
+    mixed_precision: bool = True
+    use_fp16: bool = False
+    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when using DDP
+    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively FULL_STATE_DICT; SHARDED_STATE_DICT saves one file per rank and allows resizing the world size.
+ fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/root/test.log" + log_interval: int = 5 diff --git a/src/slam_llm/models/encoder.py b/src/slam_llm/models/encoder.py index 76633404..f53b8262 100644 --- a/src/slam_llm/models/encoder.py +++ b/src/slam_llm/models/encoder.py @@ -175,3 +175,15 @@ def extract_features(self, source, padding_mask=None): _, hidden_states = self.model.get_predictions(source) out = hidden_states[self.config.encoder_layer_idx] return out + +class Emotion2vecEncoder: + + @classmethod + def load(cls, model_config): + import fairseq + model_path = UserDirModule(model_config.encoder_fairseq_dir) + fairseq.utils.import_user_module(model_path) + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) + model = model[0] + + return model \ No newline at end of file diff --git a/src/slam_llm/models/slam_model.py b/src/slam_llm/models/slam_model.py index fb7450e6..0aa93d34 100644 --- a/src/slam_llm/models/slam_model.py +++ b/src/slam_llm/models/slam_model.py @@ -98,6 +98,9 @@ def setup_encoder(train_config, model_config, **kwargs): if encoder_name == "musicfm": from slam_llm.models.encoder import MusicFMEncoder encoder = MusicFMEncoder.load(model_config) + if encoder_name == "emotion2vec": + from slam_llm.models.encoder import Emotion2vecEncoder + encoder = Emotion2vecEncoder.load(model_config) if "llama" in encoder_name.lower(): from slam_llm.models.encoder import HfTextEncoder @@ -343,6 +346,8 @@ def forward(self, audio_mel_post_mask = (~audio_mel_post_mask).float() if self.model_config.encoder_name == 'musicfm': encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask + if self.model_config.encoder_name == "emotion2vec": + encoder_outs = self.encoder.extract_features(audio, None)['x'] # bs*seq*dim if self.encoder is None: encoder_outs = audio_mel if audio_mel is not None else audio
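For reference, a minimal sketch of one way to produce the train/valid jsonl described in the README. The input CSV layout (`wav_path,caption` columns) and the file names here are illustrative assumptions, not part of the recipe; adapt them to however your captions are stored.

```
# build_manifest.py - hypothetical helper: CSV (wav_path,caption) -> SLAM-LLM jsonl manifest
import csv
import json
from pathlib import Path

def write_manifest(csv_path: str, jsonl_path: str) -> None:
    with open(csv_path, newline="", encoding="utf-8") as fin, \
         open(jsonl_path, "w", encoding="utf-8") as fout:
        for row in csv.DictReader(fin):  # expects columns: wav_path, caption
            entry = {
                "key": Path(row["wav_path"]).stem,   # unique utterance id
                "source": row["wav_path"],           # path to the wav file
                "target": row["caption"],            # emotion caption (Chinese text)
            }
            # ensure_ascii=False keeps the Chinese captions readable in the jsonl
            fout.write(json.dumps(entry, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    write_manifest("train_captions.csv", "train.jsonl")
    write_manifest("valid_captions.csv", "valid.jsonl")
```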
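And a simplified, self-contained sketch of the Q-Former-style projector idea this recipe trains: a fixed set of learned query vectors cross-attends over the variable-length emotion2vec frames, and a linear layer maps the result to the LLM dimension. This is an illustration under those assumptions, not the projector class shipped in `slam_llm`, whose implementation may differ.

```
# Sketch only: learned queries + cross-attention + linear projection to llm_dim.
import torch
import torch.nn as nn

class QFormerProjectorSketch(nn.Module):
    def __init__(self, encoder_dim=768, llm_dim=4096, num_queries=64, num_heads=8):
        super().__init__()
        # fixed number of learned query vectors, independent of audio length
        self.queries = nn.Parameter(torch.randn(num_queries, encoder_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(encoder_dim, num_heads, batch_first=True)
        self.proj = nn.Linear(encoder_dim, llm_dim)

    def forward(self, encoder_outs, padding_mask=None):
        # encoder_outs: (batch, frames, encoder_dim); padding_mask: True where padded
        batch = encoder_outs.size(0)
        q = self.queries.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.cross_attn(q, encoder_outs, encoder_outs,
                                      key_padding_mask=padding_mask)
        return self.proj(attended)  # (batch, num_queries, llm_dim)

if __name__ == "__main__":
    proj = QFormerProjectorSketch()
    feats = torch.randn(2, 250, 768)   # e.g. a few seconds of emotion2vec frames
    print(proj(feats).shape)           # torch.Size([2, 64, 4096])
```

Only the projector is updated during finetuning; the encoder and LLM stay frozen (`freeze_encoder=true`, `freeze_llm=true` in the scripts), so the fixed-length query outputs are the only trainable bridge between emotion2vec and vicuna.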